Skip to main content
⚠️ 本文档由 AI 自动翻译。如有任何不准确之处,请参考英文原版
自定义模型是指您自行部署或配置的 LLM。本文档以 Xinference 模型为例,演示如何将自定义模型集成到您的模型插件中。 默认情况下,自定义模型自动包含两个参数——其模型类型模型名称——无需在供应商 YAML 文件中进行额外定义。 您无需在供应商配置文件中实现 validate_provider_credential。在运行时,根据用户选择的模型类型或模型名称,Dify 会自动调用相应模型层的 validate_credentials 方法来验证凭据。

集成自定义模型插件

以下是集成自定义模型的步骤:
  1. 创建模型供应商文件
    确定您的自定义模型将包含的模型类型。
  2. 按模型类型创建代码文件
    根据模型的类型(例如 llmtext_embedding),创建单独的代码文件。确保每种模型类型都组织成不同的逻辑层,以便于维护和未来扩展。
  3. 开发模型调用逻辑
    在每个模型类型模块中,创建一个以该模型类型命名的 Python 文件(例如 llm.py)。在文件中定义一个类,实现特定的模型逻辑,符合系统的模型接口规范。
  4. 调试插件
    为新的供应商功能编写单元测试和集成测试,确保所有组件按预期工作。

1. 创建模型供应商文件

在插件的 /provider 目录中,创建一个 xinference.yaml 文件。 Xinference 系列模型支持 LLMText EmbeddingRerank 模型类型,因此您的 xinference.yaml 必须包含所有三种类型。 示例:
provider: xinference  # Identifies the provider
label:                # Display name; can set both en_US (English) and zh_Hans (Chinese). If zh_Hans is not set, en_US is used by default.
  en_US: Xorbits Inference
icon_small:           # Small icon; store in the _assets folder of this provider's directory. The same multi-language logic applies as with label.
  en_US: icon_s_en.svg
icon_large:           # Large icon
  en_US: icon_l_en.svg
help:                 # Help information
  title:
    en_US: How to deploy Xinference
    zh_Hans: 如何部署 Xinference
  url:
    en_US: https://github.com/xorbitsai/inference

supported_model_types:  # Model types Xinference supports: LLM/Text Embedding/Rerank
- llm
- text-embedding
- rerank

configurate_methods:     # Xinference is locally deployed and does not offer predefined models. Refer to its documentation to learn which model to use. Thus, we choose a customizable-model approach.
- customizable-model

provider_credential_schema:
  credential_form_schemas:
接下来,定义 provider_credential_schema。由于 Xinference 支持文本生成、嵌入和重排序模型,您可以按如下方式配置:
provider_credential_schema:
  credential_form_schemas:
  - variable: model_type
    type: select
    label:
      en_US: Model type
      zh_Hans: 模型类型
    required: true
    options:
    - value: text-generation
      label:
        en_US: Language Model
        zh_Hans: 语言模型
    - value: embeddings
      label:
        en_US: Text Embedding
    - value: reranking
      label:
        en_US: Rerank
Xinference 中的每个模型都需要一个 model_name
  - variable: model_name
    type: text-input
    label:
      en_US: Model name
      zh_Hans: 模型名称
    required: true
    placeholder:
      zh_Hans: 填写模型名称
      en_US: Input model name
由于 Xinference 必须在本地部署,用户需要提供服务器地址(server_url)和模型 UID。例如:
  - variable: server_url
    label:
      zh_Hans: 服务器 URL
      en_US: Server url
    type: text-input
    required: true
    placeholder:
      zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
      en_US: Enter the url of your Xinference, for example https://example.com/xxx

  - variable: model_uid
    label:
      zh_Hans: 模型 UID
      en_US: Model uid
    type: text-input
    required: true
    placeholder:
      zh_Hans: 在此输入你的 Model UID
      en_US: Enter the model uid
定义完这些参数后,自定义模型供应商的 YAML 配置就完成了。接下来,为此配置中定义的每个模型创建功能代码文件。

2. 开发模型代码

