カスタムモデルとは、自身でデプロイまたは設定する必要があるLLMを指します。本稿では、Xinferenceモデルを例に、モデルプラグイン内でカスタムモデルを統合する方法を説明します。

カスタムモデルはデフォルトでモデルタイプとモデル名の2つのパラメータを含んでおり、プロバイダーのYAMLファイルで定義する必要はありません。

プロバイダー設定ファイルでは、validate_provider_credentialを実装する必要はありません。Runtimeは、ユーザーが選択したモデルタイプまたはモデル名に基づいて、対応するモデルレイヤーのvalidate_credentialsメソッドを自動的に呼び出して検証を行います。

カスタムモデルプラグインの統合

カスタムモデルの統合は、以下のステップに分かれます:

  1. モデルプロバイダーファイルの作成

    カスタムモデルに含まれるモデルタイプを明確にします。

  2. モデルタイプに基づいたコードファイルの作成

    モデルのタイプ(例:llmまたはtext_embedding)に基づいてコードファイルを作成します。各モデルタイプが独立したロジックレイヤーを持つようにし、保守と拡張を容易にします。

  3. 異なるモデルモジュールに基づいたモデル呼び出しコードの記述

    対応するモデルタイプモジュールの下に、モデルタイプと同名のPythonファイル(例:llm.py)を作成します。ファイル内で具体的なモデルロジックを実装するクラスを定義し、そのクラスはシステムのモデルインターフェース仕様に準拠する必要があります。

  4. プラグインのデバッグ

    新たに追加されたプロバイダー機能に対して単体テストと結合テストを作成し、すべての機能モジュールが期待通りに動作することを確認します。


1. モデルプロバイダーファイルの作成

プラグインプロジェクトの/providerパスの下に、新しいxinference.yamlファイルを作成します。

XinferenceファミリーモデルはLLMText EmbeddingRerankモデルタイプをサポートしているため、xinference.yamlファイルにこれらのモデルタイプを含める必要があります。

サンプルコード:

provider: xinference # プロバイダー識別子を決定
label: # プロバイダー表示名。en_US (英語)、zh_Hans (中国語)の2言語を設定可能。zh_Hans を設定しない場合はデフォルトで en_US が使用されます。
  en_US: Xorbits Inference
  zh_Hans: Xorbits Inference # 中国語の表示名(例として英語と同じにしていますが、通常は中国語訳)
icon_small: # 小アイコン。他のプロバイダーのアイコンを参考に、対応するプロバイダー実装ディレクトリ下の _assets ディレクトリに保存。中英ポリシーは label と同様。
  en_US: icon_s_en.svg
icon_large: # 大アイコン
  en_US: icon_l_en.svg
help: # ヘルプ
  title:
    en_US: How to deploy Xinference
    zh_Hans: Xinferenceのデプロイ方法
  url:
    en_US: https://github.com/xorbitsai/inference
supported_model_types: # サポートされるモデルタイプ。XinferenceはLLM/Text Embedding/Rerankを同時にサポートします。
- llm
- text-embedding
- rerank
configurate_methods: # Xinferenceはローカルデプロイのプロバイダーであり、事前定義されたモデルはありません。どのモデルを使用するかはXinferenceのドキュメントに従ってデプロイする必要があるため、ここのメソッドはカスタムモデルです。
- customizable-model
provider_credential_schema:
  credential_form_schemas:

次に、provider_credential_schemaフィールドを定義する必要があります。Xinferencetext-generationembeddingsrerankingモデルをサポートしています。サンプルコードは以下の通りです:

provider_credential_schema:
  credential_form_schemas:
  - variable: model_type
    type: select
    label:
      en_US: Model type
      zh_Hans: モデルタイプ
    required: true
    options:
    - value: text-generation
      label:
        en_US: Language Model
        zh_Hans: 言語モデル
    - value: embeddings
      label:
        en_US: Text Embedding
        zh_Hans: テキスト埋め込み # 中国語の表示名(例)
    - value: reranking
      label:
        en_US: Rerank
        zh_Hans: リランク # 中国語の表示名(例)

