65 lines
1.8 KiB
Python
65 lines
1.8 KiB
Python
from typing import Any
|
|
import openai
|
|
|
|
from loguru import logger
|
|
from openai.types.chat import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageParam
|
|
from pydantic import BaseModel, Field
|
|
|
|
from konabot.common.path import CONFIG_PATH
|
|
|
|
# Location of the LLM configuration file: a JSON mapping of model names to
# endpoint settings (see LLMConfig below).
LLM_CONFIG_PATH = CONFIG_PATH / 'llm.json'

# Seed an empty config on first run so the import-time read below never fails
# with FileNotFoundError; "{}" validates to an LLMConfig with all defaults.
if not LLM_CONFIG_PATH.exists():
    # Robustness fix: make sure the config directory exists before writing —
    # write_text raises FileNotFoundError when the parent is missing.
    LLM_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding avoids locale-dependent defaults (identical bytes for ASCII "{}").
    LLM_CONFIG_PATH.write_text("{}", encoding="utf-8")
|
|
|
|
|
|
class LLMInfo(BaseModel):
    """Connection settings for one LLM endpoint, with an async chat helper."""

    # OpenAI-compatible API base URL of the endpoint.
    base_url: str
    # Credential sent with every request to this endpoint.
    api_key: str
    # Model identifier passed to the completions API.
    model_name: str

    def get_openai_client(self):
        """Build an async OpenAI client bound to this endpoint's URL and key.

        NOTE(review): a fresh client is constructed on every call and never
        explicitly closed — presumably fine at low call volume; confirm.
        """
        client = openai.AsyncClient(
            api_key=self.api_key,
            base_url=self.base_url,
        )
        return client

    async def chat(
        self,
        messages: list[ChatCompletionMessageParam],
        timeout: float | None = 30.0,
        max_tokens: int | None = None,
        **kwargs: Any,
    ) -> ChatCompletionMessage:
        """Run a single non-streaming chat completion.

        Args:
            messages: Conversation history in OpenAI message format.
            timeout: Per-request timeout in seconds (``None`` disables it).
            max_tokens: Optional completion-length cap forwarded to the API.
            **kwargs: Extra arguments forwarded verbatim to
                ``chat.completions.create``.

        Returns:
            The message of the first choice in the completion.
        """
        logger.info(f"调用 LLM: BASE_URL={self.base_url} MODEL_NAME={self.model_name}")
        client = self.get_openai_client()
        completion: ChatCompletion = await client.chat.completions.create(
            messages=messages,
            model=self.model_name,
            max_tokens=max_tokens,
            timeout=timeout,
            stream=False,
            **kwargs,
        )
        first = completion.choices[0]
        logger.info(
            f"调用 LLM 完成: BASE_URL={self.base_url} MODEL_NAME={self.model_name} REASON={first.finish_reason}"
        )
        return first.message
|
|
|
|
|
|
class LLMConfig(BaseModel):
    """Top-level schema of llm.json: named LLM endpoints plus a default pick."""

    # Mapping of configuration name -> endpoint settings; empty when nothing is configured.
    llms: dict[str, LLMInfo] = Field(default_factory=dict)
    # Name looked up in `llms` when a caller does not choose a model explicitly.
    default_llm: str = "Qwen2.5-7B-Instruct"
|
|
|
|
|
|
llm_config = LLMConfig.model_validate_json(LLM_CONFIG_PATH.read_text())
|
|
|
|
|
|
def get_llm(llm_model: str | None = None) -> "LLMInfo":
    """Look up an LLM endpoint by name, falling back to the configured default.

    Args:
        llm_model: Configuration name to look up; ``None`` selects
            ``llm_config.default_llm``.

    Raises:
        NotImplementedError: When the requested name is not configured.
    """
    chosen = llm_config.default_llm if llm_model is None else llm_model
    if chosen not in llm_config.llms:
        raise NotImplementedError("LLM 未配置,该功能无法使用")
    return llm_config.llms[chosen]
|
|
|