由于 Xinference 支持 llm、rerank、speech2text 和 tts,您应该在 /models 下创建相应的目录,每个目录包含其各自的功能代码。 以下是 llm 类型模型的示例。您需要创建一个名为 llm.py 的文件,然后定义一个类——例如 XinferenceAILargeLanguageModel——继承自 __base.large_language_model.LargeLanguageModel。该类应包含:
  • LLM 调用
调用 LLM 的核心方法,支持流式和同步响应:
def _invoke(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None,
    stop: Optional[list[str]] = None,
    stream: bool = True,
    user: Optional[str] = None
) -> Union[LLMResult, Generator]:
    """
    Invoke the large language model.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: determines if response is streamed
    :param user: unique user id
    :return: full response or a chunk generator
    """
您需要两个单独的函数来处理流式和同步响应。Python 将任何包含 yield 的函数视为返回 Generator 类型的生成器,因此最好将它们分开:
def _invoke(self, stream: bool, **kwargs) -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    return LLMResult(**response)
  • 预计算输入令牌
如果您的模型不提供令牌计数接口,只需返回 0:
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None
) -> int:
    """
    Get the number of tokens for the given prompt messages.
    """
    return 0
或者,您可以从 AIModel 基类调用 self._get_num_tokens_by_gpt2(text: str),它使用 GPT-2 分词器。请记住这只是一个近似值,可能与您的模型不完全匹配。
  • 验证模型凭据
类似于供应商级别的凭据检查,但范围限定于单个模型:
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials.
    """
  • 动态模型参数模式
预定义模型不同,没有 YAML 定义模型支持哪些参数。您必须动态生成参数模式。 例如,Xinference 支持 max_tokenstemperaturetop_p。某些其他供应商(例如 OpenLLM)可能仅对某些模型支持 top_k 等参数。这意味着您需要根据每个模型的能力调整您的模式:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
    """
        used to define customizable model schema
    """
    rules = [
        ParameterRule(
            name='temperature', type=ParameterType.FLOAT,
            use_template='temperature',
            label=I18nObject(
                zh_Hans='温度', en_US='Temperature'
            )
        ),
        ParameterRule(
            name='top_p', type=ParameterType.FLOAT,
            use_template='top_p',
            label=I18nObject(
                zh_Hans='Top P', en_US='Top P'
            )
        ),
        ParameterRule(
            name='max_tokens', type=ParameterType.INT,
            use_template='max_tokens',
            min=1,
            default=512,
            label=I18nObject(
                zh_Hans='最大生成长度', en_US='Max Tokens'
            )
        )
    ]

    # if model is A, add top_k to rules
    if model == 'A':
        rules.append(
            ParameterRule(
                name='top_k', type=ParameterType.INT,
                use_template='top_k',
                min=1,
                default=50,
                label=I18nObject(
                    zh_Hans='Top K', en_US='Top K'
                )
            )
        )

    """
        some NOT IMPORTANT code here
    """

    entity = AIModelEntity(
        model=model,
        label=I18nObject(
            en_US=model
        ),
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_type=model_type,
        model_properties={ 
            ModelPropertyKey.MODE:  ModelType.LLM,
        },
        parameter_rules=rules
    )

    return entity
  • 错误映射
当模型调用过程中发生错误时,将其映射到运行时识别的适当 InvokeError 类型。这使 Dify 能够以标准化的方式处理不同的错误: 运行时错误:
•	`InvokeConnectionError`
•	`InvokeServerUnavailableError`
•	`InvokeRateLimitError`
•	`InvokeAuthorizationError`
•	`InvokeBadRequestError`
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invocation errors to unified error types.
    The key is the error type thrown to the caller.
    The value is the error type thrown by the model, which needs to be mapped to a
    unified Dify error for consistent handling.
    """
    # return {
    #   InvokeConnectionError: [requests.exceptions.ConnectionError],
    #   ...
    # }
有关接口方法的更多详细信息,请参阅模型文档 要查看本指南中讨论的完整代码文件,请访问 GitHub 仓库

3. 调试插件

完成开发后,测试插件以确保其正常运行。有关更多详细信息,请参阅:

调试插件

4. 发布插件

如果您想在 Dify Marketplace 上列出此插件,请参阅: 发布到 Dify Marketplace

探索更多

快速开始: 插件端点文档:
Edit this page | Report an issue