"""Source code for ``xrag.embs.chatglmemb``: ChatGLM (ZhipuAI) embeddings for LlamaIndex."""
import asyncio
from typing import Optional, List, Mapping, Any, Sequence, Dict

from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from zhipuai import ZhipuAI
class ChatGLMEmbeddings(BaseEmbedding):
    """LlamaIndex embedding model backed by the ZhipuAI (ChatGLM) embeddings API.

    Each text is embedded with its own ``client.embeddings.create`` call whose
    ``input`` is that single string; batches are processed one text at a time.

    Args:
        model: Name of the ZhipuAI embedding model (default ``'embedding-2'``).
        reuse_client: When True, one ``ZhipuAI`` client is built lazily and
            cached on the instance; when False, a fresh client is built for
            every request.
        api_key: ZhipuAI API key, passed straight to ``ZhipuAI(api_key=...)``.
            May be None — presumably the SDK then falls back to an environment
            variable (TODO confirm against the zhipuai SDK docs).
        **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
    """

    model: str = Field(default='embedding-2', description="The ChatGlM model to use. embedding-2")
    # Optional[...] is required: ``__init__`` always forwards ``api_key``, and
    # under pydantic v2 a plain ``str`` field rejects an explicit None.
    api_key: Optional[str] = Field(default=None, description="The ChatGLM API key.")
    reuse_client: bool = Field(
        default=True,
        description=(
            "Reuse the client between requests. When doing anything with large "
            "volumes of async API calls, setting this to false can improve stability."
        ),
    )

    # Cached SDK client; only populated by ``_get_client`` when reuse_client is True.
    _client: Optional[Any] = PrivateAttr(default=None)

    def __init__(
        self,
        model: str = 'embedding-2',
        reuse_client: bool = True,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            model=model,
            api_key=api_key,
            reuse_client=reuse_client,
            **kwargs,
        )
        # Start without a cached client; ``_get_client`` builds one on demand.
        self._client = None

    def _get_client(self) -> ZhipuAI:
        """Return a ZhipuAI client, caching it on the instance when reuse_client is set.

        The lazy initialisation is unlocked: if several executor threads race
        here, more than one client may be constructed and the last assignment
        wins, which is harmless.
        """
        if not self.reuse_client:
            return ZhipuAI(api_key=self.api_key)
        if self._client is None:
            self._client = ZhipuAI(api_key=self.api_key)
        return self._client

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier reported to LlamaIndex.

        NOTE(review): this differs from the Python class name (no trailing
        "s"); kept unchanged because serialized objects may depend on it.
        """
        return "ChatGLMEmbedding"

    async def _run_blocking(self, func: Any, *args: Any) -> Any:
        """Await ``func(*args)`` executed in the running loop's default executor.

        The ZhipuAI SDK call used here is synchronous; running it in a worker
        thread keeps it from blocking the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronously get a query embedding without blocking the event loop."""
        return await self._run_blocking(self.get_general_text_embedding, query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get a text embedding without blocking the event loop."""
        return await self._run_blocking(self.get_general_text_embedding, text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings, one API request per text, in input order."""
        return [self.get_general_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings; the whole batch runs in one worker thread."""
        return await self._run_blocking(self._get_text_embeddings, texts)

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Embed a single string with the configured ZhipuAI model.

        Args:
            prompt: Text to embed.

        Returns:
            The ``embedding`` of the first item in ``response.data``; exactly
            one string is sent per request.
        """
        response = self._get_client().embeddings.create(
            model=self.model,  # name of the model to call
            input=prompt,
        )
        return response.data[0].embedding