@dataclasses.dataclass
class Auth:
    """Obtain, cache, persist and refresh a Baidu OAuth access token.

    Token values come from the constructor arguments, falling back to the
    environment after ``load_env()`` has run. When no access token is
    available, the interactive authorization-code flow is started. Newly
    obtained tokens are written to a ``.env`` file via ``set_key`` and copied
    into ``os.environ``.
    """

    # These fields default to None and are filled from the environment in
    # __post_init__, *after* load_env() has run. Using os.getenv() directly as
    # the field default would snapshot the environment once, at import time,
    # before any .env file was loaded.
    access_token: Optional[str] = None
    access_expiredAt: Optional[str] = None  # ISO 8601 timestamp string
    access_refresh_token: Optional[str] = None

    # Seconds before an OAuth HTTP request is abandoned. Without a timeout,
    # requests may block indefinitely. Unannotated, so not a dataclass field.
    _REQUEST_TIMEOUT = 30

    def __post_init__(self) -> None:
        """Load the environment, fill unset fields from it, and log in if needed.

        If no access token is available after that, the interactive
        authorization flow (``_get_access_token``) is run.
        """
        # NOTE(review): presumably load_env() loads .env values into
        # os.environ; the os.getenv() fallbacks below rely on that.
        load_env()
        self.access_token = self.access_token or os.getenv("BAIDU_ACCESS_TOKEN")
        self.access_expiredAt = self.access_expiredAt or os.getenv("BAIDU_EXPIREDAT")
        self.access_refresh_token = self.access_refresh_token or os.getenv(
            "BAIDU_REFRESH_TOKEN"
        )
        if not self.access_token:
            # No access token anywhere: it must be obtained interactively.
            self._get_access_token()

    def save_info(self, res: dict) -> None:
        """Store a token response on this instance and persist it.

        Args:
            res: validated token response. It must contain ``access_token``,
                ``expires_in`` (lifetime in seconds from now) and
                ``refresh_token``.

        The values are written to ``./.env`` if it exists, otherwise to
        ``~/.env.panbd`` if that exists, otherwise a new ``./.env`` is
        created. They are also copied into ``os.environ``. This makes later
        ``os.getenv`` lookups in this process see the new values instead of
        the ones loaded at startup.
        """
        self.access_token = res["access_token"]
        expires_at = int(datetime.now().timestamp()) + res["expires_in"]
        self.access_expiredAt = datetime.fromtimestamp(expires_at).isoformat()
        self.access_refresh_token = res["refresh_token"]

        project_env_path = os.path.join(os.getcwd(), ".env")
        system_env_path = os.path.join(os.path.expanduser("~"), ".env.panbd")
        if os.path.exists(project_env_path):
            env_path = project_env_path
        elif os.path.exists(system_env_path):
            env_path = system_env_path
        else:
            # Neither file exists: create the project-local .env by default.
            env_path = project_env_path
            Path(env_path).touch()

        for key, value in (
            ("BAIDU_ACCESS_TOKEN", self.access_token),
            ("BAIDU_EXPIREDAT", self.access_expiredAt),
            ("BAIDU_REFRESH_TOKEN", self.access_refresh_token),
        ):
            if value:
                set_key(env_path, key, value)
                # set_key only edits the file; keep this process in sync too.
                os.environ[key] = value
        print(f"✅ access_token 获取成功,并保存: {env_path}")

    def _get_access_token(self) -> "Auth":
        """Run Baidu's interactive OAuth authorization-code flow.

        Prints the authorization URL and reads from stdin the ``code`` the
        user obtains in the browser. It then exchanges the code for a token
        and persists the result via ``save_info``.

        Exits the process with status 1 if the app credentials are missing
        or the token exchange fails.
        """
        client_id = os.getenv("BAIDU_API_APPKEY")
        client_secret = os.getenv("BAIDU_API_SECRETKEY")
        # Check credentials before prompting. Otherwise the user is sent
        # through a browser login that cannot succeed (urlencode would emit
        # a literal "client_id=None").
        if not client_id:
            print("❌ 无法获取 access_token, 请设置环境变量 BAIDU_API_APPKEY")
            sys.exit(1)
        if not client_secret:
            print("❌ 无法获取 access_token, 请设置环境变量 BAIDU_API_SECRETKEY")
            sys.exit(1)

        authorize_params = {
            "response_type": "code",
            "client_id": client_id,
            "redirect_uri": "oob",
            "scope": "basic,netdisk",
            "device_id": os.getenv("BAIDU_API_APPID"),
        }
        # Omit unset optional parameters rather than sending the text "None".
        query_string = urlencode(
            {key: value for key, value in authorize_params.items() if value is not None}
        )
        url = f"https://openapi.baidu.com/oauth/2.0/authorize?{query_string}"
        print("请在浏览器中登录您的百度账号并打开下面链接获取 code 参数")
        print(f"\n{url}\n")
        # Pasted codes often carry stray whitespace or newlines.
        code = input("请输入获得的 code 值:").strip()

        # Exchange the code for a token and validate the response.
        try:
            response = requests.request(
                method="get",
                url="https://openapi.baidu.com/oauth/2.0/token",
                params={
                    "grant_type": "authorization_code",
                    "code": code,
                    "client_id": client_id,
                    "client_secret": client_secret,
                    "redirect_uri": "oob",
                },
                headers=HEADERS,
                timeout=self._REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            res = response.json()
            validate(instance=res, schema=schema_)
        except Exception as e:
            print(f"❌ access_token 获取失败: {e}")
            sys.exit(1)
        self.save_info(res)
        return self

    def _is_token_expired(self) -> bool:
        """Return True if the recorded expiry time of the access token has passed.

        Returns False when no expiry is recorded. Prints a message and returns
        True when the stored expiry is not a valid ISO 8601 string.
        """
        if not self.access_expiredAt:
            return False  # no expiry recorded: treat the token as still valid
        try:
            expire_dt = datetime.fromisoformat(self.access_expiredAt)
        except ValueError:
            print("❌ Invalid access_expiredAt format")
            return True
        # Take "now" in the expiry's own timezone. Values written by save_info
        # are naive; a value from the environment may carry a UTC offset.
        # Comparing naive with aware datetimes raises TypeError.
        return datetime.now(expire_dt.tzinfo) >= expire_dt

    @property
    def token(self) -> str | None:
        """Return the current access token, refreshing it first if it has expired."""
        if self._is_token_expired():
            print("⚠️ access_token 已过期,正在刷新...")
            self.refresh_access_token()
        return self.access_token

    def set_access_token(self, access_token: str) -> "Auth":
        """Use the given access token for this instance only.

        The token is not written to the .env file or the environment. Its
        expiry is set 30 days (2592000 s) from now, so ``token`` will not
        try to refresh it before then. The refresh token is cleared.

        Args:
            access_token (str): the access token to use.
        """
        self.access_token = access_token
        self.access_expiredAt = datetime.fromtimestamp(
            int(datetime.now().timestamp()) + 2592000
        ).isoformat()
        self.access_refresh_token = None
        return self

    @retry(stop=stop_after_attempt(3), wait=wait_random(min=1, max=5))
    def refresh_access_token(self) -> "Auth":
        """Force-refresh the access token using the refresh token.

        The refresh token comes from this instance and falls back to the
        ``BAIDU_REFRESH_TOKEN`` environment variable. This matters because
        ``save_info`` replaces the instance value after every successful
        refresh.

        Exits with status 1 if a credential is missing or the response is
        rejected or invalid.

        Retry behavior (assuming ``retry`` is tenacity's decorator with its
        default Exception-only condition): only exceptions raised by
        ``requests.request`` itself (network errors, timeouts) are retried.
        The other failure paths call ``sys.exit(1)``. ``SystemExit`` is not
        an ``Exception`` subclass, so those paths are not retried.
        """
        refresh_token = self.access_refresh_token or os.getenv("BAIDU_REFRESH_TOKEN")
        client_id = os.getenv("BAIDU_API_APPKEY")
        client_secret = os.getenv("BAIDU_API_SECRETKEY")
        if not refresh_token:
            print("❌ 无法刷新, 请设置环境变量 BAIDU_REFRESH_TOKEN")
            sys.exit(1)
        if not client_id:
            print("❌ 无法刷新, 请设置环境变量 BAIDU_API_APPKEY")
            sys.exit(1)
        if not client_secret:
            print("❌ 无法刷新 , 请设置环境变量 BAIDU_API_SECRETKEY")
            sys.exit(1)

        # Deliberately outside the try below, so that transport errors
        # propagate to @retry instead of exiting immediately.
        response = requests.request(
            method="get",
            url="https://openapi.baidu.com/oauth/2.0/token",
            params={
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client_id,
                "client_secret": client_secret,
            },
            headers=HEADERS,
            timeout=self._REQUEST_TIMEOUT,
        )
        try:
            response.raise_for_status()
            res = response.json()
            validate(instance=res, schema=schema_)
        except Exception as e:
            print(f"❌ access_token 刷新失败: {e}")
            sys.exit(1)
        self.save_info(res)
        return self