auto_trade_sys/backend/database/models.py
薇薇安 9b81832af2 feat(trades, database, frontend): 增强交易记录同步与展示功能
在 `trades.py` 中更新了 `include_sync` 参数的默认值为 `True`,以便于订单记录与币安一致,并添加了提示信息以指导用户如何补全缺失的订单号。在 `models.py` 中新增了 `get_trades_missing_entry_order_id` 方法,用于获取缺少 `entry_order_id` 的记录,确保在同步时能够补全数据。前端组件 `StatsDashboard.jsx` 和 `TradeList.jsx` 中相应调整了开仓时间的展示逻辑和无交易记录时的提示信息,提升了用户体验与数据准确性。
2026-02-20 12:17:01 +08:00

1457 lines
62 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库模型定义
"""
from database.connection import db
from datetime import datetime, timezone, timedelta
import json
import logging
import os
# Beijing time zone: a fixed UTC+8 offset (datetime.timezone has no DST rules).
BEIJING_TZ = timezone(timedelta(hours=8))


def get_beijing_time():
    """Return the current moment as whole Unix seconds.

    The clock is read as an aware datetime in ``BEIJING_TZ``. A Unix timestamp
    counts seconds since the UTC epoch, so the returned number does not depend
    on the zone used to read the clock; fractional seconds are truncated.
    """
    current = datetime.now(tz=BEIJING_TZ)
    return int(current.timestamp())


logger = logging.getLogger(__name__)
def _resolve_default_account_id() -> int:
    """
    Resolve the default account id for this process.

    - trading_system may run one process per account; each process can pick its
      account through the ATS_ACCOUNT_ID environment variable.
    - The backend falls back to 1 when no account_id is given (single-account
      compatibility).

    The variables ATS_ACCOUNT_ID, ACCOUNT_ID and ATS_DEFAULT_ACCOUNT_ID are
    checked in that order; the first whose stripped value parses with ``int``
    wins. Blank or non-numeric values are skipped.
    """
    candidate_vars = ("ATS_ACCOUNT_ID", "ACCOUNT_ID", "ATS_DEFAULT_ACCOUNT_ID")
    for env_name in candidate_vars:
        raw = (os.getenv(env_name, "") or "").strip()
        if not raw:
            continue
        try:
            return int(raw)
        except Exception:
            # Not an integer: try the next variable.
            pass
    return 1


# Resolved once at import time.
DEFAULT_ACCOUNT_ID = _resolve_default_account_id()
# Positive-result cache for _table_has_column: "table.column" keys known to exist.
# Only successful probes are cached, so a column added later by a migration is
# still detected, and a transient DB error is simply retried on the next call.
# NOTE(review): a column dropped at runtime stays "present" until restart.
_EXISTING_COLUMNS = set()


def _table_has_column(table: str, col: str) -> bool:
    """
    Return True if ``table`` has a column named ``col``.

    Probes with ``SELECT col FROM table LIMIT 1``. Any exception (unknown
    column, missing table, or a DB/connection error) is reported as False.
    ``table`` and ``col`` are interpolated into the SQL text, so they must be
    trusted identifiers (the callers visible here pass string literals).

    Positive results are memoised in ``_EXISTING_COLUMNS``. Callers probe the
    same columns on every query/insert, and a failed probe makes them fall back
    to legacy (e.g. non-account-scoped) SQL. Remembering columns once seen
    avoids both the extra round-trips and that fallback on a transient error.
    """
    key = f"{table}.{col}"
    if key in _EXISTING_COLUMNS:
        return True
    try:
        db.execute_one(f"SELECT {col} FROM {table} LIMIT 1")
    except Exception:
        return False
    _EXISTING_COLUMNS.add(key)
    return True
class Account:
    """
    Trading account model (multi-account).

    API key/secret are meant to be stored encrypted in the ``accounts`` table
    rather than in ``trading_config``.
    """

    @staticmethod
    def get(account_id: int):
        """Fetch one ``accounts`` row (all columns) by id; logs hit/miss and returns the row (falsy if absent)."""
        logger.info(f"Account.get called with account_id={account_id}")
        row = db.execute_one("SELECT * FROM accounts WHERE id = %s", (int(account_id),))
        if not row:
            logger.warning(f"Account.get: account_id={account_id} not found in database")
            return row
        logger.info(
            f"Account.get: found account_id={account_id}, "
            f"name={row.get('name', 'N/A')}, status={row.get('status', 'N/A')}"
        )
        return row

    @staticmethod
    def get_by_id(account_id: int):
        """Alias of :meth:`Account.get`."""
        return Account.get(account_id)

    @staticmethod
    def list_all():
        """List all accounts ordered by id, including the encrypted credential columns and use_testnet."""
        sql = "SELECT id, name, status, created_at, updated_at, api_key_enc, api_secret_enc, use_testnet FROM accounts ORDER BY id ASC"
        return db.execute_query(sql)

    @staticmethod
    def create(name: str, api_key: str = "", api_secret: str = "", use_testnet: bool = False, status: str = "active"):
        """
        Insert an account with encrypted credentials and return its new id.

        NOTE(review): the id comes from a separate ``SELECT LAST_INSERT_ID()``,
        which assumes ``db`` runs both statements on the same connection — confirm.
        """
        # Deferred import so a missing crypto dependency does not break module import.
        from security.crypto import encrypt_str
        params = (
            name,
            status,
            encrypt_str(api_key or ""),
            encrypt_str(api_secret or ""),
            bool(use_testnet),
        )
        db.execute_update(
            """INSERT INTO accounts (name, status, api_key_enc, api_secret_enc, use_testnet)
            VALUES (%s, %s, %s, %s, %s)""",
            params,
        )
        inserted = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return inserted["id"]

    @staticmethod
    def update(account_id: int, name: str = None, status: str = None, testnet: int = None, api_key: str = None, api_secret: str = None):
        """
        Generic update; only arguments that are not None are written.

        When api_key or api_secret is given, credentials (and testnet, if given)
        go through :meth:`update_credentials`. Otherwise testnet is written here
        together with name/status.
        """
        touches_credentials = api_key is not None or api_secret is not None
        if touches_credentials:
            Account.update_credentials(account_id, api_key, api_secret, None if testnet is None else bool(testnet))

        assignments = []
        params = []
        for column, new_value in (("name", name), ("status", status)):
            if new_value is not None:
                assignments.append(f"{column} = %s")
                params.append(new_value)
        # testnet alone (without key/secret) must still be persisted.
        if testnet is not None and not touches_credentials:
            assignments.append("use_testnet = %s")
            params.append(bool(testnet))

        if not assignments:
            return
        params.append(int(account_id))
        db.execute_update(f"UPDATE accounts SET {', '.join(assignments)} WHERE id = %s", tuple(params))

    @staticmethod
    def update_credentials(account_id: int, api_key: str = None, api_secret: str = None, use_testnet: bool = None):
        """Encrypt and store whichever of api_key / api_secret / use_testnet are not None; no-op if all are None."""
        from security.crypto import encrypt_str  # deferred import
        assignments = []
        params = []
        for column, secret in (("api_key_enc", api_key), ("api_secret_enc", api_secret)):
            if secret is not None:
                assignments.append(f"{column} = %s")
                params.append(encrypt_str(secret))
        if use_testnet is not None:
            assignments.append("use_testnet = %s")
            params.append(bool(use_testnet))
        if not assignments:
            return
        params.append(int(account_id))
        db.execute_update(f"UPDATE accounts SET {', '.join(assignments)} WHERE id = %s", tuple(params))

    @staticmethod
    def get_credentials(account_id: int):
        """
        Return ``(api_key, api_secret, use_testnet, status)`` for an account.

        Encrypted columns are decrypted with ``security.crypto.decrypt_str``.
        A missing account yields ``("", "", False, "disabled")``.
        If importing or decrypting fails (e.g. no ``cryptography`` package or no
        master key configured), status becomes ``"disabled"``. Stored values are
        then returned as-is only if they lack the ``enc:v1:`` ciphertext prefix,
        so ciphertext is never handed out as a key. Plaintext rows still yield
        usable (but insecure) values.
        """
        logger.info(f"Account.get_credentials called with account_id={account_id}")
        row = Account.get(account_id)
        if not row:
            logger.warning(f"Account.get_credentials: account_id={account_id} not found in database")
            return "", "", False, "disabled"

        def _plain_or_blank(raw):
            text = str(raw or "")
            return "" if text.startswith("enc:v1:") else text

        try:
            from security.crypto import decrypt_str
            status = row.get("status") or "active"
            api_key = decrypt_str(row.get("api_key_enc") or "")
            api_secret = decrypt_str(row.get("api_secret_enc") or "")
        except Exception:
            status = "disabled"
            api_key = _plain_or_blank(row.get("api_key_enc"))
            api_secret = _plain_or_blank(row.get("api_secret_enc"))
        use_testnet = bool(row.get("use_testnet"))
        return api_key, api_secret, use_testnet, status
class User:
    """Login user (administrator or regular user) stored in the ``users`` table."""

    @staticmethod
    def get_by_username(username: str):
        """Return the full ``users`` row for ``username`` (coerced to str), or a falsy result."""
        sql = "SELECT * FROM users WHERE username = %s"
        return db.execute_one(sql, (str(username),))

    @staticmethod
    def get_by_id(user_id: int):
        """Return the full ``users`` row for ``user_id``, or a falsy result."""
        sql = "SELECT * FROM users WHERE id = %s"
        return db.execute_one(sql, (int(user_id),))

    @staticmethod
    def list_all():
        """List users ordered by id (password hashes are not selected)."""
        sql = "SELECT id, username, role, status, created_at, updated_at FROM users ORDER BY id ASC"
        return db.execute_query(sql)

    @staticmethod
    def create(username: str, password_hash: str, role: str = "user", status: str = "active"):
        """Insert a user and return the id reported by ``LAST_INSERT_ID()``."""
        db.execute_update(
            "INSERT INTO users (username, password_hash, role, status) VALUES (%s, %s, %s, %s)",
            (username, password_hash, role, status),
        )
        new_row = db.execute_one("SELECT LAST_INSERT_ID() as id")
        return new_row["id"]

    @staticmethod
    def _update_column(user_id, column, value):
        """Set a single ``users`` column for one user (column name is a trusted literal)."""
        db.execute_update(f"UPDATE users SET {column} = %s WHERE id = %s", (value, int(user_id)))

    @staticmethod
    def set_password(user_id: int, password_hash: str):
        """Replace the stored password hash."""
        User._update_column(user_id, "password_hash", password_hash)

    @staticmethod
    def set_status(user_id: int, status: str):
        """Change the user's status."""
        User._update_column(user_id, "status", status)

    @staticmethod
    def set_role(user_id: int, role: str):
        """Change the user's role."""
        User._update_column(user_id, "role", role)
