カスタムモデルとは、自身でデプロイまたは設定する必要があるLLMを指します。本稿では、Xinferenceモデルを例に、モデルプラグイン内でカスタムモデルを統合する方法を説明します。
カスタムモデルはデフォルトでモデルタイプとモデル名の2つのパラメータを含んでおり、プロバイダーのYAMLファイルで定義する必要はありません。
プロバイダー設定ファイルでは、validate_provider_credential
を実装する必要はありません。Runtimeは、ユーザーが選択したモデルタイプまたはモデル名に基づいて、対応するモデルレイヤーのvalidate_credentials
メソッドを自動的に呼び出して検証を行います。
カスタムモデルプラグインの統合
カスタムモデルの統合は、以下のステップに分かれます:
-
モデルプロバイダーファイルの作成
カスタムモデルに含まれるモデルタイプを明確にします。
-
モデルタイプに基づいたコードファイルの作成
モデルのタイプ(例:llm
またはtext_embedding
)に基づいてコードファイルを作成します。各モデルタイプが独立したロジックレイヤーを持つようにし、保守と拡張を容易にします。
-
異なるモデルモジュールに基づいたモデル呼び出しコードの記述
対応するモデルタイプモジュールの下に、モデルタイプと同名のPythonファイル(例:llm.py)を作成します。ファイル内で具体的なモデルロジックを実装するクラスを定義し、そのクラスはシステムのモデルインターフェース仕様に準拠する必要があります。
-
プラグインのデバッグ
新たに追加されたプロバイダー機能に対して単体テストと結合テストを作成し、すべての機能モジュールが期待通りに動作することを確認します。
1. モデルプロバイダーファイルの作成
プラグインプロジェクトの/provider
パスの下に、新しいxinference.yaml
ファイルを作成します。
Xinference
ファミリーモデルはLLM
、Text Embedding
、Rerank
モデルタイプをサポートしているため、xinference.yaml
ファイルにこれらのモデルタイプを含める必要があります。
サンプルコード:
provider: xinference # プロバイダー識別子を決定
label: # プロバイダー表示名。en_US (英語)、zh_Hans (中国語)の2言語を設定可能。zh_Hans を設定しない場合はデフォルトで en_US が使用されます。
en_US: Xorbits Inference
zh_Hans: Xorbits Inference # 中国語の表示名(例として英語と同じにしていますが、通常は中国語訳)
icon_small: # 小アイコン。他のプロバイダーのアイコンを参考に、対応するプロバイダー実装ディレクトリ下の _assets ディレクトリに保存。中英ポリシーは label と同様。
en_US: icon_s_en.svg
icon_large: # 大アイコン
en_US: icon_l_en.svg
help: # ヘルプ
title:
en_US: How to deploy Xinference
zh_Hans: Xinferenceのデプロイ方法
url:
en_US: https://github.com/xorbitsai/inference
supported_model_types: # サポートされるモデルタイプ。XinferenceはLLM/Text Embedding/Rerankを同時にサポートします。
- llm
- text-embedding
- rerank
configurate_methods: # Xinferenceはローカルデプロイのプロバイダーであり、事前定義されたモデルはありません。どのモデルを使用するかはXinferenceのドキュメントに従ってデプロイする必要があるため、ここのメソッドはカスタムモデルです。
- customizable-model
provider_credential_schema:
credential_form_schemas:
次に、provider_credential_schema
フィールドを定義する必要があります。Xinference
はtext-generation
、embeddings
、reranking
モデルをサポートしています。サンプルコードは以下の通りです:
provider_credential_schema:
credential_form_schemas:
- variable: model_type
type: select
label:
en_US: Model type
zh_Hans: モデルタイプ
required: true
options:
- value: text-generation
label:
en_US: Language Model
zh_Hans: 言語モデル
- value: embeddings
label:
en_US: Text Embedding
zh_Hans: テキスト埋め込み # 中国語の表示名(例)
- value: reranking
label:
en_US: Rerank
zh_Hans: リランク # 中国語の表示名(例)
Xinferenceの各モデルでは、名前model_name
を定義する必要があります。
- variable: model_name
type: text-input
label:
en_US: Model name
zh_Hans: モデル名
required: true
placeholder:
zh_Hans: モデル名を入力してください
en_US: Input model name
Xinferenceモデルでは、ユーザーがモデルのローカルデプロイアドレスを入力する必要があります。プラグイン内では、Xinferenceモデルのローカルデプロイアドレス(server_url)とモデルUIDを入力できる場所を提供する必要があります。サンプルコードは以下の通りです:
- variable: server_url
label:
zh_Hans: サーバーURL
en_US: Server url
type: text-input
required: true
placeholder:
zh_Hans: ここにXinferenceのサーバーアドレスを入力してください(例:https://example.com/xxx)
en_US: Enter the url of your Xinference, for example https://example.com/xxx
- variable: model_uid
label:
zh_Hans: モデルUID
en_US: Model uid
type: text-input
required: true
placeholder:
zh_Hans: ここにあなたのモデルUIDを入力してください
en_US: Enter the model uid
すべてのパラメータを入力すると、カスタムモデルプロバイダーのYAML設定ファイルの作成が完了します。次に、設定ファイル内で定義されたモデルに具体的な機能コードファイルを追加する必要があります。
2. モデルコードの記述
Xinferenceモデルプロバイダーのモデルタイプには、llm、rerank、speech2text、ttsタイプが含まれるため、/models
パスの下に各モデルタイプごとに独立したグループを作成し、対応する機能コードファイルを作成する必要があります。
以下では、llmタイプを例に、llm.py
コードファイルの作成方法を説明します。コード作成時には、XinferenceAILargeLanguageModel
という名前のXinference LLMクラスを作成し、__base.large_language_model.LargeLanguageModel
ベースクラスを継承し、以下のいくつかのメソッドを実装する必要があります:
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) \
-> Union[LLMResult, Generator]:
"""
大規模言語モデルを呼び出す
:param model: モデル名
:param credentials: モデルの認証情報
:param prompt_messages: プロンプトメッセージ
:param model_parameters: モデルパラメータ
:param tools: ツール呼び出し用のツール
:param stop: ストップワード
:param stream: ストリーム応答であるか
:param user: 一意のユーザーID
:return: 完全な応答またはストリーム応答チャンクジェネレータの結果
"""
コードを実装する際には、同期応答とストリーミング応答をそれぞれ処理するために、2つの関数を使用してデータを返す必要があることに注意してください。
Pythonは、関数内にyield
キーワードが含まれる関数をジェネレータ関数として認識し、返されるデータ型はGenerator
に固定されるため、同期応答とストリーミング応答をそれぞれ実装する必要があります。例えば、以下のサンプルコードです:
この例では簡略化されたパラメータを使用しています。実際のコード作成時には、上記のパラメータリストを参照してください。
def _invoke(self, stream: bool, **kwargs) \
-> Union[LLMResult, Generator]:
if stream:
return self._handle_stream_response(**kwargs)
return self._handle_sync_response(**kwargs)
def _handle_stream_response(self, **kwargs) -> Generator:
for chunk in response:
yield chunk
def _handle_sync_response(self, **kwargs) -> LLMResult:
return LLMResult(**response)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
指定されたプロンプトメッセージのトークン数を取得する
:param model: モデル名
:param credentials: モデルの認証情報
:param prompt_messages: プロンプトメッセージ
:param tools: ツール呼び出し用のツール
:return:
"""
場合によっては、直接0を返したくない場合は、self._get_num_tokens_by_gpt2(text: str)
メソッドを使用してトークンを計算できます。このメソッドはAIModel
ベースクラスにあり、GPT-2のTokenizerを使用して計算します。ただし、これは代替案であり、計算結果にはある程度の誤差が生じる可能性があることに注意してください。
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
モデルの認証情報を検証する
:param model: モデル名
:param credentials: モデルの認証情報
:return:
"""
-
モデルパラメータスキーマ
事前定義済みモデルタイプとは異なり、YAMLファイルでモデルがサポートするパラメータが事前設定されていないため、モデルパラメータのスキーマを動的に生成する必要があります。
例えば、Xinferenceはmax_tokens
、temperature
、top_p
の3つのモデルパラメータをサポートしています。しかし、一部のプロバイダー(例えばOpenLLM)は、具体的なモデルによって異なるパラメータをサポートします。
例を挙げると、プロバイダーOpenLLM
のAモデルはtop_k
パラメータをサポートしていますが、Bモデルはtop_k
をサポートしていません。この場合、各モデルに対応するパラメータスキーマを動的に生成する必要があります。サンプルコードは以下の通りです:
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
カスタマイズ可能なモデルスキーマを定義するために使用されます
"""
rules = [
ParameterRule(
name='temperature', type=ParameterType.FLOAT,
use_template='temperature',
label=I18nObject(
zh_Hans='温度', en_US='Temperature'
)
),
ParameterRule(
name='top_p', type=ParameterType.FLOAT,
use_template='top_p',
label=I18nObject(
zh_Hans='Top P', en_US='Top P'
)
),
ParameterRule(
name='max_tokens', type=ParameterType.INT,
use_template='max_tokens',
min=1,
default=512,
label=I18nObject(
zh_Hans='最大生成長', en_US='Max Tokens'
)
)
]
# モデルがAの場合、top_kをルールに追加
if model == 'A':
rules.append(
ParameterRule(
name='top_k', type=ParameterType.INT,
use_template='top_k',
min=1,
default=50,
label=I18nObject(
zh_Hans='Top K', en_US='Top K'
)
)
)
"""
ここには重要でないコードがあります
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model,
zh_Hans=model # 必要に応じて翻訳
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=model_type,
model_properties={
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules=rules
)
return entity
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
モデル呼び出しエラーを統一エラーにマッピングする
キーは呼び出し元にスローされるエラータイプです
値はモデルによってスローされるエラータイプであり、呼び出し元のために統一されたエラータイプに変換する必要があります。
:return: 呼び出しエラーマッピング
"""
より多くのインターフェースメソッドについては、インターフェースドキュメント:Modelを参照してください。
本稿で扱った完全なコードファイルを入手するには、GitHubコードリポジトリにアクセスしてください。
3. プラグインのデバッグ
プラグインの開発が完了したら、次にプラグインが正常に動作するかをテストする必要があります。詳細については、以下を参照してください:
debug-plugin.md
4. プラグインの公開
プラグインをDify Marketplaceに公開したい場合は、以下の内容を参照してください:
publish-to-dify-marketplace
さらに探索
クイックスタート:
プラグインインターフェースドキュメント:
このページを編集する | 問題を報告する