Comprehensive guide to the Dify model plugin API, including implementation requirements for LLM, Text Embedding, Rerank, Speech2text, Text2speech, and Moderation models, with detailed specifications for all related data structures.
This document details the interfaces and data structures required to implement Dify model plugins. It serves as a technical reference for developers integrating AI models with the Dify platform.
```python
def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials by making a test API call

    Parameters:
        credentials: Provider credentials as defined in `provider_credential_schema`

    Raises:
        CredentialsValidateFailedError: If validation fails
    """
    try:
        # Example implementation - validate using an LLM model instance
        model_instance = self.get_model_instance(ModelType.LLM)
        model_instance.validate_credentials(
            model="example-model",
            credentials=credentials
        )
    except Exception as ex:
        logger.exception("Credential validation failed")
        raise CredentialsValidateFailedError(f"Invalid credentials: {str(ex)}")
```
Credential information as defined in the provider’s YAML configuration under provider_credential_schema.
Typically includes fields like api_key, organization_id, etc.
If validation fails, your implementation must raise a CredentialsValidateFailedError exception. This ensures proper error handling in the Dify UI.
For predefined model providers, you should implement a thorough validation method that verifies the credentials work with your API. For custom model providers (where each model has its own credentials), a simplified implementation is sufficient.
Dify supports six distinct model types, each requiring implementation of specific interfaces. However, all model types share some common requirements.
```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate that the provided credentials work with the specified model

    Parameters:
        model: The specific model identifier (e.g., "gpt-4")
        credentials: Authentication details for the model

    Raises:
        CredentialsValidateFailedError: If validation fails
    """
    try:
        # Make a lightweight API call to verify credentials
        # Example: List available models or check account status
        response = self._api_client.validate_api_key(credentials["api_key"])

        # Verify the specific model is available if applicable
        if model not in response.get("available_models", []):
            raise CredentialsValidateFailedError(f"Model {model} is not available")
    except ApiException as e:
        raise CredentialsValidateFailedError(str(e))
```
You can alternatively raise these standardized error types directly in your code instead of relying on the error mapping. This approach gives you more control over error messages.
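For example, a low-level request helper could translate provider HTTP status codes into the standardized invoke errors directly. The following is a minimal sketch; the import path and the `_api_client` helper are assumptions to be adapted to your plugin SDK version:

```python
# Minimal sketch - the import path and _api_client are assumptions;
# adjust them to match your plugin SDK version.
from dify_plugin.errors.model import (
    InvokeAuthorizationError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)

def _post_chat_request(self, request_data: dict) -> dict:
    # Hypothetical low-level request helper shared by the model methods
    response = self._api_client.post("/v1/chat/completions", json=request_data)

    if response.status_code == 401:
        raise InvokeAuthorizationError("API key rejected by the provider")
    if response.status_code == 429:
        raise InvokeRateLimitError("Rate limit exceeded, please retry later")
    if response.status_code >= 500:
        raise InvokeServerUnavailableError(f"Provider returned HTTP {response.status_code}")

    return response.json()
```

Model implementations also need token counting. The `get_num_tokens` method estimates how many tokens a prompt will consume before the request is sent: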
```python
def get_num_tokens(
        self, model: str, credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None) -> int:
    """
    Calculate the number of tokens in the prompt
    """
    # Convert prompt_messages to the format expected by the tokenizer
    text = self._convert_messages_to_text(prompt_messages)

    try:
        # Use the appropriate tokenizer for this model
        tokenizer = self._get_tokenizer(model)
        return len(tokenizer.encode(text))
    except Exception:
        # Fall back to a generic tokenizer
        return self._get_num_tokens_by_gpt2(text)
```
If the model doesn’t provide a tokenizer, you can use the base class’s _get_num_tokens_by_gpt2(text) method for a reasonable approximation.
```python
def get_customizable_model_schema(
        self, model: str, credentials: dict) -> Optional[AIModelEntity]:
    """
    Get parameter schema for custom models
    """
    # For fine-tuned models, you might return the base model's schema
    if model.startswith("ft:"):
        base_model = self._extract_base_model(model)
        return self._get_predefined_model_schema(base_model)

    # For standard models, return None to use the predefined schema
    return None
```
This method is only necessary for providers that support custom models. It allows custom models to inherit parameter rules from base models.
Text embedding models convert text into high-dimensional vectors that capture semantic meaning, which is useful for retrieval, similarity search, and classification.
To implement a Text Embedding provider, inherit from the __base.text_embedding_model.TextEmbeddingModel base class:
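The core method to implement is `_invoke`, which embeds the input texts and returns a `TextEmbeddingResult`. The sketch below assumes an OpenAI-style embeddings endpoint, a hypothetical `_get_client` helper, and a `_calc_response_usage` helper on the base class; check your SDK version for the exact signature:

```python
def _invoke(
        self, model: str, credentials: dict, texts: list[str],
        user: Optional[str] = None) -> TextEmbeddingResult:
    """
    Embed the given texts and return one vector per input, in order
    """
    # Hypothetical helper that builds an API client from credentials
    client = self._get_client(credentials)

    # Call the provider's embedding endpoint (OpenAI-style shape assumed)
    response = client.embeddings.create(model=model, input=texts, user=user)
    embeddings = [item.embedding for item in response.data]

    # Build usage information from the provider's reported token count;
    # _calc_response_usage is assumed to be provided by the base class
    usage = self._calc_response_usage(
        model=model,
        credentials=credentials,
        tokens=response.usage.total_tokens,
    )

    return TextEmbeddingResult(
        model=model,
        embeddings=embeddings,
        usage=usage,
    )
```

You should also implement `get_num_tokens` for the texts to be embedded: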
```python
def get_num_tokens(
        self, model: str, credentials: dict, texts: list[str]) -> int:
    """
    Calculate the number of tokens in the texts to be embedded
    """
    # Join all texts to estimate token count
    combined_text = " ".join(texts)

    try:
        # Use the appropriate tokenizer for this model
        tokenizer = self._get_tokenizer(model)
        return len(tokenizer.encode(combined_text))
    except Exception:
        # Fall back to a generic tokenizer
        return self._get_num_tokens_by_gpt2(combined_text)
```
For embedding models, accurate token counting is important for cost estimation, but not critical for functionality. The _get_num_tokens_by_gpt2 method provides a reasonable approximation for most models.
Reranking models help improve search quality by re-ordering a set of candidate documents based on their relevance to a query, typically after an initial retrieval phase.
To implement a Reranking provider, inherit from the __base.rerank_model.RerankModel base class:
```python
def _invoke(
        self, model: str, credentials: dict, query: str, docs: list[str],
        score_threshold: Optional[float] = None, top_n: Optional[int] = None,
        user: Optional[str] = None) -> RerankResult:
    """
    Rerank documents based on relevance to the query
    """
    # Set up API client with credentials
    client = self._get_client(credentials)

    # Prepare request data
    request_data = {
        "query": query,
        "documents": docs,
    }

    # Call reranking API endpoint
    response = client.rerank(
        model=model,
        **request_data,
        user=user
    )

    # Process results
    ranked_results = []
    for result in response.results:
        # Create RerankDocument for each result
        doc = RerankDocument(
            index=result.document_index,       # Original index in docs list
            text=docs[result.document_index],  # Original text
            score=result.relevance_score       # Relevance score
        )
        ranked_results.append(doc)

    # Sort by score in descending order
    ranked_results.sort(key=lambda x: x.score, reverse=True)

    # Apply score threshold filtering if specified
    if score_threshold is not None:
        ranked_results = [doc for doc in ranked_results if doc.score >= score_threshold]

    # Apply top_n limit if specified
    if top_n is not None and top_n > 0:
        ranked_results = ranked_results[:top_n]

    return RerankResult(
        model=model,
        docs=ranked_results
    )
```
docs: The list of reranked RerankDocument objects, each carrying the original index, text, and relevance score
Reranking can be computationally expensive, especially with large document sets. Implement batching for large document collections to avoid timeouts or excessive resource consumption.
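One simple approach, sketched below, splits the documents into fixed-size batches, reranks each batch with the `_invoke` logic above, and merges the results. The `batch_size` value is illustrative, and the sketch assumes the provider's relevance scores are comparable across calls for the same query:

```python
def _rerank_in_batches(
        self, model: str, credentials: dict, query: str,
        docs: list[str], batch_size: int = 100) -> list[RerankDocument]:
    """
    Rerank a large document set in fixed-size batches and merge the results
    """
    all_results: list[RerankDocument] = []

    for start in range(0, len(docs), batch_size):
        batch = docs[start:start + batch_size]
        result = self._invoke(model=model, credentials=credentials,
                              query=query, docs=batch)

        # Shift indices so they refer to positions in the original docs list
        for doc in result.docs:
            all_results.append(RerankDocument(
                index=start + doc.index,
                text=doc.text,
                score=doc.score,
            ))

    # Re-sort the merged results by relevance score
    all_results.sort(key=lambda d: d.score, reverse=True)
    return all_results
```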
Speech-to-text models convert spoken language from audio files into written text, enabling applications like transcription services, voice commands, and accessibility features.
To implement a Speech-to-Text provider, inherit from the __base.speech2text_model.Speech2TextModel base class:
```python
def _invoke(
        self, model: str, credentials: dict, file: IO[bytes],
        user: Optional[str] = None) -> str:
    """
    Convert speech audio to text
    """
    # Set up API client with credentials
    client = self._get_client(credentials)

    try:
        # Determine the file format
        file_format = self._detect_audio_format(file)

        # Prepare the file for API submission
        # Most APIs require either a file path or binary data
        audio_data = file.read()

        # Call the speech-to-text API
        response = client.audio.transcriptions.create(
            model=model,
            file=(f"audio.{file_format}", audio_data),
            user=user
        )

        # Extract and return the transcribed text
        return response.text
    except Exception as e:
        # Map to appropriate error type
        self._handle_api_error(e)
    finally:
        # Reset file pointer for potential reuse
        file.seek(0)
```
Audio format detection is important for proper handling of different file types. Consider implementing a helper method that detects the format from the file header, like the `_detect_audio_format` call in the example above.
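A minimal sketch of such a helper, based on common magic bytes (the set of formats covered here is illustrative):

```python
def _detect_audio_format(self, file: IO[bytes]) -> str:
    """
    Guess the audio format from the file's magic bytes
    """
    header = file.read(12)
    file.seek(0)  # Rewind so the caller can still read the full file

    if header.startswith(b"RIFF") and header[8:12] == b"WAVE":
        return "wav"
    if header.startswith(b"OggS"):
        return "ogg"
    if header.startswith(b"fLaC"):
        return "flac"
    if header[4:8] == b"ftyp":
        return "m4a"
    if header.startswith(b"ID3") or header[:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"):
        return "mp3"
    return "mp3"  # Default assumption when the format cannot be determined
```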
Some speech-to-text APIs have file size limitations. Consider implementing chunking for large audio files if necessary.
Text-to-speech models convert written text into natural-sounding speech, enabling applications such as voice assistants, screen readers, and audio content generation.
To implement a Text-to-Speech provider, inherit from the __base.text2speech_model.Text2SpeechModel base class:
```python
def _invoke(
        self, model: str, credentials: dict, content_text: str,
        streaming: bool, user: Optional[str] = None
) -> Union[bytes, Generator[bytes, None, None]]:
    """
    Convert text to speech audio
    """
    # Set up API client with credentials
    client = self._get_client(credentials)

    # Get voice settings based on model
    voice = self._get_voice_for_model(model)

    try:
        # Choose implementation based on streaming preference
        if streaming:
            return self._stream_audio(
                client=client,
                model=model,
                text=content_text,
                voice=voice,
                user=user
            )
        else:
            return self._generate_complete_audio(
                client=client,
                model=model,
                text=content_text,
                voice=voice,
                user=user
            )
    except Exception as e:
        self._handle_api_error(e)
```
Most text-to-speech APIs require you to specify a voice along with the model. Consider implementing a mapping between Dify’s model identifiers and the provider’s voice options.
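A hedged sketch of such a mapping; the model identifiers and voice names below are placeholders for your provider's actual options:

```python
# Placeholder mapping from model identifiers to default voice names
_DEFAULT_VOICES = {
    "example-tts-standard": "voice-a",
    "example-tts-hd": "voice-b",
}

def _get_voice_for_model(self, model: str, voice: Optional[str] = None) -> str:
    """
    Resolve the voice to use: prefer an explicitly requested voice,
    otherwise fall back to the per-model default
    """
    if voice:
        return voice
    return self._DEFAULT_VOICES.get(model, "voice-a")
```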
Long text inputs may need to be chunked for better speech synthesis quality. Consider implementing text preprocessing to handle punctuation, numbers, and special characters properly.
Moderation models analyze content for potentially harmful, inappropriate, or unsafe material, helping maintain platform safety and content policies.
To implement a Moderation provider, inherit from the __base.moderation_model.ModerationModel base class:
```python
def _invoke(
        self, model: str, credentials: dict, text: str,
        user: Optional[str] = None) -> bool:
    """
    Analyze text for harmful content

    Returns:
        bool: False if the text is safe, True if it contains harmful content
    """
    # Set up API client with credentials
    client = self._get_client(credentials)

    try:
        # Call moderation API
        response = client.moderations.create(
            model=model,
            input=text,
            user=user
        )

        # Check if any categories were flagged
        result = response.results[0]

        # Return True if flagged in any category, False if safe
        return result.flagged
    except Exception as e:
        # Log the error but default to safe if there's an API issue
        # This is a conservative approach - production systems might want
        # different fallback behavior
        logger.error(f"Moderation API error: {str(e)}")
        return False
```
Moderation is often used as a safety mechanism. Consider the implications of false negatives (letting harmful content through) versus false positives (blocking safe content) when implementing your solution.
Many moderation APIs provide detailed category scores rather than just a binary result. Consider extending this implementation to return more detailed information about specific categories of harmful content if your application needs it.
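For instance, if the provider returns per-category flags and scores (the dict-shaped response below is an assumption modeled on common moderation APIs), a small helper can surface the flagged categories:

```python
def _flagged_categories(self, result: dict) -> dict[str, float]:
    """
    Extract scores for flagged categories from a moderation result.
    Assumes `result` contains `categories` (bool flags) and
    `category_scores` (floats), as many moderation APIs return.
    """
    return {
        category: result.get("category_scores", {}).get(category, 0.0)
        for category, is_flagged in result.get("categories", {}).items()
        if is_flagged
    }
```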
The base class for message content. It is used only for type declarations and should not be instantiated directly.
```python
class PromptMessageContent(BaseModel):
    """
    Model class for prompt message content.
    """
    type: PromptMessageContentType
    data: str  # Content data
```
Two content types are currently supported: text and image, and a single message can combine text with multiple images. Construct TextPromptMessageContent and ImagePromptMessageContent instances separately and include them in the content list.
```python
class TextPromptMessageContent(PromptMessageContent):
    """
    Model class for text prompt message content.
    """
    type: PromptMessageContentType = PromptMessageContentType.TEXT
```
When sending text together with images, the text portion must be wrapped in this entity and included in the content list.
```python
class ImagePromptMessageContent(PromptMessageContent):
    """
    Model class for image prompt message content.
    """
    class DETAIL(Enum):
        LOW = 'low'
        HIGH = 'high'

    type: PromptMessageContentType = PromptMessageContentType.IMAGE
    detail: DETAIL = DETAIL.LOW  # Resolution
```
When sending text together with images, each image must be wrapped in this entity and included in the content list. data can be either a URL or a base64-encoded image string.
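For example, a user message that combines text with an image could be constructed as follows (the URL is a placeholder; UserPromptMessage is the user-role subclass of PromptMessage):

```python
message = UserPromptMessage(
    content=[
        TextPromptMessageContent(data="Describe what is in this picture."),
        ImagePromptMessageContent(
            data="https://example.com/cat.png",  # Placeholder URL; a base64-encoded string also works
            detail=ImagePromptMessageContent.DETAIL.HIGH,
        ),
    ]
)
```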
The base class for all role-specific message bodies. It is used only for type declarations and should not be instantiated directly.
```python
class PromptMessage(ABC, BaseModel):
    """
    Model class for prompt message.
    """
    role: PromptMessageRole  # Message role
    # Supports either a plain string or a content list; the content list is
    # for multimodal input, see PromptMessageContent for details
    content: Optional[str | list[PromptMessageContent]] = None
    name: Optional[str] = None  # Name, optional
```
Represents a model response message, typically used for few-shot examples or chat history input.
```python
class AssistantPromptMessage(PromptMessage):
    """
    Model class for assistant prompt message.
    """
    class ToolCall(BaseModel):
        """
        Model class for assistant prompt message tool call.
        """
        class ToolCallFunction(BaseModel):
            """
            Model class for assistant prompt message tool call function.
            """
            name: str       # Tool name
            arguments: str  # Tool parameters

        id: str    # Tool ID, only effective for OpenAI tool calls; a unique ID for the tool invocation, the same tool can be called multiple times
        type: str  # Default is "function"
        function: ToolCallFunction  # Tool call information

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    tool_calls: list[ToolCall] = []  # The model's tool calls (only returned when tools are passed in and the model decides to call them)
```
Here, tool_calls is the list of tool calls returned by the model when tools have been passed to it.
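For illustration, an assistant message recording a single call to a hypothetical get_weather tool might look like this:

```python
assistant_message = AssistantPromptMessage(
    content="",
    tool_calls=[
        AssistantPromptMessage.ToolCall(
            id="call_abc123",  # Illustrative tool call ID
            type="function",
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name="get_weather",
                arguments='{"city": "Berlin"}',
            ),
        )
    ],
)
```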
Represents tool messages, used to pass results to the model for next-step planning after a tool has been executed.
```python
class ToolPromptMessage(PromptMessage):
    """
    Model class for tool prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.TOOL
    tool_call_id: str  # Tool call ID; if OpenAI tool calls are not supported, the tool name can be passed instead
```
The content field inherited from the base class carries the tool's execution result.
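For example, returning the result of the hypothetical get_weather call above back to the model (the values are illustrative):

```python
tool_message = ToolPromptMessage(
    tool_call_id="call_abc123",  # Must match the ID of the tool call being answered
    content='{"temperature_c": 21, "condition": "sunny"}',  # Illustrative tool output
)
```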
```python
class LLMResult(BaseModel):
    """
    Model class for llm result.
    """
    model: str                                # Actually used model
    prompt_messages: list[PromptMessage]      # Prompt message list
    message: AssistantPromptMessage           # Reply message
    usage: LLMUsage                           # Token usage and cost information
    system_fingerprint: Optional[str] = None  # Request fingerprint, refer to OpenAI parameter definition
```
The delta entity emitted on each iteration of a streaming response.
```python
class LLMResultChunkDelta(BaseModel):
    """
    Model class for llm result chunk delta.
    """
    index: int                           # Sequence number
    message: AssistantPromptMessage      # Reply message
    usage: Optional[LLMUsage] = None     # Token usage and cost information, only returned in the last chunk
    finish_reason: Optional[str] = None  # Completion reason, only returned in the last chunk
```
```python
class LLMResultChunk(BaseModel):
    """
    Model class for llm result chunk.
    """
    model: str                                # Actually used model
    prompt_messages: list[PromptMessage]      # Prompt message list
    system_fingerprint: Optional[str] = None  # Request fingerprint, refer to OpenAI parameter definition
    delta: LLMResultChunkDelta                # Content change for this iteration
```
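As an illustration of how these entities fit together, a streaming LLM implementation typically yields one LLMResultChunk per provider event. The sketch below assumes a hypothetical provider stream whose events expose `text` and `finish_reason` attributes:

```python
def _handle_stream_response(
        self, model: str, prompt_messages: list[PromptMessage],
        provider_stream) -> Generator[LLMResultChunk, None, None]:
    """
    Translate a hypothetical provider event stream into LLMResultChunk objects
    """
    for index, event in enumerate(provider_stream):
        yield LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=index,
                message=AssistantPromptMessage(content=event.text),
                # usage and finish_reason are expected only on the final chunk
                finish_reason=event.finish_reason,
            ),
        )
```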
```python
class LLMUsage(ModelUsage):
    """
    Model class for llm usage.
    """
    prompt_tokens: int              # Tokens used by the prompt
    prompt_unit_price: Decimal      # Prompt unit price
    prompt_price_unit: Decimal      # Prompt price unit, i.e., the number of tokens the unit price is based on
    prompt_price: Decimal           # Prompt cost
    completion_tokens: int          # Tokens used by the completion
    completion_unit_price: Decimal  # Completion unit price
    completion_price_unit: Decimal  # Completion price unit, i.e., the number of tokens the unit price is based on
    completion_price: Decimal       # Completion cost
    total_tokens: int               # Total tokens used
    total_price: Decimal            # Total cost
    currency: str                   # Currency unit
    latency: float                  # Request latency (s)
```
```python
class TextEmbeddingResult(BaseModel):
    """
    Model class for text embedding result.
    """
    model: str                     # Actually used model
    embeddings: list[list[float]]  # Embedding vector list, corresponding to the input texts list
    usage: EmbeddingUsage          # Usage information
```
```python
class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.
    """
    tokens: int           # Tokens used
    total_tokens: int     # Total tokens used
    unit_price: Decimal   # Unit price
    price_unit: Decimal   # Price unit, i.e., the number of tokens the unit price is based on
    total_price: Decimal  # Total cost
    currency: str         # Currency unit
    latency: float        # Request latency (s)
```
```python
class RerankResult(BaseModel):
    """
    Model class for rerank result.
    """
    model: str                  # Actually used model
    docs: list[RerankDocument]  # List of reranked document segments
```
```python
class RerankDocument(BaseModel):
    """
    Model class for rerank document.
    """
    index: int    # Original index in the input list
    text: str     # Segment text content
    score: float  # Relevance score
```