class UserAccountMembership:
    """User ↔ trading-account authorization links (``user_account_memberships``)."""

    @staticmethod
    def add(user_id: int, account_id: int, role: str = "viewer"):
        """Grant ``role`` on an account; an existing link has its role overwritten (upsert)."""
        uid, aid = int(user_id), int(account_id)
        db.execute_update(
            """INSERT INTO user_account_memberships (user_id, account_id, role)
            VALUES (%s, %s, %s)
            ON DUPLICATE KEY UPDATE role = VALUES(role)""",
            (uid, aid, role),
        )

    @staticmethod
    def add_membership(user_id: int, account_id: int, role: str = "viewer"):
        """Alias of :meth:`add`."""
        return UserAccountMembership.add(user_id, account_id, role)

    @staticmethod
    def remove(user_id: int, account_id: int):
        """Delete the link between a user and an account."""
        uid, aid = int(user_id), int(account_id)
        db.execute_update(
            "DELETE FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            (uid, aid),
        )

    @staticmethod
    def clear_other_owners_for_account(account_id: int, keep_user_id: int):
        """Only one owner per account: demote every other owner of ``account_id`` to viewer."""
        aid, keep = int(account_id), int(keep_user_id)
        db.execute_update(
            "UPDATE user_account_memberships SET role = 'viewer' WHERE account_id = %s AND user_id != %s AND role = 'owner'",
            (aid, keep),
        )

    @staticmethod
    def list_for_user(user_id: int):
        """All membership rows of a user, ordered by account_id."""
        sql = "SELECT * FROM user_account_memberships WHERE user_id = %s ORDER BY account_id ASC"
        return db.execute_query(sql, (int(user_id),))

    @staticmethod
    def get_user_accounts(user_id: int):
        """Accounts linked to a user, with account details (incl. encrypted credential columns) and the user's role."""
        sql = """
            SELECT a.id, a.name, a.status, a.created_at, a.updated_at, a.api_key_enc, a.api_secret_enc, m.role
            FROM accounts a
            JOIN user_account_memberships m ON a.id = m.account_id
            WHERE m.user_id = %s
            ORDER BY a.id ASC
            """
        return db.execute_query(sql, (int(user_id),))

    @staticmethod
    def list_for_account(account_id: int):
        """All membership rows of an account, ordered by user_id."""
        sql = "SELECT * FROM user_account_memberships WHERE account_id = %s ORDER BY user_id ASC"
        return db.execute_query(sql, (int(account_id),))

    @staticmethod
    def has_access(user_id: int, account_id: int) -> bool:
        """True if any membership row links the user to the account."""
        found = db.execute_one(
            "SELECT 1 as ok FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            (int(user_id), int(account_id)),
        )
        return bool(found)

    @staticmethod
    def get_role(user_id: int, account_id: int) -> str:
        """The user's role on the account, or "" when there is no (dict) row or the role is empty."""
        found = db.execute_one(
            "SELECT role FROM user_account_memberships WHERE user_id = %s AND account_id = %s",
            (int(user_id), int(account_id)),
        )
        if not isinstance(found, dict):
            return ""
        return found.get("role") or ""
