
Introduction

This document details the interfaces and data structures required to implement a Dify model plugin. It serves as a technical reference for developers integrating AI models with the Dify platform.
Before diving into this API reference, we recommend reading Model Design Rules and Model Plugin Introduction for a conceptual understanding.

Model Provider

Every model provider must inherit from the __base.model_provider.ModelProvider base class and implement the credential validation interface.

Provider Credential Validation

def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials by making a test API call
    
    Parameters:
        credentials: Provider credentials as defined in `provider_credential_schema`
        
    Raises:
        CredentialsValidateFailedError: If validation fails
    """
    try:
        # Example implementation - validate using an LLM model instance
        model_instance = self.get_model_instance(ModelType.LLM)
        model_instance.validate_credentials(
            model="example-model", 
            credentials=credentials
        )
    except Exception as ex:
        logger.exception(f"Credential validation failed")
        raise CredentialsValidateFailedError(f"Invalid credentials: {str(ex)}")
credentials (dict): Credential information defined under provider_credential_schema in the provider's YAML configuration. It typically includes fields such as api_key and organization_id.
If validation fails, your implementation must raise a CredentialsValidateFailedError exception. This ensures proper error handling in the Dify UI.
For predefined model providers, you should implement a thorough validation method that confirms the credentials actually work with your API. For custom model providers (where each model carries its own credentials), a simplified implementation is sufficient.

Models

Dify supports six model types, each requiring a specific interface implementation. All model types, however, share some common requirements.

Common Interfaces

Every model implementation, regardless of type, must implement these two fundamental methods:

1. Model Credential Validation

def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate that the provided credentials work with the specified model
    
    Parameters:
        model: The specific model identifier (e.g., "gpt-4")
        credentials: Authentication details for the model
        
    Raises:
        CredentialsValidateFailedError: If validation fails
    """
    try:
        # Make a lightweight API call to verify credentials
        # Example: List available models or check account status
        response = self._api_client.validate_api_key(credentials["api_key"])
        
        # Verify the specific model is available if applicable
        if model not in response.get("available_models", []):
            raise CredentialsValidateFailedError(f"Model {model} is not available")
            
    except ApiException as e:
        raise CredentialsValidateFailedError(str(e))
model (string, required): The specific model identifier to validate (e.g., "gpt-4", "claude-3-opus")
credentials (dict, required): Credential information as defined in the provider configuration

2. Error Mapping

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map provider-specific exceptions to standardized Dify error types
    
    Returns:
        Dictionary mapping Dify error types to lists of provider exception types
    """
    return {
        InvokeConnectionError: [
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            ConnectionRefusedError
        ],
        InvokeServerUnavailableError: [
            ServiceUnavailableError,
            HTTPStatusError
        ],
        InvokeRateLimitError: [
            RateLimitExceededError,
            QuotaExceededError
        ],
        InvokeAuthorizationError: [
            AuthenticationError,
            InvalidAPIKeyError,
            PermissionDeniedError
        ],
        InvokeBadRequestError: [
            InvalidRequestError,
            ValidationError
        ]
    }
InvokeConnectionError: Network connection failures and timeouts
InvokeServerUnavailableError: The provider's service is down or unavailable
InvokeRateLimitError: Rate limits or quota limits reached
InvokeAuthorizationError: Authentication or permission problems
InvokeBadRequestError: Invalid parameters or requests
You can also raise these standardized error types directly in your code instead of relying on the error mapping. This approach gives you more control over the error messages.
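For example, a minimal sketch of raising a standardized error directly inside a provider call, assuming a hypothetical ProviderAPIError exception that carries an HTTP status code:

def _call_api(self, request):
    try:
        return self._api_client.send(request)  # hypothetical provider client
    except ProviderAPIError as e:  # hypothetical provider exception
        if e.status_code == 429:
            raise InvokeRateLimitError("Rate limit exceeded, please retry later")
        if e.status_code in (401, 403):
            raise InvokeAuthorizationError("Invalid or expired API key")
        raise InvokeBadRequestError(str(e))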

LLM Implementation

To implement a large language model provider, inherit from the __base.large_language_model.LargeLanguageModel base class and implement the following methods:

1. Model Invocation

This core method handles both streaming and non-streaming API calls to the language model.
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    prompt_messages: list[PromptMessage], 
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None, 
    stop: Optional[list[str]] = None,
    stream: bool = True, 
    user: Optional[str] = None
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
    """
    Invoke the language model
    """
    # Prepare API parameters
    api_params = self._prepare_api_parameters(
        model, 
        credentials, 
        prompt_messages, 
        model_parameters,
        tools, 
        stop
    )
    
    try:
        # Choose between streaming and non-streaming implementation
        if stream:
            return self._invoke_stream(model, api_params, user)
        else:
            return self._invoke_sync(model, api_params, user)
            
    except Exception as e:
        # Map errors using the error mapping property
        self._handle_api_error(e)

# Helper methods for streaming and non-streaming calls
def _invoke_stream(self, model, api_params, user):
    # Implement streaming call and yield chunks
    pass
    
def _invoke_sync(self, model, api_params, user):
    # Implement synchronous call and return complete result
    pass
model (string, required): The model identifier (e.g., "gpt-4", "claude-3")
credentials (dict, required): Authentication credentials for the API
prompt_messages (list[PromptMessage], required): A list of messages in Dify's standardized format:
  • For completion models: a single UserPromptMessage
  • For chat models: SystemPromptMessage, UserPromptMessage, AssistantPromptMessage, and ToolPromptMessage elements as needed
model_parameters (dict, required): Model-specific parameters (temperature, top_p, etc.) as defined in the model's YAML configuration
tools (list[PromptMessageTool]): Tool definitions for function-calling capabilities
stop (list[string]): Stop sequences that halt generation when encountered
stream (boolean, default: true): Whether to return a streaming response
user (string): A user identifier for API monitoring
When stream=True, returns Generator[LLMResultChunk, None, None]: a generator that yields response chunks as they become available
When stream=False, returns LLMResult: a complete response object containing the full generated text
We recommend implementing separate helper methods for streaming and non-streaming calls to keep the code organized and maintainable.
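As a minimal sketch of the streaming helper, assuming _prepare_api_parameters keeps the original prompt_messages in api_params and that the provider SDK exposes a chat(..., stream=True) call (both are placeholders, not real Dify or SDK APIs):

def _invoke_stream(self, model, api_params, user):
    prompt_messages = api_params.pop("prompt_messages")

    # Placeholder for your provider's streaming SDK call
    stream = self._client.chat(model=model, stream=True, user=user, **api_params)

    for index, event in enumerate(stream):
        # Wrap each partial completion in Dify's streaming entities
        yield LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=index,
                message=AssistantPromptMessage(content=event.text or ""),
            ),
        )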

2. Token Counting

def get_num_tokens(
    self, 
    model: str, 
    credentials: dict, 
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None
) -> int:
    """
    Calculate the number of tokens in the prompt
    """
    # Convert prompt_messages to the format expected by the tokenizer
    text = self._convert_messages_to_text(prompt_messages)
    
    try:
        # Use the appropriate tokenizer for this model
        tokenizer = self._get_tokenizer(model)
        return len(tokenizer.encode(text))
    except Exception:
        # Fall back to a generic tokenizer
        return self._get_num_tokens_by_gpt2(text)
If the model does not provide a tokenizer, you can use the base class's _get_num_tokens_by_gpt2(text) method for a reasonable approximation.
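The _convert_messages_to_text helper shown above is not part of the base class; a minimal sketch that flattens messages into role-prefixed lines for tokenization could look like this:

def _convert_messages_to_text(self, prompt_messages: list[PromptMessage]) -> str:
    lines = []
    for message in prompt_messages:
        if isinstance(message.content, str):
            lines.append(f"{message.role.value}: {message.content}")
        elif message.content:
            # Multimodal content list: only the text parts count toward tokens
            texts = [c.data for c in message.content if c.type == PromptMessageContentType.TEXT]
            lines.append(f"{message.role.value}: {' '.join(texts)}")
    return "\n".join(lines)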

3. Custom Model Schema (Optional)

def get_customizable_model_schema(
    self, 
    model: str, 
    credentials: dict
) -> Optional[AIModelEntity]:
    """
    Get parameter schema for custom models
    """
    # For fine-tuned models, you might return the base model's schema
    if model.startswith("ft:"):
        base_model = self._extract_base_model(model)
        return self._get_predefined_model_schema(base_model)
    
    # For standard models, return None to use the predefined schema
    return None
This method is only required for providers that support custom models. It lets a custom model inherit parameter rules from its base model.

TextEmbedding Implementation

Text embedding models convert text into high-dimensional vectors that capture semantic meaning, which is useful for retrieval, similarity search, and classification.
To implement a text embedding provider, inherit from the __base.text_embedding_model.TextEmbeddingModel base class:

1. Core Embedding Method

def _invoke(
    self, 
    model: str, 
    credentials: dict,
    texts: list[str], 
    user: Optional[str] = None
) -> TextEmbeddingResult:
    """
    Generate embedding vectors for multiple texts
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    # Handle batching if needed
    batch_size = self._get_batch_size(model)
    all_embeddings = []
    total_tokens = 0
    start_time = time.time()
    
    # Process in batches to avoid API limits
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        
        # Make API call to the embeddings endpoint
        response = client.embeddings.create(
            model=model,
            input=batch,
            user=user
        )
        
        # Extract embeddings from response
        batch_embeddings = [item.embedding for item in response.data]
        all_embeddings.extend(batch_embeddings)
        
        # Track token usage
        total_tokens += response.usage.total_tokens
    
    # Calculate usage metrics
    elapsed_time = time.time() - start_time
    usage = self._create_embedding_usage(
        model=model,
        tokens=total_tokens,
        latency=elapsed_time
    )
    
    return TextEmbeddingResult(
        model=model,
        embeddings=all_embeddings,
        usage=usage
    )
model (string, required): The embedding model identifier
credentials (dict, required): Authentication credentials for the embedding service
texts (list[string], required): The list of text inputs to embed
user (string): A user identifier for API monitoring
Returns a TextEmbeddingResult, a structured response containing:
  • model: the model used for embedding
  • embeddings: a list of embedding vectors corresponding to the input texts
  • usage: metadata about token usage and cost
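The _create_embedding_usage helper used above is not provided by the base class; a minimal sketch built on the EmbeddingUsage entity (see Entities below), with placeholder pricing values you should replace with your model's actual pricing configuration:

from decimal import Decimal

def _create_embedding_usage(self, model: str, tokens: int, latency: float) -> EmbeddingUsage:
    unit_price = Decimal("0.0001")  # placeholder unit price
    price_unit = Decimal("1000")    # placeholder: price quoted per 1,000 tokens
    total_price = Decimal(tokens) / price_unit * unit_price
    return EmbeddingUsage(
        tokens=tokens,
        total_tokens=tokens,
        unit_price=unit_price,
        price_unit=price_unit,
        total_price=total_price,
        currency="USD",
        latency=latency,
    )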

2. Token Counting Method

def get_num_tokens(
    self, 
    model: str, 
    credentials: dict, 
    texts: list[str]
) -> int:
    """
    Calculate the number of tokens in the texts to be embedded
    """
    # Join all texts to estimate token count
    combined_text = " ".join(texts)
    
    try:
        # Use the appropriate tokenizer for this model
        tokenizer = self._get_tokenizer(model)
        return len(tokenizer.encode(combined_text))
    except Exception:
        # Fall back to a generic tokenizer
        return self._get_num_tokens_by_gpt2(combined_text)
For embedding models, accurate token counting matters for cost estimation but is not critical to functionality. The _get_num_tokens_by_gpt2 method provides a reasonable approximation for most models.

Rerank Implementation

Rerank models improve search quality by reordering a set of candidate documents according to their relevance to a query, typically after an initial retrieval stage.
To implement a rerank provider, inherit from the __base.rerank_model.RerankModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    query: str, 
    docs: list[str], 
    score_threshold: Optional[float] = None, 
    top_n: Optional[int] = None,
    user: Optional[str] = None
) -> RerankResult:
    """
    Rerank documents based on relevance to the query
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    # Prepare request data
    request_data = {
        "query": query,
        "documents": docs,
    }
    
    # Call reranking API endpoint
    response = client.rerank(
        model=model,
        **request_data,
        user=user
    )
    
    # Process results
    ranked_results = []
    for result in response.results:
        # Create RerankDocument for each result
        doc = RerankDocument(
            index=result.document_index,  # Original index in docs list
            text=docs[result.document_index],  # Original text
            score=result.relevance_score  # Relevance score
        )
        ranked_results.append(doc)
    
    # Sort by score in descending order
    ranked_results.sort(key=lambda x: x.score, reverse=True)
    
    # Apply score threshold filtering if specified
    if score_threshold is not None:
        ranked_results = [doc for doc in ranked_results if doc.score >= score_threshold]
    
    # Apply top_n limit if specified
    if top_n is not None and top_n > 0:
        ranked_results = ranked_results[:top_n]
    
    return RerankResult(
        model=model,
        docs=ranked_results
    )
model (string, required): The rerank model identifier
credentials (dict, required): Authentication credentials for the API
query (string, required): The search query text
docs (list[string], required): The list of document texts to rerank
score_threshold (float): Optional minimum score threshold for filtering results
top_n (int): Optional limit on the number of results returned
user (string): A user identifier for API monitoring
Returns a RerankResult, a structured response containing:
  • model: the model used for reranking
  • docs: a list of RerankDocument objects with index, text, and score
Reranking can be computationally expensive, especially for large document sets. Implement batch processing for large document collections to avoid timeouts or excessive resource consumption, as sketched below.
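A minimal batching sketch under that advice: rerank fixed-size slices of the document list, remap each result's index back to the full list, then merge and sort. The _rerank_batch helper is hypothetical and would wrap the API call shown above:

def _rerank_in_batches(self, model, credentials, query, docs, batch_size=100):
    merged = []
    for offset in range(0, len(docs), batch_size):
        batch = docs[offset:offset + batch_size]
        for doc in self._rerank_batch(model, credentials, query, batch):
            # Remap the batch-local index to the position in the full docs list
            merged.append(RerankDocument(
                index=offset + doc.index,
                text=doc.text,
                score=doc.score,
            ))
    merged.sort(key=lambda d: d.score, reverse=True)
    return merged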

Speech2Text Implementation

Speech-to-text models convert spoken language in audio files into written text, powering applications such as transcription services, voice commands, and accessibility features.
To implement a speech-to-text provider, inherit from the __base.speech2text_model.Speech2TextModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    file: IO[bytes], 
    user: Optional[str] = None
) -> str:
    """
    Convert speech audio to text
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    try:
        # Determine the file format
        file_format = self._detect_audio_format(file)
        
        # Prepare the file for API submission
        # Most APIs require either a file path or binary data
        audio_data = file.read()
        
        # Call the speech-to-text API
        response = client.audio.transcriptions.create(
            model=model,
            file=("audio.mp3", audio_data),  # Adjust filename based on actual format
            user=user
        )
        
        # Extract and return the transcribed text
        return response.text
        
    except Exception as e:
        # Map to appropriate error type
        self._handle_api_error(e)
        
    finally:
        # Reset file pointer for potential reuse
        file.seek(0)
model (string, required): The speech-to-text model identifier
credentials (dict, required): Authentication credentials for the API
file (IO[bytes], required): A binary file object containing the audio to transcribe
user (string): A user identifier for API monitoring
Returns text (string): the text transcribed from the audio file
Audio format detection is important for handling different file types correctly. Consider implementing a helper method that detects the format from the file header, as sketched below.
Some speech-to-text APIs impose file size limits. Consider implementing chunked processing for large audio files if necessary.
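A minimal sketch of the _detect_audio_format helper referenced in the code above, identifying common formats by their magic bytes:

def _detect_audio_format(self, file: IO[bytes]) -> str:
    # Peek at the header, then rewind so the caller can still read the file
    header = file.read(12)
    file.seek(0)

    if header.startswith(b"RIFF") and header[8:12] == b"WAVE":
        return "wav"
    if header.startswith(b"ID3") or header[:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"):
        return "mp3"
    if header.startswith(b"OggS"):
        return "ogg"
    if header.startswith(b"fLaC"):
        return "flac"
    return "mp3"  # fallback; adjust to your provider's expected default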

Text2Speech Implementation

Text-to-speech models convert written text into natural-sounding speech, powering applications such as voice assistants, screen readers, and audio content generation.
To implement a text-to-speech provider, inherit from the __base.text2speech_model.Text2SpeechModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict, 
    content_text: str, 
    streaming: bool,
    user: Optional[str] = None
) -> Union[bytes, Generator[bytes, None, None]]:
    """
    Convert text to speech audio
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    # Get voice settings based on model
    voice = self._get_voice_for_model(model)
    
    try:
        # Choose implementation based on streaming preference
        if streaming:
            return self._stream_audio(
                client=client,
                model=model,
                text=content_text,
                voice=voice,
                user=user
            )
        else:
            return self._generate_complete_audio(
                client=client,
                model=model,
                text=content_text,
                voice=voice,
                user=user
            )
    except Exception as e:
        self._handle_api_error(e)
model (string, required): The text-to-speech model identifier
credentials (dict, required): Authentication credentials for the API
content_text (string, required): The text content to convert to speech
streaming (boolean, required): Whether to return streaming audio or a complete file
user (string): A user identifier for API monitoring
When streaming=True, returns Generator[bytes, None, None]: a generator that yields audio chunks as they become available
When streaming=False, returns bytes: the complete audio data
Most text-to-speech APIs require you to specify a voice alongside the model. Consider implementing a mapping between Dify's model identifiers and the provider's voice options.
Long text inputs may need to be chunked for better speech synthesis quality, as sketched below. Consider implementing text preprocessing to handle punctuation, numbers, and special characters correctly.
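A minimal chunking sketch under that advice, splitting on sentence boundaries and packing sentences into chunks up to a maximum length (the 500-character limit is a placeholder; check your provider's input cap):

import re

def _split_text_for_tts(text: str, max_chars: int = 500) -> list[str]:
    # Split after sentence-ending punctuation (Western and CJK)
    sentences = re.split(r"(?<=[.!?。!?])\s*", text)
    chunks, current = [], ""
    for sentence in sentences:
        if current and len(current) + len(sentence) > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current += sentence
    if current:
        chunks.append(current)
    return chunks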

Moderation Implementation

Moderation models analyze content for potentially harmful, inappropriate, or unsafe material, helping maintain platform safety and content policies.
To implement a moderation provider, inherit from the __base.moderation_model.ModerationModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    text: str, 
    user: Optional[str] = None
) -> bool:
    """
    Analyze text for harmful content
    
    Returns:
        bool: False if the text is safe, True if it contains harmful content
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    try:
        # Call moderation API
        response = client.moderations.create(
            model=model,
            input=text,
            user=user
        )
        
        # Check if any categories were flagged
        result = response.results[0]
        
        # Return True if flagged in any category, False if safe
        return result.flagged
        
    except Exception as e:
        # Log the error but default to safe if there's an API issue
        # This is a conservative approach - production systems might want
        # different fallback behavior
        logger.error(f"Moderation API error: {str(e)}")
        return False
model (string, required): The moderation model identifier
credentials (dict, required): Authentication credentials for the API
text (string, required): The text content to analyze
user (string): A user identifier for API monitoring
Returns result (boolean), indicating content safety:
  • False: the content is safe
  • True: the content contains harmful material
Moderation is typically used as a safety mechanism. When implementing a solution, weigh the impact of false negatives (letting harmful content through) against false positives (blocking safe content).
Many moderation APIs provide detailed per-category scores rather than just a binary result. If your application needs them, consider extending this implementation to return more detailed information about specific harmful content categories, as sketched below.
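For instance, a hedged sketch that logs the flagged categories before returning the boolean verdict, assuming the response exposes categories as a mapping of category name to flag (field names follow the OpenAI-style shape used in the example and may differ for your API):

# Inside _invoke, before returning result.flagged
result = response.results[0]
if result.flagged:
    flagged_categories = [
        name for name, is_flagged in result.categories.items() if is_flagged
    ]
    logger.warning(f"Content flagged in categories: {flagged_categories}")
return result.flagged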

Entities

PromptMessageRole

Message roles.
class PromptMessageRole(Enum):
    """
    Enum class for prompt message.
    """
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

PromptMessageContentType

Message content types: plain text and image.
class PromptMessageContentType(Enum):
    """
    Enum class for prompt message content type.
    """
    TEXT = 'text'
    IMAGE = 'image'

PromptMessageContent

Base class for message content; used only for parameter declaration and cannot be instantiated.
class PromptMessageContent(BaseModel):
    """
    Model class for prompt message content.
    """
    type: PromptMessageContentType
    data: str  # Content data
Two content types are currently supported, text and image, and a single message can carry text together with multiple images. Instantiate TextPromptMessageContent and ImagePromptMessageContent respectively.

TextPromptMessageContent

class TextPromptMessageContent(PromptMessageContent):
    """
    Model class for text prompt message content.
    """
    type: PromptMessageContentType = PromptMessageContentType.TEXT
When passing in both text and images, the text must be wrapped in this entity as part of the content list.

ImagePromptMessageContent

class ImagePromptMessageContent(PromptMessageContent):
    """
    Model class for image prompt message content.
    """
    class DETAIL(Enum):
        LOW = 'low'
        HIGH = 'high'

    type: PromptMessageContentType = PromptMessageContentType.IMAGE
    detail: DETAIL = DETAIL.LOW  # Resolution
When passing in both text and images, each image must be wrapped in this entity as part of the content list. data can be a URL or a base64-encoded string of the image.
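For example, a multimodal user message combining text and one image can be constructed like this:

message = UserPromptMessage(content=[
    TextPromptMessageContent(data="What is in this image?"),
    ImagePromptMessageContent(
        data="https://example.com/photo.jpg",  # or a base64-encoded string
        detail=ImagePromptMessageContent.DETAIL.HIGH,
    ),
])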

PromptMessage

Base class for all role message bodies; used only for parameter declaration and cannot be instantiated.
class PromptMessage(ABC, BaseModel):
    """
    Model class for prompt message.
    """
    role: PromptMessageRole  # Message role
    content: Optional[str | list[PromptMessageContent]] = None  # Supports two types: string and content list. The content list is for multimodal needs, see PromptMessageContent for details.
    name: Optional[str] = None  # Name, optional.

UserPromptMessage

The UserMessage body, representing a user message.
class UserPromptMessage(PromptMessage):
    """
    Model class for user prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.USER

AssistantPromptMessage

Represents a model response message, typically used for few-shot examples or chat history input.
class AssistantPromptMessage(PromptMessage):
    """
    Model class for assistant prompt message.
    """
    class ToolCall(BaseModel):
        """
        Model class for assistant prompt message tool call.
        """
        class ToolCallFunction(BaseModel):
            """
            Model class for assistant prompt message tool call function.
            """
            name: str  # Tool name
            arguments: str  # Tool parameters

        id: str  # Tool ID, only effective for OpenAI tool call, a unique ID for tool invocation, the same tool can be called multiple times
        type: str  # Default is function
        function: ToolCallFunction  # Tool call information

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    tool_calls: list[ToolCall] = []  # Model's tool call results (only returned when tools are passed in and the model decides to call them)
tool_calls here is the list of tool calls the model returns after tools are passed in.

SystemPromptMessage

Represents a system message, typically used to set system instructions for the model.
class SystemPromptMessage(PromptMessage):
    """
    Model class for system prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.SYSTEM

ToolPromptMessage

Represents a tool message, used to pass a tool's execution result back to the model for the next planning step.
class ToolPromptMessage(PromptMessage):
    """
    Model class for tool prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.TOOL
    tool_call_id: str  # Tool call ID, if OpenAI tool call is not supported, you can also pass in the tool name
The tool execution result is passed in the base class's content field.
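For example, a typical tool round-trip appends the model's AssistantPromptMessage (with its tool_calls) to the history, executes the tool, and feeds the result back as a ToolPromptMessage; run_tool here stands in for your own dispatcher:

# The model asked to call a tool; execute it and return the result
tool_call = assistant_message.tool_calls[0]
tool_output = run_tool(tool_call.function.name, tool_call.function.arguments)

prompt_messages.append(assistant_message)
prompt_messages.append(ToolPromptMessage(
    content=tool_output,        # the tool execution result
    tool_call_id=tool_call.id,  # ties the result back to the originating call
))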

PromptMessageTool

class PromptMessageTool(BaseModel):
    """
    Model class for prompt message tool.
    """
    name: str  # Tool name
    description: str  # Tool description
    parameters: dict  # Tool parameters dict


LLMResult

class LLMResult(BaseModel):
    """
    Model class for llm result.
    """
    model: str  # Actually used model
    prompt_messages: list[PromptMessage]  # Prompt message list
    message: AssistantPromptMessage  # Reply message
    usage: LLMUsage  # Tokens used and cost information
    system_fingerprint: Optional[str] = None  # Request fingerprint, refer to OpenAI parameter definition

LLMResultChunkDelta

The delta entity yielded on each iteration of a streaming response.
class LLMResultChunkDelta(BaseModel):
    """
    Model class for llm result chunk delta.
    """
    index: int  # Sequence number
    message: AssistantPromptMessage  # Reply message
    usage: Optional[LLMUsage] = None  # Tokens used and cost information, only returned in the last message
    finish_reason: Optional[str] = None  # Completion reason, only returned in the last message

LLMResultChunk

The per-iteration entity of a streaming response.
class LLMResultChunk(BaseModel):
    """
    Model class for llm result chunk.
    """
    model: str  # Actually used model
    prompt_messages: list[PromptMessage]  # Prompt message list
    system_fingerprint: Optional[str] = None  # Request fingerprint, refer to OpenAI parameter definition
    delta: LLMResultChunkDelta  # Changes in content for each iteration

LLMUsage

class LLMUsage(ModelUsage):
    """
    Model class for llm usage.
    """
    prompt_tokens: int  # Tokens used by prompt
    prompt_unit_price: Decimal  # Prompt unit price
    prompt_price_unit: Decimal  # Prompt price unit, i.e., unit price based on how many tokens
    prompt_price: Decimal  # Prompt cost
    completion_tokens: int  # Tokens used by completion
    completion_unit_price: Decimal  # Completion unit price
    completion_price_unit: Decimal  # Completion price unit, i.e., unit price based on how many tokens
    completion_price: Decimal  # Completion cost
    total_tokens: int  # Total tokens used
    total_price: Decimal  # Total cost
    currency: str  # Currency unit
    latency: float  # Request time (s)

TextEmbeddingResult

class TextEmbeddingResult(BaseModel):
    """
    Model class for text embedding result.
    """
    model: str  # Actually used model
    embeddings: list[list[float]]  # Embedding vector list, corresponding to the input texts list
    usage: EmbeddingUsage  # Usage information

EmbeddingUsage

class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.
    """
    tokens: int  # Tokens used
    total_tokens: int  # Total tokens used
    unit_price: Decimal  # Unit price
    price_unit: Decimal  # Price unit, i.e., unit price based on how many tokens
    total_price: Decimal  # Total cost
    currency: str  # Currency unit
    latency: float  # Request time (s)

RerankResult

class RerankResult(BaseModel):
    """
    Model class for rerank result.
    """
    model: str  # Actually used model
    docs: list[RerankDocument]  # List of reranked segments        

RerankDocument

class RerankDocument(BaseModel):
    """
    Model class for rerank document.
    """
    index: int  # Original sequence number
    text: str  # Segment text content
    score: float  # Score
