120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
from typing import Annotated
|
|
import nonebot
|
|
from nonebot.adapters import Event
|
|
from nonebot.params import Depends
|
|
from nonebot.rule import Rule
|
|
|
|
from konabot.common.appcontext import after_init
|
|
from konabot.common.database import DatabaseManager
|
|
from konabot.common.pager import PagerQuery
|
|
from konabot.common.path import DATA_PATH
|
|
from konabot.common.permsys.entity import PermEntity, get_entity_chain
|
|
from konabot.common.permsys.migrates import execute_migration
|
|
from konabot.common.permsys.repo import PermRepo
|
|
|
|
|
|
# Database manager for the permission store, backed by the file
# "perm.sqlite3" under DATA_PATH.
db = DatabaseManager(DATA_PATH / "perm.sqlite3")

# Permission keys that the startup hook in create_startup() grants to the
# global system entity PermEntity("sys", "global", "global").
# Populated through register_default_allow_permission().
_default_allow_permissions: set[str] = set()

# Anything that can be resolved into an ordered entity chain: a NoneBot
# event, a single PermEntity, or an explicit list of entities. Order matters:
# PermManager.check_has_permission_info returns the first entity that matches.
_EntityLike = Event | PermEntity | list[PermEntity]
|
|
|
|
|
|
async def _to_entity_chain(el: _EntityLike):
|
|
if isinstance(el, Event):
|
|
return await get_entity_chain(el) # pragma: no cover
|
|
if isinstance(el, PermEntity):
|
|
return [el]
|
|
return el
|
|
|
|
|
|
class PermManager:
    """Permission lookup and update facade over ``PermRepo``.

    Every public method opens its own connection via ``self.db.get_conn()``
    and wraps it in a fresh ``PermRepo``.
    """

    def __init__(self, db: DatabaseManager) -> None:
        # Source of connections for every repository call below.
        self.db = db

    async def check_has_permission_info(self, entities: _EntityLike, key: str):
        """Find the first explicit setting that governs ``key`` for ``entities``.

        ``key`` is normalized in three steps:
        - strip one trailing ``"*"``, then one trailing ``"."``;
        - split on ``"."``;
        - drop empty segments.

        Candidate keys run from the full key down through each shorter prefix
        and end with the global wildcard ``"*"``. For example, ``"a.b.c"``
        yields ``["a.b.c", "a.b", "a", "*"]``.

        Entities are visited in chain order. For each entity, candidates are
        tried from most to least specific. The first ``(entity, key)`` pair
        whose stored value is not ``None`` wins.

        Returns:
            ``(entity, matched_key, value)`` for the winning pair, or ``None``
            when no candidate is set for any entity.
        """
        chain = await _to_entity_chain(entities)

        trimmed = key.removesuffix("*").removesuffix(".")
        segments = [seg for seg in trimmed.split(".") if seg]
        # Longest prefix first, so more specific settings shadow broader ones.
        candidates = [".".join(segments[:n]) for n in range(len(segments), 0, -1)]
        candidates.append("*")

        async with self.db.get_conn() as conn:
            found = await PermRepo(conn).get_perm_info_batch(chain, candidates)

        for entity in chain:
            for candidate in candidates:
                value = found.get((entity, candidate))
                if value is not None:
                    return (entity, candidate, value)
        return None

    async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
        """Return the value of the governing setting for ``key``.

        Falls back to ``False`` when no entity in the chain has any matching
        setting.
        """
        hit = await self.check_has_permission_info(entities, key)
        return False if hit is None else hit[2]

    async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
        """Store ``perm`` for ``(entity, key)`` via ``PermRepo.update_perm_info``.

        Passing ``None`` presumably clears the explicit setting, since lookups
        skip ``None`` values. NOTE(review): confirm against PermRepo.
        """
        async with self.db.get_conn() as conn:
            await PermRepo(conn).update_perm_info(entity, key, perm)

    async def list_permission(self, entities: _EntityLike, query: PagerQuery):
        """List stored settings for the resolved entity chain.

        Delegates to ``PermRepo.list_perm_info_batch`` and passes ``query``
        through unchanged.
        """
        chain = await _to_entity_chain(entities)
        async with self.db.get_conn() as conn:
            return await PermRepo(conn).list_perm_info_batch(chain, query)
|
|
|
|
|
|
def perm_manager(_db: DatabaseManager | None = None) -> PermManager:  # pragma: no cover
    """Build a ``PermManager``.

    Uses ``_db`` when it is given, otherwise the module-level ``db``.
    """
    target = db if _db is None else _db
    return PermManager(target)
|
|
|
|
|
|
@after_init
def create_startup():  # pragma: no cover
    """Register NoneBot driver lifecycle hooks for the permission database.

    On startup:
    - run the schema migration;
    - grant every configured super-admin QQ account the wildcard ``"*"``;
    - grant every key in ``_default_allow_permissions`` to the global system
      entity.

    On shutdown, close all database connections on a best-effort basis. A
    failure is logged instead of being silently discarded.
    """
    # Imported inside the function rather than at module top; presumably to
    # defer config loading or avoid an import cycle. NOTE(review): confirm.
    from konabot.common.nb.is_admin import cfg

    driver = nonebot.get_driver()

    @driver.on_startup
    async def _():
        async with db.get_conn() as conn:
            await execute_migration(conn)

        pm = perm_manager(db)
        # Super administrators, defined through environment variables.
        # The "*" wildcard grants them every permission.
        for account in cfg.admin_qq_account:
            await pm.update_permission(
                PermEntity("ob11", "user", str(account)), "*", True
            )
        for key in _default_allow_permissions:
            await pm.update_permission(
                PermEntity("sys", "global", "global"), key, True
            )

    @driver.on_shutdown
    async def _():
        try:
            await db.close_all_connections()
        except Exception:
            # Cleanup is best-effort and must not abort shutdown, but the
            # failure should still be visible instead of silently swallowed.
            nonebot.logger.exception(
                "Failed to close permission database connections on shutdown"
            )
|
|
|
|
|
|
# NoneBot dependency-injection alias. A handler parameter annotated with this
# type receives the PermManager returned by perm_manager(), which defaults to
# the module-level `db` when no database is passed.
DepPermManager = Annotated[PermManager, Depends(perm_manager)]
|
|
|
|
|
|
def register_default_allow_permission(key: str):
    """Register ``key`` to be granted to the global system entity.

    The set is only read by the driver startup hook set up in
    ``create_startup``. As far as this module shows, keys registered after
    that hook has run are not persisted until the next startup.
    """
    _default_allow_permissions.update({key})
|
|
|
|
|
|
def require_permission(perm: str) -> Rule:  # pragma: no cover
    """Build a NoneBot ``Rule`` that passes when the event's entities hold ``perm``.

    The ``PermManager`` is injected on each check through ``DepPermManager``.
    The event is resolved to its entity chain by
    ``PermManager.check_has_permission``.
    """

    async def _checker(event: Event, pm: DepPermManager) -> bool:
        allowed = await pm.check_has_permission(event, perm)
        return allowed

    rule = Rule(_checker)
    return rule
|