class TradingConfig:
    """Trading configuration model (``trading_config`` table).

    When the table has an ``account_id`` column, every read/write is scoped to
    one account (``DEFAULT_ACCOUNT_ID`` when none is given). Otherwise the
    legacy single-account queries are used.
    """

    @staticmethod
    def get_all(account_id: int = None):
        """Return all config rows for the account, ordered by category then key."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            return db.execute_query(
                "SELECT * FROM trading_config WHERE account_id = %s ORDER BY category, config_key",
                (aid,),
            )
        return db.execute_query("SELECT * FROM trading_config ORDER BY category, config_key")

    @staticmethod
    def get(key, account_id: int = None):
        """Return the raw config row for ``key`` (or a falsy result if absent)."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            return db.execute_one(
                "SELECT * FROM trading_config WHERE account_id = %s AND config_key = %s",
                (aid, key),
            )
        return db.execute_one("SELECT * FROM trading_config WHERE config_key = %s", (key,))

    @staticmethod
    def set(key, value, config_type, category, description=None, account_id: int = None):
        """Upsert a config entry; ``value`` is serialized with :meth:`_convert_to_string`."""
        value_str = TradingConfig._convert_to_string(value, config_type)
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trading_config", "account_id"):
            db.execute_update(
                """INSERT INTO trading_config
                (account_id, config_key, config_value, config_type, category, description)
                VALUES (%s, %s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                config_value = VALUES(config_value),
                config_type = VALUES(config_type),
                category = VALUES(category),
                description = VALUES(description),
                updated_at = CURRENT_TIMESTAMP""",
                (aid, key, value_str, config_type, category, description),
            )
        else:
            db.execute_update(
                """INSERT INTO trading_config
                (config_key, config_value, config_type, category, description)
                VALUES (%s, %s, %s, %s, %s)
                ON DUPLICATE KEY UPDATE
                config_value = VALUES(config_value),
                config_type = VALUES(config_type),
                category = VALUES(category),
                description = VALUES(description),
                updated_at = CURRENT_TIMESTAMP""",
                (key, value_str, config_type, category, description),
            )

    @staticmethod
    def get_value(key, default=None, account_id: int = None):
        """Return the typed value for ``key`` (see :meth:`_convert_value`), or ``default`` if absent."""
        result = TradingConfig.get(key, account_id=account_id)
        if result:
            return TradingConfig._convert_value(result['config_value'], result['config_type'])
        return default

    @staticmethod
    def _convert_value(value, config_type):
        """
        Convert a stored config string to a Python value according to ``config_type``.

        - 'number'/'float'/'int'/'integer': text containing '.' → float, else int.
          Exponent forms such as "1e-05" (what ``str()`` produces for small
          floats written by :meth:`set`) are retried as float. Unparseable → 0.
        - 'boolean'/'bool': True iff the lower-cased text is 'true', '1', 'yes' or 'on'.
        - 'json': ``json.loads``; invalid or non-string input → {}.
        - anything else: returned unchanged.
        """
        if config_type in ('number', 'float', 'int', 'integer'):
            try:
                return float(value) if '.' in str(value) else int(value)
            except (TypeError, ValueError, OverflowError):
                pass
            try:
                # Without this retry, "1e-05" would fail int() and read back as 0.
                return float(value)
            except (TypeError, ValueError, OverflowError):
                return 0
        if config_type in ('boolean', 'bool'):
            return str(value).lower() in ('true', '1', 'yes', 'on')
        if config_type == 'json':
            try:
                return json.loads(value)
            except (TypeError, ValueError):  # JSONDecodeError subclasses ValueError
                return {}
        return value

    @staticmethod
    def _convert_to_string(value, config_type):
        """Serialize a value for storage: JSON (non-ASCII kept) for 'json', else ``str(value)``."""
        if config_type == 'json':
            return json.dumps(value, ensure_ascii=False)
        return str(value)
# Cache-key constants (used together with the market_cache table).
MARKET_CACHE_KEY_EXCHANGE_INFO = "exchange_info"
MARKET_CACHE_KEY_FUNDING_INFO = "funding_info"


class MarketCache:
    """
    Market-data cache in the ``market_cache`` table.

    Holds fairly static data such as trading-pair info (exchange_info) and
    funding-rate rules (funding_info) to reduce Binance API calls.
    trading_system can read from the DB first and, when an entry is stale or
    missing, call the API and write the result back.
    """

    @staticmethod
    def _table_exists():
        """True if ``market_cache`` can be queried (any error counts as missing)."""
        try:
            db.execute_one("SELECT 1 FROM market_cache LIMIT 1")
            return True
        except Exception:
            return False

    @staticmethod
    def get(cache_key: str):
        """Return ``{"cache_value": parsed value, "updated_at": ...}`` or None.

        String values are JSON-decoded (left raw if that fails). A missing
        table, missing row, NULL value or DB error yields None.
        """
        if not MarketCache._table_exists():
            return None
        try:
            row = db.execute_one(
                "SELECT cache_value, updated_at FROM market_cache WHERE cache_key = %s",
                (cache_key,),
            )
            if not row:
                return None
            raw = row.get("cache_value")
            updated_at = row.get("updated_at")
            if raw is None:
                return None
            try:
                value = json.loads(raw) if isinstance(raw, str) else raw
            except Exception:
                value = raw
            return {"cache_value": value, "updated_at": updated_at}
        except Exception as e:
            logger.debug("MarketCache.get %s: %s", cache_key, e)
            return None

    @staticmethod
    def set(cache_key: str, value) -> bool:
        """Upsert a cache entry; non-string values are serialized to JSON. Returns success."""
        if not MarketCache._table_exists():
            return False
        try:
            payload = json.dumps(value, ensure_ascii=False) if not isinstance(value, str) else value
            db.execute_update(
                """INSERT INTO market_cache (cache_key, cache_value)
                VALUES (%s, %s)
                ON DUPLICATE KEY UPDATE cache_value = VALUES(cache_value), updated_at = CURRENT_TIMESTAMP""",
                (cache_key, payload),
            )
            return True
        except Exception as e:
            logger.warning("MarketCache.set %s: %s", cache_key, e)
            return False

    @staticmethod
    def _is_fresh(updated_at, max_age_seconds) -> bool:
        """
        Decide whether an entry written at ``updated_at`` is at most ``max_age_seconds`` old.

        - ``max_age_seconds`` None, or a falsy ``updated_at``: fresh (no check).
        - int/float ``updated_at``: treated as Unix seconds.
        - objects with ``.timestamp()`` (datetimes): converted via that method.
          A naive datetime is interpreted in the Python process's local time
          zone. NOTE(review): if the DB returns naive values in a different zone
          the computed age is skewed — confirm the session/server time zone.
        - anything else: treated as epoch 0, i.e. stale.
        - Any error while computing/comparing the age: best effort, fresh.
        """
        if max_age_seconds is None or not updated_at:
            return True
        try:
            if isinstance(updated_at, (int, float)):
                updated_ts = float(updated_at)
            elif hasattr(updated_at, "timestamp"):
                updated_ts = updated_at.timestamp()
            else:
                updated_ts = 0
            age = datetime.now(timezone.utc).timestamp() - updated_ts
            return not (age > max_age_seconds)
        except Exception:
            return True

    @staticmethod
    def _get_fresh(cache_key: str, max_age_seconds):
        """Cached value for ``cache_key`` if present and fresh (see :meth:`_is_fresh`), else None."""
        entry = MarketCache.get(cache_key)
        if not entry:
            return None
        if not MarketCache._is_fresh(entry.get("updated_at"), max_age_seconds):
            return None
        return entry.get("cache_value")

    @staticmethod
    def get_exchange_info(max_age_seconds: int = 86400):
        """
        Cached exchange_info (Binance GET /fapi/v1/exchangeInfo).

        Returns the parsed dict if cached within ``max_age_seconds``, otherwise
        None (the caller should fetch from the API and call set_exchange_info).
        """
        return MarketCache._get_fresh(MARKET_CACHE_KEY_EXCHANGE_INFO, max_age_seconds)

    @staticmethod
    def set_exchange_info(data: dict) -> bool:
        """Write exchange_info to market_cache."""
        return MarketCache.set(MARKET_CACHE_KEY_EXCHANGE_INFO, data)

    @staticmethod
    def get_funding_info(max_age_seconds: int = 86400):
        """
        Cached funding_info (Binance GET /fapi/v1/fundingInfo).

        Returns the parsed list/dict if cached within ``max_age_seconds``, else None.
        """
        return MarketCache._get_fresh(MARKET_CACHE_KEY_FUNDING_INFO, max_age_seconds)

    @staticmethod
    def set_funding_info(data) -> bool:
        """Write funding_info to market_cache."""
        return MarketCache.set(MARKET_CACHE_KEY_FUNDING_INFO, data)
