auto_trade_sys/backend/database/connection.py
薇薇安 922a8f3820 1
2026-02-04 13:45:30 +08:00

183 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库连接管理
"""
import pymysql
from contextlib import contextmanager
import os
import logging
from pathlib import Path
from sqlalchemy import create_engine, pool
logger = logging.getLogger(__name__)
# 尝试加载.env文件
try:
from dotenv import load_dotenv
# 从backend目录或项目根目录查找.env文件
backend_dir = Path(__file__).parent.parent
project_root = backend_dir.parent
# 按优先级查找.env文件
env_files = [
backend_dir / '.env', # backend/.env (优先)
project_root / '.env', # 项目根目录/.env
]
loaded = False
for env_file in env_files:
if env_file.exists():
load_dotenv(env_file, override=True)
logger.info(f"已加载环境变量文件: {env_file}")
loaded = True
break
if not loaded:
# 如果都不存在,尝试自动查找(不报错)
load_dotenv(project_root / '.env', override=False)
except ImportError:
# 如果没有安装python-dotenv跳过
logger.debug("python-dotenv未安装跳过.env文件加载")
except Exception as e:
logger.warning(f"加载.env文件失败: {e}")
class Database:
"""数据库连接类使用SQLAlchemy连接池"""
_engine = None
def __init__(self):
self.host = os.getenv('DB_HOST', 'localhost')
self.port = int(os.getenv('DB_PORT', 3306))
self.user = os.getenv('DB_USER', 'root')
self.password = os.getenv('DB_PASSWORD', '')
self.database = os.getenv('DB_NAME', 'auto_trade_sys')
# 记录配置信息(不显示密码)
logger.debug(f"数据库配置: host={self.host}, port={self.port}, user={self.user}, database={self.database}")
# 初始化连接池
self._init_engine()
def _init_engine(self):
"""初始化SQLAlchemy引擎和连接池"""
if Database._engine is None:
# 构建数据库URL
# 注意:密码中如果有特殊字符需要转义,这里简单处理
from urllib.parse import quote_plus
encoded_password = quote_plus(self.password)
db_url = f"mysql+pymysql://{self.user}:{encoded_password}@{self.host}:{self.port}/{self.database}?charset=utf8mb4"
try:
Database._engine = create_engine(
db_url,
pool_size=20, # 基础连接池大小
max_overflow=30, # 最大溢出连接数
pool_recycle=3600, # 连接回收时间(秒)
pool_timeout=30, # 获取连接超时时间(秒)
pool_pre_ping=True, # 预检测连接是否可用
connect_args={
# 'cursorclass': pymysql.cursors.DictCursor, # Removed to prevent KeyError: 0 in SQLAlchemy init
'autocommit': False
}
)
logger.info("数据库连接池初始化成功")
except Exception as e:
logger.error(f"数据库连接池初始化失败: {e}")
raise
@contextmanager
def get_connection(self):
"""获取数据库连接(从连接池)"""
conn = None
try:
# 获取原始pymysql连接
conn = Database._engine.raw_connection()
# Explicitly set cursor class to DictCursor since we removed it from create_engine
# We need to set it on the underlying DBAPI connection
try:
if hasattr(conn, 'driver_connection'):
# SQLAlchemy 2.0+
conn.driver_connection.cursorclass = pymysql.cursors.DictCursor
elif hasattr(conn, 'connection'):
# Older SQLAlchemy
conn.connection.cursorclass = pymysql.cursors.DictCursor
else:
# Fallback
conn.cursorclass = pymysql.cursors.DictCursor
except Exception as e:
logger.warning(f"设置DictCursor失败: {e}")
# 设置时区为北京时间UTC+8
# 注意raw_connection可能不自动应用connect_args中的autocommit需确认
# SQLAlchemy的raw_connection通常返回DBAPI连接autocommit行为取决于驱动
# 这里显式关闭autocommit以保持兼容性
try:
conn.autocommit(False)
except AttributeError:
# 某些旧版本pymysql或wrapper可能不支持方法调用尝试属性赋值
pass
with conn.cursor() as cursor:
cursor.execute("SET time_zone = '+08:00'")
# 注意不在这里commit除非是只读操作。调用者负责commit/rollback
# 但原代码在yield前commit了时区设置?
# 原代码cursor.execute(...); conn.commit(); yield conn
# SET time_zone 不需要 commit但为了保险起见保留原行为
conn.commit()
yield conn
except Exception as e:
if conn:
try:
conn.rollback()
except:
pass
logger.error(f"数据库连接错误: {e}")
raise
finally:
if conn:
conn.close() # 归还给连接池
def execute_query(self, query, params=None):
"""执行查询,返回所有结果"""
with self.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(query, params)
conn.commit()
return cursor.fetchall()
def execute_one(self, query, params=None):
"""执行查询,返回单条结果"""
with self.get_connection() as conn:
with conn.cursor() as cursor:
cursor.execute(query, params)
conn.commit()
return cursor.fetchone()
def execute_update(self, query, params=None):
"""执行更新,返回影响行数"""
try:
with self.get_connection() as conn:
with conn.cursor() as cursor:
affected = cursor.execute(query, params)
conn.commit()
return affected
except Exception as e:
# 重新抛出异常,让调用者处理(如 update_exit 中的异常处理)
# 不要在这里记录为"数据库连接错误",因为这可能是业务逻辑错误(如唯一约束冲突)
raise
def execute_many(self, query, params_list):
"""批量执行"""
with self.get_connection() as conn:
with conn.cursor() as cursor:
affected = cursor.executemany(query, params_list)
conn.commit()
return affected
# 全局数据库实例
db = Database()