feat(redis_cache, kline_stream, user_data_stream, risk_manager): optimize caching and memory management

Introduce Redis as the primary cache in multiple modules to reduce process memory usage. Update `binance_client.py`, `kline_stream.py`, `user_data_stream.py`, and `risk_manager.py` to read from Redis first and fall back to in-memory caches only when Redis is unavailable. Adjust cache TTLs and entry limits to keep the system stable and performant. This improves data-access efficiency and reduces server memory pressure.
This commit is contained in:
薇薇安 2026-02-19 00:19:54 +08:00
parent 80872231a5
commit 59e25558cd
9 changed files with 388 additions and 109 deletions

View File

@ -0,0 +1,71 @@
# Caching strategy: Redis-first (near-zero server memory)
## Goals
- **Redis/Valkey is the primary store for all cached data**: when Redis is available, reads and writes go only to Redis and processes keep almost no cache, easing server memory pressure.
- **Everything in Redis expires**: every key written to Redis carries a TTL, so nothing grows unbounded inside Redis.
## 1. Unified Redis keys and TTLs
All TTLs and key prefixes are defined centrally in `trading_system/redis_ttl.py`; writing business keys to Redis **without an expiry** is forbidden.
| Purpose | Key / prefix example | TTL | Notes |
|------------------|--------------------------|------|------------------------|
| Position cache | `ats:positions:cache` | 300 | 5 minutes |
| Balance cache | `ats:balance:cache:USDT` | 300 | 5 minutes |
| Klines (WS writes) | `market:kline:{s}:{i}` | 600 | 10 minutes |
| 24h ticker | `ticker_24h:{symbol}` | 30 | per symbol |
| Market-wide 24h | `market:ticker_24h` | 120 | 2 minutes |
| BookTicker | `market:book_ticker` | 30 | 30 seconds |
| Symbol info | `symbol_info:{symbol}` | 3600 | 1 hour |
| listenKey cache | `listen_key:*` | 3300 | 55 minutes |
| Market WS leader | `market_ws_leader` | 30 | leader-election renewal |
See `trading_system/redis_ttl.py` for the full list (including per-interval kline TTLs).
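The "every key must expire" rule can be made mechanical with a thin guard around the client's `set`. A minimal sketch, assuming an async `set(key, value, ttl=...)` shaped like `RedisCache.set`; `RedisLike` and `require_ttl_set` are illustrative names, not part of the codebase:

```python
from typing import Any, Protocol

# TTL constants mirror trading_system/redis_ttl.py.
TTL_POSITIONS = 300      # 5 min, ats:positions:cache
TTL_SYMBOL_INFO = 3600   # 1 h, symbol_info:{symbol}

class RedisLike(Protocol):
    async def set(self, key: str, value: Any, ttl: int = 3600): ...

async def require_ttl_set(cache: RedisLike, key: str, value: Any, ttl: int) -> None:
    """Write-path guard: every business key must carry a positive TTL."""
    if not isinstance(ttl, int) or ttl <= 0:
        raise ValueError(f"refusing to write {key!r} without a TTL")
    await cache.set(key, value, ttl=ttl)
```

Routing all writes through such a guard turns the "no expiry is forbidden" policy from a convention into an enforced invariant.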
## 2. Per-module behavior
### 1. RedisCache in-memory fallback
**File**: `trading_system/redis_cache.py`
- Falls back to process memory when Redis is unavailable.
- The in-memory cache is **bounded**: at most 200 entries, each expiring after **5 minutes**; when an entry expires or the cache is full, the stalest entry is evicted, so the cache cannot grow without limit.
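The fallback described above amounts to a size-capped dict of `(value, timestamp)` pairs. A minimal sketch of that behavior (the class name is illustrative; the real logic lives inside `RedisCache`):

```python
import time
from typing import Any, Dict, Optional, Tuple

class BoundedMemoryCache:
    """Fallback cache: capped entry count, per-entry TTL, evict stalest on overflow."""

    def __init__(self, max_size: int = 200, ttl_sec: float = 300.0):
        self.max_size = max_size
        self.ttl_sec = ttl_sec
        self._data: Dict[str, Tuple[Any, float]] = {}

    def get(self, key: str) -> Optional[Any]:
        entry = self._data.get(key)
        if entry is None:
            return None
        value, ts = entry
        if time.time() - ts > self.ttl_sec:
            del self._data[key]  # lazily expire on read
            return None
        return value

    def set(self, key: str, value: Any) -> None:
        # Evict the entry with the oldest write timestamp until there is room.
        while len(self._data) >= self.max_size:
            oldest = min(self._data, key=lambda k: self._data[k][1])
            del self._data[oldest]
        self._data[key] = (value, time.time())
```

Because eviction is by write timestamp rather than last access, this is stalest-written eviction, which is what the size cap in the diff below implements.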
### 2. Klines
**File**: `trading_system/kline_stream.py`
- Leader process: klines received over WebSocket are written to Redis (TTL: `redis_ttl.TTL_KLINE_STREAM`); **after a successful write, the key is removed from the in-process `_kline_cache`**, keeping Redis authoritative and process memory near zero.
- Non-leader / read path: reads come from Redis (`get_klines_from_redis`); the process keeps only a small buffer of data not yet flushed to Redis.
### 3. Positions / balance
**File**: `trading_system/user_data_stream.py`
- **With Redis**: write only to Redis; `_position_updates_cache` / `_balance_updates_cache` are **not** written; reads prefer Redis.
- **Without Redis**: write to and read from process memory.
- All Redis keys carry a TTL (`redis_ttl.TTL_POSITIONS` / `TTL_BALANCE`).
### 4. Prices and symbol info
**File**: `trading_system/binance_client.py`
- **Price (get_ticker_24h)**: read Redis first; after a REST refresh, write only Redis. Process memory is written only when the Redis write fails (degraded mode).
- **Symbol info (get_symbol_info)**: read Redis first; after fetching from the DB/API, write Redis. `_symbol_info_cache` is written only when Redis is unavailable.
### 5. Other
- Tickers / BookTicker / listenKey / recommendation results, etc.: everything written to Redis carries a TTL (see the relevant modules and `redis_ttl.py`).
## 3. Usage and operations
1. **Keep Redis/Valkey available**: configure `REDIS_URL` (e.g. AWS Valkey) and make sure the trading service can reach it.
2. **Restart to apply**: after changing the caching logic, restart the trading service.
3. **Monitor Redis memory**: check memory in the Valkey console or via `INFO memory`; every key expires, so usage should not grow unbounded.
4. **Process memory**: in normal operation processes hold very little cache; memory is used only as a bounded fallback (capped entry count and TTL) when Redis is unavailable.
## 4. Expected outcome
- Server process memory drops noticeably (with Redis available, no large caches are held in-process).
- Every business key in Redis/Valkey expires, so Redis memory cannot grow without bound.

View File

@ -201,8 +201,10 @@ class BinanceClient:
self._last_request_time = {} # last request time per API endpoint
self._request_delay = 0.1 # delay between requests (seconds) to avoid rate limits
self._semaphore = asyncio.Semaphore(10) # cap concurrent requests
# ⚠️ Memory optimization: process memory is only a temporary cache; primary data lives in Redis
self._price_cache: Dict[str, Dict] = {} # WebSocket price cache {symbol: {price, volume, changePercent, timestamp}}
self._price_cache_ttl = 60 # price cache TTL (seconds)
self._price_cache_ttl = 30 # price cache TTL (seconds; reduced to cut process memory)
self._price_cache_max_size = 50 # keep prices for at most 50 symbols (reduces process memory)
# Display name -> API symbol mapping (when the exchange returns a Chinese/non-ASCII symbol, use baseAsset+quoteAsset as the order symbol)
self._display_to_api_symbol: Dict[str, str] = {}
@ -895,29 +897,25 @@ class BinanceClient:
import time
symbol = self._resolve_api_symbol(symbol)
# 1. Prefer the WebSocket cache
# Redis-first: read Redis; after REST, write only Redis, not process memory (memory is used only when Redis is unavailable)
cache_key = f"ticker_24h:{symbol}"
cached = await self.redis_cache.get(cache_key)
if cached:
logger.debug(f"Got {symbol} 24h ticker from Redis cache")
return cached
# Redis miss: fall back to process memory (holds data only when Redis is unavailable)
if symbol in self._price_cache:
cached = self._price_cache[symbol]
cache_age = time.time() - cached.get('timestamp', 0)
if cache_age < self._price_cache_ttl:
logger.debug(f"Got {symbol} price from WebSocket cache: {cached['price']:.8f} (age: {cache_age:.1f}s)")
logger.debug(f"Got {symbol} price from process-memory cache (degraded)")
return {
'symbol': symbol,
'price': cached['price'],
'volume': cached.get('volume', 0),
'changePercent': cached.get('changePercent', 0)
}
else:
logger.debug(f"{symbol} WebSocket cache expired ({cache_age:.1f}s > {self._price_cache_ttl}s)")
# 2. Read from the Redis cache
cache_key = f"ticker_24h:{symbol}"
cached = await self.redis_cache.get(cache_key)
if cached:
logger.debug(f"Got {symbol} 24h ticker from Redis cache")
return cached
# 3. Cache unavailable or expired: fall back to the REST API
self._price_cache.pop(symbol, None)
logger.debug(f"{symbol} not cached; fetching via REST API")
try:
ticker = await self._rate_limited_request(
@ -934,16 +932,13 @@ class BinanceClient:
'volume': float(stats.get('quoteVolume', 0)),
'changePercent': float(stats.get('priceChangePercent', 0))
}
# Update the WebSocket cache
self._price_cache[symbol] = {
**result,
'timestamp': time.time()
}
# Write to the Redis cache (TTL: 30 seconds)
await self.redis_cache.set(cache_key, result, ttl=30)
# Write Redis only; write process memory only when the Redis write fails (degraded)
wrote_redis = await self.redis_cache.set(cache_key, result, ttl=30)
if not wrote_redis:
if len(self._price_cache) >= self._price_cache_max_size:
oldest_key = min(self._price_cache.keys(), key=lambda k: self._price_cache[k].get('timestamp', 0))
self._price_cache.pop(oldest_key, None)
self._price_cache[symbol] = {**result, 'timestamp': time.time()}
return result
except BinanceAPIException as e:
error_code = e.code if hasattr(e, 'code') else None
@ -1371,38 +1366,32 @@ class BinanceClient:
"""
symbol = self._resolve_api_symbol(symbol)
loop = asyncio.get_event_loop()
# 1. Check the in-memory cache first
if symbol in self._symbol_info_cache:
cached_mem = self._symbol_info_cache[symbol]
# Back-compat: early caches lacked tickSize/pricePrecision, which could trigger -4014/-1111
if isinstance(cached_mem, dict) and ("tickSize" not in cached_mem or "pricePrecision" not in cached_mem):
try:
self._symbol_info_cache.pop(symbol, None)
except Exception:
pass
else:
return cached_mem
# 2. Read from the Redis cache
cache_key = f"symbol_info:{symbol}"
# 1. Prefer Redis (Redis-first; keeps process memory free)
cached = await self.redis_cache.get(cache_key)
if cached:
logger.debug(f"Got {symbol} symbol info from Redis cache")
# Back-compat: early caches lacked tickSize/pricePrecision, which could trigger -4014/-1111
if isinstance(cached, dict) and ("tickSize" not in cached or "pricePrecision" not in cached):
logger.info(f"{symbol} symbol_info cache lacks tickSize/pricePrecision; refreshing once")
else:
self._symbol_info_cache[symbol] = cached
return cached
# 2. Fall back to process memory (holds data only when Redis is unavailable)
if symbol in self._symbol_info_cache:
cached_mem = self._symbol_info_cache[symbol]
if isinstance(cached_mem, dict) and ("tickSize" in cached_mem and "pricePrecision" in cached_mem):
return cached_mem
self._symbol_info_cache.pop(symbol, None)
# 3. Try reading exchange_info from the DB market_cache and parsing this symbol
try:
exchange_info_db = await loop.run_in_executor(None, lambda: _load_exchange_info_from_db(86400))
info = _parse_symbol_info_from_exchange_info(exchange_info_db, symbol) if exchange_info_db else None
if info:
self._symbol_info_cache[symbol] = info
if self.redis_cache:
await self.redis_cache.set(cache_key, info, ttl=3600)
# On a successful Redis write, keep nothing in process memory
else:
self._symbol_info_cache[symbol] = info
logger.debug(f"Parsed {symbol} symbol info from the DB cache")
return info
except Exception as e:
@ -1484,11 +1473,9 @@ class BinanceClient:
'maxLeverage': int(max_leverage_supported) # maximum leverage supported by this symbol
}
# Write to the Redis cache (TTL: 1 hour)
await self.redis_cache.set(cache_key, info, ttl=3600)
# Also update the in-memory cache
self._symbol_info_cache[symbol] = info
wrote = await self.redis_cache.set(cache_key, info, ttl=3600)
if not wrote:
self._symbol_info_cache[symbol] = info
logger.debug(f"Got {symbol} precision info: {info}")
return info

View File

@ -16,19 +16,28 @@ try:
from .market_ws_leader import KEY_KLINE_PREFIX
except ImportError:
KEY_KLINE_PREFIX = "market:kline:"
try:
from .redis_ttl import TTL_KLINE_STREAM
except ImportError:
TTL_KLINE_STREAM = 600
# Kline cache: { (symbol, interval): [kline1, kline2, ...] }, keeping at most `limit` bars
_kline_cache: Dict[Tuple[str, str], List[List]] = {}
_kline_cache_updated_at: Dict[Tuple[str, str], float] = {}
_kline_cache_limit: Dict[Tuple[str, str], int] = {} # per-(symbol, interval) limit
# ⚠️ Memory optimization: cap total cache size to avoid unbounded growth (2-CPU / 4 GB server)
_MAX_CACHE_ENTRIES = 200 # keep at most 200 (symbol, interval) entries
# Note: primary data lives in Redis; process memory is only a small temporary cache
_MAX_CACHE_ENTRIES = 50 # keep at most 50 (symbol, interval) entries (reduces process memory)
_CACHE_CLEANUP_INTERVAL_SEC = 300 # sweep expired entries every 5 minutes
_CACHE_MAX_AGE_SEC = 600 # drop entries not updated for 10 minutes
_CACHE_MAX_AGE_SEC = 300 # drop entries not updated for 5 minutes (more aggressive; Redis is preferred)
def get_klines_from_cache(symbol: str, interval: str, limit: int = 50) -> Optional[List[List]]:
"""Return klines from the cache (compatible with the REST get_klines format). Returns None when not subscribed or data is insufficient."""
"""
Return klines from the cache (compatible with the REST get_klines format).
Memory optimization: process memory is read first (fast), but it is size-capped; primary data lives in Redis.
Returns None when not subscribed or data is insufficient.
"""
key = (symbol.upper(), interval.lower())
cached = _kline_cache.get(key)
if not cached or len(cached) < limit:
@ -398,7 +407,8 @@ class KlineStream:
pass
async def _batch_write_redis(self):
"""Batch-write to Redis (reduces write frequency)"""
"""Batch-write to Redis; on success drop entries from process memory (Redis-first, near-zero server memory)"""
global _kline_cache, _kline_cache_updated_at, _kline_cache_limit
if not self._redis_write_pending:
return
try:
@ -406,7 +416,16 @@ class KlineStream:
self._redis_write_pending.clear()
for rkey, (data, _) in pending.items():
try:
await self._redis_cache.set(rkey, data, ttl=600)
await self._redis_cache.set(rkey, data, ttl=TTL_KLINE_STREAM)
# After writing to Redis, drop the in-process copy to avoid duplicate memory use
if rkey.startswith(KEY_KLINE_PREFIX):
suffix = rkey[len(KEY_KLINE_PREFIX):]
if ":" in suffix:
s, i = suffix.split(":", 1)
key = (s.upper(), i.lower())
_kline_cache.pop(key, None)
_kline_cache_updated_at.pop(key, None)
_kline_cache_limit.pop(key, None)
except Exception:
pass
except Exception:
@ -422,7 +441,7 @@ class KlineStream:
async def _write_kline_to_redis(self, rkey: str, data: List[List]) -> None:
try:
if self._redis_cache:
await self._redis_cache.set(rkey, data, ttl=600)
await self._redis_cache.set(rkey, data, ttl=TTL_KLINE_STREAM)
except Exception as e:
logger.debug("KlineStream: Redis write failed %s", e)

View File

@ -374,12 +374,13 @@ async def main():
logger.info(f"Starting User Data Stream (account {account_id})...")
if await user_data_stream.start():
logger.info(f"✓ User Data Stream started (account {account_id}; order/position/balance WS pushes, 30-minute keepalive)")
# Seed the cache with current REST results; WS applies incremental updates afterwards, and callers can read the cache first
# Seed the cache (Redis) with current REST results; WS applies incremental updates afterwards, and callers can read the cache first
try:
seed_balance_cache(balance)
redis_cache = getattr(client, "redis_cache", None)
await seed_balance_cache(balance, redis_cache)
positions_seed = await client.get_open_positions()
seed_position_cache(positions_seed)
logger.info(f"✓ Seeded position/balance caches (positions: {len(positions_seed)})")
await seed_position_cache(positions_seed, redis_cache)
logger.info(f"✓ Seeded position/balance caches (positions: {len(positions_seed)}; written to Redis)")
except Exception as e:
logger.warning(f"Failed to seed WS caches (REST only will be used): {e}")
else:

View File

@ -145,18 +145,20 @@ class PositionManager:
self._last_auto_close_fail_log_ms: Dict[str, int] = {}
async def _get_open_positions(self) -> List[Dict]:
"""Prefer the User Data Stream position cache; use REST when the cache is empty or the stream is not running."""
"""Prefer the User Data Stream position cache (Redis); use REST when the cache is empty or the stream is not running."""
if get_stream_instance() is not None:
min_notional = float(getattr(config, "POSITION_MIN_NOTIONAL_USDT", 1.0) or 1.0)
cached = get_positions_from_cache(min_notional)
redis_cache = getattr(self.client, "redis_cache", None)
cached = await get_positions_from_cache(min_notional, redis_cache)
if cached is not None:
return cached
return await self.client.get_open_positions()
async def _get_account_balance(self) -> Dict:
"""Prefer the User Data Stream balance cache; use REST when there is no cache."""
"""Prefer the User Data Stream balance cache (Redis); use REST when there is no cache."""
if get_stream_instance() is not None:
bal = get_balance_from_cache()
redis_cache = getattr(self.client, "redis_cache", None)
bal = await get_balance_from_cache(redis_cache)
if bal is not None:
return bal
return await self.client.get_account_balance()

View File

@ -1,9 +1,11 @@
"""
Redis cache manager - supports TLS connections
Redis cache manager - supports TLS connections.
Redis is the primary store; when it is unavailable, fall back to an in-memory cache (bounded entry count and TTL, so it cannot grow without limit).
"""
import json
import logging
from typing import Optional, Any, Dict, List
import time
from typing import Optional, Any, Dict, List, Tuple
try:
# Use the async client from redis-py 4.2+ (replaces aioredis)
@ -76,7 +78,10 @@ class RedisCache:
self.username = username
self.password = password
self.redis: Optional[Redis] = None
self._memory_cache: Dict[str, Any] = {} # in-memory fallback cache
# In-memory fallback: key -> (value, write timestamp); bounded entry count and TTL prevent unbounded growth
self._memory_cache: Dict[str, Tuple[Any, float]] = {}
self._memory_cache_max_size = 200 # at most 200 entries
self._memory_cache_ttl_sec = 300 # each entry expires after 5 minutes
self._connected = False
async def connect(self):
@ -205,10 +210,16 @@ class RedisCache:
if not self._connected:
await self.connect()
# Fall back to the in-memory cache
# Fall back to the in-memory cache (TTL-aware; tuple entries are (value, ts); non-tuple entries such as hash dicts never expire)
if key in self._memory_cache:
return self._memory_cache[key]
entry = self._memory_cache[key]
if isinstance(entry, tuple) and len(entry) == 2:
val, ts = entry
if time.time() - ts <= self._memory_cache_ttl_sec:
return val
del self._memory_cache[key]
else:
return entry
return None
async def set(self, key: str, value: Any, ttl: int = 3600):
@ -240,8 +251,14 @@ class RedisCache:
except:
pass
# Fall back to the in-memory cache (no TTL; the in-memory cache does not support it)
self._memory_cache[key] = value
# Fall back to the in-memory cache: cap entry count + TTL to avoid unbounded growth
def _ts(k):
v = self._memory_cache.get(k)
return v[1] if isinstance(v, tuple) and len(v) == 2 else 0
while len(self._memory_cache) >= self._memory_cache_max_size and self._memory_cache:
oldest_key = min(self._memory_cache.keys(), key=_ts)
del self._memory_cache[oldest_key]
self._memory_cache[key] = (value, time.time())
return False
async def get_int(self, key: str, default: int = 0) -> int:
@ -257,7 +274,9 @@ class RedisCache:
return int(default or 0)
if key in self._memory_cache:
try:
return int(self._memory_cache.get(key) or 0)
entry = self._memory_cache[key]
val = entry[0] if isinstance(entry, tuple) else entry
return int(val or 0)
except Exception:
return int(default or 0)
except Exception:
@ -288,14 +307,22 @@ class RedisCache:
except Exception as e:
logger.debug(f"Redis incr失败 {key}: {e}")
# 内存兜底(不做 TTL
# 内存兜底:限制条数 + TTL
cur = 0
try:
cur = int(self._memory_cache.get(key) or 0)
if key in self._memory_cache:
entry = self._memory_cache[key]
cur = int(entry[0] if isinstance(entry, tuple) else entry)
except Exception:
cur = 0
cur += inc
self._memory_cache[key] = cur
def _ts(k):
v = self._memory_cache.get(k)
return v[1] if isinstance(v, tuple) and len(v) == 2 else 0
while len(self._memory_cache) >= self._memory_cache_max_size and key not in self._memory_cache and self._memory_cache:
oldest_key = min(self._memory_cache.keys(), key=_ts)
del self._memory_cache[oldest_key]
self._memory_cache[key] = (cur, time.time())
return int(cur)
async def delete(self, key: str):

View File

@ -0,0 +1,40 @@
"""
Unified Redis cache keys and TTLs. Every key carries an expiry so Valkey/Redis cannot grow without bound.
Redis is the primary cache; process memory serves only as a fallback or a minimal buffer.
"""
# Key prefixes (for easier inspection and cleanup)
PREFIX_ATS = "ats:"
PREFIX_MARKET = "market:"
PREFIX_KLINES = "klines:"
PREFIX_TICKER = "ticker_24h:"
PREFIX_SYMBOL_INFO = "symbol_info:"
PREFIX_LISTEN_KEY = "listen_key:"
PREFIX_LEADER = "market_ws_leader"
# TTL- 所有写入 Redis 的键必须带 TTL禁止无过期
TTL_LEADER = 30
TTL_TICKER_24H = 120
TTL_BOOK_TICKER = 30
TTL_KLINE_STREAM = 600 # 10 minutes (klines written live over WS)
TTL_POSITIONS = 300 # 5 minutes (position cache)
TTL_BALANCE = 300 # 5 minutes (balance cache)
TTL_TICKER_SYMBOL = 30 # per-symbol 24h ticker
TTL_SYMBOL_INFO = 3600 # 1 hour (symbol info)
TTL_KLINES_REST = 1800 # REST-fetched klines, default 30 minutes
TTL_KLINES_REST_OLD = 300 # old-format klines:{s}:{i}:{limit}, default 5 minutes
TTL_LISTEN_KEY = 55 * 60 # 55 minutes (listenKey cache)
TTL_TREND_STATE = 3600
TTL_INDICATORS = 30
TTL_RECO_SNAPSHOT = 7200
TTL_RECO_ITEM = 3600
TTL_LOCK_RECO = 10
# Per-interval kline TTLs (written to the shared cache after REST fetches)
TTL_KLINES_BY_INTERVAL = {
'1m': 60, '3m': 120, '5m': 180, '15m': 300, '30m': 600,
'1h': 900, '2h': 1800, '4h': 3600, '6h': 5400, '8h': 7200, '12h': 10800, '1d': 21600,
}
TTL_KLINES_OLD_BY_INTERVAL = {
'1m': 10, '3m': 20, '5m': 30, '15m': 60, '30m': 120,
'1h': 300, '2h': 600, '4h': 900, '6h': 1200, '8h': 1800, '12h': 2400, '1d': 3600,
}

View File

@ -25,17 +25,22 @@ def _get_stream_instance():
except Exception:
return None
def _get_balance_from_cache():
async def _get_balance_from_cache(client: Optional[BinanceClient] = None):
"""Get the balance from the cache (Redis first, falling back to process memory)"""
try:
from .user_data_stream import get_balance_from_cache
return get_balance_from_cache()
redis_cache = getattr(client, "redis_cache", None) if client else None
return await get_balance_from_cache(redis_cache)
except Exception:
return None
def _get_positions_from_cache():
async def _get_positions_from_cache(client: Optional[BinanceClient] = None):
"""Get positions from the cache (Redis first, falling back to process memory)"""
try:
from .user_data_stream import get_positions_from_cache
return get_positions_from_cache(float(getattr(config, "POSITION_MIN_NOTIONAL_USDT", 1.0) or 1.0))
redis_cache = getattr(client, "redis_cache", None) if client else None
min_notional = float(getattr(config, "POSITION_MIN_NOTIONAL_USDT", 1.0) or 1.0)
return await get_positions_from_cache(min_notional, redis_cache)
except Exception:
return None
@ -71,8 +76,8 @@ class RiskManager:
try:
logger.info(f"Checking single-position size for {symbol}...")
# Get the account balance (prefer the WS cache)
balance = _get_balance_from_cache() if _get_stream_instance() else None
# Get the account balance (prefer the WS cache / Redis)
balance = await _get_balance_from_cache(self.client) if _get_stream_instance() else None
if balance is None:
balance = await self.client.get_account_balance()
available_balance = balance.get('available', 0)
@ -162,8 +167,8 @@ class RiskManager:
Whether the check passed
"""
try:
# Get current positions (prefer the WS cache)
positions = _get_positions_from_cache() if _get_stream_instance() else None
# Get current positions (prefer the WS cache / Redis)
positions = await _get_positions_from_cache(self.client) if _get_stream_instance() else None
if positions is None:
positions = await self.client.get_open_positions()
@ -194,8 +199,8 @@ class RiskManager:
# Add the new position
total_with_new = total_margin_value + new_position_margin
# Get the account balance (prefer the WS cache)
balance = _get_balance_from_cache() if _get_stream_instance() else None
# Get the account balance (prefer the WS cache / Redis)
balance = await _get_balance_from_cache(self.client) if _get_stream_instance() else None
if balance is None:
balance = await self.client.get_account_balance()
total_balance = balance.get('total', 0)
@ -430,8 +435,8 @@ class RiskManager:
try:
logger.info(f"Starting position-size calculation for {symbol}...")
# Get the account balance (prefer the WS cache)
balance = _get_balance_from_cache() if _get_stream_instance() else None
# Get the account balance (prefer the WS cache / Redis)
balance = await _get_balance_from_cache(self.client) if _get_stream_instance() else None
if balance is None:
balance = await self.client.get_account_balance()
available_balance = balance.get('available', 0)
@ -818,7 +823,7 @@ class RiskManager:
return False
# Check existing positions / total-position limit (prefer the WS cache)
positions = _get_positions_from_cache() if _get_stream_instance() else None
positions = await _get_positions_from_cache(self.client) if _get_stream_instance() else None
if positions is None:
positions = await self.client.get_open_positions()
try:

View File

@ -39,39 +39,103 @@ def get_stream_instance() -> Optional["UserDataStream"]:
return _stream_instance
def seed_position_cache(positions: List[Dict]) -> None:
"""Seed the cache with the full REST position list (called once at startup; ACCOUNT_UPDATE applies increments afterwards)."""
try:
from .redis_ttl import TTL_POSITIONS
except ImportError:
TTL_POSITIONS = 300
try:
from .redis_ttl import TTL_BALANCE
except ImportError:
TTL_BALANCE = 300
async def seed_position_cache(positions: List[Dict], redis_cache: Any = None) -> None:
"""Seed the cache with the full REST position list. With Redis, write Redis only (no process memory); without Redis, write process memory."""
global _position_updates_cache, _position_cache_seeded
_position_updates_cache.clear()
_position_cache_seeded = True
positions_list = []
for pos in positions or []:
symbol = (pos.get("symbol") or "").strip()
if not symbol:
continue
amt = float(pos.get("positionAmt") or 0)
_position_updates_cache[symbol] = [{
"s": symbol,
"pa": amt,
"ep": str(pos.get("entryPrice") or "0"),
"up": str(pos.get("unRealizedProfit") or "0"),
"ps": pos.get("positionSide") or "BOTH",
}]
logger.debug(f"UserDataStream: 已填充持仓缓存 {len(_position_updates_cache)} 个 symbol")
if not redis_cache:
if symbol not in _position_updates_cache:
_position_updates_cache[symbol] = []
_position_updates_cache[symbol] = [{
"s": symbol,
"pa": amt,
"ep": str(pos.get("entryPrice") or "0"),
"up": str(pos.get("unRealizedProfit") or "0"),
"ps": pos.get("positionSide") or "BOTH",
}]
if amt != 0:
positions_list.append({
"symbol": symbol,
"positionAmt": amt,
"entryPrice": float(pos.get("entryPrice") or 0),
"markPrice": float(pos.get("markPrice") or 0),
"unRealizedProfit": float(pos.get("unRealizedProfit") or 0),
"leverage": int(pos.get("leverage") or 0),
})
if redis_cache:
_position_updates_cache.clear()
if positions_list:
try:
await redis_cache.set("ats:positions:cache", positions_list, ttl=TTL_POSITIONS)
except Exception as e:
logger.debug(f"Failed to write position cache to Redis: {e}")
logger.debug("UserDataStream: seeded position cache (Redis=%s)", bool(redis_cache))
def seed_balance_cache(balance: Dict[str, Any]) -> None:
"""Seed the cache with the REST balance result (called once at startup; ACCOUNT_UPDATE applies increments afterwards)."""
async def seed_balance_cache(balance: Dict[str, Any], redis_cache: Any = None) -> None:
"""Seed the cache with the REST balance result. With Redis, write Redis only (no process memory); without Redis, write process memory."""
global _balance_updates_cache, _balance_cache_seeded
_balance_cache_seeded = True
if balance and isinstance(balance, dict):
wb = balance.get("walletBalance") or balance.get("total") or 0
av = balance.get("availableBalance") or balance.get("available") or wb
_balance_updates_cache["USDT"] = {"wb": str(wb), "cw": str(av), "bc": "0"}
logger.debug("UserDataStream: seeded balance cache (USDT)")
balance_data = {"wb": str(wb), "cw": str(av), "bc": "0"}
if redis_cache:
try:
await redis_cache.set("ats:balance:cache:USDT", balance_data, ttl=TTL_BALANCE)
except Exception as e:
logger.debug(f"Failed to write balance cache to Redis: {e}")
else:
_balance_updates_cache["USDT"] = balance_data
logger.debug("UserDataStream: seeded balance cache (USDT, Redis=%s)", bool(redis_cache))
def get_positions_from_cache(min_notional: float = 1.0) -> Optional[List[Dict]]:
"""Convert the position cache to the same list format as REST get_open_positions; returns None when unseeded (callers should use REST)."""
async def get_positions_from_cache(min_notional: float = 1.0, redis_cache: Any = None) -> Optional[List[Dict]]:
"""
Convert the position cache to the same list format as REST get_open_positions; returns None when unseeded (callers should use REST).
Memory optimization: read from Redis first to reduce process memory use.
"""
# Prefer Redis (shared across processes)
if redis_cache:
try:
redis_key = "ats:positions:cache"
cached = await redis_cache.get(redis_key)
if cached and isinstance(cached, list):
# Filter by minimum notional value
filtered = []
for pos in cached:
try:
pa = float(pos.get("positionAmt") or 0)
ep = float(pos.get("entryPrice") or 0)
if pa == 0:
continue
if min_notional > 0 and abs(pa) * ep < min_notional:
continue
filtered.append(pos)
except Exception:
continue
if filtered:
return filtered
except Exception as e:
logger.debug(f"Failed to read position cache from Redis: {e}")
# Fall back to the in-process memory cache
if not _position_cache_seeded:
return None
out = []
@ -97,8 +161,27 @@ def get_positions_from_cache(min_notional: float = 1.0) -> Optional[List[Dict]]:
return out
def get_balance_from_cache() -> Optional[Dict[str, Any]]:
"""Return the USDT balance from the cache (compatible with the REST get_account_balance structure); returns None when unseeded or empty."""
async def get_balance_from_cache(redis_cache: Any = None) -> Optional[Dict[str, Any]]:
"""
Return the USDT balance from the cache (compatible with the REST get_account_balance structure); returns None when unseeded or empty.
Memory optimization: read from Redis first to reduce process memory use.
"""
# Prefer Redis (shared across processes)
if redis_cache:
try:
redis_key = "ats:balance:cache:USDT"
cached = await redis_cache.get(redis_key)
if cached and isinstance(cached, dict):
try:
wb = float(cached.get("wb") or cached.get("total") or 0)
cw = float(cached.get("cw") or cached.get("available") or wb)
return {"ok": True, "available": cw, "total": wb, "margin": wb}
except (TypeError, ValueError):
pass
except Exception as e:
logger.debug(f"Failed to read balance cache from Redis: {e}")
# Fall back to the in-process memory cache
if not _balance_cache_seeded:
return None
u = _balance_updates_cache.get("USDT")
@ -455,29 +538,73 @@ class UserDataStream:
logger.warning(f"UserDataStream: set_exit_order_id_for_open_trade failed {ex}")
def _on_account_update(self, a: Dict):
# Docs: a.B = balance array; each item: a=asset, wb=wallet balance, cw=wallet balance excluding isolated margin, bc=balance change
# Docs: a.P = position array; each item: s=symbol, pa=position amount, ep=entry price, ps=LONG/SHORT/BOTH
# Docs: a.B = balance array, a.P = position array. With Redis, write Redis only, not process memory.
global _position_updates_cache, _balance_updates_cache
redis_cache = getattr(self.client, "redis_cache", None)
B = a.get("B")
if isinstance(B, list) and B:
for b in B:
asset = (b.get("a") or "").strip()
if asset:
_balance_updates_cache[asset] = {
balance_data = {
"wb": b.get("wb"),
"cw": b.get("cw"),
"bc": b.get("bc"),
}
if redis_cache:
if asset == "USDT":
asyncio.create_task(self._write_balance_to_redis(asset, balance_data))
else:
_balance_updates_cache[asset] = balance_data
logger.debug(f"UserDataStream: ACCOUNT_UPDATE balance assets: {len(B)}")
P = a.get("P")
if isinstance(P, list) and P:
positions_list = []
for p in P:
s = (p.get("s") or "").strip()
if s:
if s not in _position_updates_cache:
_position_updates_cache[s] = []
_position_updates_cache[s] = [p]
if not redis_cache:
if s not in _position_updates_cache:
_position_updates_cache[s] = []
_position_updates_cache[s] = [p]
try:
pa = float(p.get("pa") or 0)
ep = float(p.get("ep") or 0)
if pa != 0:
positions_list.append({
"symbol": s,
"positionAmt": pa,
"entryPrice": ep,
"markPrice": float(p.get("markPrice") or 0),
"unRealizedProfit": float(p.get("up") or 0),
"leverage": int(p.get("leverage") or 0),
})
except Exception:
pass
if redis_cache and positions_list:
asyncio.create_task(self._write_positions_to_redis(positions_list))
logger.debug(f"UserDataStream: ACCOUNT_UPDATE positions: {len(P)}")
async def _write_balance_to_redis(self, asset: str, balance_data: Dict):
"""Write the balance cache to Redis (with a TTL to avoid unbounded growth)"""
try:
redis_cache = getattr(self.client, "redis_cache", None)
if redis_cache:
redis_key = f"ats:balance:cache:{asset}"
await redis_cache.set(redis_key, balance_data, ttl=TTL_BALANCE)
except Exception as e:
logger.debug(f"Failed to write balance cache to Redis: {e}")
async def _write_positions_to_redis(self, positions_list: List[Dict]):
"""Write the position cache to Redis (with a TTL to avoid unbounded growth)"""
try:
redis_cache = getattr(self.client, "redis_cache", None)
if redis_cache:
redis_key = "ats:positions:cache"
await redis_cache.set(redis_key, positions_list, ttl=TTL_POSITIONS)
except Exception as e:
logger.debug(f"Failed to write position cache to Redis: {e}")
def _on_algo_update(self, o: Dict):
# Conditional-order trade update: when X=TRIGGERED/FINISHED and ai is the id of the triggered regular order, write exit_order_id back to the open record