class GlobalStrategyConfig:
    """Global strategy configuration, independent of accounts (administrators only).

    Rows live in ``global_strategy_config``. The table is detected by probing
    its ``config_key`` column. When it is missing:
    - reads return empty results;
    - ``delete`` is a no-op;
    - ``set`` falls back to ``TradingConfig`` for account 1 (legacy compatibility).
    """

    @staticmethod
    def _table_ready():
        """True when ``global_strategy_config`` is available (probed via its config_key column)."""
        return _table_has_column("global_strategy_config", "config_key")

    @staticmethod
    def get_all():
        """Every global config row ordered by category and key; [] if the table is missing."""
        if GlobalStrategyConfig._table_ready():
            return db.execute_query(
                "SELECT * FROM global_strategy_config ORDER BY category, config_key"
            )
        return []

    @staticmethod
    def get(key):
        """Raw row for ``key``; None if the table is missing."""
        if not GlobalStrategyConfig._table_ready():
            return None
        return db.execute_one(
            "SELECT * FROM global_strategy_config WHERE config_key = %s",
            (key,),
        )

    @staticmethod
    def set(key, value, config_type, category, description=None, updated_by=None):
        """Upsert a global config entry, recording ``updated_by``."""
        if not GlobalStrategyConfig._table_ready():
            # Legacy fallback: keep the value in trading_config under account 1.
            return TradingConfig.set(key, value, config_type, category, description, account_id=1)
        serialized = TradingConfig._convert_to_string(value, config_type)
        db.execute_update(
            """INSERT INTO global_strategy_config
            (config_key, config_value, config_type, category, description, updated_by)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
            config_value = VALUES(config_value),
            config_type = VALUES(config_type),
            category = VALUES(category),
            description = VALUES(description),
            updated_by = VALUES(updated_by),
            updated_at = CURRENT_TIMESTAMP""",
            (key, serialized, config_type, category, description, updated_by),
        )

    @staticmethod
    def get_value(key, default=None):
        """Typed value for ``key`` (same conversion rules as TradingConfig), or ``default``."""
        row = GlobalStrategyConfig.get(key)
        if not row:
            return default
        return GlobalStrategyConfig._convert_value(row['config_value'], row['config_type'])

    @staticmethod
    def _convert_value(value, config_type):
        """Delegate to ``TradingConfig._convert_value``."""
        return TradingConfig._convert_value(value, config_type)

    @staticmethod
    def delete(key):
        """Remove a global config entry; no-op if the table is missing."""
        if GlobalStrategyConfig._table_ready():
            db.execute_update(
                "DELETE FROM global_strategy_config WHERE config_key = %s",
                (key,),
            )
