A custom model refers to an LLM that you deploy or configure on your own. This document uses the Xinference model as an example to demonstrate how to integrate a custom model into your model plugin.
By default, a custom model automatically includes two parameters—its model type and model name—and does not require additional definitions in the provider YAML file.
You also do not need to implement a provider-level validate_provider_credential method. At runtime, based on the user's choice of model type and model name, Dify automatically calls the corresponding model layer's validate_credentials method to verify credentials.
Integrating a custom model takes four steps:

1. Create a model provider file: identify the model types your custom model will include.
2. Create code files by model type: depending on the model's type (e.g., llm or text_embedding), create a separate code file for each, keeping each model type in its own logical layer for easier maintenance and future expansion.
3. Develop the model invocation logic: within each model-type module, create a Python file named for that model type (for example, llm.py), and define a class in it that implements the specific model logic, conforming to the system's model interface specifications.
4. Debug the plugin: write unit and integration tests for the new provider functionality, ensuring that all components work as intended; a minimal test sketch follows this list.
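For instance, a unit test might instantiate the model class directly and exercise a method that needs no live server. This is only a sketch: the import path, class name, and no-argument constructor are assumptions about your project layout and SDK version.

```python
# test_llm.py: a minimal pytest sketch (paths and constructor are assumptions)
from models.llm.llm import XinferenceAILargeLanguageModel


def test_get_num_tokens_returns_int():
    model = XinferenceAILargeLanguageModel()  # assumes a no-arg constructor
    count = model.get_num_tokens(model="my-model", credentials={}, prompt_messages=[])
    assert isinstance(count, int)
```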
In your plugin’s /provider directory, create a xinference.yaml file.
The Xinference family of models supports LLM, Text Embedding, and Rerank model types, so your xinference.yaml must include all three.
Example:
```yaml
provider: xinference # Identifies the provider
label: # Display name; supports both en_US (English) and zh_Hans (Chinese). If zh_Hans is not set, en_US is used by default.
  en_US: Xorbits Inference
icon_small: # Small icon; store it in the _assets folder of this provider's directory. The same multi-language logic applies as with label.
  en_US: icon_s_en.svg
icon_large: # Large icon
  en_US: icon_l_en.svg
help: # Help information
  title:
    en_US: How to deploy Xinference
    zh_Hans: 如何部署 Xinference
  url:
    en_US: https://github.com/xorbitsai/inference
supported_model_types: # Model types Xinference supports: LLM / Text Embedding / Rerank
  - llm
  - text-embedding
  - rerank
configurate_methods: # Xinference is locally deployed and offers no predefined models (refer to its documentation to learn which model to use), so we choose the customizable-model approach.
  - customizable-model
provider_credential_schema:
  credential_form_schemas:
```
Next, define the provider_credential_schema. Since Xinference supports text-generation, embeddings, and reranking models, you can configure it as follows:
```yaml
provider_credential_schema:
  credential_form_schemas:
    - variable: model_type
      type: select
      label:
        en_US: Model type
        zh_Hans: 模型类型
      required: true
      options:
        - value: text-generation
          label:
            en_US: Language Model
            zh_Hans: 语言模型
        - value: embeddings
          label:
            en_US: Text Embedding
        - value: reranking
          label:
            en_US: Rerank
```
Every model in Xinference requires a model_name:
```yaml
    - variable: model_name
      type: text-input
      label:
        en_US: Model name
        zh_Hans: 模型名称
      required: true
      placeholder:
        zh_Hans: 填写模型名称
        en_US: Input model name
```
Because Xinference must be locally deployed, users need to supply the server address (server_url) and model UID. For instance:
```yaml
    - variable: server_url
      label:
        zh_Hans: 服务器 URL
        en_US: Server url
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入 Xinference 的服务器地址，如 https://example.com/xxx
        en_US: Enter the url of your Xinference, for example https://example.com/xxx
    - variable: model_uid
      label:
        zh_Hans: 模型 UID
        en_US: Model uid
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 Model UID
        en_US: Enter the model uid
```
Once you’ve defined these parameters, the YAML configuration for your custom model provider is complete. Next, create the functional code files for each model defined in this config.
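At runtime, the values users enter for these form fields arrive in your model classes as the credentials dict, keyed by the variable names declared above. As a sketch (the helper name _client_config is hypothetical):

```python
def _client_config(self, credentials: dict) -> dict:
    # The keys below match the `variable` names declared in xinference.yaml.
    # (_client_config is a hypothetical helper, shown only to illustrate
    # how the YAML form fields map onto the credentials dict.)
    return {
        "server_url": credentials["server_url"],
        "model_uid": credentials["model_uid"],
    }
```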
Since this configuration declares the llm, text-embedding, and rerank model types, create a corresponding directory for each under /models, with each directory containing its respective feature code.
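For instance, the plugin might be laid out as follows (a sketch; the text_embedding.py and rerank.py file names are assumed by analogy with llm.py):

```
├── _assets
│   ├── icon_l_en.svg
│   └── icon_s_en.svg
├── provider
│   └── xinference.yaml
└── models
    ├── llm
    │   └── llm.py
    ├── text_embedding
    │   └── text_embedding.py
    └── rerank
        └── rerank.py
```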
Below is an example for an llm-type model. You’d create a file named llm.py, then define a class—such as XinferenceAILargeLanguageModel—that extends __base.large_language_model.LargeLanguageModel. This class should include:
LLM Invocation
The core method for invoking the LLM, supporting both streaming and synchronous responses:
```python
def _invoke(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None,
    stop: Optional[list[str]] = None,
    stream: bool = True,
    user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
    """
    Invoke the large language model.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: determines if the response is streamed
    :param user: unique user id
    :return: full response or a chunk generator
    """
```
You’ll need two separate functions to handle streaming and synchronous responses. Because Python treats any function whose body contains the yield keyword as a generator function, whose return type is always Generator, it’s best to split the two response paths into separate methods:
```python
def _invoke(self, stream: bool, **kwargs) -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    # `response` stands for the streaming response returned by your model API client
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    # `response` stands for the complete response returned by your model API client
    return LLMResult(**response)
```
Pre-calculating Input Tokens
If your model doesn’t provide a token-counting interface, simply return 0:
```python
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None,
) -> int:
    """
    Get the number of tokens for the given prompt messages.
    """
    return 0
```
Alternatively, you can call self._get_num_tokens_by_gpt2(text: str) from the AIModel base class, which uses a GPT-2 tokenizer. Remember this is an approximation and may not match your model exactly.
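For example, an approximate implementation might concatenate the text of all prompt messages and delegate to the GPT-2 tokenizer. This is a sketch and assumes each message's content is plain text:

```python
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None,
) -> int:
    # Approximate the token count with the GPT-2 tokenizer provided by the
    # AIModel base class. Assumes each message's content is plain text.
    text = "".join(str(message.content) for message in prompt_messages)
    return self._get_num_tokens_by_gpt2(text)
```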
Validating Model Credentials
Similar to provider-level credential checks, but scoped to a single model:
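A minimal sketch is below. The signature follows the model interface used throughout this document; the test-invocation body and the CredentialsValidateFailedError exception are assumptions for illustration, so adapt them to your runtime:

```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate the credentials for a single model.

    :param model: model name
    :param credentials: model credentials
    """
    try:
        # Assumption for illustration: issue a minimal test call to confirm
        # that server_url and model_uid actually reach a live model.
        self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=[UserPromptMessage(content="ping")],
            model_parameters={"max_tokens": 5},
            stream=False,
        )
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))
```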
Dynamic Model Schema
Unlike with predefined models, there is no YAML file declaring which parameters a model supports, so you must generate the parameter schema dynamically.
For example, Xinference supports max_tokens, temperature, and top_p. Some other providers (e.g., OpenLLM) may support parameters like top_k only for certain models. This means you need to adapt your schema to each model’s capabilities:
```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
    """
    Define the customizable model schema.
    """
    rules = [
        ParameterRule(
            name='temperature',
            type=ParameterType.FLOAT,
            use_template='temperature',
            label=I18nObject(
                zh_Hans='温度',
                en_US='Temperature'
            )
        ),
        ParameterRule(
            name='top_p',
            type=ParameterType.FLOAT,
            use_template='top_p',
            label=I18nObject(
                zh_Hans='Top P',
                en_US='Top P'
            )
        ),
        ParameterRule(
            name='max_tokens',
            type=ParameterType.INT,
            use_template='max_tokens',
            min=1,
            default=512,
            label=I18nObject(
                zh_Hans='最大生成长度',
                en_US='Max Tokens'
            )
        )
    ]

    # if model is A, add top_k to the rules
    if model == 'A':
        rules.append(
            ParameterRule(
                name='top_k',
                type=ParameterType.INT,
                use_template='top_k',
                min=1,
                default=50,
                label=I18nObject(
                    zh_Hans='Top K',
                    en_US='Top K'
                )
            )
        )

    # ... some NOT IMPORTANT code here ...

    entity = AIModelEntity(
        model=model,
        label=I18nObject(
            en_US=model
        ),
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_type=model_type,  # the ModelType for this module, e.g. ModelType.LLM
        model_properties={
            ModelPropertyKey.MODE: ModelType.LLM,
        },
        parameter_rules=rules
    )

    return entity
```
Error Mapping
When an error occurs during model invocation, map it to the appropriate InvokeError type recognized by the runtime. This lets Dify handle different errors in a standardized manner:
```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invocation errors to unified error types.
    The key is the error type thrown to the caller; the value is the error
    type raised by the model, which needs to be mapped to a unified Dify
    error for consistent handling.
    """
    # return {
    #     InvokeConnectionError: [requests.exceptions.ConnectionError],
    #     ...
    # }
```
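For reference, a filled-in mapping might look like the sketch below. It assumes your client raises exceptions from the requests library; the InvokeError subclasses shown follow Dify's unified error types, but the exact right-hand lists depend on what your model client actually raises:

```python
import requests  # at the top of llm.py

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    return {
        InvokeConnectionError: [
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
        ],
        InvokeServerUnavailableError: [requests.exceptions.HTTPError],
        InvokeRateLimitError: [],
        InvokeAuthorizationError: [],
        InvokeBadRequestError: [ValueError],
    }
```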