From 464b6af410e4d9ac2d55ce17921f89ab40e7ddff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=96=87=E8=96=87=E5=AE=89?= Date: Tue, 3 Feb 2026 10:55:11 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B4=A6=E5=8F=B7=E4=B8=8D=E5=88=87=E6=8D=A2?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/database/models.py | 47 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/backend/database/models.py b/backend/database/models.py index 62dc1aa..ebe47f5 100644 --- a/backend/database/models.py +++ b/backend/database/models.py @@ -61,6 +61,10 @@ class Account: logger.warning(f"Account.get: account_id={account_id} not found in database") return row + @staticmethod + def get_by_id(account_id: int): + return Account.get(account_id) + @staticmethod def list_all(): return db.execute_query("SELECT id, name, status, created_at, updated_at FROM accounts ORDER BY id ASC") @@ -78,6 +82,31 @@ class Account: ) return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] + @staticmethod + def update(account_id: int, name: str = None, status: str = None, testnet: int = None, api_key: str = None, api_secret: str = None): + """通用更新方法""" + # 如果涉及敏感字段,调用 update_credentials + if api_key is not None or api_secret is not None: + Account.update_credentials(account_id, api_key, api_secret, bool(testnet) if testnet is not None else None) + + # 更新普通字段 + fields = [] + values = [] + if name is not None: + fields.append("name = %s") + values.append(name) + if status is not None: + fields.append("status = %s") + values.append(status) + # 如果只有 testnet 而没有 key/secret,也需要更新 + if testnet is not None and api_key is None and api_secret is None: + fields.append("use_testnet = %s") + values.append(bool(testnet)) + + if fields: + values.append(int(account_id)) + db.execute_update(f"UPDATE accounts SET {', '.join(fields)} WHERE id = %s", tuple(values)) + @staticmethod def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None): from security.crypto import encrypt_str # 延迟导入 @@ -177,6 +206,10 @@ class UserAccountMembership: (int(user_id), int(account_id), role), ) + @staticmethod + def add_membership(user_id: int, account_id: int, role: str = "viewer"): + return UserAccountMembership.add(user_id, account_id, role) + @staticmethod def remove(user_id: int, account_id: int): db.execute_update( @@ -191,6 +224,20 @@ class UserAccountMembership: (int(user_id),), ) + @staticmethod + def get_user_accounts(user_id: int): + """获取用户关联的账号列表(包含账号详情)""" + return db.execute_query( + """ + SELECT a.id, a.name, a.status, a.created_at, a.updated_at, m.role + FROM accounts a + JOIN user_account_memberships m ON a.id = m.account_id + WHERE m.user_id = %s + ORDER BY a.id ASC + """, + (int(user_id),) + ) + @staticmethod def list_for_account(account_id: int): return db.execute_query(