我拿 AI 改坏枪代码!
This commit is contained in:
@ -1,64 +1,127 @@
|
||||
import os
|
||||
import sqlite3
|
||||
from typing import List, Dict, Any, Optional
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""超级无敌神奇的数据库!"""
|
||||
|
||||
@classmethod
|
||||
def query(cls, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
|
||||
"""异步数据库管理器"""
|
||||
|
||||
def __init__(self, db_path: Optional[Union[str, Path]] = None):
|
||||
"""
|
||||
初始化数据库管理器
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径,支持str和Path类型
|
||||
"""
|
||||
if db_path is None:
|
||||
self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
|
||||
else:
|
||||
self.db_path = str(db_path) if isinstance(db_path, Path) else db_path
|
||||
|
||||
# 连接池
|
||||
self._connection_pool = []
|
||||
self._pool_size = 5
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _get_connection(self) -> aiosqlite.Connection:
|
||||
"""从连接池获取连接"""
|
||||
async with self._lock:
|
||||
if self._connection_pool:
|
||||
return self._connection_pool.pop()
|
||||
|
||||
# 如果连接池为空,创建新连接
|
||||
conn = await aiosqlite.connect(self.db_path)
|
||||
await conn.execute("PRAGMA foreign_keys = ON")
|
||||
return conn
|
||||
|
||||
async def _return_connection(self, conn: aiosqlite.Connection) -> None:
|
||||
"""将连接返回到连接池"""
|
||||
async with self._lock:
|
||||
if len(self._connection_pool) < self._pool_size:
|
||||
self._connection_pool.append(conn)
|
||||
else:
|
||||
await conn.close()
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[tuple] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""执行查询语句并返回结果"""
|
||||
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(query, params or ())
|
||||
columns = [description[0] for description in cursor.description]
|
||||
results = [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
cursor.close()
|
||||
conn.close()
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
def query_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
|
||||
conn = await self._get_connection()
|
||||
try:
|
||||
cursor = await conn.execute(query, params or ())
|
||||
columns = [description[0] for description in cursor.description]
|
||||
rows = await cursor.fetchall()
|
||||
results = [dict(zip(columns, row)) for row in rows]
|
||||
await cursor.close()
|
||||
return results
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
async def query_by_sql_file(
|
||||
self, file_path: Union[str, Path], params: Optional[tuple] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""从 SQL 文件中读取查询语句并执行"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
query = f.read()
|
||||
return cls.query(query, params)
|
||||
return await self.query(query, params)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, command: str, params: Optional[tuple] = None) -> None:
|
||||
async def execute(self, command: str, params: Optional[tuple] = None) -> None:
|
||||
"""执行非查询语句"""
|
||||
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(command, params or ())
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
conn = await self._get_connection()
|
||||
try:
|
||||
await conn.execute(command, params or ())
|
||||
await conn.commit()
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
@classmethod
|
||||
def execute_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> None:
|
||||
async def execute_script(self, script: str) -> None:
|
||||
"""执行SQL脚本"""
|
||||
conn = await self._get_connection()
|
||||
try:
|
||||
await conn.executescript(script)
|
||||
await conn.commit()
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
async def execute_by_sql_file(
|
||||
self, file_path: Union[str, Path], params: Optional[tuple] = None
|
||||
) -> None:
|
||||
"""从 SQL 文件中读取非查询语句并执行"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
command = f.read()
|
||||
# 按照需要执行多条语句
|
||||
commands = command.split(';')
|
||||
for cmd in commands:
|
||||
cmd = cmd.strip()
|
||||
if cmd:
|
||||
cls.execute(cmd, params)
|
||||
|
||||
@classmethod
|
||||
def execute_many(cls, command: str, seq_of_params: List[tuple]) -> None:
|
||||
"""执行多条非查询语句"""
|
||||
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
|
||||
cursor = conn.cursor()
|
||||
cursor.executemany(command, seq_of_params)
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
conn.close()
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
script = f.read()
|
||||
# 如果有参数,使用execute方法而不是execute_script
|
||||
if params:
|
||||
await self.execute(script, params)
|
||||
else:
|
||||
await self.execute_script(script)
|
||||
|
||||
@classmethod
|
||||
def execute_many_values_by_sql_file(cls, file_path: str, seq_of_params: List[tuple]) -> None:
|
||||
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
|
||||
"""执行多条非查询语句"""
|
||||
conn = await self._get_connection()
|
||||
try:
|
||||
await conn.executemany(command, seq_of_params)
|
||||
await conn.commit()
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
async def execute_many_values_by_sql_file(
|
||||
self, file_path: Union[str, Path], seq_of_params: List[tuple]
|
||||
) -> None:
|
||||
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
command = f.read()
|
||||
cls.execute_many(command, seq_of_params)
|
||||
await self.execute_many(command, seq_of_params)
|
||||
|
||||
async def close_all_connections(self) -> None:
|
||||
"""关闭所有连接"""
|
||||
async with self._lock:
|
||||
for conn in self._connection_pool:
|
||||
await conn.close()
|
||||
self._connection_pool.clear()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user