""" 数据库模型定义 """ from database.connection import db from datetime import datetime, timezone, timedelta import json import logging import os # 北京时间时区(UTC+8) BEIJING_TZ = timezone(timedelta(hours=8)) def get_beijing_time(): """获取当前北京时间(UTC+8)的Unix时间戳(秒)""" return int(datetime.now(BEIJING_TZ).timestamp()) logger = logging.getLogger(__name__) def _resolve_default_account_id() -> int: """ 默认账号ID: - trading_system 多进程:每个进程可通过 ATS_ACCOUNT_ID 指定自己的 account_id - backend:未传 account_id 时默认走 1(兼容单账号) """ for k in ("ATS_ACCOUNT_ID", "ACCOUNT_ID", "ATS_DEFAULT_ACCOUNT_ID"): v = (os.getenv(k, "") or "").strip() if v: try: return int(v) except Exception: continue return 1 DEFAULT_ACCOUNT_ID = _resolve_default_account_id() def _table_has_column(table: str, col: str) -> bool: try: db.execute_one(f"SELECT {col} FROM {table} LIMIT 1") return True except Exception: return False class Account: """ 账号模型(多账号) - API Key/Secret 建议加密存储在 accounts 表中,而不是 trading_config """ @staticmethod def get(account_id: int): import logging logger = logging.getLogger(__name__) logger.debug("Account.get called with account_id=%s", account_id) row = db.execute_one("SELECT * FROM accounts WHERE id = %s", (int(account_id),)) if row: logger.debug("Account.get: found account_id=%s, name=%s, status=%s", account_id, row.get('name', 'N/A'), row.get('status', 'N/A')) else: logger.warning(f"Account.get: account_id={account_id} not found in database") return row @staticmethod def get_by_id(account_id: int): return Account.get(account_id) @staticmethod def list_all(): return db.execute_query("SELECT id, name, status, created_at, updated_at, api_key_enc, api_secret_enc, use_testnet FROM accounts ORDER BY id ASC") @staticmethod def create(name: str, api_key: str = "", api_secret: str = "", use_testnet: bool = False, status: str = "active"): from security.crypto import encrypt_str # 延迟导入,避免无依赖时直接崩 api_key_enc = encrypt_str(api_key or "") api_secret_enc = encrypt_str(api_secret or "") db.execute_update( """INSERT INTO accounts (name, status, api_key_enc, api_secret_enc, use_testnet) VALUES (%s, %s, %s, %s, %s)""", (name, status, api_key_enc, api_secret_enc, bool(use_testnet)), ) return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] @staticmethod def update(account_id: int, name: str = None, status: str = None, testnet: int = None, api_key: str = None, api_secret: str = None): """通用更新方法""" # 如果涉及敏感字段,调用 update_credentials if api_key is not None or api_secret is not None: Account.update_credentials(account_id, api_key, api_secret, bool(testnet) if testnet is not None else None) # 更新普通字段 fields = [] values = [] if name is not None: fields.append("name = %s") values.append(name) if status is not None: fields.append("status = %s") values.append(status) # 如果只有 testnet 而没有 key/secret,也需要更新 if testnet is not None and api_key is None and api_secret is None: fields.append("use_testnet = %s") values.append(bool(testnet)) if fields: values.append(int(account_id)) db.execute_update(f"UPDATE accounts SET {', '.join(fields)} WHERE id = %s", tuple(values)) @staticmethod def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None): from security.crypto import encrypt_str # 延迟导入 fields = [] params = [] if api_key is not None: fields.append("api_key_enc = %s") params.append(encrypt_str(api_key)) if api_secret is not None: fields.append("api_secret_enc = %s") params.append(encrypt_str(api_secret)) if use_testnet is not None: fields.append("use_testnet = %s") params.append(bool(use_testnet)) if not fields: return params.append(int(account_id)) db.execute_update(f"UPDATE accounts SET {', '.join(fields)} WHERE id = %s", tuple(params)) @staticmethod def get_credentials(account_id: int): """ 返回 (api_key, api_secret, use_testnet, status);密文字段会自动解密。 若未配置 master key 且库里是明文,仍可工作(但不安全)。 """ import logging logger = logging.getLogger(__name__) logger.debug("Account.get_credentials called with account_id=%s", account_id) row = Account.get(account_id) if not row: logger.warning(f"Account.get_credentials: account_id={account_id} not found in database") return "", "", False, "disabled" try: from security.crypto import decrypt_str status = row.get("status") or "active" api_key = decrypt_str(row.get("api_key_enc") or "") api_secret = decrypt_str(row.get("api_secret_enc") or "") except Exception: # 兼容:无 cryptography 或未配 master key 时: # - 若库里是明文,仍可工作 # - 若库里是 enc:v1 密文但未配 ATS_MASTER_KEY,则不能解密,也不能把密文当作 Key 使用 status = "disabled" api_key_raw = row.get("api_key_enc") or "" api_secret_raw = row.get("api_secret_enc") or "" api_key = "" if str(api_key_raw).startswith("enc:v1:") else str(api_key_raw) api_secret = "" if str(api_secret_raw).startswith("enc:v1:") else str(api_secret_raw) use_testnet = bool(row.get("use_testnet") or False) return api_key, api_secret, use_testnet, status class User: """登录用户(管理员/普通用户)""" @staticmethod def get_by_username(username: str): return db.execute_one("SELECT * FROM users WHERE username = %s", (str(username),)) @staticmethod def get_by_id(user_id: int): return db.execute_one("SELECT * FROM users WHERE id = %s", (int(user_id),)) @staticmethod def list_all(): return db.execute_query("SELECT id, username, role, status, created_at, updated_at FROM users ORDER BY id ASC") @staticmethod def create(username: str, password_hash: str, role: str = "user", status: str = "active"): db.execute_update( "INSERT INTO users (username, password_hash, role, status) VALUES (%s, %s, %s, %s)", (username, password_hash, role, status), ) return db.execute_one("SELECT LAST_INSERT_ID() as id")["id"] @staticmethod def set_password(user_id: int, password_hash: str): db.execute_update("UPDATE users SET password_hash = %s WHERE id = %s", (password_hash, int(user_id))) @staticmethod def set_status(user_id: int, status: str): db.execute_update("UPDATE users SET status = %s WHERE id = %s", (status, int(user_id))) @staticmethod def set_role(user_id: int, role: str): db.execute_update("UPDATE users SET role = %s WHERE id = %s", (role, int(user_id))) class UserAccountMembership: """用户-交易账号授权关系""" @staticmethod def add(user_id: int, account_id: int, role: str = "viewer"): db.execute_update( """INSERT INTO user_account_memberships (user_id, account_id, role) VALUES (%s, %s, %s) ON DUPLICATE KEY UPDATE role = VALUES(role)""", (int(user_id), int(account_id), role), ) @staticmethod def add_membership(user_id: int, account_id: int, role: str = "viewer"): return UserAccountMembership.add(user_id, account_id, role) @staticmethod def remove(user_id: int, account_id: int): db.execute_update( "DELETE FROM user_account_memberships WHERE user_id = %s AND account_id = %s", (int(user_id), int(account_id)), ) @staticmethod def clear_other_owners_for_account(account_id: int, keep_user_id: int): """每个账号仅允许一名 owner:将本账号下其他用户的 owner 降为 viewer。""" db.execute_update( "UPDATE user_account_memberships SET role = 'viewer' WHERE account_id = %s AND user_id != %s AND role = 'owner'", (int(account_id), int(keep_user_id)), ) @staticmethod def list_for_user(user_id: int): return db.execute_query( "SELECT * FROM user_account_memberships WHERE user_id = %s ORDER BY account_id ASC", (int(user_id),), ) @staticmethod def get_user_accounts(user_id: int): """获取用户关联的账号列表(包含账号详情)""" return db.execute_query( """ SELECT a.id, a.name, a.status, a.created_at, a.updated_at, a.api_key_enc, a.api_secret_enc, m.role FROM accounts a JOIN user_account_memberships m ON a.id = m.account_id WHERE m.user_id = %s ORDER BY a.id ASC """, (int(user_id),) ) @staticmethod def list_for_account(account_id: int): return db.execute_query( "SELECT * FROM user_account_memberships WHERE account_id = %s ORDER BY user_id ASC", (int(account_id),), ) @staticmethod def has_access(user_id: int, account_id: int) -> bool: row = db.execute_one( "SELECT 1 as ok FROM user_account_memberships WHERE user_id = %s AND account_id = %s", (int(user_id), int(account_id)), ) return bool(row) @staticmethod def get_role(user_id: int, account_id: int) -> str: row = db.execute_one( "SELECT role FROM user_account_memberships WHERE user_id = %s AND account_id = %s", (int(user_id), int(account_id)), ) return (row.get("role") if isinstance(row, dict) else None) or "" class TradingConfig: """交易配置模型""" @staticmethod def get_all(account_id: int = None): """获取所有配置""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trading_config", "account_id"): return db.execute_query( "SELECT * FROM trading_config WHERE account_id = %s ORDER BY category, config_key", (aid,), ) return db.execute_query("SELECT * FROM trading_config ORDER BY category, config_key") @staticmethod def get(key, account_id: int = None): """获取单个配置""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trading_config", "account_id"): return db.execute_one( "SELECT * FROM trading_config WHERE account_id = %s AND config_key = %s", (aid, key), ) return db.execute_one("SELECT * FROM trading_config WHERE config_key = %s", (key,)) @staticmethod def set(key, value, config_type, category, description=None, account_id: int = None): """设置配置""" value_str = TradingConfig._convert_to_string(value, config_type) aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trading_config", "account_id"): db.execute_update( """INSERT INTO trading_config (account_id, config_key, config_value, config_type, category, description) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE config_value = VALUES(config_value), config_type = VALUES(config_type), category = VALUES(category), description = VALUES(description), updated_at = CURRENT_TIMESTAMP""", (aid, key, value_str, config_type, category, description), ) else: db.execute_update( """INSERT INTO trading_config (config_key, config_value, config_type, category, description) VALUES (%s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE config_value = VALUES(config_value), config_type = VALUES(config_type), category = VALUES(category), description = VALUES(description), updated_at = CURRENT_TIMESTAMP""", (key, value_str, config_type, category, description), ) @staticmethod def get_value(key, default=None, account_id: int = None): """获取配置值(自动转换类型)""" result = TradingConfig.get(key, account_id=account_id) if result: return TradingConfig._convert_value(result['config_value'], result['config_type']) return default @staticmethod def _convert_value(value, config_type): """转换配置值类型""" if config_type in ('number', 'float', 'int', 'integer'): try: return float(value) if '.' in str(value) else int(value) except: return 0 elif config_type in ('boolean', 'bool'): return str(value).lower() in ('true', '1', 'yes', 'on') elif config_type == 'json': try: return json.loads(value) except: return {} return value @staticmethod def _convert_to_string(value, config_type): """转换值为字符串""" if config_type == 'json': return json.dumps(value, ensure_ascii=False) return str(value) # 缓存键常量(与 market_cache 表配合使用) MARKET_CACHE_KEY_EXCHANGE_INFO = "exchange_info" MARKET_CACHE_KEY_FUNDING_INFO = "funding_info" class MarketCache: """ 市场数据缓存:交易对信息(exchange_info)、资金费率规则(funding_info) 等较固定内容入库, 减少对币安 API 的调用。trading_system 可优先从 DB 读取,过期或缺失时再拉 API 并回写。 """ @staticmethod def _table_exists(): try: db.execute_one("SELECT 1 FROM market_cache LIMIT 1") return True except Exception: return False @staticmethod def get(cache_key: str): """获取缓存:返回 dict { cache_value: 解析后的对象, updated_at } 或 None。""" if not MarketCache._table_exists(): return None try: row = db.execute_one( "SELECT cache_value, updated_at FROM market_cache WHERE cache_key = %s", (cache_key,), ) if not row: return None raw = row.get("cache_value") updated_at = row.get("updated_at") if raw is None: return None try: value = json.loads(raw) if isinstance(raw, str) else raw except Exception: value = raw return {"cache_value": value, "updated_at": updated_at} except Exception as e: logger.debug("MarketCache.get %s: %s", cache_key, e) return None @staticmethod def set(cache_key: str, value) -> bool: """写入缓存,value 可为 dict/list(将序列化为 JSON)。""" if not MarketCache._table_exists(): return False try: payload = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value db.execute_update( """INSERT INTO market_cache (cache_key, cache_value) VALUES (%s, %s) ON DUPLICATE KEY UPDATE cache_value = VALUES(cache_value), updated_at = CURRENT_TIMESTAMP""", (cache_key, payload), ) return True except Exception as e: logger.warning("MarketCache.set %s: %s", cache_key, e) return False @staticmethod def get_exchange_info(max_age_seconds: int = 86400): """ 获取缓存的 exchange_info(币安 GET /fapi/v1/exchangeInfo)。 若缓存存在且 updated_at 在 max_age_seconds 内,返回解析后的 exchange_info 字典; 否则返回 None(调用方应拉 API 并调用 set_exchange_info 回写)。 """ out = MarketCache.get(MARKET_CACHE_KEY_EXCHANGE_INFO) if not out: return None updated_at = out.get("updated_at") if max_age_seconds is not None and updated_at: try: from datetime import datetime, timezone, timedelta if isinstance(updated_at, (int, float)): utc_ts = float(updated_at) else: # datetime 转 UTC 时间戳 utc_ts = updated_at.timestamp() if hasattr(updated_at, "timestamp") else 0 age = datetime.now(timezone.utc).timestamp() - utc_ts if age > max_age_seconds: return None except Exception: pass return out.get("cache_value") @staticmethod def set_exchange_info(data: dict) -> bool: """写入 exchange_info 到 market_cache。""" return MarketCache.set(MARKET_CACHE_KEY_EXCHANGE_INFO, data) @staticmethod def get_funding_info(max_age_seconds: int = 86400): """ 获取缓存的 funding_info(币安 GET /fapi/v1/fundingInfo)。 若缓存存在且在 max_age_seconds 内,返回解析后的列表/字典;否则返回 None。 """ out = MarketCache.get(MARKET_CACHE_KEY_FUNDING_INFO) if not out: return None updated_at = out.get("updated_at") if max_age_seconds is not None and updated_at: try: from datetime import datetime, timezone if isinstance(updated_at, (int, float)): utc_ts = float(updated_at) else: utc_ts = updated_at.timestamp() if hasattr(updated_at, "timestamp") else 0 age = datetime.now(timezone.utc).timestamp() - utc_ts if age > max_age_seconds: return None except Exception: pass return out.get("cache_value") @staticmethod def set_funding_info(data) -> bool: """写入 funding_info 到 market_cache。""" return MarketCache.set(MARKET_CACHE_KEY_FUNDING_INFO, data) class GlobalStrategyConfig: """全局策略配置模型(独立于账户,管理员专用)""" @staticmethod def get_all(): """获取所有全局配置""" if not _table_has_column("global_strategy_config", "config_key"): return [] return db.execute_query( "SELECT * FROM global_strategy_config ORDER BY category, config_key" ) @staticmethod def get(key): """获取单个全局配置""" if not _table_has_column("global_strategy_config", "config_key"): return None return db.execute_one( "SELECT * FROM global_strategy_config WHERE config_key = %s", (key,) ) @staticmethod def set(key, value, config_type, category, description=None, updated_by=None): """设置全局配置""" if not _table_has_column("global_strategy_config", "config_key"): # 表不存在时,回退到trading_config(兼容旧系统) return TradingConfig.set(key, value, config_type, category, description, account_id=1) value_str = TradingConfig._convert_to_string(value, config_type) db.execute_update( """INSERT INTO global_strategy_config (config_key, config_value, config_type, category, description, updated_by) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE config_value = VALUES(config_value), config_type = VALUES(config_type), category = VALUES(category), description = VALUES(description), updated_by = VALUES(updated_by), updated_at = CURRENT_TIMESTAMP""", (key, value_str, config_type, category, description, updated_by), ) @staticmethod def get_value(key, default=None): """获取全局配置值(自动转换类型)""" result = GlobalStrategyConfig.get(key) if result: return GlobalStrategyConfig._convert_value(result['config_value'], result['config_type']) return default @staticmethod def _convert_value(value, config_type): """转换配置值类型(复用TradingConfig的逻辑)""" return TradingConfig._convert_value(value, config_type) @staticmethod def delete(key): """删除全局配置""" if not _table_has_column("global_strategy_config", "config_key"): return db.execute_update( "DELETE FROM global_strategy_config WHERE config_key = %s", (key,) ) class Trade: """交易记录模型""" @staticmethod def create( symbol, side, quantity, entry_price, leverage=10, entry_reason=None, entry_order_id=None, client_order_id=None, stop_loss_price=None, take_profit_price=None, take_profit_1=None, take_profit_2=None, atr=None, notional_usdt=None, margin_usdt=None, account_id: int = None, entry_context=None, status: str = "open", entry_time=None, ): """创建交易记录(使用北京时间) Args: entry_time: 开仓时间(Unix 时间戳秒,或 None 表示当前北京时间)。补建/同步时建议从币安订单或成交取真实时间。 symbol: 交易对 status: 状态,默认 "open";先落库等 WS 成交时可传 "pending" side: 方向 quantity: 数量 entry_price: 入场价 leverage: 杠杆 entry_reason: 入场原因(简短文本) entry_order_id: 币安开仓订单号(可选,用于对账) client_order_id: 币安自定义订单号 clientOrderId(可选,用于在订单记录中核对系统单) stop_loss_price: 实际使用的止损价格(考虑了ATR等动态计算) take_profit_price: 实际使用的止盈价格(考虑了ATR等动态计算) take_profit_1: 第一目标止盈价(可选) take_profit_2: 第二目标止盈价(可选) atr: 开仓时使用的ATR值(可选) notional_usdt: 名义下单量(USDT,可选) margin_usdt: 保证金(USDT,可选) entry_context: 入场思路/过程(dict,将存为 JSON):信号强度、市场状态、趋势、过滤通过情况等,便于事后分析策略执行效果 """ if entry_time is not None: try: if hasattr(entry_time, "timestamp"): entry_time = int(entry_time.timestamp()) else: entry_time = int(float(entry_time)) except (TypeError, ValueError): entry_time = get_beijing_time() else: entry_time = get_beijing_time() # 自动计算 notional/margin(若调用方没传) try: if notional_usdt is None and quantity is not None and entry_price is not None: notional_usdt = float(quantity) * float(entry_price) except Exception: pass try: if margin_usdt is None and notional_usdt is not None: lv = float(leverage) if leverage else 0 margin_usdt = (float(notional_usdt) / lv) if lv and lv > 0 else float(notional_usdt) except Exception: pass def _has_column(col: str) -> bool: try: db.execute_one(f"SELECT {col} FROM trades LIMIT 1") return True except Exception: return False # 动态构建 INSERT(兼容不同schema) columns = ["symbol", "side", "quantity", "entry_price", "leverage", "entry_reason", "status", "entry_time"] values = [symbol, side, quantity, entry_price, leverage, entry_reason, (status or "open"), entry_time] # 在真正执行 INSERT 之前,利用 entry_order_id / client_order_id 做一次幂等去重 try: if entry_order_id is not None and _has_column("entry_order_id"): try: existing = Trade.get_by_entry_order_id(entry_order_id) except Exception: existing = None if existing: try: logger.debug( f"Trade.create: entry_order_id={entry_order_id} 已存在 (id={existing.get('id')}, " f"symbol={existing.get('symbol')}, status={existing.get('status')}),直接复用" ) except Exception: pass return existing.get("id") if client_order_id and _has_column("client_order_id"): try: existing = Trade.get_by_client_order_id(client_order_id) except Exception: existing = None if existing: try: logger.debug( f"Trade.create: client_order_id={client_order_id!r} 已存在 (id={existing.get('id')}, " f"symbol={existing.get('symbol')}, status={existing.get('status')}),直接复用" ) except Exception: pass return existing.get("id") except Exception: # 去重失败不影响后续正常插入,由数据库唯一约束兜底 pass if _has_column("account_id"): columns.insert(0, "account_id") values.insert(0, int(account_id or DEFAULT_ACCOUNT_ID)) if _has_column("entry_order_id"): columns.append("entry_order_id") values.append(entry_order_id) if _has_column("client_order_id"): columns.append("client_order_id") values.append(client_order_id) if _has_column("notional_usdt"): columns.append("notional_usdt") values.append(notional_usdt) if _has_column("margin_usdt"): columns.append("margin_usdt") values.append(margin_usdt) if _has_column("atr"): columns.append("atr") values.append(atr) if _has_column("stop_loss_price"): columns.append("stop_loss_price") values.append(stop_loss_price) if _has_column("take_profit_price"): columns.append("take_profit_price") values.append(take_profit_price) if _has_column("take_profit_1"): columns.append("take_profit_1") values.append(take_profit_1) if _has_column("take_profit_2"): columns.append("take_profit_2") values.append(take_profit_2) if _has_column("entry_context") and entry_context is not None: try: entry_context_str = json.dumps(entry_context, ensure_ascii=False) if isinstance(entry_context, dict) else str(entry_context) except Exception: entry_context_str = None if entry_context_str is not None: columns.append("entry_context") values.append(entry_context_str) placeholders = ", ".join(["%s"] * len(columns)) sql = f"INSERT INTO trades ({', '.join(columns)}) VALUES ({placeholders})" db.execute_update(sql, tuple(values)) return db.execute_one("SELECT LAST_INSERT_ID() as id")['id'] @staticmethod def update_exit( trade_id, exit_price, exit_reason, pnl, pnl_percent, exit_order_id=None, strategy_type=None, duration_minutes=None, exit_time_ts=None, realized_pnl=None, commission=None, commission_asset=None, ): """更新平仓信息(使用北京时间) Args: trade_id: 交易记录ID exit_price: 出场价 exit_reason: 平仓原因 pnl: 盈亏 pnl_percent: 盈亏百分比 exit_order_id: 币安平仓订单号(可选,用于对账) realized_pnl: 币安实际结算盈亏(可选) commission: 交易手续费(可选) commission_asset: 手续费币种(可选) 注意:如果 exit_order_id 已存在且属于其他交易记录,会跳过更新 exit_order_id 以避免唯一约束冲突 """ # exit_time_ts: 允许外部传入“真实成交时间”(Unix秒)以便统计持仓时长更准确 try: exit_time = int(exit_time_ts) if exit_time_ts is not None else get_beijing_time() except Exception: exit_time = get_beijing_time() # 准备额外字段更新 helper def _append_extra_fields(fields, values): if strategy_type is not None: fields.append("strategy_type = %s") values.append(strategy_type) if duration_minutes is not None: fields.append("duration_minutes = %s") values.append(duration_minutes) # 新增字段(检查是否存在) if realized_pnl is not None and _table_has_column("trades", "realized_pnl"): fields.append("realized_pnl = %s") values.append(realized_pnl) if commission is not None and _table_has_column("trades", "commission"): fields.append("commission = %s") values.append(commission) if commission_asset is not None and _table_has_column("trades", "commission_asset"): fields.append("commission_asset = %s") values.append(commission_asset) # 如果提供了 exit_order_id,先检查是否已被其他交易记录使用 if exit_order_id is not None: try: existing_trade = Trade.get_by_exit_order_id(exit_order_id) if existing_trade: if existing_trade['id'] == trade_id: # exit_order_id 属于当前交易记录:允许继续更新 logger.debug( f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,将继续更新其他字段" ) else: # 如果 exit_order_id 已被其他交易记录使用,记录警告但不更新 exit_order_id logger.warning( f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用," f"跳过更新 exit_order_id,只更新其他字段" ) # 只更新其他字段,不更新 exit_order_id update_fields = [ "exit_price = %s", "exit_time = %s", "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'" ] update_values = [exit_price, exit_time, exit_reason, pnl, pnl_percent] _append_extra_fields(update_fields, update_values) update_values.append(trade_id) db.execute_update( f"UPDATE trades SET {', '.join(update_fields)} WHERE id = %s", tuple(update_values) ) return except Exception as e: # 如果查询失败,记录警告但继续正常更新 logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新") # 正常更新(包括 exit_order_id) try: update_fields = [ "exit_price = %s", "exit_time = %s", "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'", "exit_order_id = %s" ] update_values = [exit_price, exit_time, exit_reason, pnl, pnl_percent, exit_order_id] _append_extra_fields(update_fields, update_values) update_values.append(trade_id) db.execute_update( f"UPDATE trades SET {', '.join(update_fields)} WHERE id = %s", tuple(update_values) ) except Exception as e: # 如果更新失败(可能是唯一约束冲突),尝试不更新 exit_order_id error_str = str(e) if "Duplicate entry" in error_str and "exit_order_id" in error_str: logger.warning( f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突," f"尝试不更新 exit_order_id" ) # 只更新其他字段,不更新 exit_order_id update_fields = [ "exit_price = %s", "exit_time = %s", "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'" ] update_values = [exit_price, exit_time, exit_reason, pnl, pnl_percent] if strategy_type is not None: update_fields.append("strategy_type = %s") update_values.append(strategy_type) if duration_minutes is not None: update_fields.append("duration_minutes = %s") update_values.append(duration_minutes) update_values.append(trade_id) db.execute_update( f"UPDATE trades SET {', '.join(update_fields)} WHERE id = %s", tuple(update_values) ) else: # 其他错误,重新抛出 raise @staticmethod def update_open_fields(trade_id: int, **kwargs): """开仓后完善字段:止损/止盈/名义/保证金/entry_context 等(用于先 pending 落库、成交后补全)。""" if not kwargs: return allowed = { "stop_loss_price", "take_profit_price", "take_profit_1", "take_profit_2", "notional_usdt", "margin_usdt", "entry_context", "atr" } updates = [] values = [] for k, v in kwargs.items(): if k not in allowed or v is None: continue if k == "entry_context": try: v = json.dumps(v, ensure_ascii=False) if isinstance(v, dict) else str(v) except Exception: continue updates.append(f"{k} = %s") values.append(v) if not updates: return values.append(trade_id) try: db.execute_update( f"UPDATE trades SET {', '.join(updates)} WHERE id = %s", tuple(values) ) except Exception as e: logger.error(f"[DB] update_open_fields trade_id={trade_id} 失败 (SL/TP 等未写入): {e}") @staticmethod def update_status(trade_id, status: str) -> bool: """更新交易记录状态(如下单失败时将 pending 标为 cancelled)。""" if not trade_id or not status: return False try: db.execute_update("UPDATE trades SET status = %s WHERE id = %s", (str(status).strip(), int(trade_id))) return True except Exception as e: logger.error(f"[DB] update_status trade_id={trade_id} status={status} 失败: {e}") return False @staticmethod def get_by_entry_order_id(entry_order_id): """根据开仓订单号获取交易记录""" return db.execute_one( "SELECT * FROM trades WHERE entry_order_id = %s", (entry_order_id,) ) @staticmethod def get_by_exit_order_id(exit_order_id): """根据平仓订单号获取交易记录""" return db.execute_one( "SELECT * FROM trades WHERE exit_order_id = %s", (exit_order_id,) ) @staticmethod def get_trades_missing_entry_order_id(symbol: str, account_id: int, limit: int = 50): """ 获取该 symbol+account 下缺少 entry_order_id 的记录(用于从币安同步订单时补全)。 包含 entry_time 为 NULL 的旧记录,避免「仅可对账」下看不到今日有单却未补全的情况。 """ if not symbol or account_id is None: return [] try: if not _table_has_column("trades", "entry_order_id"): return [] if _table_has_column("trades", "account_id"): return db.execute_query( """SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND (entry_order_id IS NULL OR entry_order_id = 0 OR entry_order_id = '') ORDER BY id DESC LIMIT %s""", (int(account_id), symbol.strip(), int(limit)), ) return db.execute_query( """SELECT * FROM trades WHERE symbol = %s AND (entry_order_id IS NULL OR entry_order_id = 0 OR entry_order_id = '') ORDER BY id DESC LIMIT %s""", (symbol.strip(), int(limit)), ) except Exception as e: logger.debug(f"get_trades_missing_entry_order_id 失败: {e}") return [] @staticmethod def get_pending_recent(account_id: int, limit: int = 50, max_age_sec: int = 86400): """ 获取最近时间范围内 status=pending 的记录,用于正向流程 pending 对账。 返回有 client_order_id 或 entry_order_id 的记录,按 id 降序。 """ if account_id is None: return [] try: if not _table_has_column("trades", "account_id"): return [] if not _table_has_column("trades", "client_order_id") and not _table_has_column("trades", "entry_order_id"): return [] import time cutoff_ts = int(time.time()) - max(1, int(max_age_sec)) # 用 created_at 或 entry_time 作为时间过滤(created_at 更准确表示插入时间) time_col = "created_at" if _table_has_column("trades", "created_at") else "entry_time" query = f""" SELECT * FROM trades WHERE account_id = %s AND status = 'pending' AND {time_col} >= %s AND (client_order_id IS NOT NULL AND client_order_id != '' OR entry_order_id IS NOT NULL AND entry_order_id != 0 AND entry_order_id != '') ORDER BY id DESC LIMIT %s """ return db.execute_query(query, (int(account_id), cutoff_ts, int(limit))) except Exception as e: logger.debug(f"get_pending_recent 失败: {e}") return [] @staticmethod def get_by_client_order_id(client_order_id, account_id: int = None): """根据 clientOrderId 获取交易记录(可选按 account_id 隔离)""" if not client_order_id: return None try: if account_id is not None and _table_has_column("trades", "account_id"): return db.execute_one( "SELECT * FROM trades WHERE client_order_id = %s AND account_id = %s", (client_order_id, int(account_id)) ) return db.execute_one( "SELECT * FROM trades WHERE client_order_id = %s", (client_order_id,) ) except Exception: return None @staticmethod def update_pending_to_filled(client_order_id, account_id: int, entry_order_id, entry_price: float, quantity: float): """WS 或 REST 收到成交后:按 client_order_id 将 pending 记录完善为已成交(幂等)""" if not client_order_id or account_id is None: return False try: row = Trade.get_by_client_order_id(client_order_id, account_id) if not row: return False # 仅当仍为 pending 或 entry_order_id 为空时更新,避免覆盖已完善数据 if row.get("status") not in ("pending", "open") and row.get("entry_order_id"): return True # 已完善,视为成功 db.execute_update( """UPDATE trades SET entry_order_id = %s, entry_price = %s, quantity = %s, status = 'open' WHERE client_order_id = %s AND account_id = %s""", (entry_order_id, float(entry_price), float(quantity), client_order_id, int(account_id)) ) return True except Exception as e: logger.error(f"[DB] update_pending_to_filled 失败 client_order_id={client_order_id!r} account_id={account_id}: {e}") return False @staticmethod def update_pending_by_entry_order_id(symbol: str, account_id: int, entry_order_id, entry_price: float, quantity: float): """ UDS 开仓 FILLED 但无 clientOrderId 时的兜底:用 orderId 完善一条 pending 记录。 若 DB 中已有该 entry_order_id 则跳过;否则在该 symbol+account 下找「status 为 pending 且 entry_order_id 为空」的记录, 若恰好 1 条则用 orderId/价格/数量更新(避免误匹配多笔 pending 时只更新一条)。 """ if not symbol or account_id is None or entry_order_id is None: return False try: existing = Trade.get_by_entry_order_id(entry_order_id) if existing: return True # 已存在该订单号,无需兜底 if not _table_has_column("trades", "account_id"): return False # 查该 symbol+account 下 pending 且无 entry_order_id 的记录 rows = db.execute_query( """SELECT id FROM trades WHERE account_id = %s AND symbol = %s AND status = 'pending' AND (entry_order_id IS NULL OR entry_order_id = 0) ORDER BY id DESC""", (int(account_id), symbol.strip()) ) if not rows or len(rows) != 1: return False tid = rows[0]["id"] db.execute_update( """UPDATE trades SET entry_order_id = %s, entry_price = %s, quantity = %s, status = 'open' WHERE id = %s""", (str(entry_order_id), float(entry_price), float(quantity), int(tid)) ) logger.info(f"Trade.update_pending_by_entry_order_id: 已用 orderId={entry_order_id} 兜底完善 pending 记录 id={tid} symbol={symbol!r}") return True except Exception as e: logger.warning(f"update_pending_by_entry_order_id 失败 symbol={symbol!r} orderId={entry_order_id}: {e}") return False @staticmethod def update_entry_order_id(trade_id: int, entry_order_id): """补全或更新开仓订单号(用于 REST 兜底补全)""" if not trade_id or not entry_order_id: return False try: db.execute_update( """UPDATE trades SET entry_order_id = %s WHERE id = %s AND (entry_order_id IS NULL OR entry_order_id = '')""", (str(entry_order_id), int(trade_id)) ) return True except Exception as e: logger.warning(f"update_entry_order_id 失败 trade_id={trade_id}: {e}") return False @staticmethod def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None, account_id: int = None, time_filter: str = "exit", limit: int = None, reconciled_only: bool = False, include_sync: bool = True): """ 获取交易记录。 time_filter: 时间范围按哪种时间筛选 - "exit": 按平仓时间(已平仓用 exit_time,未平仓用 entry_time)。选「今天」= 今天平掉的单 + 今天开的未平仓,更符合直觉。 - "entry": 按开仓时间(实际入场时间,适合策略分析:何时进场、持仓时长)。 - "created": 按创建时间(记录写入 DB 的时间,略早于或等于成交时间;无 created_at 时回退为 entry)。 - "both": 原逻辑,COALESCE(exit_time, entry_time)。 limit: 最多返回条数,None 表示不限制。 reconciled_only: 仅可对账(有 entry_order_id,已平仓的还有 exit_order_id),在 SQL 中过滤以减轻负载。 include_sync: 是否包含 entry_reason=sync_recovered / exit_reason=sync 的记录,在 SQL 中过滤。 """ query = "SELECT * FROM trades WHERE 1=1" params = [] # 多账号隔离(兼容旧schema) try: if _table_has_column("trades", "account_id"): query += " AND account_id = %s" params.append(int(account_id or DEFAULT_ACCOUNT_ID)) except Exception: pass if reconciled_only and _table_has_column("trades", "entry_order_id"): query += " AND entry_order_id IS NOT NULL AND entry_order_id != 0" query += " AND (status != 'closed' OR (exit_order_id IS NOT NULL AND exit_order_id != 0))" if include_sync is False: query += " AND (entry_reason IS NULL OR entry_reason != 'sync_recovered')" query += " AND (exit_reason IS NULL OR exit_reason != 'sync')" # 按创建时间筛选:有 created_at 列则用 COALESCE(created_at, entry_time);无则回退为 entry_time,保证「按创建时间」有结果 use_created = (time_filter == "created" and _table_has_column("trades", "created_at")) time_col = "COALESCE(created_at, entry_time)" if use_created else None if start_timestamp is not None and end_timestamp is not None: if time_filter == "exit": query += " AND ((status = 'closed' AND exit_time >= %s AND exit_time <= %s) OR (status != 'closed' AND entry_time >= %s AND entry_time <= %s))" params.extend([start_timestamp, end_timestamp, start_timestamp, end_timestamp]) elif time_filter == "entry": query += " AND entry_time >= %s AND entry_time <= %s" params.extend([start_timestamp, end_timestamp]) elif use_created: query += " AND " + time_col + " >= %s AND " + time_col + " <= %s" params.extend([start_timestamp, end_timestamp]) elif time_filter == "created": query += " AND entry_time >= %s AND entry_time <= %s" params.extend([start_timestamp, end_timestamp]) else: query += " AND COALESCE(exit_time, entry_time) >= %s AND COALESCE(exit_time, entry_time) <= %s" params.extend([start_timestamp, end_timestamp]) elif start_timestamp is not None: if time_filter == "exit": query += " AND ((status = 'closed' AND exit_time >= %s) OR (status != 'closed' AND entry_time >= %s))" params.extend([start_timestamp, start_timestamp]) elif use_created: query += " AND " + time_col + " >= %s" params.append(start_timestamp) elif time_filter == "entry" or time_filter == "created": query += " AND entry_time >= %s" params.append(start_timestamp) else: query += " AND COALESCE(exit_time, entry_time) >= %s" params.append(start_timestamp) elif end_timestamp is not None: if time_filter == "exit": query += " AND ((status = 'closed' AND exit_time <= %s) OR (status != 'closed' AND entry_time <= %s))" params.extend([end_timestamp, end_timestamp]) elif use_created: query += " AND " + time_col + " <= %s" params.append(end_timestamp) elif time_filter == "entry" or time_filter == "created": query += " AND entry_time <= %s" params.append(end_timestamp) else: query += " AND COALESCE(exit_time, entry_time) <= %s" params.append(end_timestamp) if symbol: query += " AND symbol = %s" params.append(symbol) if status: query += " AND status = %s" params.append(status) if trade_type: query += " AND side = %s" params.append(trade_type) if exit_reason: query += " AND exit_reason = %s" params.append(exit_reason) if use_created: query += " ORDER BY " + time_col + " DESC, id DESC" elif time_filter == "created": query += " ORDER BY entry_time DESC, id DESC" else: query += " ORDER BY COALESCE(exit_time, entry_time) DESC, id DESC" # 未传 limit 时使用默认上限,防止全表加载导致内存暴增(2 CPU 4G 场景) _limit = limit if _limit is None or _limit <= 0: _limit = 10000 query += " LIMIT %s" params.append(int(_limit)) logger.debug(f"查询交易记录: time_filter={time_filter}, limit={_limit}, reconciled_only={reconciled_only}, include_sync={include_sync}") result = db.execute_query(query, params) return result @staticmethod def get_by_symbol(symbol, status='open', account_id: int = None): """根据交易对获取持仓""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("trades", "account_id"): return db.execute_query( "SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND status = %s", (aid, symbol, status), ) return db.execute_query( "SELECT * FROM trades WHERE symbol = %s AND status = %s", (symbol, status), ) @staticmethod def set_exit_order_id_for_open_trade(symbol: str, account_id: int, exit_order_id, entry_order_id: int = None): """ ALGO_UPDATE/ORDER_TRADE_UPDATE 平仓成交:为指定 symbol 下未填 exit_order_id 的 open 记录补全平仓订单号。 优先按 entry_order_id 精确匹配,若无则按 symbol 匹配最早的一条 open 记录。 Returns: (True, trade_id) 成功并返回该记录 id,便于后续 update_exit;(False, None) 失败或未匹配。 """ if not symbol or account_id is None or exit_order_id is None: return False, None try: if not _table_has_column("trades", "account_id"): return False, None # 优先按 entry_order_id 精确匹配(如果提供了 entry_order_id) if entry_order_id: row = db.execute_one( """SELECT id FROM trades WHERE account_id = %s AND symbol = %s AND status = 'open' AND entry_order_id = %s AND (exit_order_id IS NULL OR exit_order_id = '' OR exit_order_id = '0') LIMIT 1""", (int(account_id), symbol.strip(), str(entry_order_id)) ) if row: db.execute_update( """UPDATE trades SET exit_order_id = %s WHERE id = %s""", (str(exit_order_id), row["id"]) ) logger.debug(f"set_exit_order_id_for_open_trade: 按 entry_order_id={entry_order_id} 精确匹配成功 id={row['id']}") return True, row["id"] # 否则按 symbol 匹配最早的一条 open 记录(按 entry_time 排序) row = db.execute_one( """SELECT id FROM trades WHERE account_id = %s AND symbol = %s AND status = 'open' AND (exit_order_id IS NULL OR exit_order_id = '' OR exit_order_id = '0') ORDER BY entry_time ASC LIMIT 1""", (int(account_id), symbol.strip()) ) if not row: return False, None db.execute_update( """UPDATE trades SET exit_order_id = %s WHERE id = %s""", (str(exit_order_id), row["id"]) ) return True, row["id"] except Exception as e: logger.warning(f"set_exit_order_id_for_open_trade 失败 symbol={symbol!r}: {e}") return False, None class TradeStats: """ 交易统计:按交易对+日期、按小时聚合,写入 trade_stats_daily / trade_stats_time_bucket。 数据源优先 binance_trades(定时同步的币安成交),无表或无数据时回退到 trades。 """ @staticmethod def _binance_trades_exists(): try: db.execute_one("SELECT 1 FROM binance_trades LIMIT 1") return True except Exception: return False @staticmethod def _ensure_tables(): db.execute_update( """ CREATE TABLE IF NOT EXISTS trade_stats_daily ( id BIGINT PRIMARY KEY AUTO_INCREMENT, account_id INT NOT NULL, trade_date DATE NOT NULL, symbol VARCHAR(50) NOT NULL, trade_count INT NOT NULL, win_count INT NOT NULL, loss_count INT NOT NULL, gross_pnl DECIMAL(20,8) NOT NULL, net_pnl DECIMAL(20,8) NOT NULL, total_commission DECIMAL(20,8) NOT NULL, avg_pnl_per_trade DECIMAL(20,8) NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY uniq_account_date_symbol (account_id, trade_date, symbol), KEY idx_trade_date (trade_date), KEY idx_symbol (symbol) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; """ ) db.execute_update( """ CREATE TABLE IF NOT EXISTS trade_stats_time_bucket ( id BIGINT PRIMARY KEY AUTO_INCREMENT, account_id INT NOT NULL, trade_date DATE NOT NULL, hour TINYINT NOT NULL, trade_count INT NOT NULL, win_count INT NOT NULL, loss_count INT NOT NULL, gross_pnl DECIMAL(20,8) NOT NULL, net_pnl DECIMAL(20,8) NOT NULL, total_commission DECIMAL(20,8) NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, UNIQUE KEY uniq_account_date_hour (account_id, trade_date, hour), KEY idx_trade_date (trade_date) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4; """ ) @staticmethod def _aggregate_from_binance_trades(aid: int, start_ts: int, end_ts: int): """从 binance_trades 聚合,返回 (daily, hourly) 或 (None, None)。""" from datetime import datetime, timezone start_ms = start_ts * 1000 end_ms = end_ts * 1000 try: rows = db.execute_query( """SELECT symbol, trade_time, realized_pnl, commission FROM binance_trades WHERE account_id = %s AND trade_time >= %s AND trade_time <= %s LIMIT 100000""", (aid, start_ms, end_ms), ) except Exception as e: logger.debug(f"[TradeStats] 查询 binance_trades 失败: {e}") return None, None if not rows: return None, None def to_date_hour(ts_ms): try: dt = datetime.fromtimestamp(int(ts_ms) // 1000, timezone.utc).astimezone(BEIJING_TZ) return dt.date(), dt.hour except Exception: return None, None daily, hourly = {}, {} for r in rows: try: tm = r.get("trade_time") if tm is None: continue date, hour = to_date_hour(tm) if date is None: continue sym = (r.get("symbol") or "").strip() if not sym: continue pnl = float(r.get("realized_pnl") or 0) comm = float(r.get("commission") or 0) dkey, hkey = (date, sym), (date, hour) if dkey not in daily: daily[dkey] = {"trade_count": 0, "win_count": 0, "loss_count": 0, "gross_pnl": 0.0, "total_commission": 0.0} daily[dkey]["trade_count"] += 1 if pnl > 0: daily[dkey]["win_count"] += 1 elif pnl < 0: daily[dkey]["loss_count"] += 1 daily[dkey]["gross_pnl"] += pnl daily[dkey]["total_commission"] += comm if hkey not in hourly: hourly[hkey] = {"trade_count": 0, "win_count": 0, "loss_count": 0, "gross_pnl": 0.0, "total_commission": 0.0} hourly[hkey]["trade_count"] += 1 if pnl > 0: hourly[hkey]["win_count"] += 1 elif pnl < 0: hourly[hkey]["loss_count"] += 1 hourly[hkey]["gross_pnl"] += pnl hourly[hkey]["total_commission"] += comm except Exception as e: logger.debug(f"[TradeStats] 处理行失败: {e}") return daily, hourly @staticmethod def _aggregate_from_trades(aid: int, start_ts: int, end_ts: int): """从 trades 表聚合,返回 (daily, hourly)。""" from datetime import datetime, timezone try: rows = Trade.get_all( start_timestamp=start_ts, end_timestamp=end_ts, account_id=aid, time_filter="exit", limit=100000, reconciled_only=False, ) except Exception as e: logger.debug(f"[TradeStats] Trade.get_all 失败: {e}") return {}, {} if not rows: return {}, {} def to_date_hour(ts): try: dt = datetime.fromtimestamp(int(ts), timezone.utc).astimezone(BEIJING_TZ) return dt.date(), dt.hour except Exception: return None, None daily, hourly = {}, {} for r in rows: try: ts = r.get("exit_time") or r.get("entry_time") if not ts: continue date, hour = to_date_hour(ts) if date is None: continue sym = (r.get("symbol") or "").strip() if not sym: continue pnl = float(r.get("pnl") or 0) comm = float(r.get("commission") or 0) if "commission" in r else 0 dkey, hkey = (date, sym), (date, hour) if dkey not in daily: daily[dkey] = {"trade_count": 0, "win_count": 0, "loss_count": 0, "gross_pnl": 0.0, "total_commission": 0.0} daily[dkey]["trade_count"] += 1 if pnl > 0: daily[dkey]["win_count"] += 1 elif pnl < 0: daily[dkey]["loss_count"] += 1 daily[dkey]["gross_pnl"] += pnl daily[dkey]["total_commission"] += comm if hkey not in hourly: hourly[hkey] = {"trade_count": 0, "win_count": 0, "loss_count": 0, "gross_pnl": 0.0, "total_commission": 0.0} hourly[hkey]["trade_count"] += 1 if pnl > 0: hourly[hkey]["win_count"] += 1 elif pnl < 0: hourly[hkey]["loss_count"] += 1 hourly[hkey]["gross_pnl"] += pnl hourly[hkey]["total_commission"] += comm except Exception as e: logger.debug(f"[TradeStats] 处理 trades 行失败: {e}") return daily, hourly @staticmethod def _upsert_stats(aid: int, daily: dict, hourly: dict): """把 daily/hourly 聚合结果写入 trade_stats_daily / trade_stats_time_bucket。""" for (trade_date, symbol), v in daily.items(): tc = v["trade_count"] net = v["gross_pnl"] - v["total_commission"] avg = (net / tc) if tc > 0 else 0.0 try: db.execute_update( """INSERT INTO trade_stats_daily ( account_id, trade_date, symbol, trade_count, win_count, loss_count, gross_pnl, net_pnl, total_commission, avg_pnl_per_trade ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE trade_count = VALUES(trade_count), win_count = VALUES(win_count), loss_count = VALUES(loss_count), gross_pnl = VALUES(gross_pnl), net_pnl = VALUES(net_pnl), total_commission = VALUES(total_commission), avg_pnl_per_trade = VALUES(avg_pnl_per_trade)""", (aid, trade_date, symbol, tc, v["win_count"], v["loss_count"], v["gross_pnl"], net, v["total_commission"], avg), ) except Exception as e: logger.warning(f"[TradeStats] 写入 daily 失败 {symbol} {trade_date}: {e}") for (trade_date, hour), v in hourly.items(): net = v["gross_pnl"] - v["total_commission"] try: db.execute_update( """INSERT INTO trade_stats_time_bucket ( account_id, trade_date, hour, trade_count, win_count, loss_count, gross_pnl, net_pnl, total_commission ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE trade_count = VALUES(trade_count), win_count = VALUES(win_count), loss_count = VALUES(loss_count), gross_pnl = VALUES(gross_pnl), net_pnl = VALUES(net_pnl), total_commission = VALUES(total_commission)""", (aid, trade_date, int(hour), v["trade_count"], v["win_count"], v["loss_count"], v["gross_pnl"], net, v["total_commission"]), ) except Exception as e: logger.warning(f"[TradeStats] 写入 time_bucket 失败 {trade_date} h={hour}: {e}") @staticmethod def aggregate_recent_days(days: int = 7, account_id: int = None): """聚合最近 N 天到统计表。优先 binance_trades,无数据则用 trades。""" if days <= 0: return TradeStats._ensure_tables() aid = int(account_id or DEFAULT_ACCOUNT_ID) now_ts = get_beijing_time() start_ts = now_ts - int(days) * 86400 daily, hourly = None, None if TradeStats._binance_trades_exists(): daily, hourly = TradeStats._aggregate_from_binance_trades(aid, start_ts, now_ts) if daily is None or hourly is None: daily, hourly = TradeStats._aggregate_from_trades(aid, start_ts, now_ts) if daily or hourly: TradeStats._upsert_stats(aid, daily or {}, hourly or {}) @staticmethod def get_daily_stats(account_id: int = None, days: int = 7): """查询最近 N 天按交易对聚合统计,供 API/仪表盘展示。表不存在或无数据时返回 []。""" aid = int(account_id or DEFAULT_ACCOUNT_ID) try: rows = db.execute_query( """SELECT trade_date, symbol, trade_count, win_count, loss_count, gross_pnl, net_pnl, total_commission, avg_pnl_per_trade FROM trade_stats_daily WHERE account_id = %s AND trade_date >= DATE_SUB(CURDATE(), INTERVAL %s DAY) ORDER BY trade_date DESC, net_pnl DESC LIMIT 500""", (aid, max(1, int(days))), ) if not rows: return [] return [dict(r) for r in rows] except Exception as e: logger.debug(f"[TradeStats] get_daily_stats 失败: {e}") return [] @staticmethod def get_hourly_stats(account_id: int = None, days: int = 7): """查询最近 N 天按小时聚合统计,供 API/仪表盘展示。表不存在或无数据时返回 []。""" aid = int(account_id or DEFAULT_ACCOUNT_ID) try: rows = db.execute_query( """SELECT trade_date, hour, trade_count, win_count, loss_count, gross_pnl, net_pnl, total_commission FROM trade_stats_time_bucket WHERE account_id = %s AND trade_date >= DATE_SUB(CURDATE(), INTERVAL %s DAY) ORDER BY trade_date DESC, hour ASC LIMIT 500""", (aid, max(1, int(days))), ) if not rows: return [] return [dict(r) for r in rows] except Exception as e: logger.debug(f"[TradeStats] get_hourly_stats 失败: {e}") return [] class AccountSnapshot: """账户快照模型""" @staticmethod def create(total_balance, available_balance, total_position_value, total_pnl, open_positions, account_id: int = None): """创建账户快照(使用北京时间)""" snapshot_time = get_beijing_time() if _table_has_column("account_snapshots", "account_id"): db.execute_update( """INSERT INTO account_snapshots (account_id, total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) VALUES (%s, %s, %s, %s, %s, %s, %s)""", (int(account_id or DEFAULT_ACCOUNT_ID), total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time), ) else: db.execute_update( """INSERT INTO account_snapshots (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time) VALUES (%s, %s, %s, %s, %s, %s)""", (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time), ) @staticmethod def get_recent(days=7, account_id: int = None): """获取最近的快照""" aid = int(account_id or DEFAULT_ACCOUNT_ID) if _table_has_column("account_snapshots", "account_id"): return db.execute_query( """SELECT * FROM account_snapshots WHERE account_id = %s AND snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY snapshot_time DESC""", (aid, days), ) return db.execute_query( """SELECT * FROM account_snapshots WHERE snapshot_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY snapshot_time DESC""", (days,), ) class MarketScan: """市场扫描记录模型""" @staticmethod def create(symbols_scanned, symbols_found, top_symbols, scan_duration): """创建扫描记录(使用北京时间)""" scan_time = get_beijing_time() db.execute_update( """INSERT INTO market_scans (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time) VALUES (%s, %s, %s, %s, %s)""", (symbols_scanned, symbols_found, json.dumps(top_symbols), scan_duration, scan_time) ) @staticmethod def get_recent(limit=100): """获取最近的扫描记录""" return db.execute_query( "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s", (limit,) ) class TradingSignal: """交易信号模型""" @staticmethod def create(symbol, signal_direction, signal_strength, signal_reason, rsi=None, macd_histogram=None, market_regime=None): """创建交易信号(使用北京时间)""" signal_time = get_beijing_time() db.execute_update( """INSERT INTO trading_signals (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime, signal_time) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, signal_direction, signal_strength, signal_reason, rsi, macd_histogram, market_regime, signal_time) ) @staticmethod def mark_executed(signal_id): """标记信号已执行""" db.execute_update( "UPDATE trading_signals SET executed = TRUE WHERE id = %s", (signal_id,) ) @staticmethod def get_recent(limit=100): """获取最近的信号""" return db.execute_query( "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s", (limit,) ) class TradeRecommendation: """推荐交易对模型""" @staticmethod def create( symbol, direction, current_price, change_percent, recommendation_reason, signal_strength, market_regime=None, trend_4h=None, rsi=None, macd_histogram=None, bollinger_upper=None, bollinger_middle=None, bollinger_lower=None, ema20=None, ema50=None, ema20_4h=None, atr=None, suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None, suggested_position_percent=None, suggested_leverage=10, order_type='LIMIT', suggested_limit_price=None, volume_24h=None, volatility=None, notes=None, user_guide=None, recommendation_category=None, risk_level=None, expected_hold_time=None, trading_tutorial=None, max_hold_days=3 ): """创建推荐记录(使用北京时间)""" recommendation_time = get_beijing_time() # 默认24小时后过期 expires_at = recommendation_time + timedelta(hours=24) # 检查字段是否存在(兼容旧数据库schema) try: db.execute_one("SELECT order_type FROM trade_recommendations LIMIT 1") has_order_fields = True except: has_order_fields = False # 检查是否有 user_guide 字段 try: db.execute_one("SELECT user_guide FROM trade_recommendations LIMIT 1") has_user_guide_fields = True except: has_user_guide_fields = False if has_user_guide_fields: # 包含所有新字段(user_guide等) db.execute_update( """INSERT INTO trade_recommendations (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes, user_guide, recommendation_category, risk_level, expected_hold_time, trading_tutorial, max_hold_days) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes, user_guide, recommendation_category, risk_level, expected_hold_time, trading_tutorial, max_hold_days) ) elif has_order_fields: # 只有 order_type 字段,没有 user_guide 字段 db.execute_update( """INSERT INTO trade_recommendations (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, order_type, suggested_limit_price, volume_24h, volatility, expires_at, notes) ) else: # 兼容旧schema db.execute_update( """INSERT INTO trade_recommendations (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, volume_24h, volatility, expires_at, notes) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""", (symbol, direction, recommendation_time, current_price, change_percent, recommendation_reason, signal_strength, market_regime, trend_4h, rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower, ema20, ema50, ema20_4h, atr, suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2, suggested_position_percent, suggested_leverage, volume_24h, volatility, expires_at, notes) ) return db.execute_one("SELECT LAST_INSERT_ID() as id")['id'] @staticmethod def mark_executed(recommendation_id, trade_id=None, execution_result='success'): """标记推荐已执行""" executed_at = get_beijing_time() db.execute_update( """UPDATE trade_recommendations SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s WHERE id = %s""", (executed_at, execution_result, trade_id, recommendation_id) ) @staticmethod def mark_expired(recommendation_id): """标记推荐已过期""" db.execute_update( "UPDATE trade_recommendations SET status = 'expired' WHERE id = %s", (recommendation_id,) ) @staticmethod def mark_cancelled(recommendation_id, notes=None): """标记推荐已取消""" db.execute_update( "UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s", (notes, recommendation_id) ) @staticmethod def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None): """获取推荐记录""" query = "SELECT * FROM trade_recommendations WHERE 1=1" params = [] if status: query += " AND status = %s" params.append(status) if direction: query += " AND direction = %s" params.append(direction) if start_date: query += " AND recommendation_time >= %s" params.append(start_date) if end_date: query += " AND recommendation_time <= %s" params.append(end_date) query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s" params.append(limit) return db.execute_query(query, params) @staticmethod def get_active(): """获取当前有效的推荐(未过期、未执行、未取消) 同一交易对只返回最新的推荐(去重) """ return db.execute_query( """SELECT t1.* FROM trade_recommendations t1 INNER JOIN ( SELECT symbol, MAX(recommendation_time) as max_time FROM trade_recommendations WHERE status = 'active' AND (expires_at IS NULL OR expires_at > NOW()) GROUP BY symbol ) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > NOW()) ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC""" ) @staticmethod def get_by_id(recommendation_id): """根据ID获取推荐""" return db.execute_one( "SELECT * FROM trade_recommendations WHERE id = %s", (recommendation_id,) ) @staticmethod def get_by_symbol(symbol, limit=10): """根据交易对获取推荐记录""" return db.execute_query( """SELECT * FROM trade_recommendations WHERE symbol = %s ORDER BY recommendation_time DESC LIMIT %s""", (symbol, limit) )