class Trade:
    """Trade record model backed by the ``trades`` table.

    Optional columns (account_id, order ids, sizing, risk levels, entry_context,
    realized_pnl/commission, ...) are probed with ``_table_has_column`` so the
    model keeps working against older schemas.
    """

    @staticmethod
    def create(
        symbol,
        side,
        quantity,
        entry_price,
        leverage=10,
        entry_reason=None,
        entry_order_id=None,
        client_order_id=None,
        stop_loss_price=None,
        take_profit_price=None,
        take_profit_1=None,
        take_profit_2=None,
        atr=None,
        notional_usdt=None,
        margin_usdt=None,
        account_id: int = None,
        entry_context=None,
        status: str = "open",
        entry_time=None,
    ):
        """Create a trade record and return its id.

        Idempotent on exchange order ids: if a row with the same
        ``entry_order_id`` (or else ``client_order_id``) already exists, that
        row's id is returned and nothing is inserted.

        Args:
            symbol: trading pair.
            side: direction.
            quantity: quantity.
            entry_price: entry price.
            leverage: leverage (default 10); also used to derive margin_usdt.
            entry_reason: short free-text entry reason.
            entry_order_id: Binance entry orderId (optional, for reconciliation).
            client_order_id: Binance clientOrderId (optional; lets system orders
                be matched in the order history).
            stop_loss_price: stop-loss price actually used (may include ATR-based adjustment).
            take_profit_price: take-profit price actually used (may include ATR-based adjustment).
            take_profit_1: first take-profit target (optional).
            take_profit_2: second take-profit target (optional).
            atr: ATR value used at entry (optional).
            notional_usdt: notional size in USDT; derived as quantity * entry_price when omitted.
            margin_usdt: margin in USDT; derived as notional / leverage (or the
                notional itself when leverage is not positive) when omitted.
            account_id: owning account; defaults to DEFAULT_ACCOUNT_ID.
            entry_context: entry rationale (signal strength, market regime,
                trend, filter results ...); dicts are stored as JSON, other
                values via ``str``.
            status: initial status, "open" by default; pass "pending" to persist
                before the WS fill arrives.
            entry_time: entry time as Unix seconds or a datetime; None (or an
                unparseable value) means now. When backfilling/syncing, prefer
                the real time from the Binance order or fill.
        """
        # Normalize entry_time to integer Unix seconds.
        if entry_time is not None:
            try:
                if hasattr(entry_time, "timestamp"):
                    entry_time = int(entry_time.timestamp())
                else:
                    entry_time = int(float(entry_time))
            except (TypeError, ValueError):
                entry_time = get_beijing_time()
        else:
            entry_time = get_beijing_time()

        # Derive notional / margin when the caller did not supply them (best effort).
        try:
            if notional_usdt is None and quantity is not None and entry_price is not None:
                notional_usdt = float(quantity) * float(entry_price)
        except Exception:
            pass
        try:
            if margin_usdt is None and notional_usdt is not None:
                lv = float(leverage) if leverage else 0
                margin_usdt = (float(notional_usdt) / lv) if lv and lv > 0 else float(notional_usdt)
        except Exception:
            pass

        # Idempotency: reuse an existing row carrying the same exchange order id.
        # Failures here are ignored; DB unique constraints remain the backstop.
        try:
            if entry_order_id is not None and _table_has_column("trades", "entry_order_id"):
                try:
                    existing = Trade.get_by_entry_order_id(entry_order_id)
                except Exception:
                    existing = None
                if existing:
                    try:
                        logger.debug(
                            f"Trade.create: entry_order_id={entry_order_id} 已存在 (id={existing.get('id')}, "
                            f"symbol={existing.get('symbol')}, status={existing.get('status')}),直接复用"
                        )
                    except Exception:
                        pass
                    return existing.get("id")
            if client_order_id and _table_has_column("trades", "client_order_id"):
                try:
                    existing = Trade.get_by_client_order_id(client_order_id)
                except Exception:
                    existing = None
                if existing:
                    try:
                        logger.debug(
                            f"Trade.create: client_order_id={client_order_id!r} 已存在 (id={existing.get('id')}, "
                            f"symbol={existing.get('symbol')}, status={existing.get('status')}),直接复用"
                        )
                    except Exception:
                        pass
                    return existing.get("id")
        except Exception:
            pass

        # Build the INSERT dynamically so it matches whatever schema is deployed.
        columns = ["symbol", "side", "quantity", "entry_price", "leverage", "entry_reason", "status", "entry_time"]
        values = [symbol, side, quantity, entry_price, leverage, entry_reason, (status or "open"), entry_time]
        if _table_has_column("trades", "account_id"):
            columns.insert(0, "account_id")
            values.insert(0, int(account_id or DEFAULT_ACCOUNT_ID))
        optional_columns = (
            ("entry_order_id", entry_order_id),
            ("client_order_id", client_order_id),
            ("notional_usdt", notional_usdt),
            ("margin_usdt", margin_usdt),
            ("atr", atr),
            ("stop_loss_price", stop_loss_price),
            ("take_profit_price", take_profit_price),
            ("take_profit_1", take_profit_1),
            ("take_profit_2", take_profit_2),
        )
        for column, value in optional_columns:
            if _table_has_column("trades", column):
                columns.append(column)
                values.append(value)
        if entry_context is not None and _table_has_column("trades", "entry_context"):
            try:
                entry_context_str = json.dumps(entry_context, ensure_ascii=False) if isinstance(entry_context, dict) else str(entry_context)
            except Exception:
                entry_context_str = None
            if entry_context_str is not None:
                columns.append("entry_context")
                values.append(entry_context_str)

        placeholders = ", ".join(["%s"] * len(columns))
        sql = f"INSERT INTO trades ({', '.join(columns)}) VALUES ({placeholders})"
        db.execute_update(sql, tuple(values))
        # NOTE(review): LAST_INSERT_ID() is per-connection in MySQL; this assumes
        # db runs both statements on the same connection — confirm.
        return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']

    @staticmethod
    def update_exit(
        trade_id,
        exit_price,
        exit_reason,
        pnl,
        pnl_percent,
        exit_order_id=None,
        strategy_type=None,
        duration_minutes=None,
        exit_time_ts=None,
        realized_pnl=None,
        commission=None,
        commission_asset=None,
    ):
        """Mark a trade as closed and record its exit details.

        Args:
            trade_id: trade record id.
            exit_price: exit price.
            exit_reason: close reason.
            pnl: profit/loss.
            pnl_percent: profit/loss percentage.
            exit_order_id: Binance exit orderId (optional, for reconciliation).
                Written (even when None) unless it conflicts with another trade.
            strategy_type: written when not None.
            duration_minutes: written when not None.
            exit_time_ts: real fill time as Unix seconds, so holding-time
                statistics are accurate; defaults to (and falls back to) now.
            realized_pnl: PnL actually settled by Binance (optional; written only
                if the column exists).
            commission: trading fee (optional; written only if the column exists).
            commission_asset: fee asset (optional; written only if the column exists).

        If ``exit_order_id`` already belongs to another trade (detected up front
        or via a "Duplicate entry" error mentioning exit_order_id), the row is
        updated without ``exit_order_id`` to avoid the unique-constraint
        conflict. All other fields, including realized_pnl/commission, are still
        written. Other database errors propagate.
        """
        try:
            exit_time = int(exit_time_ts) if exit_time_ts is not None else get_beijing_time()
        except Exception:
            exit_time = get_beijing_time()

        def _same_trade(found_id) -> bool:
            # Tolerate int/str mismatches (e.g. trade_id "12" vs DB id 12).
            if found_id == trade_id:
                return True
            try:
                return int(found_id) == int(trade_id)
            except (TypeError, ValueError):
                return False

        def _write_close(with_exit_order_id: bool) -> None:
            # Single builder for every close path so optional fields are never dropped.
            fields = [
                "exit_price = %s", "exit_time = %s",
                "exit_reason = %s", "pnl = %s", "pnl_percent = %s", "status = 'closed'",
            ]
            values = [exit_price, exit_time, exit_reason, pnl, pnl_percent]
            if with_exit_order_id:
                fields.append("exit_order_id = %s")
                values.append(exit_order_id)
            if strategy_type is not None:
                fields.append("strategy_type = %s")
                values.append(strategy_type)
            if duration_minutes is not None:
                fields.append("duration_minutes = %s")
                values.append(duration_minutes)
            if realized_pnl is not None and _table_has_column("trades", "realized_pnl"):
                fields.append("realized_pnl = %s")
                values.append(realized_pnl)
            if commission is not None and _table_has_column("trades", "commission"):
                fields.append("commission = %s")
                values.append(commission)
            if commission_asset is not None and _table_has_column("trades", "commission_asset"):
                fields.append("commission_asset = %s")
                values.append(commission_asset)
            values.append(trade_id)
            db.execute_update(
                f"UPDATE trades SET {', '.join(fields)} WHERE id = %s",
                tuple(values),
            )

        # Check up front whether exit_order_id already belongs to another trade.
        if exit_order_id is not None:
            try:
                existing_trade = Trade.get_by_exit_order_id(exit_order_id)
                if existing_trade:
                    if _same_trade(existing_trade['id']):
                        logger.debug(
                            f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已存在,将继续更新其他字段"
                        )
                    else:
                        logger.warning(
                            f"交易记录 {trade_id} 的 exit_order_id {exit_order_id} 已被交易记录 {existing_trade['id']} 使用,"
                            f"跳过更新 exit_order_id只更新其他字段"
                        )
                        _write_close(with_exit_order_id=False)
                        return
            except Exception as e:
                # Lookup failed: log and fall through to the normal update.
                logger.warning(f"检查 exit_order_id {exit_order_id} 时出错: {e},继续正常更新")

        # Normal update including exit_order_id.
        try:
            _write_close(with_exit_order_id=True)
        except Exception as e:
            error_str = str(e)
            if "Duplicate entry" in error_str and "exit_order_id" in error_str:
                logger.warning(
                    f"更新交易记录 {trade_id} 时 exit_order_id {exit_order_id} 唯一约束冲突,"
                    f"尝试不更新 exit_order_id"
                )
                _write_close(with_exit_order_id=False)
            else:
                raise

    @staticmethod
    def update_open_fields(trade_id: int, **kwargs):
        """Complete fields after opening (stop-loss/take-profit/notional/margin/entry_context/atr).

        Used when a trade is first persisted as pending and filled in after the
        fill. Unknown keys and None values are ignored; ``entry_context`` dicts
        are stored as JSON. DB errors are logged, not raised.
        """
        if not kwargs:
            return
        allowed = {
            "stop_loss_price", "take_profit_price", "take_profit_1", "take_profit_2",
            "notional_usdt", "margin_usdt", "entry_context", "atr"
        }
        updates = []
        values = []
        for k, v in kwargs.items():
            if k not in allowed or v is None:
                continue
            if k == "entry_context":
                try:
                    v = json.dumps(v, ensure_ascii=False) if isinstance(v, dict) else str(v)
                except Exception:
                    continue
            updates.append(f"{k} = %s")
            values.append(v)
        if not updates:
            return
        values.append(trade_id)
        try:
            db.execute_update(
                f"UPDATE trades SET {', '.join(updates)} WHERE id = %s",
                tuple(values)
            )
        except Exception as e:
            logger.warning(f"update_open_fields trade_id={trade_id} 失败: {e}")

    @staticmethod
    def get_by_entry_order_id(entry_order_id):
        """Trade row by entry orderId (not scoped by account)."""
        return db.execute_one(
            "SELECT * FROM trades WHERE entry_order_id = %s",
            (entry_order_id,)
        )

    @staticmethod
    def get_by_exit_order_id(exit_order_id):
        """Trade row by exit orderId (not scoped by account)."""
        return db.execute_one(
            "SELECT * FROM trades WHERE exit_order_id = %s",
            (exit_order_id,)
        )

    @staticmethod
    def get_trades_missing_entry_order_id(symbol: str, account_id: int, limit: int = 50):
        """
        Rows for symbol (+ account when the column exists) lacking entry_order_id.

        Used to backfill order ids when syncing orders from Binance. There is
        no entry_time filter, so old rows with a NULL entry_time are included
        too. Newest first, at most ``limit`` rows; returns [] on bad input or error.
        """
        if not symbol or account_id is None:
            return []
        try:
            if not _table_has_column("trades", "entry_order_id"):
                return []
            if _table_has_column("trades", "account_id"):
                return db.execute_query(
                    """SELECT * FROM trades
                    WHERE account_id = %s AND symbol = %s
                    AND (entry_order_id IS NULL OR entry_order_id = 0 OR entry_order_id = '')
                    ORDER BY id DESC
                    LIMIT %s""",
                    (int(account_id), symbol.strip(), int(limit)),
                )
            return db.execute_query(
                """SELECT * FROM trades
                WHERE symbol = %s
                AND (entry_order_id IS NULL OR entry_order_id = 0 OR entry_order_id = '')
                ORDER BY id DESC
                LIMIT %s""",
                (symbol.strip(), int(limit)),
            )
        except Exception as e:
            logger.debug(f"get_trades_missing_entry_order_id 失败: {e}")
            return []

    @staticmethod
    def get_by_client_order_id(client_order_id, account_id: int = None):
        """Trade row by clientOrderId, optionally scoped to ``account_id``; None on empty id or error."""
        if not client_order_id:
            return None
        try:
            if account_id is not None and _table_has_column("trades", "account_id"):
                return db.execute_one(
                    "SELECT * FROM trades WHERE client_order_id = %s AND account_id = %s",
                    (client_order_id, int(account_id))
                )
            return db.execute_one(
                "SELECT * FROM trades WHERE client_order_id = %s",
                (client_order_id,)
            )
        except Exception:
            return None

    @staticmethod
    def update_pending_to_filled(client_order_id, account_id: int, entry_order_id, entry_price: float, quantity: float):
        """Once WS or REST reports a fill, complete the row found by client_order_id.

        Sets entry_order_id/entry_price/quantity and status 'open'. The row is
        skipped (reported as success) only when its status is neither 'pending'
        nor 'open' AND it already has an entry_order_id.
        NOTE(review): a row with status 'open' is overwritten again, and a
        non-open row lacking entry_order_id would be set back to 'open' —
        confirm that this is intended.
        """
        if not client_order_id or account_id is None:
            return False
        try:
            row = Trade.get_by_client_order_id(client_order_id, account_id)
            if not row:
                return False
            if row.get("status") not in ("pending", "open") and row.get("entry_order_id"):
                return True  # already completed; treat as success
            db.execute_update(
                """UPDATE trades SET entry_order_id = %s, entry_price = %s, quantity = %s, status = 'open'
                WHERE client_order_id = %s AND account_id = %s""",
                (entry_order_id, float(entry_price), float(quantity), client_order_id, int(account_id))
            )
            return True
        except Exception as e:
            logger.warning(f"update_pending_to_filled 失败 client_order_id={client_order_id!r}: {e}")
            return False

    @staticmethod
    def update_pending_by_entry_order_id(symbol: str, account_id: int, entry_order_id, entry_price: float, quantity: float):
        """
        Fallback when a UDS entry FILLED event has no clientOrderId: complete one pending row by orderId.

        Returns True if a row with this entry_order_id already exists. Otherwise
        it looks for rows of symbol+account with status 'pending' and no
        entry_order_id. It updates only when exactly one exists, so that with
        several pending orders the wrong one is not picked.
        """
        if not symbol or account_id is None or entry_order_id is None:
            return False
        try:
            existing = Trade.get_by_entry_order_id(entry_order_id)
            if existing:
                return True  # order id already recorded; nothing to backfill
            if not _table_has_column("trades", "account_id"):
                return False
            rows = db.execute_query(
                """SELECT id FROM trades
                WHERE account_id = %s AND symbol = %s AND status = 'pending'
                AND (entry_order_id IS NULL OR entry_order_id = 0)
                ORDER BY id DESC""",
                (int(account_id), symbol.strip())
            )
            if not rows or len(rows) != 1:
                return False
            tid = rows[0]["id"]
            db.execute_update(
                """UPDATE trades SET entry_order_id = %s, entry_price = %s, quantity = %s, status = 'open'
                WHERE id = %s""",
                (str(entry_order_id), float(entry_price), float(quantity), int(tid))
            )
            logger.info(f"Trade.update_pending_by_entry_order_id: 已用 orderId={entry_order_id} 兜底完善 pending 记录 id={tid} symbol={symbol!r}")
            return True
        except Exception as e:
            logger.warning(f"update_pending_by_entry_order_id 失败 symbol={symbol!r} orderId={entry_order_id}: {e}")
            return False

    @staticmethod
    def update_entry_order_id(trade_id: int, entry_order_id):
        """Fill in entry_order_id only when it is still empty (REST fallback).

        Returns True whenever the statement runs, even if no row matched.
        """
        if not trade_id or not entry_order_id:
            return False
        try:
            db.execute_update(
                """UPDATE trades SET entry_order_id = %s WHERE id = %s AND (entry_order_id IS NULL OR entry_order_id = '')""",
                (str(entry_order_id), int(trade_id))
            )
            return True
        except Exception as e:
            logger.warning(f"update_entry_order_id 失败 trade_id={trade_id}: {e}")
            return False

    @staticmethod
    def get_all(start_timestamp=None, end_timestamp=None, symbol=None, status=None, trade_type=None, exit_reason=None, account_id: int = None, time_filter: str = "exit", limit: int = None, reconciled_only: bool = False, include_sync: bool = True):
        """
        Query trade records, newest first.

        time_filter: which timestamp the range applies to.
            - "exit": closed rows by exit_time, non-closed rows by entry_time
              ("today" = closed today + opened today and still open).
            - "entry": by entry_time.
            - anything else: COALESCE(exit_time, entry_time).
        limit: maximum rows; None (or <= 0) means no limit.
        reconciled_only: only rows with an entry_order_id, and closed rows must
            also have an exit_order_id (filtered in SQL).
        include_sync: when exactly False, exclude entry_reason='sync_recovered'
            and exit_reason='sync' rows (filtered in SQL).
        trade_type: matched against the ``side`` column.
        """
        query = "SELECT * FROM trades WHERE 1=1"
        params = []
        # Multi-account isolation (compatible with the old schema).
        try:
            if _table_has_column("trades", "account_id"):
                query += " AND account_id = %s"
                params.append(int(account_id or DEFAULT_ACCOUNT_ID))
        except Exception:
            pass
        if reconciled_only and _table_has_column("trades", "entry_order_id"):
            query += " AND entry_order_id IS NOT NULL AND entry_order_id != 0"
            query += " AND (status != 'closed' OR (exit_order_id IS NOT NULL AND exit_order_id != 0))"
        if include_sync is False:
            query += " AND (entry_reason IS NULL OR entry_reason != 'sync_recovered')"
            query += " AND (exit_reason IS NULL OR exit_reason != 'sync')"
        if start_timestamp is not None and end_timestamp is not None:
            if time_filter == "exit":
                query += " AND ((status = 'closed' AND exit_time >= %s AND exit_time <= %s) OR (status != 'closed' AND entry_time >= %s AND entry_time <= %s))"
                params.extend([start_timestamp, end_timestamp, start_timestamp, end_timestamp])
            elif time_filter == "entry":
                query += " AND entry_time >= %s AND entry_time <= %s"
                params.extend([start_timestamp, end_timestamp])
            else:
                query += " AND COALESCE(exit_time, entry_time) >= %s AND COALESCE(exit_time, entry_time) <= %s"
                params.extend([start_timestamp, end_timestamp])
        elif start_timestamp is not None:
            if time_filter == "exit":
                query += " AND ((status = 'closed' AND exit_time >= %s) OR (status != 'closed' AND entry_time >= %s))"
                params.extend([start_timestamp, start_timestamp])
            elif time_filter == "entry":
                query += " AND entry_time >= %s"
                params.append(start_timestamp)
            else:
                query += " AND COALESCE(exit_time, entry_time) >= %s"
                params.append(start_timestamp)
        elif end_timestamp is not None:
            if time_filter == "exit":
                query += " AND ((status = 'closed' AND exit_time <= %s) OR (status != 'closed' AND entry_time <= %s))"
                params.extend([end_timestamp, end_timestamp])
            elif time_filter == "entry":
                query += " AND entry_time <= %s"
                params.append(end_timestamp)
            else:
                query += " AND COALESCE(exit_time, entry_time) <= %s"
                params.append(end_timestamp)
        if symbol:
            query += " AND symbol = %s"
            params.append(symbol)
        if status:
            query += " AND status = %s"
            params.append(status)
        if trade_type:
            query += " AND side = %s"
            params.append(trade_type)
        if exit_reason:
            query += " AND exit_reason = %s"
            params.append(exit_reason)
        query += " ORDER BY COALESCE(exit_time, entry_time) DESC, id DESC"
        if limit is not None and limit > 0:
            query += " LIMIT %s"
            params.append(int(limit))
        logger.debug(f"查询交易记录: time_filter={time_filter}, limit={limit}, reconciled_only={reconciled_only}, include_sync={include_sync}")
        result = db.execute_query(query, params)
        return result

    @staticmethod
    def get_by_symbol(symbol, status='open', account_id: int = None):
        """Rows for a symbol with the given status (default 'open'), scoped to the account when possible."""
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        if _table_has_column("trades", "account_id"):
            return db.execute_query(
                "SELECT * FROM trades WHERE account_id = %s AND symbol = %s AND status = %s",
                (aid, symbol, status),
            )
        return db.execute_query(
            "SELECT * FROM trades WHERE symbol = %s AND status = %s",
            (symbol, status),
        )

    @staticmethod
    def set_exit_order_id_for_open_trade(symbol: str, account_id: int, exit_order_id, entry_order_id: int = None) -> bool:
        """
        After ALGO_UPDATE / a triggered conditional order, record the exit orderId on an open row.

        Only open rows for symbol+account whose exit_order_id is empty are
        eligible. The row matching ``entry_order_id`` is preferred when that is
        given. Otherwise the earliest such row by entry_time is used. Returns
        True when a row was updated.
        """
        if not symbol or account_id is None or exit_order_id is None:
            return False
        try:
            if not _table_has_column("trades", "account_id"):
                return False
            # Prefer an exact match on entry_order_id when provided.
            if entry_order_id:
                n = db.execute_update(
                    """UPDATE trades SET exit_order_id = %s
                    WHERE account_id = %s AND symbol = %s AND status = 'open'
                    AND entry_order_id = %s
                    AND (exit_order_id IS NULL OR exit_order_id = '')
                    LIMIT 1""",
                    (str(exit_order_id), int(account_id), symbol.strip(), str(entry_order_id))
                )
                if n and n > 0:
                    logger.debug(f"set_exit_order_id_for_open_trade: 按 entry_order_id={entry_order_id} 精确匹配成功")
                    return True
            # Otherwise the earliest open row (by entry_time) for the symbol.
            n = db.execute_update(
                """UPDATE trades SET exit_order_id = %s
                WHERE account_id = %s AND symbol = %s AND status = 'open'
                AND (exit_order_id IS NULL OR exit_order_id = '')
                ORDER BY entry_time ASC
                LIMIT 1""",
                (str(exit_order_id), int(account_id), symbol.strip())
            )
            return n is not None and n > 0
        except Exception as e:
            logger.warning(f"set_exit_order_id_for_open_trade 失败 symbol={symbol!r}: {e}")
            return False
