This commit is contained in:
薇薇安 2026-02-14 11:34:37 +08:00
parent dab0981935
commit 4b6d73a5c4
2 changed files with 193 additions and 2 deletions

189
fix_trade_records.py Normal file
View File

@ -0,0 +1,189 @@
import asyncio
import logging
import os
import sys
from datetime import datetime, timedelta
# Add project root to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend'))
from backend.database.connection import db
from backend.database.models import Trade, Account
from trading_system.binance_client import BinanceClient
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fix_trades")
async def main():
# Loop through active accounts
# Based on previous check, active accounts are 2, 3, 4
active_account_ids = [2, 3, 4]
# Check columns once
existing_columns = set()
try:
cols = db.execute_query("DESCRIBE trades")
existing_columns = {row['Field'] for row in cols}
except Exception as e:
logger.error(f"Failed to describe trades table: {e}")
return
if 'realized_pnl' not in existing_columns:
logger.info("Adding 'realized_pnl' column to trades table...")
db.execute_update("ALTER TABLE trades ADD COLUMN realized_pnl DECIMAL(20, 8) NULL COMMENT '已实现盈亏'")
if 'commission' not in existing_columns:
logger.info("Adding 'commission' column to trades table...")
db.execute_update("ALTER TABLE trades ADD COLUMN commission DECIMAL(20, 8) NULL COMMENT '手续费'")
total_fixed = 0
for account_id in active_account_ids:
logger.info(f"Processing Account ID: {account_id}")
# Get account credentials
creds = Account.get_credentials(account_id)
if not creds:
logger.error(f"No account credentials found for account {account_id}")
continue
api_key, api_secret, use_testnet, status = creds
if not api_key or not api_secret:
logger.warning(f"Skipping account {account_id}: No API key/secret")
continue
if status != 'active':
logger.warning(f"Skipping account {account_id}: Status is {status}")
continue
client = BinanceClient(api_key, api_secret, testnet=use_testnet)
try:
# Check for proxy in environment
proxy = os.environ.get('HTTP_PROXY') or os.environ.get('HTTPS_PROXY')
requests_params = {'proxy': proxy} if proxy else None
await client.connect(requests_params=requests_params)
except Exception as e:
logger.error(f"Failed to connect to Binance for account {account_id}: {e}")
continue
try:
# Get recent closed trades from DB (last 30 days) for this account
thirty_days_ago = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d %H:%M:%S')
# Check if entry_time is int (unix timestamp) or string/datetime based on schema check
# Schema says entry_time is 'int unsigned', so it's a timestamp.
thirty_days_ago_ts = int((datetime.now() - timedelta(days=30)).timestamp())
query = """
SELECT * FROM trades
WHERE status = 'closed'
AND account_id = %s
AND entry_time > %s
ORDER BY id DESC
"""
trades = db.execute_query(query, (account_id, thirty_days_ago_ts))
logger.info(f"Found {len(trades)} closed trades for account {account_id} from last 30 days.")
updated_count = 0
for trade in trades:
symbol = trade['symbol']
trade_id = trade['id']
entry_time = trade['entry_time'] # Should be int
side = trade['side']
entry_ts_ms = entry_time * 1000
try:
# Get recent trades from Binance
recent_trades = await client.get_recent_trades(symbol, limit=50)
# Filter trades after entry time
closing_trades = [
t for t in recent_trades
if t.get('time', 0) > entry_ts_ms and float(t.get('realizedPnl', 0)) != 0
]
if not closing_trades:
continue
# Calculate actual values
total_pnl = 0.0
total_comm = 0.0
total_qty = 0.0
total_val = 0.0
for t in closing_trades:
pnl_val = float(t.get('realizedPnl', 0))
comm_val = float(t.get('commission', 0))
qty_val = float(t.get('qty', 0))
price_val = float(t.get('price', 0))
total_pnl += pnl_val
total_comm += comm_val
total_qty += qty_val
total_val += qty_val * price_val
if total_qty == 0:
continue
avg_exit_price = total_val / total_qty
# Check if values differ significantly from DB
db_pnl = float(trade.get('pnl') or 0)
db_exit_price = float(trade.get('exit_price') or 0)
needs_update = False
if abs(db_pnl - total_pnl) > 0.01:
needs_update = True
if 'realized_pnl' not in trade or trade.get('realized_pnl') is None:
needs_update = True
if needs_update:
logger.info(f"Fixing trade {trade_id} ({symbol}): PnL {db_pnl:.4f} -> {total_pnl:.4f}, ExitPrice {db_exit_price:.4f} -> {avg_exit_price:.4f}")
# Recalculate pnl_percent based on entry price
entry_price = float(trade.get('entry_price', 1))
if entry_price == 0:
entry_price = 1
if side == 'BUY':
pnl_percent = ((avg_exit_price - entry_price) / entry_price) * 100
else:
pnl_percent = ((entry_price - avg_exit_price) / entry_price) * 100
# Update DB
update_sql = """
UPDATE trades
SET pnl = %s,
pnl_percent = %s,
exit_price = %s,
realized_pnl = %s,
commission = %s
WHERE id = %s
"""
db.execute_update(update_sql, (total_pnl, pnl_percent, avg_exit_price, total_pnl, total_comm, trade_id))
updated_count += 1
except Exception as e:
logger.error(f"Error processing trade {trade_id} ({symbol}): {e}")
logger.info(f"Account {account_id}: Fixed {updated_count} trades.")
total_fixed += updated_count
if client.client:
await client.client.close_connection()
except Exception as e:
logger.error(f"Error processing account {account_id}: {e}")
logger.info(f"Total fixed trades: {total_fixed}")
if __name__ == "__main__":
asyncio.run(main())

View File

@ -155,13 +155,14 @@ class BinanceClient:
# 注意redis_cache 已在 __init__ 中初始化,这里不需要再次初始化
async def connect(self, timeout: int = None, retries: int = None):
async def connect(self, timeout: int = None, retries: int = None, requests_params: Dict = None):
"""
连接币安API
Args:
timeout: 连接超时时间默认从config读取
retries: 重试次数默认从config读取
requests_params: 请求参数如代理配置例如 {'proxy': 'http://127.0.0.1:7890'}
"""
# 连接前刷新API密钥确保使用最新值支持热更新
# 但如果 API 密钥为空(只用于获取公开行情),则跳过
@ -185,7 +186,8 @@ class BinanceClient:
self.client = await AsyncClient.create(
api_key=self.api_key or None, # 空字符串转为 None
api_secret=self.api_secret or None,
testnet=self.testnet
testnet=self.testnet,
requests_params=requests_params
)
# 测试连接(带超时)