93 lines
2.9 KiB
Python
93 lines
2.9 KiB
Python
import asyncio
|
||
import os
|
||
import tempfile
|
||
from pathlib import Path
|
||
|
||
import pytest
|
||
|
||
from konabot.common.database import DatabaseManager
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_database_manager():
|
||
"""测试数据库管理器的基本功能"""
|
||
# 创建临时数据库文件
|
||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
||
db_path = tmp_file.name
|
||
|
||
try:
|
||
# 初始化数据库管理器
|
||
db_manager = DatabaseManager(db_path)
|
||
|
||
# 创建测试表
|
||
create_table_sql = """
|
||
CREATE TABLE IF NOT EXISTS test_users (
|
||
id INTEGER PRIMARY KEY,
|
||
name TEXT NOT NULL,
|
||
email TEXT UNIQUE
|
||
);
|
||
"""
|
||
await db_manager.execute(create_table_sql)
|
||
|
||
# 插入测试数据
|
||
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
||
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
||
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
||
|
||
# 查询数据
|
||
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
||
results = await db_manager.query(select_sql, ("张三",))
|
||
assert len(results) == 1
|
||
assert results[0]["name"] == "张三"
|
||
assert results[0]["email"] == "zhangsan@example.com"
|
||
|
||
# 测试使用Path对象
|
||
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
||
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
||
|
||
# 关闭所有连接
|
||
await db_manager.close_all_connections()
|
||
|
||
finally:
|
||
# 清理临时文件
|
||
if os.path.exists(db_path):
|
||
os.unlink(db_path)
|
||
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_execute_script():
|
||
"""测试执行SQL脚本功能"""
|
||
# 创建临时数据库文件
|
||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
||
db_path = tmp_file.name
|
||
|
||
try:
|
||
# 初始化数据库管理器
|
||
db_manager = DatabaseManager(db_path)
|
||
|
||
# 创建测试表的脚本
|
||
script = """
|
||
CREATE TABLE IF NOT EXISTS test_products (
|
||
id INTEGER PRIMARY KEY,
|
||
name TEXT NOT NULL,
|
||
price REAL
|
||
);
|
||
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
||
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
||
"""
|
||
|
||
await db_manager.execute_script(script)
|
||
|
||
# 查询数据
|
||
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
||
assert len(results) == 2
|
||
assert results[0]["name"] == "苹果"
|
||
assert results[1]["name"] == "香蕉"
|
||
|
||
# 关闭所有连接
|
||
await db_manager.close_all_connections()
|
||
|
||
finally:
|
||
# 清理临时文件
|
||
if os.path.exists(db_path):
|
||
os.unlink(db_path) |