diff --git a/docs/缓存优化_写入Valkey.md b/docs/缓存优化_写入Valkey.md new file mode 100644 index 0000000..b4237f6 --- /dev/null +++ b/docs/缓存优化_写入Valkey.md @@ -0,0 +1,71 @@ +# 缓存策略:全用 Redis,基本不占服务器内存 + +## 目标 + +- **全量缓存以 Redis/Valkey 为主**:有 Redis 时只读只写 Redis,进程内基本不保留缓存,减轻服务器内存压力。 +- **Redis 内容全部带过期时间**:所有写入 Redis 的键均设置 TTL,不在 Redis 内无限增长。 + +## 一、Redis 键与 TTL 统一配置 + +所有 TTL 与键前缀集中在 `trading_system/redis_ttl.py` 中定义,**禁止**在 Redis 中写入无过期时间的业务键。 + +| 用途 | 键/前缀示例 | TTL(秒) | 说明 | +|----------------|--------------------------|-----------|----------------| +| 持仓缓存 | `ats:positions:cache` | 300 | 5 分钟 | +| 余额缓存 | `ats:balance:cache:USDT` | 300 | 5 分钟 | +| K 线(WS 写入)| `market:kline:{s}:{i}` | 600 | 10 分钟 | +| 24h 行情 | `ticker_24h:{symbol}` | 30 | 单 symbol | +| 全市场 24h | `market:ticker_24h` | 120 | 2 分钟 | +| BookTicker | `market:book_ticker` | 30 | 30 秒 | +| 交易对信息 | `symbol_info:{symbol}` | 3600 | 1 小时 | +| listenKey 缓存 | `listen_key:*` | 3300 | 55 分钟 | +| 市场 WS Leader | `market_ws_leader` | 30 | 选主续期 | + +更多见 `trading_system/redis_ttl.py`(含 K 线按 interval 的 TTL 等)。 + +## 二、各模块行为 + +### 1. RedisCache 降级内存缓存 + +**文件**: `trading_system/redis_cache.py` + +- Redis 不可用时降级到进程内存。 +- 内存缓存**有上限**:最多 200 条;单条** 5 分钟**过期,过期或满时淘汰最久未用,避免无限增长。 + +### 2. K 线 + +**文件**: `trading_system/kline_stream.py` + +- Leader 进程:WebSocket 收到 K 线后写入 Redis(TTL 见 `redis_ttl.TTL_KLINE_STREAM`),**写入成功后从进程内 `_kline_cache` 删除该 key**,以 Redis 为准、基本不占服务器内存。 +- 非 Leader / 读路径:从 Redis 读(`get_klines_from_redis`);进程内仅保留未刷写 Redis 的少量缓冲。 + +### 3. 持仓 / 余额 + +**文件**: `trading_system/user_data_stream.py` + +- **有 Redis 时**:只写 Redis,**不写** `_position_updates_cache` / `_balance_updates_cache`;读时优先从 Redis 读。 +- **无 Redis 时**:写进程内存,读时从进程内存读。 +- 所有 Redis 键带 TTL(见 `redis_ttl.TTL_POSITIONS` / `TTL_BALANCE`)。 + +### 4. 价格与交易对信息 + +**文件**: `trading_system/binance_client.py` + +- **价格(get_ticker_24h)**:先读 Redis;REST 回源后只写 Redis;仅当 Redis 写入失败时才写进程内存(降级)。 +- **交易对信息(get_symbol_info)**:先读 Redis;从 DB/API 得到后写 Redis;仅当 Redis 不可用时才写入 `_symbol_info_cache`。 + +### 5. 其他 + +- 行情 / BookTicker / listenKey / 推荐结果等:凡写入 Redis 的均带 TTL,见各模块及 `redis_ttl.py`。 + +## 三、使用与运维 + +1. **保证 Redis/Valkey 可用**:配置好 `REDIS_URL`(如 AWS Valkey),确保交易服务能连上。 +2. **重启生效**:改缓存逻辑后需重启交易服务。 +3. **监控 Redis 内存**:在 Valkey 控制台或 `INFO memory` 查看内存;所有键有过期,不应无限增长。 +4. **进程内存**:正常情况下进程内缓存很少;仅 Redis 不可用时才用内存降级(条数/过期受控)。 + +## 四、预期效果 + +- 服务器进程内存占用明显下降(有 Redis 时基本不存大块缓存)。 +- Redis/Valkey 内所有业务键有过期,不会无限使用内存。 diff --git a/trading_system/binance_client.py b/trading_system/binance_client.py index 7472f6c..eeba5f7 100644 --- a/trading_system/binance_client.py +++ b/trading_system/binance_client.py @@ -201,8 +201,10 @@ class BinanceClient: self._last_request_time = {} # 记录每个API端点的最后请求时间 self._request_delay = 0.1 # 请求间隔(秒),避免频率限制 self._semaphore = asyncio.Semaphore(10) # 限制并发请求数 + # ⚠️ 内存优化:进程内存只做临时缓存,主要数据在 Redis self._price_cache: Dict[str, Dict] = {} # WebSocket价格缓存 {symbol: {price, volume, changePercent, timestamp}} - self._price_cache_ttl = 60 # 价格缓存有效期(秒) + self._price_cache_ttl = 30 # 价格缓存有效期(秒,减少进程内存占用) + self._price_cache_max_size = 50 # 最多保留 50 个 symbol 的价格缓存(减少进程内存占用) # 显示名 -> API symbol 映射(当交易所返回中文/非 ASCII 的 symbol 时,用 baseAsset+quoteAsset 作为下单用 symbol) self._display_to_api_symbol: Dict[str, str] = {} @@ -895,29 +897,25 @@ class BinanceClient: import time symbol = self._resolve_api_symbol(symbol) - # 1. 优先从WebSocket缓存读取 + # 全用 Redis:优先从 Redis 读;REST 后只写 Redis,不写进程内存(Redis 不可用时才写内存) + cache_key = f"ticker_24h:{symbol}" + cached = await self.redis_cache.get(cache_key) + if cached: + logger.debug(f"从Redis缓存获取 {symbol} 24小时行情数据") + return cached + # Redis 未命中时降级到进程内存(仅当 Redis 不可用时会有数据) if symbol in self._price_cache: cached = self._price_cache[symbol] cache_age = time.time() - cached.get('timestamp', 0) if cache_age < self._price_cache_ttl: - logger.debug(f"从WebSocket缓存获取 {symbol} 价格: {cached['price']:.8f} (缓存年龄: {cache_age:.1f}秒)") + logger.debug(f"从进程内存缓存获取 {symbol} 价格 (降级)") return { 'symbol': symbol, 'price': cached['price'], 'volume': cached.get('volume', 0), 'changePercent': cached.get('changePercent', 0) } - else: - logger.debug(f"{symbol} WebSocket缓存已过期 ({cache_age:.1f}秒 > {self._price_cache_ttl}秒)") - - # 2. 从 Redis 缓存读取 - cache_key = f"ticker_24h:{symbol}" - cached = await self.redis_cache.get(cache_key) - if cached: - logger.debug(f"从Redis缓存获取 {symbol} 24小时行情数据") - return cached - - # 3. 如果缓存不可用或过期,使用REST API(fallback) + self._price_cache.pop(symbol, None) logger.debug(f"{symbol} 未在缓存中,使用REST API获取") try: ticker = await self._rate_limited_request( @@ -934,16 +932,13 @@ class BinanceClient: 'volume': float(stats.get('quoteVolume', 0)), 'changePercent': float(stats.get('priceChangePercent', 0)) } - - # 更新 WebSocket 缓存 - self._price_cache[symbol] = { - **result, - 'timestamp': time.time() - } - - # 写入 Redis 缓存(TTL: 30秒) - await self.redis_cache.set(cache_key, result, ttl=30) - + # 只写 Redis;仅当 Redis 写入失败时才写进程内存(降级) + wrote_redis = await self.redis_cache.set(cache_key, result, ttl=30) + if not wrote_redis: + if len(self._price_cache) >= self._price_cache_max_size: + oldest_key = min(self._price_cache.keys(), key=lambda k: self._price_cache[k].get('timestamp', 0)) + self._price_cache.pop(oldest_key, None) + self._price_cache[symbol] = {**result, 'timestamp': time.time()} return result except BinanceAPIException as e: error_code = e.code if hasattr(e, 'code') else None @@ -1371,38 +1366,32 @@ class BinanceClient: """ symbol = self._resolve_api_symbol(symbol) loop = asyncio.get_event_loop() - # 1. 先检查内存缓存 - if symbol in self._symbol_info_cache: - cached_mem = self._symbol_info_cache[symbol] - # 兼容旧缓存:早期版本没有 tickSize/pricePrecision,容易触发 -4014/-1111 - if isinstance(cached_mem, dict) and ("tickSize" not in cached_mem or "pricePrecision" not in cached_mem): - try: - self._symbol_info_cache.pop(symbol, None) - except Exception: - pass - else: - return cached_mem - - # 2. 从 Redis 缓存读取 cache_key = f"symbol_info:{symbol}" + # 1. 优先从 Redis 读取(全用 Redis,不占进程内存) cached = await self.redis_cache.get(cache_key) if cached: logger.debug(f"从Redis缓存获取 {symbol} 交易对信息") - # 兼容旧缓存:早期版本没有 tickSize/pricePrecision,容易触发 -4014/-1111 if isinstance(cached, dict) and ("tickSize" not in cached or "pricePrecision" not in cached): logger.info(f"{symbol} symbol_info 缓存缺少 tickSize/pricePrecision,自动刷新一次") else: - self._symbol_info_cache[symbol] = cached return cached + # 2. 降级到进程内存(仅当 Redis 不可用时会有数据) + if symbol in self._symbol_info_cache: + cached_mem = self._symbol_info_cache[symbol] + if isinstance(cached_mem, dict) and ("tickSize" in cached_mem and "pricePrecision" in cached_mem): + return cached_mem + self._symbol_info_cache.pop(symbol, None) # 3. 尝试从 DB market_cache 读取 exchange_info 并解析本 symbol try: exchange_info_db = await loop.run_in_executor(None, lambda: _load_exchange_info_from_db(86400)) info = _parse_symbol_info_from_exchange_info(exchange_info_db, symbol) if exchange_info_db else None if info: - self._symbol_info_cache[symbol] = info if self.redis_cache: await self.redis_cache.set(cache_key, info, ttl=3600) + # Redis 写入成功则不占进程内存 + else: + self._symbol_info_cache[symbol] = info logger.debug(f"从 DB 缓存解析 {symbol} 交易对信息") return info except Exception as e: @@ -1484,11 +1473,9 @@ class BinanceClient: 'maxLeverage': int(max_leverage_supported) # 交易对支持的最大杠杆 } - # 写入 Redis 缓存(TTL: 1小时) - await self.redis_cache.set(cache_key, info, ttl=3600) - - # 同时更新内存缓存 - self._symbol_info_cache[symbol] = info + wrote = await self.redis_cache.set(cache_key, info, ttl=3600) + if not wrote: + self._symbol_info_cache[symbol] = info logger.debug(f"获取 {symbol} 精度信息: {info}") return info diff --git a/trading_system/kline_stream.py b/trading_system/kline_stream.py index 261247e..ba1af65 100644 --- a/trading_system/kline_stream.py +++ b/trading_system/kline_stream.py @@ -16,19 +16,28 @@ try: from .market_ws_leader import KEY_KLINE_PREFIX except ImportError: KEY_KLINE_PREFIX = "market:kline:" +try: + from .redis_ttl import TTL_KLINE_STREAM +except ImportError: + TTL_KLINE_STREAM = 600 # K线缓存:{ (symbol, interval): [kline1, kline2, ...] },最多保留 limit 根 _kline_cache: Dict[Tuple[str, str], List[List]] = {} _kline_cache_updated_at: Dict[Tuple[str, str], float] = {} _kline_cache_limit: Dict[Tuple[str, str], int] = {} # 每个 (symbol, interval) 的 limit # ⚠️ 内存优化:限制缓存总大小,避免内存无限增长(2 CPU 4G 服务器) -_MAX_CACHE_ENTRIES = 200 # 最多保留 200 个 (symbol, interval) 的缓存 +# 注意:主要数据在 Redis,进程内存只做临时缓存(减少大小) +_MAX_CACHE_ENTRIES = 50 # 最多保留 50 个 (symbol, interval) 的缓存(减少进程内存占用) _CACHE_CLEANUP_INTERVAL_SEC = 300 # 每 5 分钟清理一次过期缓存 -_CACHE_MAX_AGE_SEC = 600 # 缓存超过 10 分钟未更新则清理 +_CACHE_MAX_AGE_SEC = 300 # 缓存超过 5 分钟未更新则清理(更激进,优先用 Redis) def get_klines_from_cache(symbol: str, interval: str, limit: int = 50) -> Optional[List[List]]: - """从缓存返回K线数据(与 REST get_klines 格式兼容)。未订阅或数据不足时返回 None。""" + """ + 从缓存返回K线数据(与 REST get_klines 格式兼容)。 + ⚠️ 内存优化:优先从进程内存读取(快速),但进程内存缓存已限制大小,主要数据在 Redis。 + 未订阅或数据不足时返回 None。 + """ key = (symbol.upper(), interval.lower()) cached = _kline_cache.get(key) if not cached or len(cached) < limit: @@ -398,7 +407,8 @@ class KlineStream: pass async def _batch_write_redis(self): - """批量写入 Redis,减少写入频率""" + """批量写入 Redis;写入成功后从进程内存移除,以 Redis 为主、基本不占服务器内存""" + global _kline_cache, _kline_cache_updated_at, _kline_cache_limit if not self._redis_write_pending: return try: @@ -406,7 +416,16 @@ class KlineStream: self._redis_write_pending.clear() for rkey, (data, _) in pending.items(): try: - await self._redis_cache.set(rkey, data, ttl=600) + await self._redis_cache.set(rkey, data, ttl=TTL_KLINE_STREAM) + # 写入 Redis 后从进程内存移除,避免重复占用 + if rkey.startswith(KEY_KLINE_PREFIX): + suffix = rkey[len(KEY_KLINE_PREFIX):] + if ":" in suffix: + s, i = suffix.split(":", 1) + key = (s.upper(), i.lower()) + _kline_cache.pop(key, None) + _kline_cache_updated_at.pop(key, None) + _kline_cache_limit.pop(key, None) except Exception: pass except Exception: @@ -422,7 +441,7 @@ class KlineStream: async def _write_kline_to_redis(self, rkey: str, data: List[List]) -> None: try: if self._redis_cache: - await self._redis_cache.set(rkey, data, ttl=600) + await self._redis_cache.set(rkey, data, ttl=TTL_KLINE_STREAM) except Exception as e: logger.debug("KlineStream: 写入 Redis 失败 %s", e) diff --git a/trading_system/main.py b/trading_system/main.py index 9ff0b42..399075e 100644 --- a/trading_system/main.py +++ b/trading_system/main.py @@ -374,12 +374,13 @@ async def main(): logger.info(f"正在启动 User Data Stream(账号 {account_id})...") if await user_data_stream.start(): logger.info(f"✓ User Data Stream 已启动(账号 {account_id},订单/持仓/余额 WS 推送,30 分钟 keepalive)") - # 用当前 REST 结果播种缓存,后续由 WS 增量更新,业务可优先读缓存 + # 用当前 REST 结果播种缓存,后续由 WS 增量更新,业务可优先读缓存(Redis) try: - seed_balance_cache(balance) + redis_cache = getattr(client, "redis_cache", None) + await seed_balance_cache(balance, redis_cache) positions_seed = await client.get_open_positions() - seed_position_cache(positions_seed) - logger.info(f"✓ 已播种持仓/余额缓存(持仓 {len(positions_seed)} 个)") + await seed_position_cache(positions_seed, redis_cache) + logger.info(f"✓ 已播种持仓/余额缓存(持仓 {len(positions_seed)} 个,已写入 Redis)") except Exception as e: logger.warning(f"播种 WS 缓存失败(将仅用 REST): {e}") else: diff --git a/trading_system/position_manager.py b/trading_system/position_manager.py index 492bc35..e02f8f5 100644 --- a/trading_system/position_manager.py +++ b/trading_system/position_manager.py @@ -145,18 +145,20 @@ class PositionManager: self._last_auto_close_fail_log_ms: Dict[str, int] = {} async def _get_open_positions(self) -> List[Dict]: - """优先使用 User Data Stream 持仓缓存,无缓存或未启动时走 REST。""" + """优先使用 User Data Stream 持仓缓存(Redis),无缓存或未启动时走 REST。""" if get_stream_instance() is not None: min_notional = float(getattr(config, "POSITION_MIN_NOTIONAL_USDT", 1.0) or 1.0) - cached = get_positions_from_cache(min_notional) + redis_cache = getattr(self.client, "redis_cache", None) + cached = await get_positions_from_cache(min_notional, redis_cache) if cached is not None: return cached return await self.client.get_open_positions() async def _get_account_balance(self) -> Dict: - """优先使用 User Data Stream 余额缓存,无缓存时走 REST。""" + """优先使用 User Data Stream 余额缓存(Redis),无缓存时走 REST。""" if get_stream_instance() is not None: - bal = get_balance_from_cache() + redis_cache = getattr(self.client, "redis_cache", None) + bal = await get_balance_from_cache(redis_cache) if bal is not None: return bal return await self.client.get_account_balance() diff --git a/trading_system/redis_cache.py b/trading_system/redis_cache.py index a083fae..013994a 100644 --- a/trading_system/redis_cache.py +++ b/trading_system/redis_cache.py @@ -1,9 +1,11 @@ """ -Redis 缓存管理器 - 支持 TLS 连接 +Redis 缓存管理器 - 支持 TLS 连接。 +全量缓存以 Redis 为主;Redis 不可用时降级到内存缓存,内存缓存有最大条数与 TTL,避免无限增长。 """ import json import logging -from typing import Optional, Any, Dict, List +import time +from typing import Optional, Any, Dict, List, Tuple try: # 使用 redis-py 4.2+ 的异步客户端(替代 aioredis) @@ -76,7 +78,10 @@ class RedisCache: self.username = username self.password = password self.redis: Optional[Redis] = None - self._memory_cache: Dict[str, Any] = {} # 降级到内存缓存 + # 降级到内存缓存:key -> (value, 写入时间戳),有最大条数与 TTL,避免无限增长 + self._memory_cache: Dict[str, Tuple[Any, float]] = {} + self._memory_cache_max_size = 200 # 最多 200 条 + self._memory_cache_ttl_sec = 300 # 单条 5 分钟过期 self._connected = False async def connect(self): @@ -205,10 +210,16 @@ class RedisCache: if not self._connected: await self.connect() - # 降级到内存缓存 + # 降级到内存缓存(带 TTL 清理;tuple 为 (value, ts),非 tuple 如 hash 的 dict 不做过期) if key in self._memory_cache: - return self._memory_cache[key] - + entry = self._memory_cache[key] + if isinstance(entry, tuple) and len(entry) == 2: + val, ts = entry + if time.time() - ts <= self._memory_cache_ttl_sec: + return val + del self._memory_cache[key] + else: + return entry return None async def set(self, key: str, value: Any, ttl: int = 3600): @@ -240,8 +251,14 @@ class RedisCache: except: pass - # 降级到内存缓存(不设置 TTL,因为内存缓存不支持) - self._memory_cache[key] = value + # 降级到内存缓存:限制条数 + TTL,避免无限增长 + def _ts(k): + v = self._memory_cache.get(k) + return v[1] if isinstance(v, tuple) and len(v) == 2 else 0 + while len(self._memory_cache) >= self._memory_cache_max_size and self._memory_cache: + oldest_key = min(self._memory_cache.keys(), key=_ts) + del self._memory_cache[oldest_key] + self._memory_cache[key] = (value, time.time()) return False async def get_int(self, key: str, default: int = 0) -> int: @@ -257,7 +274,9 @@ class RedisCache: return int(default or 0) if key in self._memory_cache: try: - return int(self._memory_cache.get(key) or 0) + entry = self._memory_cache[key] + val = entry[0] if isinstance(entry, tuple) else entry + return int(val or 0) except Exception: return int(default or 0) except Exception: @@ -288,14 +307,22 @@ class RedisCache: except Exception as e: logger.debug(f"Redis incr失败 {key}: {e}") - # 内存兜底(不做 TTL) + # 内存兜底:限制条数 + TTL cur = 0 try: - cur = int(self._memory_cache.get(key) or 0) + if key in self._memory_cache: + entry = self._memory_cache[key] + cur = int(entry[0] if isinstance(entry, tuple) else entry) except Exception: cur = 0 cur += inc - self._memory_cache[key] = cur + def _ts(k): + v = self._memory_cache.get(k) + return v[1] if isinstance(v, tuple) and len(v) == 2 else 0 + while len(self._memory_cache) >= self._memory_cache_max_size and key not in self._memory_cache and self._memory_cache: + oldest_key = min(self._memory_cache.keys(), key=_ts) + del self._memory_cache[oldest_key] + self._memory_cache[key] = (cur, time.time()) return int(cur) async def delete(self, key: str): diff --git a/trading_system/redis_ttl.py b/trading_system/redis_ttl.py new file mode 100644 index 0000000..317afc8 --- /dev/null +++ b/trading_system/redis_ttl.py @@ -0,0 +1,40 @@ +""" +Redis 缓存键与 TTL 统一配置:所有键均带过期时间,避免 Valkey/Redis 无限增长。 +全量缓存以 Redis 为主,进程内存仅作降级或最小缓冲。 +""" +# 键前缀(便于排查与清理) +PREFIX_ATS = "ats:" +PREFIX_MARKET = "market:" +PREFIX_KLINES = "klines:" +PREFIX_TICKER = "ticker_24h:" +PREFIX_SYMBOL_INFO = "symbol_info:" +PREFIX_LISTEN_KEY = "listen_key:" +PREFIX_LEADER = "market_ws_leader" + +# TTL(秒)- 所有写入 Redis 的键必须带 TTL,禁止无过期 +TTL_LEADER = 30 +TTL_TICKER_24H = 120 +TTL_BOOK_TICKER = 30 +TTL_KLINE_STREAM = 600 # 10 分钟(WS 实时写入的 K 线) +TTL_POSITIONS = 300 # 5 分钟(持仓缓存) +TTL_BALANCE = 300 # 5 分钟(余额缓存) +TTL_TICKER_SYMBOL = 30 # 单 symbol 24h 行情 +TTL_SYMBOL_INFO = 3600 # 1 小时(交易对信息) +TTL_KLINES_REST = 1800 # REST 拉取的 K 线默认 30 分钟 +TTL_KLINES_REST_OLD = 300 # 旧格式 klines:{s}:{i}:{limit} 默认 5 分钟 +TTL_LISTEN_KEY = 55 * 60 # 55 分钟(listenKey 缓存) +TTL_TREND_STATE = 3600 +TTL_INDICATORS = 30 +TTL_RECO_SNAPSHOT = 7200 +TTL_RECO_ITEM = 3600 +TTL_LOCK_RECO = 10 + +# K 线按周期 TTL(REST 拉取后写入共享缓存) +TTL_KLINES_BY_INTERVAL = { + '1m': 60, '3m': 120, '5m': 180, '15m': 300, '30m': 600, + '1h': 900, '2h': 1800, '4h': 3600, '6h': 5400, '8h': 7200, '12h': 10800, '1d': 21600, +} +TTL_KLINES_OLD_BY_INTERVAL = { + '1m': 10, '3m': 20, '5m': 30, '15m': 60, '30m': 120, + '1h': 300, '2h': 600, '4h': 900, '6h': 1200, '8h': 1800, '12h': 2400, '1d': 3600, +} diff --git a/trading_system/risk_manager.py b/trading_system/risk_manager.py index c838ed5..731d87b 100644 --- a/trading_system/risk_manager.py +++ b/trading_system/risk_manager.py @@ -25,17 +25,22 @@ def _get_stream_instance(): except Exception: return None -def _get_balance_from_cache(): +async def _get_balance_from_cache(client: Optional[BinanceClient] = None): + """从缓存获取余额(优先 Redis,降级到进程内存)""" try: from .user_data_stream import get_balance_from_cache - return get_balance_from_cache() + redis_cache = getattr(client, "redis_cache", None) if client else None + return await get_balance_from_cache(redis_cache) except Exception: return None -def _get_positions_from_cache(): +async def _get_positions_from_cache(client: Optional[BinanceClient] = None): + """从缓存获取持仓(优先 Redis,降级到进程内存)""" try: from .user_data_stream import get_positions_from_cache - return get_positions_from_cache(float(getattr(config, "POSITION_MIN_NOTIONAL_USDT", 1.0) or 1.0)) + redis_cache = getattr(client, "redis_cache", None) if client else None + min_notional = float(getattr(config, "POSITION_MIN_NOTIONAL_USDT", 1.0) or 1.0) + return await get_positions_from_cache(min_notional, redis_cache) except Exception: return None @@ -71,8 +76,8 @@ class RiskManager: try: logger.info(f"检查 {symbol} 单笔仓位大小...") - # 获取账户余额(优先 WS 缓存) - balance = _get_balance_from_cache() if _get_stream_instance() else None + # 获取账户余额(优先 WS 缓存,Redis) + balance = await _get_balance_from_cache(self.client) if _get_stream_instance() else None if balance is None: balance = await self.client.get_account_balance() available_balance = balance.get('available', 0) @@ -162,8 +167,8 @@ class RiskManager: 是否通过检查 """ try: - # 获取当前持仓(优先 WS 缓存) - positions = _get_positions_from_cache() if _get_stream_instance() else None + # 获取当前持仓(优先 WS 缓存,Redis) + positions = await _get_positions_from_cache(self.client) if _get_stream_instance() else None if positions is None: positions = await self.client.get_open_positions() @@ -194,8 +199,8 @@ class RiskManager: # 加上新仓位 total_with_new = total_margin_value + new_position_margin - # 获取账户余额(优先 WS 缓存) - balance = _get_balance_from_cache() if _get_stream_instance() else None + # 获取账户余额(优先 WS 缓存,Redis) + balance = await _get_balance_from_cache(self.client) if _get_stream_instance() else None if balance is None: balance = await self.client.get_account_balance() total_balance = balance.get('total', 0) @@ -430,8 +435,8 @@ class RiskManager: try: logger.info(f"开始计算 {symbol} 的仓位大小...") - # 获取账户余额(优先 WS 缓存) - balance = _get_balance_from_cache() if _get_stream_instance() else None + # 获取账户余额(优先 WS 缓存,Redis) + balance = await _get_balance_from_cache(self.client) if _get_stream_instance() else None if balance is None: balance = await self.client.get_account_balance() available_balance = balance.get('available', 0) @@ -818,7 +823,7 @@ class RiskManager: return False # 检查是否已有持仓 / 总持仓数量限制(优先 WS 缓存) - positions = _get_positions_from_cache() if _get_stream_instance() else None + positions = await _get_positions_from_cache(self.client) if _get_stream_instance() else None if positions is None: positions = await self.client.get_open_positions() try: diff --git a/trading_system/user_data_stream.py b/trading_system/user_data_stream.py index 7de9ffb..695a9ff 100644 --- a/trading_system/user_data_stream.py +++ b/trading_system/user_data_stream.py @@ -39,39 +39,103 @@ def get_stream_instance() -> Optional["UserDataStream"]: return _stream_instance -def seed_position_cache(positions: List[Dict]) -> None: - """用 REST 全量持仓结果填充缓存(启动时调用一次,之后由 ACCOUNT_UPDATE 增量更新)。""" +try: + from .redis_ttl import TTL_POSITIONS +except ImportError: + TTL_POSITIONS = 300 +try: + from .redis_ttl import TTL_BALANCE +except ImportError: + TTL_BALANCE = 300 + + +async def seed_position_cache(positions: List[Dict], redis_cache: Any = None) -> None: + """用 REST 全量持仓结果填充缓存。有 Redis 时只写 Redis、不占进程内存;无 Redis 时写进程内存。""" global _position_updates_cache, _position_cache_seeded - _position_updates_cache.clear() _position_cache_seeded = True + positions_list = [] for pos in positions or []: symbol = (pos.get("symbol") or "").strip() if not symbol: continue amt = float(pos.get("positionAmt") or 0) - _position_updates_cache[symbol] = [{ - "s": symbol, - "pa": amt, - "ep": str(pos.get("entryPrice") or "0"), - "up": str(pos.get("unRealizedProfit") or "0"), - "ps": pos.get("positionSide") or "BOTH", - }] - logger.debug(f"UserDataStream: 已填充持仓缓存 {len(_position_updates_cache)} 个 symbol") + if not redis_cache: + if symbol not in _position_updates_cache: + _position_updates_cache[symbol] = [] + _position_updates_cache[symbol] = [{ + "s": symbol, + "pa": amt, + "ep": str(pos.get("entryPrice") or "0"), + "up": str(pos.get("unRealizedProfit") or "0"), + "ps": pos.get("positionSide") or "BOTH", + }] + if amt != 0: + positions_list.append({ + "symbol": symbol, + "positionAmt": amt, + "entryPrice": float(pos.get("entryPrice") or 0), + "markPrice": float(pos.get("markPrice") or 0), + "unRealizedProfit": float(pos.get("unRealizedProfit") or 0), + "leverage": int(pos.get("leverage") or 0), + }) + if redis_cache: + _position_updates_cache.clear() + if positions_list: + try: + await redis_cache.set("ats:positions:cache", positions_list, ttl=TTL_POSITIONS) + except Exception as e: + logger.debug(f"写入持仓缓存到 Redis 失败: {e}") + logger.debug(f"UserDataStream: 已填充持仓缓存(Redis=%s)", bool(redis_cache)) -def seed_balance_cache(balance: Dict[str, Any]) -> None: - """用 REST 余额结果填充缓存(启动时调用一次,之后由 ACCOUNT_UPDATE 增量更新)。""" +async def seed_balance_cache(balance: Dict[str, Any], redis_cache: Any = None) -> None: + """用 REST 余额结果填充缓存。有 Redis 时只写 Redis、不占进程内存;无 Redis 时写进程内存。""" global _balance_updates_cache, _balance_cache_seeded _balance_cache_seeded = True if balance and isinstance(balance, dict): wb = balance.get("walletBalance") or balance.get("total") or 0 av = balance.get("availableBalance") or balance.get("available") or wb - _balance_updates_cache["USDT"] = {"wb": str(wb), "cw": str(av), "bc": "0"} - logger.debug("UserDataStream: 已填充余额缓存 (USDT)") + balance_data = {"wb": str(wb), "cw": str(av), "bc": "0"} + if redis_cache: + try: + await redis_cache.set("ats:balance:cache:USDT", balance_data, ttl=TTL_BALANCE) + except Exception as e: + logger.debug(f"写入余额缓存到 Redis 失败: {e}") + else: + _balance_updates_cache["USDT"] = balance_data + logger.debug("UserDataStream: 已填充余额缓存 (USDT, Redis=%s)", bool(redis_cache)) -def get_positions_from_cache(min_notional: float = 1.0) -> Optional[List[Dict]]: - """将持仓缓存转为与 REST get_open_positions 一致的列表格式;未播种时返回 None(业务应走 REST)。""" +async def get_positions_from_cache(min_notional: float = 1.0, redis_cache: Any = None) -> Optional[List[Dict]]: + """ + 将持仓缓存转为与 REST get_open_positions 一致的列表格式;未播种时返回 None(业务应走 REST)。 + ⚠️ 内存优化:优先从 Redis 读取,减少进程内存占用。 + """ + # 优先从 Redis 读取(多进程共享) + if redis_cache: + try: + redis_key = "ats:positions:cache" + cached = await redis_cache.get(redis_key) + if cached and isinstance(cached, list): + # 过滤最小名义价值 + filtered = [] + for pos in cached: + try: + pa = float(pos.get("positionAmt") or 0) + ep = float(pos.get("entryPrice") or 0) + if pa == 0: + continue + if min_notional > 0 and abs(pa) * ep < min_notional: + continue + filtered.append(pos) + except Exception: + continue + if filtered: + return filtered + except Exception as e: + logger.debug(f"从 Redis 读取持仓缓存失败: {e}") + + # 降级到进程内存缓存 if not _position_cache_seeded: return None out = [] @@ -97,8 +161,27 @@ def get_positions_from_cache(min_notional: float = 1.0) -> Optional[List[Dict]]: return out -def get_balance_from_cache() -> Optional[Dict[str, Any]]: - """从缓存返回 USDT 余额(与 REST get_account_balance 结构兼容);未播种或无缓存时返回 None。""" +async def get_balance_from_cache(redis_cache: Any = None) -> Optional[Dict[str, Any]]: + """ + 从缓存返回 USDT 余额(与 REST get_account_balance 结构兼容);未播种或无缓存时返回 None。 + ⚠️ 内存优化:优先从 Redis 读取,减少进程内存占用。 + """ + # 优先从 Redis 读取(多进程共享) + if redis_cache: + try: + redis_key = "ats:balance:cache:USDT" + cached = await redis_cache.get(redis_key) + if cached and isinstance(cached, dict): + try: + wb = float(cached.get("wb") or cached.get("total") or 0) + cw = float(cached.get("cw") or cached.get("available") or wb) + return {"ok": True, "available": cw, "total": wb, "margin": wb} + except (TypeError, ValueError): + pass + except Exception as e: + logger.debug(f"从 Redis 读取余额缓存失败: {e}") + + # 降级到进程内存缓存 if not _balance_cache_seeded: return None u = _balance_updates_cache.get("USDT") @@ -455,29 +538,73 @@ class UserDataStream: logger.warning(f"UserDataStream: set_exit_order_id_for_open_trade 失败 {ex}") def _on_account_update(self, a: Dict): - # 文档: a.B = 余额数组,每项 a=资产, wb=钱包余额, cw=除逐仓外的钱包余额, bc=余额变化量 - # 文档: a.P = 持仓信息数组,每项 s=symbol, pa=仓位, ep=入仓价, ps=LONG/SHORT/BOTH 等 + # 文档: a.B = 余额数组,a.P = 持仓信息数组。有 Redis 时只写 Redis、不写进程内存。 global _position_updates_cache, _balance_updates_cache + redis_cache = getattr(self.client, "redis_cache", None) + B = a.get("B") if isinstance(B, list) and B: for b in B: asset = (b.get("a") or "").strip() if asset: - _balance_updates_cache[asset] = { + balance_data = { "wb": b.get("wb"), "cw": b.get("cw"), "bc": b.get("bc"), } + if redis_cache: + if asset == "USDT": + asyncio.create_task(self._write_balance_to_redis(asset, balance_data)) + else: + _balance_updates_cache[asset] = balance_data logger.debug(f"UserDataStream: ACCOUNT_UPDATE 余额资产数 {len(B)}") P = a.get("P") if isinstance(P, list) and P: + positions_list = [] for p in P: s = (p.get("s") or "").strip() if s: - if s not in _position_updates_cache: - _position_updates_cache[s] = [] - _position_updates_cache[s] = [p] + if not redis_cache: + if s not in _position_updates_cache: + _position_updates_cache[s] = [] + _position_updates_cache[s] = [p] + try: + pa = float(p.get("pa") or 0) + ep = float(p.get("ep") or 0) + if pa != 0: + positions_list.append({ + "symbol": s, + "positionAmt": pa, + "entryPrice": ep, + "markPrice": float(p.get("markPrice") or 0), + "unRealizedProfit": float(p.get("up") or 0), + "leverage": int(p.get("leverage") or 0), + }) + except Exception: + pass + if redis_cache and positions_list: + asyncio.create_task(self._write_positions_to_redis(positions_list)) logger.debug(f"UserDataStream: ACCOUNT_UPDATE 持仓数 {len(P)}") + + async def _write_balance_to_redis(self, asset: str, balance_data: Dict): + """写入余额缓存到 Redis(带 TTL,避免无限增长)""" + try: + redis_cache = getattr(self.client, "redis_cache", None) + if redis_cache: + redis_key = f"ats:balance:cache:{asset}" + await redis_cache.set(redis_key, balance_data, ttl=TTL_BALANCE) + except Exception as e: + logger.debug(f"写入余额缓存到 Redis 失败: {e}") + + async def _write_positions_to_redis(self, positions_list: List[Dict]): + """写入持仓缓存到 Redis(带 TTL,避免无限增长)""" + try: + redis_cache = getattr(self.client, "redis_cache", None) + if redis_cache: + redis_key = "ats:positions:cache" + await redis_cache.set(redis_key, positions_list, ttl=TTL_POSITIONS) + except Exception as e: + logger.debug(f"写入持仓缓存到 Redis 失败: {e}") def _on_algo_update(self, o: Dict): # 条件单交易更新推送:X=TRIGGERED/FINISHED 且 ai=触发后普通订单 id 时,回写 open 记录的 exit_order_id