Xinferenceの各モデルでは、名前model_nameを定義する必要があります。

  - variable: model_name
    type: text-input
    label:
      en_US: Model name
      zh_Hans: モデル名
    required: true
    placeholder:
      zh_Hans: モデル名を入力してください
      en_US: Input model name

Xinferenceモデルでは、ユーザーがモデルのローカルデプロイアドレスを入力する必要があります。プラグイン内では、Xinferenceモデルのローカルデプロイアドレス(server_url)とモデルUIDを入力できる場所を提供する必要があります。サンプルコードは以下の通りです:

  - variable: server_url
    label:
      zh_Hans: サーバーURL
      en_US: Server url
    type: text-input
    required: true
    placeholder:
      zh_Hans: ここにXinferenceのサーバーアドレスを入力してください(例:https://example.com/xxx)
      en_US: Enter the url of your Xinference, for example https://example.com/xxx
  - variable: model_uid
    label:
      zh_Hans: モデルUID
      en_US: Model uid
    type: text-input
    required: true
    placeholder:
      zh_Hans: ここにあなたのモデルUIDを入力してください
      en_US: Enter the model uid

すべてのパラメータを入力すると、カスタムモデルプロバイダーのYAML設定ファイルの作成が完了します。次に、設定ファイル内で定義されたモデルに具体的な機能コードファイルを追加する必要があります。

2. モデルコードの記述

Xinferenceモデルプロバイダーのモデルタイプには、llm、rerank、speech2text、ttsタイプが含まれるため、/modelsパスの下に各モデルタイプごとに独立したグループを作成し、対応する機能コードファイルを作成する必要があります。

以下では、llmタイプを例に、llm.pyコードファイルの作成方法を説明します。コード作成時には、XinferenceAILargeLanguageModelという名前のXinference LLMクラスを作成し、__base.large_language_model.LargeLanguageModelベースクラスを継承し、以下のいくつかのメソッドを実装する必要があります:

  • LLM呼び出し

    LLM呼び出しのコアメソッドであり、ストリーミングと同期の両方のレスポンスをサポートします。

def _invoke(self, model: str, credentials: dict,
            prompt_messages: list[PromptMessage], model_parameters: dict,
            tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
            stream: bool = True, user: Optional[str] = None) \
        -> Union[LLMResult, Generator]:
    """
    大規模言語モデルを呼び出す

    :param model: モデル名
    :param credentials: モデルの認証情報
    :param prompt_messages: プロンプトメッセージ
    :param model_parameters: モデルパラメータ
    :param tools: ツール呼び出し用のツール
    :param stop: ストップワード
    :param stream: ストリーム応答であるか
    :param user: 一意のユーザーID
    :return: 完全な応答またはストリーム応答チャンクジェネレータの結果
    """

コードを実装する際には、同期応答とストリーミング応答をそれぞれ処理するために、2つの関数を使用してデータを返す必要があることに注意してください。

Pythonは、関数内にyieldキーワードが含まれる関数をジェネレータ関数として認識し、返されるデータ型はGeneratorに固定されるため、同期応答とストリーミング応答をそれぞれ実装する必要があります。例えば、以下のサンプルコードです:

この例では簡略化されたパラメータを使用しています。実際のコード作成時には、上記のパラメータリストを参照してください。

def _invoke(self, stream: bool, **kwargs) \
        -> Union[LLMResult, Generator]:
    if stream:
          return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    for chunk in response:
          yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
    return LLMResult(**response)
  • 入力トークンの事前計算

    モデルがトークンを事前計算するインターフェースを提供していない場合は、直接0を返すことができます。

def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 tools: Optional[list[PromptMessageTool]] = None) -> int:
  """
  指定されたプロンプトメッセージのトークン数を取得する

  :param model: モデル名
  :param credentials: モデルの認証情報
  :param prompt_messages: プロンプトメッセージ
  :param tools: ツール呼び出し用のツール
  :return:
  """

