イントロダクション
ベンダー統合が完了した後、次にベンダーの下でモデルのインテグレーションを行います。ここでは、全体のプロセスを理解するために、例としてXinference
を使用して、段階的にベンダーのインテグレーションを完了します。
注意が必要なのは、カスタムモデルの場合、各モデルのインテグレーションには完全なベンダークレデンシャルの記入が必要です。
事前定義モデルとは異なり、カスタムベンダーのインテグレーション時には常に以下の2つのパラメータが存在し、ベンダー yaml に定義する必要はありません。
前述したように、ベンダーはvalidate_provider_credential
を実装する必要はなく、Runtimeがユーザーが選択したモデルタイプとモデル名に基づいて、対応するモデル層のvalidate_credentials
を呼び出して検証を行います。
ベンダー yaml の作成
まず、インテグレーションを行うベンダーがどのタイプのモデルをサポートしているかを確認します。
現在サポートされているモデルタイプは以下の通りです:
text_embedding
テキスト Embedding モデル
Xinference
はLLM
、Text Embedding
、Rerank
をサポートしているため、xinference.yaml
を作成します。
Copy provider : xinference # ベンダー識別子
label : # ベンダー表示名、en_US 英語、zh_Hans 中国語の両方の言語で設定可能、zh_Hans が設定されていない場合は en_US がデフォルト
en_US : Xorbits Inference
icon_small : # 小アイコン、他のベンダーのアイコンを参考にし、対応するベンダー実装ディレクトリの _assets ディレクトリに保存
en_US : icon_s_en.svg
icon_large : # 大アイコン
en_US : icon_l_en.svg
help : # ヘルプ
title :
en_US : How to deploy Xinference
zh_Hans : 如何部署 Xinference
url :
en_US : https://github.com/xorbitsai/inference
supported_model_types : # サポートされるモデルタイプ、XinferenceはLLM/Text Embedding/Rerankをサポート
- llm
- text-embedding
- rerank
configurate_methods : # Xinferenceはローカルデプロイのベンダーであり、事前定義モデルがないため、必要なモデルを自分でデプロイする必要があるので、ここではカスタムモデルのみサポート
- customizable-model
provider_credential_schema :
credential_form_schemas :
その後、Xinferenceでモデルを定義するために必要なクレデンシャルを考えます。
3つの異なるモデルをサポートするため、model_type
を使用してこのモデルのタイプを指定する必要があります。3つのタイプがあるので、次のように記述します。
Copy provider_credential_schema :
credential_form_schemas :
- variable : model_type
type : select
label :
en_US : Model type
zh_Hans : 模型类型
required : true
options :
- value : text-generation
label :
en_US : Language Model
zh_Hans : 言語モデル
- value : embeddings
label :
en_US : Text Embedding
- value : reranking
label :
en_US : Rerank
各モデルには独自の名称model_name
があるため、ここで定義する必要があります。
Copy - variable : model_name
type : text-input
label :
en_US : Model name
zh_Hans : モデル名
required : true
placeholder :
zh_Hans : 填写模型名称
en_US : Input model name
Xinferenceのローカルデプロイのアドレスを記入します。
Copy - variable : server_url
label :
zh_Hans : 服务器URL
en_US : Server url
type : text-input
required : true
placeholder :
zh_Hans : 在此输入Xinference的服务器地址,如 https://example.com/xxx
en_US : Enter the url of your Xinference, for example https://example.com/xxx
各モデルには一意の model_uid があるため、ここで定義する必要があります。
Copy - variable : model_uid
label :
zh_Hans : 模型 UID
en_US : Model uid
type : text-input
required : true
placeholder :
zh_Hans : 在此输入您的 Model UID
en_US : Enter the model uid
これで、ベンダーの基本定義が完了しました。
モデルコードの作成
次に、llm
タイプを例にとって、xinference.llm.llm.py
を作成します。
llm.py
内で、Xinference LLM クラスを作成し、XinferenceAILargeLanguageModel
(任意の名前)と名付けて、__base.large_language_model.LargeLanguageModel
基底クラスを継承し、以下のメソッドを実装します:
LLM 呼び出し
LLM 呼び出しのコアメソッドを実装し、ストリームレスポンスと同期レスポンスの両方をサポートします。
Copy def _invoke ( self , model : str , credentials : dict ,
prompt_messages : list [ PromptMessage ], model_parameters : dict ,
tools : Optional [ list [ PromptMessageTool ]] = None , stop : Optional [ List [ str ]] = None ,
stream : bool = True , user : Optional [ str ] = None ) \
-> Union [ LLMResult , Generator ] :
"""
Invoke large language model
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:return: full response or stream response chunk generator result
"""
実装時には、同期レスポンスとストリームレスポンスを処理するために2つの関数を使用してデータを返す必要があります。Pythonはyield
キーワードを含む関数をジェネレータ関数として認識し、返されるデータ型は固定でジェネレーターになります。そのため、同期レスポンスとストリームレスポンスは別々に実装する必要があります。以下のように実装します(例では簡略化されたパラメータを使用していますが、実際の実装では上記のパラメータリストに従って実装してください):
Copy def _invoke ( self , stream : bool , ** kwargs ) \
-> Union [ LLMResult , Generator ] :
if stream :
return self . _handle_stream_response ( ** kwargs)
return self . _handle_sync_response ( ** kwargs)
def _handle_stream_response ( self , ** kwargs ) -> Generator:
for chunk in response :
yield chunk
def _handle_sync_response ( self , ** kwargs ) -> LLMResult:
return LLMResult ( ** response)
予測トークン数の計算
モデルが予測トークン数の計算インターフェースを提供していない場合、直接0を返すことができます。
Copy def get_num_tokens ( self , model : str , credentials : dict , prompt_messages : list [ PromptMessage ],
tools : Optional [ list [ PromptMessageTool ]] = None ) -> int :
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:return:
"""
時には、直接0を返す必要がない場合もあります。その場合はself._get_num_tokens_by_gpt2(text: str)
を使用して予測トークン数を取得することができます。このメソッドはAIModel
基底クラスにあり、GPT2のTokenizerを使用して計算を行いますが、代替方法として使用されるものであり、完全に正確ではありません。
モデルクレデンシャル検証
ベンダークレデンシャル検証と同様に、ここでは個々のモデルについて検証を行います。
Copy def validate_credentials ( self , model : str , credentials : dict ) -> None :
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
モデルパラメータスキーマ
カスタムタイプとは異なり、yamlファイルでモデルがサポートするパラメータを定義していないため、動的にモデルパラメータのスキーマを生成する必要があります。
例えば、Xinferenceはmax_tokens
、temperature
、top_p
の3つのモデルパラメータをサポートしています。
しかし、ベンダーによっては異なるモデルに対して異なるパラメータをサポートしている場合があります。例えば、ベンダーOpenLLM
はtop_k
をサポートしていますが、全てのモデルがtop_k
をサポートしているわけではありません。ここでは、例としてAモデルがtop_k
をサポートし、Bモデルがtop_k
をサポートしていない場合、以下のように動的にモデルパラメータのスキーマを生成します:
Copy def get_customizable_model_schema ( self , model : str , credentials : dict ) -> AIModelEntity | None :
"""
used to define customizable model schema
"""
rules = [
ParameterRule (
name = 'temperature' , type = ParameterType.FLOAT,
use_template = 'temperature' ,
label = I18nObject (
zh_Hans = '温度' , en_US = 'Temperature'
)
),
ParameterRule (
name = 'top_p' , type = ParameterType.FLOAT,
use_template = 'top_p' ,
label = I18nObject (
zh_Hans = 'Top P' , en_US = 'Top P'
)
),
ParameterRule (
name = 'max_tokens' , type = ParameterType.INT,
use_template = 'max_tokens' ,
min = 1 ,
default = 512 ,
label = I18nObject (
zh_Hans = '最大生成长度' , en_US = 'Max Tokens'
)
)
]
# if model is A, add top_k to rules
if model == 'A' :
rules . append (
ParameterRule (
name = 'top_k' , type = ParameterType.INT,
use_template = 'top_k' ,
min = 1 ,
default = 50 ,
label = I18nObject (
zh_Hans = 'Top K' , en_US = 'Top K'
)
)
)
"""
some NOT IMPORTANT code here
"""
entity = AIModelEntity (
model = model,
label = I18nObject (
en_US = model
),
fetch_from = FetchFrom.CUSTOMIZABLE_MODEL,
model_type = model_type,
model_properties = {
ModelPropertyKey.MODE: ModelType.LLM,
},
parameter_rules = rules
)
return entity
呼び出しエラーマッピングテーブル
モデル呼び出し時にエラーが発生した場合、Runtimeが指定するInvokeError
タイプにマッピングする必要があります。これにより、Difyは異なるエラーに対して異なる後続処理を行うことができます。
Runtime エラー:
InvokeConnectionError
呼び出し接続エラー
InvokeServerUnavailableError
呼び出しサービスが利用不可
InvokeRateLimitError
呼び出し回数制限に達した
InvokeAuthorizationError
認証エラー
InvokeBadRequestError
不正なリクエストパラメータ
Copy @ property
def _invoke_error_mapping ( self ) -> dict [ type [ InvokeError ], list [ type [ Exception ]]] :
"""
Map model invoke error to uni ```
-> dict[type[呼び出しエラー], list[type[例外]]]: """ モデル呼び出しエラーを統一エラーにマッピングする キーは呼び出し元に投げられるエラータイプ バリューはモデルが投げるエラータイプであり、 呼び出し元に対して統一エラータイプに変換する必要があります。
Copy :return: 呼び出しエラーのマッピング
"""
Copy
インターフェース方法の詳細については:[インターフェース](https://github.com/langgenius/dify/blob/main/api/core/model_runtime/docs/zh_Hans/interfaces.md)をご覧ください。具体的な実装例については、[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)を参照してください。