基于 LangChain 自定义 Embeddings

发布于 2024-01-29  107 次阅读


基于 LangChain 自定义 Embeddings

在 LangChain 中支持 OpenAI、LLAMA 等大模型 Embeddings 的调用接口,不过没有内置所有大模型,但是允许用户自定义 Embeddings 类型。
接下来以 ZhipuAI 为例,基于 LangChain 自定义 Embeddings。

设计思路

  • 要实现自定义 Embeddings,需要定义一个自定义类继承自 LangChain 的 Embeddings 基类,然后定义三个函数
    • _embed: 接受一个字符串,并返回一个存放 Embeddings 的 List[float],即模型的核心调用
    • embed_query: 用于对单个字符串 (query) 进行 embedding
    • embed_documents: 用于对字符串列表 (documents) 进行 embedding

第三方库

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

自定义 Embedding

ZhipuAIEmbeddings

定义一个继承自 Embeddings 类的自定义 Embeddings 类:

class ZhipuAIEmbeddings(BaseModel, Embeddings):
    """`Zhipuai Embeddings` embedding models."""

    zhipuai_api_key: Optional[str] = None
    """Zhipuai application apikey"""

root_validator 接收一个函数作为参数,该函数包含需要校验的逻辑。函数应该返回一个字典,其中包含经过校验的数据。如果校验失败,则抛出一个 ValueError 异常。

装饰器 root_validator 确保导入了相关的包和并配置了相关的 API_Key 这里取巧,在确保导入 zhipuai model 后直接将 zhipuai.model_api 绑定到 client 上,减少和其他 Embeddings 类的差异。

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """
        验证环境变量或配置文件中的zhipuai_api_key是否可用。

        Args:

            values (Dict): 包含配置信息的字典,必须包含 zhipuai_api_key 的字段
        Returns:

            values (Dict): 包含配置信息的字典。如果环境变量或配置文件中未提供 zhipuai_api_key,则将返回原始值;否则将返回包含 zhipuai_api_key 的值。
        Raises:

            ValueError: zhipuai package not found, please install it with `pip install
            zhipuai`
        """
        values["zhipuai_api_key"] = get_from_dict_or_env(
            values,
            "zhipuai_api_key",
            "ZHIPUAI_API_KEY",
        )

        try:
            import zhipuai
            zhipuai.api_key = values["zhipuai_api_key"]
            values["client"] = zhipuai.model_api

        except ImportError:
            raise ValueError(
                "Zhipuai package not found, please install it with "
                "`pip install zhipuai`"
            )
        return values

Override _embed

    def _embed(self, texts: str) -> List[float]:
        """
        生成输入文本的 embedding。

        Args:
            texts (str): 要生成 embedding 的文本。

        Return:
            embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表。
        """
        try:
            resp = self.client.invoke(
                model="text_embedding",
                prompt=texts
            )
        except Exception as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        if resp["code"] != 200:
            raise ValueError(
                "Error raised by inference API HTTP code: %s, %s"
                % (resp["code"], resp["msg"])
            )
        embeddings = resp["data"]["embedding"]
        return embeddings

Override embed_documents

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        生成输入文本列表的 embedding。
        Args:
            texts (List[str]): 要生成 embedding 的文本列表.

        Returns:
            List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。
        """
        return [self._embed(text) for text in texts]

Override embed_query

embed_query 是对单个文本计算 embedding 的方法,因为我们已经定义好对文档列表计算 embedding 的方法 embed_documents 了,这里可以直接将单个文本组装成 list 的形式传给 embed_documents

    def embed_query(self, text: str) -> List[float]:
        """
        生成输入文本的 embedding。

        Args:
            text (str): 要生成 embedding 的文本。

        Return:
            List [float]: 输入文本的 embedding,一个浮点数值列表。
        """
        resp = self.embed_documents([text])
        return resp[0]

本当の声を響かせてよ