183 lines
6.9 KiB
Python
183 lines
6.9 KiB
Python
"""
|
||
数据库连接管理
|
||
"""
|
||
import pymysql
|
||
from contextlib import contextmanager
|
||
import os
|
||
import logging
|
||
from pathlib import Path
|
||
from sqlalchemy import create_engine, pool
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 尝试加载.env文件
|
||
try:
|
||
from dotenv import load_dotenv
|
||
# 从backend目录或项目根目录查找.env文件
|
||
backend_dir = Path(__file__).parent.parent
|
||
project_root = backend_dir.parent
|
||
|
||
# 按优先级查找.env文件
|
||
env_files = [
|
||
backend_dir / '.env', # backend/.env (优先)
|
||
project_root / '.env', # 项目根目录/.env
|
||
]
|
||
|
||
loaded = False
|
||
for env_file in env_files:
|
||
if env_file.exists():
|
||
load_dotenv(env_file, override=True)
|
||
logger.info(f"已加载环境变量文件: {env_file}")
|
||
loaded = True
|
||
break
|
||
|
||
if not loaded:
|
||
# 如果都不存在,尝试自动查找(不报错)
|
||
load_dotenv(project_root / '.env', override=False)
|
||
except ImportError:
|
||
# 如果没有安装python-dotenv,跳过
|
||
logger.debug("python-dotenv未安装,跳过.env文件加载")
|
||
except Exception as e:
|
||
logger.warning(f"加载.env文件失败: {e}")
|
||
|
||
|
||
class Database:
|
||
"""数据库连接类(使用SQLAlchemy连接池)"""
|
||
|
||
_engine = None
|
||
|
||
def __init__(self):
|
||
self.host = os.getenv('DB_HOST', 'localhost')
|
||
self.port = int(os.getenv('DB_PORT', 3306))
|
||
self.user = os.getenv('DB_USER', 'root')
|
||
self.password = os.getenv('DB_PASSWORD', '')
|
||
self.database = os.getenv('DB_NAME', 'auto_trade_sys_new')
|
||
|
||
# 记录配置信息(不显示密码)
|
||
logger.debug(f"数据库配置: host={self.host}, port={self.port}, user={self.user}, database={self.database}")
|
||
|
||
# 初始化连接池
|
||
self._init_engine()
|
||
|
||
def _init_engine(self):
|
||
"""初始化SQLAlchemy引擎和连接池"""
|
||
if Database._engine is None:
|
||
# 构建数据库URL
|
||
# 注意:密码中如果有特殊字符需要转义,这里简单处理
|
||
from urllib.parse import quote_plus
|
||
encoded_password = quote_plus(self.password)
|
||
db_url = f"mysql+pymysql://{self.user}:{encoded_password}@{self.host}:{self.port}/{self.database}?charset=utf8mb4"
|
||
|
||
try:
|
||
Database._engine = create_engine(
|
||
db_url,
|
||
pool_size=20, # 基础连接池大小
|
||
max_overflow=30, # 最大溢出连接数
|
||
pool_recycle=3600, # 连接回收时间(秒)
|
||
pool_timeout=30, # 获取连接超时时间(秒)
|
||
pool_pre_ping=True, # 预检测连接是否可用
|
||
connect_args={
|
||
# 'cursorclass': pymysql.cursors.DictCursor, # Removed to prevent KeyError: 0 in SQLAlchemy init
|
||
'autocommit': False
|
||
}
|
||
)
|
||
logger.info("数据库连接池初始化成功")
|
||
except Exception as e:
|
||
logger.error(f"数据库连接池初始化失败: {e}")
|
||
raise
|
||
|
||
@contextmanager
|
||
def get_connection(self):
|
||
"""获取数据库连接(从连接池)"""
|
||
conn = None
|
||
try:
|
||
# 获取原始pymysql连接
|
||
conn = Database._engine.raw_connection()
|
||
|
||
# Explicitly set cursor class to DictCursor since we removed it from create_engine
|
||
# We need to set it on the underlying DBAPI connection
|
||
try:
|
||
if hasattr(conn, 'driver_connection'):
|
||
# SQLAlchemy 2.0+
|
||
conn.driver_connection.cursorclass = pymysql.cursors.DictCursor
|
||
elif hasattr(conn, 'connection'):
|
||
# Older SQLAlchemy
|
||
conn.connection.cursorclass = pymysql.cursors.DictCursor
|
||
else:
|
||
# Fallback
|
||
conn.cursorclass = pymysql.cursors.DictCursor
|
||
except Exception as e:
|
||
logger.warning(f"设置DictCursor失败: {e}")
|
||
|
||
# 设置时区为北京时间(UTC+8)
|
||
# 注意:raw_connection可能不自动应用connect_args中的autocommit,需确认
|
||
# SQLAlchemy的raw_connection通常返回DBAPI连接,autocommit行为取决于驱动
|
||
# 这里显式关闭autocommit以保持兼容性
|
||
try:
|
||
conn.autocommit(False)
|
||
except AttributeError:
|
||
# 某些旧版本pymysql或wrapper可能不支持方法调用,尝试属性赋值
|
||
pass
|
||
|
||
with conn.cursor() as cursor:
|
||
cursor.execute("SET time_zone = '+08:00'")
|
||
# 注意:不在这里commit,除非是只读操作。调用者负责commit/rollback
|
||
# 但原代码在yield前commit了时区设置?
|
||
# 原代码:cursor.execute(...); conn.commit(); yield conn
|
||
# SET time_zone 不需要 commit,但为了保险起见保留原行为
|
||
conn.commit()
|
||
|
||
yield conn
|
||
except Exception as e:
|
||
if conn:
|
||
try:
|
||
conn.rollback()
|
||
except:
|
||
pass
|
||
logger.error(f"数据库连接错误: {e}")
|
||
raise
|
||
finally:
|
||
if conn:
|
||
conn.close() # 归还给连接池
|
||
|
||
def execute_query(self, query, params=None):
|
||
"""执行查询,返回所有结果"""
|
||
with self.get_connection() as conn:
|
||
with conn.cursor() as cursor:
|
||
cursor.execute(query, params)
|
||
conn.commit()
|
||
return cursor.fetchall()
|
||
|
||
def execute_one(self, query, params=None):
|
||
"""执行查询,返回单条结果"""
|
||
with self.get_connection() as conn:
|
||
with conn.cursor() as cursor:
|
||
cursor.execute(query, params)
|
||
conn.commit()
|
||
return cursor.fetchone()
|
||
|
||
def execute_update(self, query, params=None):
|
||
"""执行更新,返回影响行数"""
|
||
try:
|
||
with self.get_connection() as conn:
|
||
with conn.cursor() as cursor:
|
||
affected = cursor.execute(query, params)
|
||
conn.commit()
|
||
return affected
|
||
except Exception as e:
|
||
# 重新抛出异常,让调用者处理(如 update_exit 中的异常处理)
|
||
# 不要在这里记录为"数据库连接错误",因为这可能是业务逻辑错误(如唯一约束冲突)
|
||
raise
|
||
|
||
def execute_many(self, query, params_list):
|
||
"""批量执行"""
|
||
with self.get_connection() as conn:
|
||
with conn.cursor() as cursor:
|
||
affected = cursor.executemany(query, params_list)
|
||
conn.commit()
|
||
return affected
|
||
|
||
|
||
# 全局数据库实例
|
||
db = Database()
|