class AccountSnapshot:
    """Account snapshot model: periodic balance/position records.

    ``snapshot_time`` is written as the integer Unix timestamp returned by
    ``get_beijing_time()``, so reads must filter on epoch seconds too.
    """
    @staticmethod
    def create(total_balance, available_balance, total_position_value, total_pnl, open_positions, account_id: int = None):
        """Insert a snapshot stamped with ``get_beijing_time()``.

        Args:
            total_balance, available_balance, total_position_value, total_pnl,
                open_positions: stored as-is in the columns of the same name.
            account_id: owning account; falls back to DEFAULT_ACCOUNT_ID when
                falsy. Not written on the legacy schema that lacks an
                ``account_id`` column.
        """
        snapshot_time = get_beijing_time()
        if _table_has_column("account_snapshots", "account_id"):
            db.execute_update(
                """INSERT INTO account_snapshots
                   (account_id, total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
                   VALUES (%s, %s, %s, %s, %s, %s, %s)""",
                (int(account_id or DEFAULT_ACCOUNT_ID), total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time),
            )
        else:
            db.execute_update(
                """INSERT INTO account_snapshots
                   (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time)
                   VALUES (%s, %s, %s, %s, %s, %s)""",
                (total_balance, available_balance, total_position_value, total_pnl, open_positions, snapshot_time),
            )
    @staticmethod
    def get_recent(days=7, account_id: int = None):
        """Return snapshots taken within the last ``days`` days, newest first.

        Args:
            days: look-back window in days (fractions allowed).
            account_id: account to filter on when the table has an
                ``account_id`` column; falls back to DEFAULT_ACCOUNT_ID when falsy.

        The cutoff is computed from ``get_beijing_time()``, the same epoch-second
        clock ``create`` writes into ``snapshot_time``. Comparing that integer
        column against ``DATE_SUB(NOW(), ...)`` (a DATETIME) makes MySQL compare
        the two numerically: epoch seconds vs YYYYMMDDhhmmss. As a result no
        row ever matched.
        NOTE(review): assumes snapshot_time is an integer column, as create's
        writes imply. Confirm against the schema.
        """
        aid = int(account_id or DEFAULT_ACCOUNT_ID)
        cutoff = get_beijing_time() - int(float(days) * 24 * 60 * 60)
        if _table_has_column("account_snapshots", "account_id"):
            return db.execute_query(
                """SELECT * FROM account_snapshots
                   WHERE account_id = %s AND snapshot_time >= %s
                   ORDER BY snapshot_time DESC""",
                (aid, cutoff),
            )
        return db.execute_query(
            """SELECT * FROM account_snapshots
               WHERE snapshot_time >= %s
               ORDER BY snapshot_time DESC""",
            (cutoff,),
        )
