カスタムモデルのインテグレーション

イントロダクション

ベンダー統合が完了した後、次にベンダーの下でモデルのインテグレーションを行います。ここでは、全体のプロセスを理解するために、例としてXinferenceを使用して、段階的にベンダーのインテグレーションを完了します。

注意が必要なのは、カスタムモデルの場合、各モデルのインテグレーションには完全なベンダークレデンシャルの記入が必要です。

事前定義モデルとは異なり、カスタムベンダーのインテグレーション時には常に以下の2つのパラメータが存在し、ベンダー yaml に定義する必要はありません。

前述したように、ベンダーはvalidate_provider_credentialを実装する必要はなく、Runtimeがユーザーが選択したモデルタイプとモデル名に基づいて、対応するモデル層のvalidate_credentialsを呼び出して検証を行います。

ベンダー yaml の作成

まず、インテグレーションを行うベンダーがどのタイプのモデルをサポートしているかを確認します。

現在サポートされているモデルタイプは以下の通りです:

  • llm テキスト生成モデル

  • text_embedding テキスト Embedding モデル

  • rerank Rerank モデル

  • speech2text 音声からテキスト変換

  • tts テキストから音声変換

  • moderation モデレーション

XinferenceLLMText EmbeddingRerankをサポートしているため、xinference.yamlを作成します。

provider: xinference # ベンダー識別子
label: # ベンダー表示名、en_US 英語、zh_Hans 中国語の両方の言語で設定可能、zh_Hans が設定されていない場合は en_US がデフォルト
  en_US: Xorbits Inference
icon_small: # 小アイコン、他のベンダーのアイコンを参考にし、対応するベンダー実装ディレクトリの _assets ディレクトリに保存
  en_US: icon_s_en.svg
icon_large: # 大アイコン
  en_US: icon_l_en.svg
help: # ヘルプ
  title:
    en_US: How to deploy Xinference
    zh_Hans: 如何部署 Xinference
  url:
    en_US: https://github.com/xorbitsai/inference
supported_model_types: # サポートされるモデルタイプ、XinferenceはLLM/Text Embedding/Rerankをサポート
- llm
- text-embedding
- rerank
configurate_methods: # Xinferenceはローカルデプロイのベンダーであり、事前定義モデルがないため、必要なモデルを自分でデプロイする必要があるので、ここではカスタムモデルのみサポート
- customizable-model
provider_credential_schema:
  credential_form_schemas:

その後、Xinferenceでモデルを定義するために必要なクレデンシャルを考えます。

  • 3つの異なるモデルをサポートするため、model_typeを使用してこのモデルのタイプを指定する必要があります。3つのタイプがあるので、次のように記述します。

provider_credential_schema:
  credential_form_schemas:
  - variable: model_type
    type: select
    label:
      en_US: Model type
      zh_Hans: 模型类型
    required: true
    options:
    - value: text-generation
      label:
        en_US: Language Model
        zh_Hans: 言語モデル
    - value: embeddings
      label:
        en_US: Text Embedding
    - value: reranking
      label:
        en_US: Rerank
  • 各モデルには独自の名称model_nameがあるため、ここで定義する必要があります。

  - variable: model_name
    type: text-input
    label:
      en_US: Model name
      zh_Hans: モデル名
    required: true
    placeholder:
      zh_Hans: 填写模型名称
      en_US: Input model name
  • Xinferenceのローカルデプロイのアドレスを記入します。

  - variable: server_url
    label:
      zh_Hans: 服务器URL
      en_US: Server url
    type: text-input
    required: true
    placeholder:
      zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
      en_US: Enter the url of your Xinference, for example https://example.com/xxx
  • 各モデルには一意の model_uid があるため、ここで定義する必要があります。

  - variable: model_uid
    label:
      zh_Hans: 模型 UID
      en_US: Model uid
    type: text-input
    required: true
    placeholder:
      zh_Hans: 在此输入您的 Model UID
      en_US: Enter the model uid

これで、ベンダーの基本定義が完了しました。

モデルコードの作成

次に、llmタイプを例にとって、xinference.llm.llm.pyを作成します。

