基于 LangChain 自定义 Embeddings
在 LangChain 中支持 OpenAI、LLAMA 等大模型 Embeddings 的调用接口,不过没有内置所有大模型,但是允许用户自定义 Embeddings 类型。
接下来以 ZhipuAI 为例,基于 LangChain 自定义 Embeddings。
设计思路
- 要实现自定义 Embeddings,需要定义一个自定义类继承自 LangChain 的 Embeddings 基类,然后定义三个函数
_embed
: 接受一个字符串,并返回一个存放 Embeddings 的 List[float],即模型的核心调用embed_query
: 用于对单个字符串 (query) 进行 embeddingembed_documents
: 用于对字符串列表 (documents) 进行 embedding
第三方库
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env
自定义 Embedding
ZhipuAIEmbeddings
定义一个继承自 Embeddings 类的自定义 Embeddings 类:
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""`Zhipuai Embeddings` embedding models."""
zhipuai_api_key: Optional[str] = None
"""Zhipuai application apikey"""
root_validator
接收一个函数作为参数,该函数包含需要校验的逻辑。函数应该返回一个字典,其中包含经过校验的数据。如果校验失败,则抛出一个 ValueError
异常。
装饰器 root_validator
确保导入了相关的包和并配置了相关的 API_Key 这里取巧,在确保导入 zhipuai model 后直接将 zhipuai.model_api
绑定到 client 上,减少和其他 Embeddings 类的差异。
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
验证环境变量或配置文件中的zhipuai_api_key是否可用。
Args:
values (Dict): 包含配置信息的字典,必须包含 zhipuai_api_key 的字段
Returns:
values (Dict): 包含配置信息的字典。如果环境变量或配置文件中未提供 zhipuai_api_key,则将返回原始值;否则将返回包含 zhipuai_api_key 的值。
Raises:
ValueError: zhipuai package not found, please install it with `pip install
zhipuai`
"""
values["zhipuai_api_key"] = get_from_dict_or_env(
values,
"zhipuai_api_key",
"ZHIPUAI_API_KEY",
)
try:
import zhipuai
zhipuai.api_key = values["zhipuai_api_key"]
values["client"] = zhipuai.model_api
except ImportError:
raise ValueError(
"Zhipuai package not found, please install it with "
"`pip install zhipuai`"
)
return values
Override _embed
def _embed(self, texts: str) -> List[float]:
"""
生成输入文本的 embedding。
Args:
texts (str): 要生成 embedding 的文本。
Return:
embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表。
"""
try:
resp = self.client.invoke(
model="text_embedding",
prompt=texts
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
if resp["code"] != 200:
raise ValueError(
"Error raised by inference API HTTP code: %s, %s"
% (resp["code"], resp["msg"])
)
embeddings = resp["data"]["embedding"]
return embeddings
Override embed_documents
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
生成输入文本列表的 embedding。
Args:
texts (List[str]): 要生成 embedding 的文本列表.
Returns:
List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。
"""
return [self._embed(text) for text in texts]
Override embed_query
embed_query
是对单个文本计算 embedding 的方法,因为我们已经定义好对文档列表计算 embedding 的方法 embed_documents
了,这里可以直接将单个文本组装成 list 的形式传给 embed_documents
。
def embed_query(self, text: str) -> List[float]:
"""
生成输入文本的 embedding。
Args:
text (str): 要生成 embedding 的文本。
Return:
List [float]: 输入文本的 embedding,一个浮点数值列表。
"""
resp = self.embed_documents([text])
return resp[0]
Comments | NOTHING