class MarketScan:
    """Persistence helpers for the ``market_scans`` table."""
    @staticmethod
    def create(symbols_scanned, symbols_found, top_symbols, scan_duration):
        """Insert one market-scan record.

        ``scan_time`` comes from ``get_beijing_time()``. ``top_symbols`` is
        stored as the JSON string produced by ``json.dumps``.
        """
        scan_time = get_beijing_time()
        top_symbols_json = json.dumps(top_symbols)
        db.execute_update(
            """INSERT INTO market_scans
               (symbols_scanned, symbols_found, top_symbols, scan_duration, scan_time)
               VALUES (%s, %s, %s, %s, %s)""",
            (symbols_scanned, symbols_found, top_symbols_json, scan_duration, scan_time)
        )
    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` scan records, newest ``scan_time`` first."""
        sql = "SELECT * FROM market_scans ORDER BY scan_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
class TradingSignal:
    """Persistence helpers for the ``trading_signals`` table."""
    @staticmethod
    def create(symbol, signal_direction, signal_strength, signal_reason,
               rsi=None, macd_histogram=None, market_regime=None):
        """Insert a trading signal stamped with ``get_beijing_time()``.

        The optional indicator values (rsi, macd_histogram, market_regime)
        are passed through as-is, i.e. None when not supplied.
        """
        signal_time = get_beijing_time()
        row = (
            symbol, signal_direction, signal_strength, signal_reason,
            rsi, macd_histogram, market_regime, signal_time,
        )
        db.execute_update(
            """INSERT INTO trading_signals
               (symbol, signal_direction, signal_strength, signal_reason,
                rsi, macd_histogram, market_regime, signal_time)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
            row
        )
    @staticmethod
    def mark_executed(signal_id):
        """Flag the signal with id ``signal_id`` as executed."""
        sql = "UPDATE trading_signals SET executed = TRUE WHERE id = %s"
        db.execute_update(sql, (signal_id,))
    @staticmethod
    def get_recent(limit=100):
        """Return up to ``limit`` signals, newest ``signal_time`` first."""
        sql = "SELECT * FROM trading_signals ORDER BY signal_time DESC LIMIT %s"
        return db.execute_query(sql, (limit,))
class TradeRecommendation:
"""推荐交易对模型"""
@staticmethod
def create(
symbol, direction, current_price, change_percent, recommendation_reason,
signal_strength, market_regime=None, trend_4h=None,
rsi=None, macd_histogram=None,
bollinger_upper=None, bollinger_middle=None, bollinger_lower=None,
ema20=None, ema50=None, ema20_4h=None, atr=None,
suggested_stop_loss=None, suggested_take_profit_1=None, suggested_take_profit_2=None,
suggested_position_percent=None, suggested_leverage=10,
order_type='LIMIT', suggested_limit_price=None,
volume_24h=None, volatility=None, notes=None,
user_guide=None, recommendation_category=None, risk_level=None,
expected_hold_time=None, trading_tutorial=None, max_hold_days=3
):
"""创建推荐记录(使用北京时间)"""
recommendation_time = get_beijing_time()
# 默认24小时后过期
expires_at = recommendation_time + timedelta(hours=24)
# 检查字段是否存在兼容旧数据库schema
try:
db.execute_one("SELECT order_type FROM trade_recommendations LIMIT 1")
has_order_fields = True
except:
has_order_fields = False
# 检查是否有 user_guide 字段
try:
db.execute_one("SELECT user_guide FROM trade_recommendations LIMIT 1")
has_user_guide_fields = True
except:
has_user_guide_fields = False
if has_user_guide_fields:
# 包含所有新字段user_guide等
db.execute_update(
"""INSERT INTO trade_recommendations
(symbol, direction, recommendation_time, current_price, change_percent,
recommendation_reason, signal_strength, market_regime, trend_4h,
rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
ema20, ema50, ema20_4h, atr,
suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
suggested_position_percent, suggested_leverage,
order_type, suggested_limit_price,
volume_24h, volatility, expires_at, notes,
user_guide, recommendation_category, risk_level,
expected_hold_time, trading_tutorial, max_hold_days)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
(symbol, direction, recommendation_time, current_price, change_percent,
recommendation_reason, signal_strength, market_regime, trend_4h,
rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
ema20, ema50, ema20_4h, atr,
suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
suggested_position_percent, suggested_leverage,
order_type, suggested_limit_price,
volume_24h, volatility, expires_at, notes,
user_guide, recommendation_category, risk_level,
expected_hold_time, trading_tutorial, max_hold_days)
)
elif has_order_fields:
# 只有 order_type 字段,没有 user_guide 字段
db.execute_update(
"""INSERT INTO trade_recommendations
(symbol, direction, recommendation_time, current_price, change_percent,
recommendation_reason, signal_strength, market_regime, trend_4h,
rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
ema20, ema50, ema20_4h, atr,
suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
suggested_position_percent, suggested_leverage,
order_type, suggested_limit_price,
volume_24h, volatility, expires_at, notes)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
(symbol, direction, recommendation_time, current_price, change_percent,
recommendation_reason, signal_strength, market_regime, trend_4h,
rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
ema20, ema50, ema20_4h, atr,
suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
suggested_position_percent, suggested_leverage,
order_type, suggested_limit_price,
volume_24h, volatility, expires_at, notes)
)
else:
# 兼容旧schema
db.execute_update(
"""INSERT INTO trade_recommendations
(symbol, direction, recommendation_time, current_price, change_percent,
recommendation_reason, signal_strength, market_regime, trend_4h,
rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
ema20, ema50, ema20_4h, atr,
suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
suggested_position_percent, suggested_leverage,
volume_24h, volatility, expires_at, notes)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)""",
(symbol, direction, recommendation_time, current_price, change_percent,
recommendation_reason, signal_strength, market_regime, trend_4h,
rsi, macd_histogram, bollinger_upper, bollinger_middle, bollinger_lower,
ema20, ema50, ema20_4h, atr,
suggested_stop_loss, suggested_take_profit_1, suggested_take_profit_2,
suggested_position_percent, suggested_leverage,
volume_24h, volatility, expires_at, notes)
)
return db.execute_one("SELECT LAST_INSERT_ID() as id")['id']
@staticmethod
def mark_executed(recommendation_id, trade_id=None, execution_result='success'):
"""标记推荐已执行"""
executed_at = get_beijing_time()
db.execute_update(
"""UPDATE trade_recommendations
SET status = 'executed', executed_at = %s, execution_result = %s, execution_trade_id = %s
WHERE id = %s""",
(executed_at, execution_result, trade_id, recommendation_id)
)
@staticmethod
def mark_expired(recommendation_id):
"""标记推荐已过期"""
db.execute_update(
"UPDATE trade_recommendations SET status = 'expired' WHERE id = %s",
(recommendation_id,)
)
@staticmethod
def mark_cancelled(recommendation_id, notes=None):
"""标记推荐已取消"""
db.execute_update(
"UPDATE trade_recommendations SET status = 'cancelled', notes = %s WHERE id = %s",
(notes, recommendation_id)
)
@staticmethod
def get_all(status=None, direction=None, limit=100, start_date=None, end_date=None):
"""获取推荐记录"""
query = "SELECT * FROM trade_recommendations WHERE 1=1"
params = []
if status:
query += " AND status = %s"
params.append(status)
if direction:
query += " AND direction = %s"
params.append(direction)
if start_date:
query += " AND recommendation_time >= %s"
params.append(start_date)
if end_date:
query += " AND recommendation_time <= %s"
params.append(end_date)
query += " ORDER BY recommendation_time DESC, signal_strength DESC LIMIT %s"
params.append(limit)
return db.execute_query(query, params)
@staticmethod
def get_active():
"""获取当前有效的推荐(未过期、未执行、未取消)
同一交易对只返回最新的推荐(去重)
"""
return db.execute_query(
"""SELECT t1.* FROM trade_recommendations t1
INNER JOIN (
SELECT symbol, MAX(recommendation_time) as max_time
FROM trade_recommendations
WHERE status = 'active' AND (expires_at IS NULL OR expires_at > NOW())
GROUP BY symbol
) t2 ON t1.symbol = t2.symbol AND t1.recommendation_time = t2.max_time
WHERE t1.status = 'active' AND (t1.expires_at IS NULL OR t1.expires_at > NOW())
ORDER BY t1.signal_strength DESC, t1.recommendation_time DESC"""
)
@staticmethod
def get_by_id(recommendation_id):
"""根据ID获取推荐"""
return db.execute_one(
"SELECT * FROM trade_recommendations WHERE id = %s",
(recommendation_id,)
)
@staticmethod
def get_by_symbol(symbol, limit=10):
"""根据交易对获取推荐记录"""
return db.execute_query(
"""SELECT * FROM trade_recommendations
WHERE symbol = %s
ORDER BY recommendation_time DESC LIMIT %s""",
(symbol, limit)
)