""" 数据库连接管理 """ import pymysql from contextlib import contextmanager import os import logging from pathlib import Path from sqlalchemy import create_engine, pool logger = logging.getLogger(__name__) # 尝试加载.env文件 try: from dotenv import load_dotenv # 从backend目录或项目根目录查找.env文件 backend_dir = Path(__file__).parent.parent project_root = backend_dir.parent # 按优先级查找.env文件 env_files = [ backend_dir / '.env', # backend/.env (优先) project_root / '.env', # 项目根目录/.env ] loaded = False for env_file in env_files: if env_file.exists(): load_dotenv(env_file, override=True) logger.info(f"已加载环境变量文件: {env_file}") loaded = True break if not loaded: # 如果都不存在,尝试自动查找(不报错) load_dotenv(project_root / '.env', override=False) except ImportError: # 如果没有安装python-dotenv,跳过 logger.debug("python-dotenv未安装,跳过.env文件加载") except Exception as e: logger.warning(f"加载.env文件失败: {e}") class Database: """数据库连接类(使用SQLAlchemy连接池)""" _engine = None def __init__(self): self.host = os.getenv('DB_HOST', 'localhost') self.port = int(os.getenv('DB_PORT', 3306)) self.user = os.getenv('DB_USER', 'root') self.password = os.getenv('DB_PASSWORD', '') self.database = os.getenv('DB_NAME', 'auto_trade_sys_new') # 记录配置信息(不显示密码) logger.debug(f"数据库配置: host={self.host}, port={self.port}, user={self.user}, database={self.database}") # 初始化连接池 self._init_engine() def _init_engine(self): """初始化SQLAlchemy引擎和连接池""" if Database._engine is None: # 构建数据库URL # 注意:密码中如果有特殊字符需要转义,这里简单处理 from urllib.parse import quote_plus encoded_password = quote_plus(self.password) db_url = f"mysql+pymysql://{self.user}:{encoded_password}@{self.host}:{self.port}/{self.database}?charset=utf8mb4" try: Database._engine = create_engine( db_url, pool_size=20, # 基础连接池大小 max_overflow=30, # 最大溢出连接数 pool_recycle=3600, # 连接回收时间(秒) pool_timeout=30, # 获取连接超时时间(秒) pool_pre_ping=True, # 预检测连接是否可用 connect_args={ # 'cursorclass': pymysql.cursors.DictCursor, # Removed to prevent KeyError: 0 in SQLAlchemy init 'autocommit': False } ) logger.info("数据库连接池初始化成功") except Exception as e: logger.error(f"数据库连接池初始化失败: {e}") raise @contextmanager def get_connection(self): """获取数据库连接(从连接池)""" conn = None try: # 获取原始pymysql连接 conn = Database._engine.raw_connection() # Explicitly set cursor class to DictCursor since we removed it from create_engine # We need to set it on the underlying DBAPI connection try: if hasattr(conn, 'driver_connection'): # SQLAlchemy 2.0+ conn.driver_connection.cursorclass = pymysql.cursors.DictCursor elif hasattr(conn, 'connection'): # Older SQLAlchemy conn.connection.cursorclass = pymysql.cursors.DictCursor else: # Fallback conn.cursorclass = pymysql.cursors.DictCursor except Exception as e: logger.warning(f"设置DictCursor失败: {e}") # 设置时区为北京时间(UTC+8) # 注意:raw_connection可能不自动应用connect_args中的autocommit,需确认 # SQLAlchemy的raw_connection通常返回DBAPI连接,autocommit行为取决于驱动 # 这里显式关闭autocommit以保持兼容性 try: conn.autocommit(False) except AttributeError: # 某些旧版本pymysql或wrapper可能不支持方法调用,尝试属性赋值 pass with conn.cursor() as cursor: cursor.execute("SET time_zone = '+08:00'") # 注意:不在这里commit,除非是只读操作。调用者负责commit/rollback # 但原代码在yield前commit了时区设置? # 原代码:cursor.execute(...); conn.commit(); yield conn # SET time_zone 不需要 commit,但为了保险起见保留原行为 conn.commit() yield conn except Exception as e: if conn: try: conn.rollback() except: pass logger.error(f"数据库连接错误: {e}") raise finally: if conn: conn.close() # 归还给连接池 def execute_query(self, query, params=None): """执行查询,返回所有结果""" with self.get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) conn.commit() return cursor.fetchall() def execute_one(self, query, params=None): """执行查询,返回单条结果""" with self.get_connection() as conn: with conn.cursor() as cursor: cursor.execute(query, params) conn.commit() return cursor.fetchone() def execute_update(self, query, params=None): """执行更新,返回影响行数""" try: with self.get_connection() as conn: with conn.cursor() as cursor: affected = cursor.execute(query, params) conn.commit() return affected except Exception as e: # 重新抛出异常,让调用者处理(如 update_exit 中的异常处理) # 不要在这里记录为"数据库连接错误",因为这可能是业务逻辑错误(如唯一约束冲突) raise def execute_many(self, query, params_list): """批量执行""" with self.get_connection() as conn: with conn.cursor() as cursor: affected = cursor.executemany(query, params_list) conn.commit() return affected # 全局数据库实例 db = Database()