---
dimensions:
  type:
    primary: implementation
    detail: advanced
  level: advanced
standard_title: Customizable Model
language: en
title: Integrating Custom Models
description: This document details how to integrate custom models into Dify, using
  the Xinference model as an example. It covers the complete process, including creating
  model provider files, writing code based on model type, implementing model invocation
  logic, handling exceptions, debugging, and publishing. It specifically details the
  implementation of core methods like LLM invocation, token calculation, credential
  validation, and parameter generation.
---

A **custom model** refers to an LLM that you deploy or configure on your own. This document uses the [Xinference model](https://inference.readthedocs.io/en/latest/) as an example to demonstrate how to integrate a custom model into your **model plugin**.

By default, a custom model automatically includes two parameters, its **model type** and **model name**, so you do not need to define them in the provider YAML file.

You also do not need to implement `validate_provider_credential` in your provider configuration file. At runtime, based on the model type or model name the user selects, Dify automatically calls the corresponding model layer's `validate_credentials` method to verify credentials.

## Integrating a Custom Model Plugin

Below are the steps to integrate a custom model:

1. **Create a Model Provider File**\
   Identify the model types your custom model will include.
2. **Create Code Files by Model Type**\
   Depending on the model's type (e.g., `llm` or `text_embedding`), create separate code files. Ensure that each model type is organized into distinct logical layers for easier maintenance and future expansion.
3. **Develop the Model Invocation Logic**\
   Within each model-type module, create a Python file named for that model type (for example, `llm.py`). Define a class in the file that implements the specific model logic, conforming to the system's model interface specifications.
4. **Debug the Plugin**\
   Write unit and integration tests for the new provider functionality, ensuring that all components work as intended.

***

### 1. **Create a Model Provider File**

In your plugin's `/provider` directory, create a `xinference.yaml` file.

The `Xinference` family of models supports **LLM**, **Text Embedding**, and **Rerank** model types, so your `xinference.yaml` must include all three.

**Example:**

```yaml
provider: xinference # Identifies the provider
label: # Display name; can set both en_US (English) and zh_Hans (Chinese). If zh_Hans is not set, en_US is used by default.
  en_US: Xorbits Inference
icon_small: # Small icon; store in the _assets folder of this provider's directory. The same multi-language logic applies as with label.
  en_US: icon_s_en.svg
icon_large: # Large icon
  en_US: icon_l_en.svg
help: # Help information
  title:
    en_US: How to deploy Xinference
    zh_Hans: 如何部署 Xinference
  url:
    en_US: https://github.com/xorbitsai/inference

supported_model_types: # Model types Xinference supports: LLM/Text Embedding/Rerank
  - llm
  - text-embedding
  - rerank

configurate_methods: # Xinference is locally deployed and does not offer predefined models. Refer to its documentation to learn which model to use. Thus, we choose a customizable-model approach.
  - customizable-model

provider_credential_schema:
  credential_form_schemas:
```

Next, define the `provider_credential_schema`. Since `Xinference` supports text-generation, embeddings, and reranking models, you can configure it as follows:

```yaml
provider_credential_schema:
  credential_form_schemas:
    - variable: model_type
      type: select
      label:
        en_US: Model type
        zh_Hans: 模型类型
      required: true
      options:
        - value: text-generation
          label:
            en_US: Language Model
            zh_Hans: 语言模型
        - value: embeddings
          label:
            en_US: Text Embedding
        - value: reranking
          label:
            en_US: Rerank
```

Every model in Xinference requires a `model_name`:

```yaml
    - variable: model_name
      type: text-input
      label:
        en_US: Model name
        zh_Hans: 模型名称
      required: true
      placeholder:
        zh_Hans: 填写模型名称
        en_US: Input model name
```

Because Xinference must be locally deployed, users need to supply the server address (`server_url`) and model UID. For instance:

```yaml
    - variable: server_url
      label:
        zh_Hans: 服务器 URL
        en_US: Server url
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入 Xinference 的服务器地址,如 https://example.com/xxx
        en_US: Enter the url of your Xinference, for example https://example.com/xxx

    - variable: model_uid
      label:
        zh_Hans: 模型 UID
        en_US: Model uid
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 Model UID
        en_US: Enter the model uid
```

Once you've defined these parameters, the YAML configuration for your custom model provider is complete. Next, create the functional code files for each model defined in this config.

### 2. Develop the Model Code

Since Xinference supports llm, rerank, speech2text, and tts models, create a corresponding directory under `/models` for each type, with its respective feature code.

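For orientation, a plugin covering these types might organize its `/models` directory along the following lines. The layout shown here is illustrative; see the official Xinference plugin repository linked later in this guide for a reference structure.

```
models
├── llm
│   └── llm.py
├── rerank
│   └── rerank.py
├── speech2text
│   └── speech2text.py
└── tts
    └── tts.py
```
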
Below is an example for an llm-type model. Create a file named `llm.py`, then define a class, such as `XinferenceAILargeLanguageModel`, that extends `__base.large_language_model.LargeLanguageModel`. This class should include:

* **LLM Invocation**

The core method for invoking the LLM, supporting both streaming and synchronous responses:

```python
def _invoke(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None,
    stop: Optional[list[str]] = None,
    stream: bool = True,
    user: Optional[str] = None
) -> Union[LLMResult, Generator]:
    """
    Invoke the large language model.

    :param model: model name
    :param credentials: model credentials
    :param prompt_messages: prompt messages
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: determines if the response is streamed
    :param user: unique user id
    :return: full response or a chunk generator
    """
```

You'll need two separate functions to handle streaming and synchronous responses. Python treats any function containing `yield` as a generator returning type `Generator`, so it's best to split them:

```python
def _invoke(self, stream: bool, **kwargs) -> Union[LLMResult, Generator]:
    if stream:
        return self._handle_stream_response(**kwargs)
    return self._handle_sync_response(**kwargs)

def _handle_stream_response(self, **kwargs) -> Generator:
    # Iterate over the provider's streaming response and yield each chunk
    for chunk in response:
        yield chunk

def _handle_sync_response(self, **kwargs) -> LLMResult:
    # Return the full response as a single LLMResult
    return LLMResult(**response)
```

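To make the split more concrete, the sketch below shows one way a streaming handler might wrap upstream chunks before yielding them. It assumes an OpenAI-compatible streaming response from the Xinference server, and the entity names (`LLMResultChunk`, `LLMResultChunkDelta`, `AssistantPromptMessage`) and import paths are assumptions based on Dify's model runtime; verify them against the plugin SDK version you are using.

```python
from collections.abc import Generator

# Import paths are assumptions; adjust them to the Dify plugin SDK you are using.
from dify_plugin.entities.model.llm import LLMResultChunk, LLMResultChunkDelta
from dify_plugin.entities.model.message import AssistantPromptMessage, PromptMessage


def _handle_stream_response(
    self, model: str, prompt_messages: list[PromptMessage], response
) -> Generator:
    # "response" is assumed to be an iterable of OpenAI-style streaming chunks.
    for index, chunk in enumerate(response):
        delta_text = chunk.choices[0].delta.content or ""
        yield LLMResultChunk(
            model=model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=index,
                message=AssistantPromptMessage(content=delta_text),
            ),
        )
```
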
* **Pre-calculating Input Tokens**

If your model doesn't provide a token-counting interface, simply return 0:

```python
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None
) -> int:
    """
    Get the number of tokens for the given prompt messages.
    """
    return 0
```

Alternatively, you can call `self._get_num_tokens_by_gpt2(text: str)` from the `AIModel` base class, which uses a GPT-2 tokenizer. Remember this is an approximation and may not match your model exactly.

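As a rough illustration, a `get_num_tokens` implementation based on that helper might look like the following. This is a minimal sketch that assumes each prompt message carries plain-text content; adjust it if your messages contain structured content.

```python
def get_num_tokens(
    self,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None,
) -> int:
    """
    Approximate the token count with the GPT-2 tokenizer helper from AIModel.
    """
    # Sum the approximate token counts of all plain-text message contents.
    return sum(
        self._get_num_tokens_by_gpt2(message.content)
        for message in prompt_messages
        if isinstance(message.content, str)
    )
```
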
* **Validating Model Credentials**

Similar to provider-level credential checks, but scoped to a single model:

```python
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials.
    """
```

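A common way to implement this is to attempt a minimal invocation with the supplied credentials and surface any failure as a credentials error. The sketch below assumes a `CredentialsValidateFailedError` exception and a `UserPromptMessage` class from the Dify plugin SDK; treat the import paths as assumptions to check against your SDK version.

```python
# Import paths are assumptions; adjust them to your plugin SDK version.
from dify_plugin.entities.model.message import UserPromptMessage
from dify_plugin.errors.model import CredentialsValidateFailedError


def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials by issuing a minimal test invocation.
    """
    try:
        # A tiny, non-streaming call is enough to prove the server URL,
        # model UID, and model name are usable.
        self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=[UserPromptMessage(content="ping")],
            model_parameters={"max_tokens": 5},
            stream=False,
        )
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex)) from ex
```
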
* **Dynamic Model Parameters Schema**

Unlike [predefined models](/en/plugins/quick-start/develop-plugins/model-plugin/predefined-model), there is no YAML file defining which parameters a model supports, so you must generate the parameter schema dynamically.

For example, Xinference supports `max_tokens`, `temperature`, and `top_p`. Some other providers (e.g., `OpenLLM`) may support parameters like `top_k` only for certain models. This means you need to adapt your schema to each model's capabilities:

```python
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
    """
    used to define customizable model schema
    """
    rules = [
        ParameterRule(
            name='temperature', type=ParameterType.FLOAT,
            use_template='temperature',
            label=I18nObject(
                zh_Hans='温度', en_US='Temperature'
            )
        ),
        ParameterRule(
            name='top_p', type=ParameterType.FLOAT,
            use_template='top_p',
            label=I18nObject(
                zh_Hans='Top P', en_US='Top P'
            )
        ),
        ParameterRule(
            name='max_tokens', type=ParameterType.INT,
            use_template='max_tokens',
            min=1,
            default=512,
            label=I18nObject(
                zh_Hans='最大生成长度', en_US='Max Tokens'
            )
        )
    ]

    # if model is A, add top_k to rules
    if model == 'A':
        rules.append(
            ParameterRule(
                name='top_k', type=ParameterType.INT,
                use_template='top_k',
                min=1,
                default=50,
                label=I18nObject(
                    zh_Hans='Top K', en_US='Top K'
                )
            )
        )

    """
    some NOT IMPORTANT code here
    """

    entity = AIModelEntity(
        model=model,
        label=I18nObject(
            en_US=model
        ),
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_type=model_type,
        model_properties={
            ModelPropertyKey.MODE: ModelType.LLM,
        },
        parameter_rules=rules
    )

    return entity
```

* **Error Mapping**

When an error occurs during model invocation, map it to the appropriate `InvokeError` type recognized by the runtime. This lets Dify handle different errors in a standardized manner:

Runtime Errors:

* `InvokeConnectionError`
* `InvokeServerUnavailableError`
* `InvokeRateLimitError`
* `InvokeAuthorizationError`
* `InvokeBadRequestError`

```python
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map model invocation errors to unified error types.
    The key is the error type thrown to the caller.
    The value is the error type thrown by the model, which needs to be mapped to a
    unified Dify error for consistent handling.
    """
    # return {
    #     InvokeConnectionError: [requests.exceptions.ConnectionError],
    #     ...
    # }
```

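For illustration, a concrete mapping for an HTTP-based provider accessed via `requests` might look like the sketch below. It is not the official Xinference plugin's mapping; the exception classes you list should match whatever client library your plugin actually uses.

```python
import requests

# The InvokeError subclasses are the unified runtime error types listed above.

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    # Group the HTTP client's exceptions under Dify's unified error types.
    return {
        InvokeConnectionError: [
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
        ],
        InvokeServerUnavailableError: [requests.exceptions.HTTPError],
        InvokeBadRequestError: [ValueError, KeyError],
    }
```
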
For more details on interface methods, see the [Model Documentation](https://docs.dify.ai/zh-hans/plugins/schema-definition/model).

To view the complete code files discussed in this guide, visit the [GitHub Repository](https://github.com/langgenius/dify-official-plugins/tree/main/models/xinference).

### 3. Debug the Plugin

After finishing development, test the plugin to ensure it runs correctly. For more details, refer to:

<Card title="Debug Plugin" icon="link" href="/en/plugins/quick-start/debug-plugin">
</Card>

### 4. Publish the Plugin

If you'd like to list this plugin on the Dify Marketplace, see:

Publish to Dify Marketplace

## Explore More

**Quick Start:**

* [Develop Extension Plugin](/en/plugins/quick-start/develop-plugins/extension-plugin)
* [Develop Tool Plugin](/en/plugins/quick-start/develop-plugins/tool-plugin)
* [Bundle Plugins: Package Multiple Plugins](/en/plugins/quick-start/develop-plugins/bundle)

**Plugins Endpoint Docs:**

* [Manifest](/en/plugins/schema-definition/manifest) Structure
* [Endpoint](/en/plugins/schema-definition/endpoint) Definitions
* [Reverse-Invocation of the Dify Service](/en/plugins/schema-definition/reverse-invocation-of-the-dify-service)
* [Tools](/en/plugins/schema-definition/tool)
* [Models](/en/plugins/schema-definition/model)

{/*
Contributing Section
DO NOT edit this section!
It will be automatically generated by the script.
*/}