llm.py内で、Xinference LLM クラスを作成し、XinferenceAILargeLanguageModel(任意の名前)と名付けて、__base.large_language_model.LargeLanguageModel基底クラスを継承し、以下のメソッドを実装します:

  • LLM 呼び出し

    LLM 呼び出しのコアメソッドを実装し、ストリームレスポンスと同期レスポンスの両方をサポートします。

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model
    
        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """

    実装時には、同期レスポンスとストリームレスポンスを処理するために2つの関数を使用してデータを返す必要があります。Pythonはyieldキーワードを含む関数をジェネレータ関数として認識し、返されるデータ型は固定でジェネレーターになります。そのため、同期レスポンスとストリームレスポンスは別々に実装する必要があります。以下のように実装します(例では簡略化されたパラメータを使用していますが、実際の実装では上記のパラメータリストに従って実装してください):

    def _invoke(self, stream: bool, **kwargs) \
            -> Union[LLMResult, Generator]:
        if stream:
              return self._handle_stream_response(**kwargs)
        return self._handle_sync_response(**kwargs)
    
    def _handle_stream_response(self, **kwargs) -> Generator:
        for chunk in response:
              yield chunk
    def _handle_sync_response(self, **kwargs) -> LLMResult:
        return LLMResult(**response)
  • 予測トークン数の計算

    モデルが予測トークン数の計算インターフェースを提供していない場合、直接0を返すことができます。

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                     tools: Optional[list[PromptMessageTool]] = None) -> int:
      """
      Get number of tokens for given prompt messages
    
      :param model: model name
      :param credentials: model credentials
      :param prompt_messages: prompt messages
      :param tools: tools for tool calling
      :return:
      """

    時には、直接0を返す必要がない場合もあります。その場合はself._get_num_tokens_by_gpt2(text: str)を使用して予測トークン数を取得することができます。このメソッドはAIModel基底クラスにあり、GPT2のTokenizerを使用して計算を行いますが、代替方法として使用されるものであり、完全に正確ではありません。

  • モデルクレデンシャル検証

    ベンダークレデンシャル検証と同様に、ここでは個々のモデルについて検証を行います。

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials
    
        :param model: model name
        :param credentials: model credentials
        :return:
        """
  • モデルパラメータスキーマ

    カスタムタイプとは異なり、yamlファイルでモデルがサポートするパラメータを定義していないため、動的にモデルパラメータのスキーマを生成する必要があります。

    例えば、Xinferenceはmax_tokenstemperaturetop_pの3つのモデルパラメータをサポートしています。

    しかし、ベンダーによっては異なるモデルに対して異なるパラメータをサポートしている場合があります。例えば、ベンダーOpenLLMtop_kをサポートしていますが、全てのモデルがtop_kをサポートしているわけではありません。ここでは、例としてAモデルがtop_kをサポートし、Bモデルがtop_kをサポートしていない場合、以下のように動的にモデルパラメータのスキーマを生成します:

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
            used to define customizable model schema
        """
        rules = [
            ParameterRule(
                name='temperature', type=ParameterType.FLOAT,
                use_template='temperature',
                label=I18nObject(
                    zh_Hans='温度', en_US='Temperature'
                )
            ),
            ParameterRule(
                name='top_p', type=ParameterType.FLOAT,
                use_template='top_p',
                label=I18nObject(
                    zh_Hans='Top P', en_US='Top P'
                )
            ),
            ParameterRule(
                name='max_tokens', type=ParameterType.INT,
                use_template='max_tokens',
                min=1,
                default=512,
                label=I18nObject(
                    zh_Hans='最大生成长度', en_US='Max Tokens'
                )
            )
        ]
    
        # if model is A, add top_k to rules
        if model == 'A':
            rules.append(
                ParameterRule(
                    name='top_k', type=ParameterType.INT,
                    use_template='top_k',
                    min=1,
                    default=50,
                    label=I18nObject(
                        zh_Hans='Top K', en_US='Top K'
                    )
                )
            )
    
        """
            some NOT IMPORTANT code here
        """
    
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=model_type,
            model_properties={ 
                ModelPropertyKey.MODE:  ModelType.LLM,
            },
            parameter_rules=rules
        )
    
        return entity
  • 呼び出しエラーマッピングテーブル

    モデル呼び出し時にエラーが発生した場合、Runtimeが指定するInvokeErrorタイプにマッピングする必要があります。これにより、Difyは異なるエラーに対して異なる後続処理を行うことができます。

    Runtime エラー:

    • InvokeConnectionError 呼び出し接続エラー

    • InvokeServerUnavailableError 呼び出しサービスが利用不可

    • InvokeRateLimitError 呼び出し回数制限に達した

    • InvokeAuthorizationError 認証エラー

    • InvokeBadRequestError 不正なリクエストパラメータ

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to uni ```

-> dict[type[呼び出しエラー], list[type[例外]]]: """ モデル呼び出しエラーを統一エラーにマッピングする キーは呼び出し元に投げられるエラータイプ バリューはモデルが投げるエラータイプであり、 呼び出し元に対して統一エラータイプに変換する必要があります。

    :return: 呼び出しエラーのマッピング
    """

インターフェース方法の詳細については:[インターフェース](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/docs/zh_Hans/interfaces.md)をご覧ください。具体的な実装例については、[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)を参照してください。

Last updated