import logging
from typing import Union, Generator, Optional, List
from dify_plugin.provider_kits.llm import LargeLanguageModel # Base class
from dify_plugin.provider_kits.llm import LLMResult, LLMResultChunk, LLMUsage # Result classes
from dify_plugin.provider_kits.llm import PromptMessage, PromptMessageTool # Message classes
from dify_plugin.errors.provider_error import InvokeError, InvokeAuthorizationError # Error classes
# Module-level logger, named after this module per the standard logging convention.
logger: logging.Logger = logging.getLogger(name=__name__)
class MyProviderLargeLanguageModel(LargeLanguageModel):
    """Template large-language-model implementation for a custom provider.

    ``_invoke`` dispatches to ``_invoke_stream`` / ``_invoke_sync`` and maps
    failures through ``_invoke_error_mapping``. The vendor-specific stubs
    raise ``NotImplementedError`` until they are filled in, so an unfinished
    template fails loudly instead of silently returning ``None``.
    """

    def _invoke(self, model: str, credentials: dict, prompt_messages: List[PromptMessage],
                model_parameters: dict, tools: Optional[List[PromptMessageTool]] = None,
                stop: Optional[List[str]] = None, stream: bool = True,
                user: Optional[str] = None) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
        """
        Core method for invoking the model API.

        Parameters:
            model: The model identifier to call
            credentials: Authentication credentials
            prompt_messages: List of messages to send
            model_parameters: Parameters like temperature, max_tokens
            tools: Optional tool definitions for function calling
            stop: Optional list of stop sequences
            stream: Whether to stream responses (True) or return complete response (False)
            user: Optional user identifier for API tracking

        Returns:
            If stream=True: Generator yielding LLMResultChunk objects
            If stream=False: Complete LLMResult object

        Raises:
            InvokeError, or the subclass selected via ``_invoke_error_mapping``,
            when the vendor call fails. For streaming calls this also covers
            errors raised while the returned generator is being consumed.
            NotImplementedError is propagated unchanged so unfinished stubs
            are not disguised as API failures.
        """
        # Prepare API request parameters
        api_params = self._prepare_api_params(
            credentials, model_parameters, prompt_messages, tools, stop
        )
        try:
            if stream:
                # A generator body runs lazily, after this method has returned,
                # so this try block alone would miss errors raised mid-stream.
                # _wrap_stream maps those at iteration time.
                return self._wrap_stream(self._invoke_stream(model, api_params, user))
            return self._invoke_sync(model, api_params, user)
        except NotImplementedError:
            raise
        except Exception as e:
            self._handle_api_error(e)
            # Safety net: if an override of _handle_api_error ever returns
            # instead of raising, re-raise rather than silently returning None.
            raise

    def _wrap_stream(self, chunks: Generator[LLMResultChunk, None, None]) -> Generator[LLMResultChunk, None, None]:
        """Re-yield *chunks*, mapping vendor errors raised during iteration."""
        try:
            yield from chunks
        except NotImplementedError:
            raise
        except Exception as e:
            self._handle_api_error(e)
            raise

    def _prepare_api_params(self, credentials: dict, model_parameters: dict,
                            prompt_messages: List[PromptMessage],
                            tools: Optional[List[PromptMessageTool]] = None,
                            stop: Optional[List[str]] = None) -> dict:
        """Bundle the invocation inputs into one dict for the call helpers.

        Placeholder: replace with the vendor-specific request payload. The
        mutable containers are shallow-copied so the helpers can modify the
        result without mutating the caller's objects.
        """
        return {
            "credentials": credentials,
            "model_parameters": dict(model_parameters or {}),
            "prompt_messages": list(prompt_messages),
            "tools": list(tools or []),
            "stop": list(stop or []),
        }

    def _handle_api_error(self, error: Exception) -> None:
        """Raise the Dify-standard error corresponding to *error*; never returns.

        An InvokeError is re-raised as-is. Otherwise the first
        ``_invoke_error_mapping`` entry whose vendor exception types match
        *error* is raised. An unmatched error becomes a plain InvokeError.
        The original exception is preserved as ``__cause__``.
        """
        if isinstance(error, InvokeError):
            raise error
        for invoke_error, vendor_errors in self._invoke_error_mapping.items():
            # isinstance(x, ()) is False, so empty placeholder lists never match.
            if isinstance(error, tuple(vendor_errors)):
                raise invoke_error(str(error)) from error
        raise InvokeError(str(error)) from error

    def _invoke_stream(self, model: str, api_params: dict, user: Optional[str]) -> Generator[LLMResultChunk, None, None]:
        """Helper method for streaming API calls.

        Must return, or be written as, a generator of LLMResultChunk objects
        built from *api_params*.
        """
        # Implementation details for streaming calls
        raise NotImplementedError("_invoke_stream must be implemented for this provider")

    def _invoke_sync(self, model: str, api_params: dict, user: Optional[str]) -> LLMResult:
        """Helper method for synchronous API calls; must return an LLMResult."""
        # Implementation details for synchronous calls
        raise NotImplementedError("_invoke_sync must be implemented for this provider")

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate that the credentials work for this specific model.
        Called when a user tries to add or modify credentials.

        The stub raises instead of passing: a silent ``pass`` would accept
        any credentials as valid.
        """
        # Implementation for credential validation
        raise NotImplementedError("validate_credentials must be implemented for this provider")

    def get_num_tokens(self, model: str, credentials: dict,
                       prompt_messages: List[PromptMessage],
                       tools: Optional[List[PromptMessageTool]] = None) -> int:
        """
        Estimate the number of tokens for given input.
        Optional but recommended for accurate cost estimation.

        Returns:
            Token count. This placeholder returns 0 ("no estimate") so the
            declared int contract holds until real counting is added.
        """
        # TODO: replace with the provider's tokenizer for accurate counts.
        return 0

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Define mapping from vendor-specific exceptions to Dify standard exceptions.
        This helps standardize error handling across different providers.
        Consulted in insertion order by ``_handle_api_error``.
        """
        return {
            InvokeAuthorizationError: [
                # List vendor-specific auth errors here
            ],
            # Other error mappings
        }