場合によっては、直接0を返したくない場合は、self._get_num_tokens_by_gpt2(text: str)メソッドを使用してトークンを計算できます。このメソッドはAIModelベースクラスにあり、GPT-2のTokenizerを使用して計算します。ただし、これは代替案であり、計算結果にはある程度の誤差が生じる可能性があることに注意してください。

  • モデル認証情報検証

    プロバイダーの認証情報検証と同様に、ここでは個別のモデルに対して検証を行います。

def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    モデルの認証情報を検証する

    :param model: モデル名
    :param credentials: モデルの認証情報
    :return:
    """
  • モデルパラメータスキーマ

    事前定義済みモデルタイプとは異なり、YAMLファイルでモデルがサポートするパラメータが事前設定されていないため、モデルパラメータのスキーマを動的に生成する必要があります。

    例えば、Xinferenceはmax_tokenstemperaturetop_pの3つのモデルパラメータをサポートしています。しかし、一部のプロバイダー(例えばOpenLLM)は、具体的なモデルによって異なるパラメータをサポートします。

    例を挙げると、プロバイダーOpenLLMのAモデルはtop_kパラメータをサポートしていますが、Bモデルはtop_kをサポートしていません。この場合、各モデルに対応するパラメータスキーマを動的に生成する必要があります。サンプルコードは以下の通りです:

  def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
      """
          カスタマイズ可能なモデルスキーマを定義するために使用されます
      """
      rules = [
          ParameterRule(
              name='temperature', type=ParameterType.FLOAT,
              use_template='temperature',
              label=I18nObject(
                  zh_Hans='温度', en_US='Temperature'
              )
          ),
          ParameterRule(
              name='top_p', type=ParameterType.FLOAT,
              use_template='top_p',
              label=I18nObject(
                  zh_Hans='Top P', en_US='Top P'
              )
          ),
          ParameterRule(
              name='max_tokens', type=ParameterType.INT,
              use_template='max_tokens',
              min=1,
              default=512,
              label=I18nObject(
                  zh_Hans='最大生成長', en_US='Max Tokens'
              )
          )
      ]

      # モデルがAの場合、top_kをルールに追加
      if model == 'A':
          rules.append(
              ParameterRule(
                  name='top_k', type=ParameterType.INT,
                  use_template='top_k',
                  min=1,
                  default=50,
                  label=I18nObject(
                      zh_Hans='Top K', en_US='Top K'
                  )
              )
          )

      """
          ここには重要でないコードがあります
      """

      entity = AIModelEntity(
          model=model,
          label=I18nObject(
              en_US=model,
              zh_Hans=model # 必要に応じて翻訳
          ),
          fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
          model_type=model_type,
          model_properties={
              ModelPropertyKey.MODE:  ModelType.LLM,
          },
          parameter_rules=rules
      )

      return entity
  • 呼び出し例外エラーマッピング表

    モデル呼び出しが例外をスローした場合、Runtimeが指定するInvokeErrorタイプにマッピングする必要があります。これにより、Difyが異なるエラーに対して異なる後続処理を行うのに便利です。

    Runtimeエラー:

    • InvokeConnectionError 呼び出し接続エラー
    • InvokeServerUnavailableError 呼び出しサービス利用不可
    • InvokeRateLimitError 呼び出しレート制限到達
    • InvokeAuthorizationError 呼び出し認証失敗
    • InvokeBadRequestError 呼び出しパラメータ不正
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    モデル呼び出しエラーを統一エラーにマッピングする
    キーは呼び出し元にスローされるエラータイプです
    値はモデルによってスローされるエラータイプであり、呼び出し元のために統一されたエラータイプに変換する必要があります。

    :return: 呼び出しエラーマッピング
    """

より多くのインターフェースメソッドについては、インターフェースドキュメント:Modelを参照してください。

本稿で扱った完全なコードファイルを入手するには、GitHubコードリポジトリにアクセスしてください。

3. プラグインのデバッグ

プラグインの開発が完了したら、次にプラグインが正常に動作するかをテストする必要があります。詳細については、以下を参照してください:

debug-plugin.md

4. プラグインの公開

プラグインをDify Marketplaceに公開したい場合は、以下の内容を参照してください:

publish-to-dify-marketplace

さらに探索

クイックスタート:

プラグインインターフェースドキュメント: