# Source code for xrag.retrievers.retriever

# pip install llama-index-retrievers-bm25

import logging
import os
import sys
import warnings
from typing import List, Optional

import openai
from llama_index.core import (
    SimpleDirectoryReader,
    VectorStoreIndex, Settings, SummaryIndex, TreeIndex, QueryBundle, StorageContext, DocumentSummaryIndex,
)
from llama_index.core import get_response_synthesizer
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.indices.document_summary import DocumentSummaryIndexEmbeddingRetriever, \
    DocumentSummaryIndexLLMRetriever
from llama_index.core.indices.keyword_table.retrievers import BaseKeywordTableRetriever, KeywordTableGPTRetriever
from llama_index.core.indices.list import (
    SummaryIndexEmbeddingRetriever,
    SummaryIndexLLMRetriever,
    SummaryIndexRetriever,
)
from llama_index.core.indices.tree import TreeAllLeafRetriever, TreeSelectLeafRetriever, \
    TreeSelectLeafEmbeddingRetriever, TreeRootRetriever
from llama_index.core.node_parser import SentenceSplitter, SentenceWindowNodeParser
from llama_index.core.postprocessor import MetadataReplacementPostProcessor, SimilarityPostprocessor, LongContextReorder
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.retrievers import VectorIndexRetriever, QueryFusionRetriever, AutoMergingRetriever, \
    RecursiveRetriever
from llama_index.core.retrievers.fusion_retriever import FUSION_MODES
from llama_index.core.schema import NodeWithScore, IndexNode
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.legacy.indices.keyword_table import KeywordTableSimpleRetriever
from llama_index.llms.openai import OpenAI
from llama_index.retrievers.bm25 import BM25Retriever

# Route all log output to stdout at INFO level; the root logger's handler
# list is reset first so repeated imports do not stack duplicate handlers.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().handlers = []
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))


# todo 常规检索模式: bm25检索 向量检索
def bm25_retriever(index):
    """Build a BM25 (sparse, term-frequency based) retriever over *index*.

    Returns a ``BM25Retriever`` configured to return the top 3 matches.
    """
    return BM25Retriever.from_defaults(index=index, similarity_top_k=3)
def vector_retriever(index):
    """Build a dense-vector retriever over *index*.

    Returns a ``VectorIndexRetriever`` that yields the top 3 matches,
    shows a progress bar, and overrides node storage.
    """
    return VectorIndexRetriever(
        index=index,
        similarity_top_k=3,
        show_progress=True,
        store_nodes_override=True,
    )
# todo 特定于某种索引的检索器类:汇总索引检索 树索引检索 关键字表索引检索 文档摘要索引检索 # note index必须为汇总索引 https://docs.llamaindex.ai/en/stable/api_reference/indices/list.html#llama_index.core.indices.list.SummaryIndex
def summary_retriever(summary_index, mode: int = 0):
    """Build a retriever over a SummaryIndex.

    Args:
        summary_index: must be a ``SummaryIndex``.
        mode: 0 -> plain ``SummaryIndexRetriever``;
              1 -> embedding-based ``SummaryIndexEmbeddingRetriever``;
              2 -> LLM-based ``SummaryIndexLLMRetriever``.

    Raises:
        ValueError: if *mode* is outside 0..2.
    """
    if mode < 0 or mode > 2:
        raise ValueError("Invalid mode for summary retriever." + str(mode))
    if mode == 0:
        # Simplest summary-index retriever.
        return SummaryIndexRetriever(index=summary_index)
    if mode == 1:
        # Embedding-based summary-index retriever.
        return SummaryIndexEmbeddingRetriever(
            index=summary_index,
            embed_model=Settings.embed_model,
            similarity_top_k=3,
        )
    # mode == 2: LLM-based retriever. Bug fix: the original instantiated
    # SummaryIndexEmbeddingRetriever here with an `llm` kwarg, duplicating
    # mode 1 instead of using the LLM-based class.
    return SummaryIndexLLMRetriever(index=summary_index, llm=Settings.llm)
# note index必须为树索引 https://docs.llamaindex.ai/en/latest/api_reference/indices/tree.html#llama_index.core.indices.tree.TreeIndex
def tree_retriever(index, mode=0):
    """Build a retriever over a TreeIndex.

    mode 0 (and any out-of-range value) -> root-node retriever;
    mode 1 -> all leaf nodes; mode 2 -> selected leaf nodes;
    mode 3 -> embedding-assisted leaf selection.
    """
    if not 0 <= mode <= 3:
        mode = 0
    if mode == 1:
        # Use every leaf node.
        return TreeAllLeafRetriever(index=index)
    if mode == 2:
        # Use a selected subset of leaf nodes.
        return TreeSelectLeafRetriever(index=index)
    if mode == 3:
        # Combine embeddings with leaf-node selection.
        return TreeSelectLeafEmbeddingRetriever(index=index, embed_model=Settings.embed_model)
    # mode == 0: retrieve from the root node.
    return TreeRootRetriever(index=index)
# node index必须为关键字表索引 https://docs.llamaindex.ai/en/stable/api_reference/query/retrievers/table.html
def keyword_retriever(index, mode=0):
    """Build a retriever over a keyword-table index.

    Args:
        index: must be a keyword-table index.
        mode: 0 -> basic keyword-table retriever;
              1 -> GPT keyword-table retriever.

    Raises:
        ValueError: if *mode* is outside 0..1.

    Bug fix: the original validated mode 0..1 but returned
    ``KeywordTableGPTRetriever`` for both modes, despite its own comments
    describing a "basic" variant for mode 0.
    """
    if mode < 0 or mode > 1:
        raise ValueError("Invalid mode for keyword retriever." + str(mode))
    if mode == 0:
        # Local, aliased import so we use the *core* simple retriever without
        # shadowing the legacy KeywordTableSimpleRetriever imported at module level.
        from llama_index.core.indices.keyword_table.retrievers import (
            KeywordTableSimpleRetriever as _CoreKeywordTableSimpleRetriever,
        )
        return _CoreKeywordTableSimpleRetriever(index)
    # mode == 1: GPT keyword-table retriever.
    return KeywordTableGPTRetriever(index)
# 关于该检索器可能的bug:https://github.com/run-llama/llama_index/issues/7633
def document_summary_retrievers(index, mode: int = 0):
    """Build a retriever over a DocumentSummaryIndex.

    Args:
        index: a ``DocumentSummaryIndex``.
        mode: 0 (default) -> embedding-based retriever (the original
              function's effective behavior); 1 -> LLM-based retriever.

    Bug fix: the original constructed the LLM-based retriever and then
    immediately overwrote it with the embedding-based one, making the first
    construction dead code. The default preserves the original effective
    behavior; mode=1 exposes the previously unreachable variant.

    Possible upstream bug with this retriever:
    https://github.com/run-llama/llama_index/issues/7633
    """
    if mode == 1:
        return DocumentSummaryIndexLLMRetriever(
            index,
            choice_batch_size=10,
            choice_top_k=1,
        )
    return DocumentSummaryIndexEmbeddingRetriever(index)
# todo 高级检索模式:自定义检索器 融合检索 自动合并检索(扩增上下文检索) 元数据替换+句子窗口检索 递归检索 # 自定义检索器,定义了向量检索与bm25检索混合,混合模式为取合
class CustomRetriever(BaseRetriever):
    """Hybrid retriever combining vector, BM25, and keyword retrieval.

    Results from every supplied retriever are merged by node id:
    mode "AND" keeps the intersection of ids, mode "OR" the union.

    Bug fixes vs. the original:
    - All three retriever arguments were required, so callers (see
      ``custom_retriever``) that supplied only two crashed with TypeError.
      They now default to ``None`` and missing retrievers are skipped.
    - ``_retrieve`` stored but silently ignored the keyword retriever;
      it now participates in the merge when provided.
    """

    def __init__(
        self,
        vector_retriever_c: Optional[VectorIndexRetriever] = None,
        bm25_retriever_c: Optional[BM25Retriever] = None,
        keyword_retriever_c: Optional[KeywordTableSimpleRetriever] = None,
        mode: str = "AND",
    ) -> None:
        self._vector_retriever = vector_retriever_c
        self._bm25_retriever = bm25_retriever_c
        self._keyword_retriever = keyword_retriever_c
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run every configured sub-retriever and merge results by node id."""
        retrievers = [
            r
            for r in (self._vector_retriever, self._bm25_retriever, self._keyword_retriever)
            if r is not None
        ]
        if not retrievers:
            return []
        result_lists = [r.retrieve(query_bundle) for r in retrievers]
        id_sets = [{n.node.node_id for n in nodes} for nodes in result_lists]
        # Later retrievers win on id collisions, matching the original's
        # dict.update ordering (bm25 over vector).
        combined = {}
        for nodes in result_lists:
            combined.update({n.node.node_id: n for n in nodes})
        if self._mode == "AND":
            retrieve_ids = set.intersection(*id_sets)
        else:
            retrieve_ids = set.union(*id_sets)
        return [combined[rid] for rid in retrieve_ids]
# 融合检索器 其将来自多个文档的索引作为输入,并自动进行问题扩充,以获得多次查询结果用于合并 # mode: 代表融合模式. 0 代表简单合并. 1: 采用RRF倒数排名融合
def query_fusion_retriever(index, num_queries=4, similarity_top_k=2, mode=0, retriever_weight=None):
    """Build a fusion retriever over one or more indexes.

    The fusion retriever auto-expands the user query into ``num_queries``
    variants and merges the per-retriever results.

    Args:
        index: a single index or a list of indexes.
        num_queries: number of generated queries (1 disables generation).
        similarity_top_k: results to keep after fusion.
        mode: 0 -> simple merge (llama-index default); 1 -> reciprocal-rank
              fusion (RRF). Any other value returns ``None`` (original behavior).
        retriever_weight: optional per-retriever weights; defaults to uniform.

    Bug fix: the original computed/accepted ``retriever_weight`` but never
    used it; it is now forwarded as ``retriever_weights``. The two near-identical
    branches are also merged.

    Default multi-query prompt used by llama-index:
        "You are a helpful assistant that generates multiple search queries
        based on a single input query. Generate {num_queries} search queries,
        one on each line, related to the following input query:\\n
        Query: {query}\\nQueries:\\n"
    """
    if not isinstance(index, list):
        index = [index]
    if retriever_weight is None:
        retriever_weight = [1 / len(index)] * len(index)
    if mode not in (0, 1):
        return None
    fusion_mode = FUSION_MODES.SIMPLE if mode == 0 else FUSION_MODES.RECIPROCAL_RANK
    return QueryFusionRetriever(
        [index_s.as_retriever() for index_s in index],
        llm=Settings.llm,
        retriever_weights=retriever_weight,
        similarity_top_k=similarity_top_k,
        num_queries=num_queries,
        mode=fusion_mode,
        use_async=True,
        verbose=True,
    )
# 自动合并检索 # 其将自动合并子节点为高级节点,并进行扩增上下文操作 故需要提供节点类 # 也可使用特殊节点类进行初始化 https://docs.llamaindex.ai/en/latest/examples/retrievers/auto_merging_retriever.html
def auto_merging_retriever(index, hierarchical_storage_context):
    """Build an auto-merging retriever.

    Child nodes retrieved from *index* are automatically merged into their
    parent nodes to widen context; requires a storage context built from
    hierarchical nodes. See:
    https://docs.llamaindex.ai/en/latest/examples/retrievers/auto_merging_retriever.html
    """
    base = index.as_retriever(similarity_top_k=6)
    return AutoMergingRetriever(
        base,
        storage_context=hierarchical_storage_context,
        verbose=True,
    )
# Recursive retrieval + sentence-node references
def recursive_retriever(base_nodes):
    """Build a recursive retriever with sentence-chunk node references.

    Each base node is re-split at chunk sizes 128/256/512; the sub-chunks are
    stored as ``IndexNode``s pointing back to their base node, so a hit on a
    small chunk resolves to the full parent node.

    Bug fix: the original built ``retriever_chunk`` (the RecursiveRetriever)
    and then returned the plain ``vector_retriever_chunk`` instead, so the
    recursive resolution never ran.
    """
    sub_chunk_sizes = [128, 256, 512]
    sub_node_parsers = [
        SentenceSplitter(chunk_size=c, chunk_overlap=20) for c in sub_chunk_sizes
    ]
    all_nodes = []
    for base_node in base_nodes:
        for parser in sub_node_parsers:
            sub_nodes = parser.get_nodes_from_documents([base_node])
            sub_inodes = [
                IndexNode.from_text_node(sn, base_node.node_id) for sn in sub_nodes
            ]
            all_nodes.extend(sub_inodes)
        # Also index the original node itself, referencing its own id.
        all_nodes.append(IndexNode.from_text_node(base_node, base_node.node_id))
    all_nodes_dict = {n.node_id: n for n in all_nodes}
    vector_index_chunk = VectorStoreIndex(all_nodes, embed_model=Settings.embed_model)
    vector_retriever_chunk = vector_index_chunk.as_retriever(similarity_top_k=3)
    return RecursiveRetriever(
        "vector",
        retriever_dict={"vector": vector_retriever_chunk},
        node_dict=all_nodes_dict,
        verbose=True,
    )
# 句子窗口类需要特殊的节点生成的索引
def sentence_window_retriever(index):
    """Build a retriever for a sentence-window index.

    The index must have been built with ``SentenceWindowNodeParser`` so each
    node carries a "window" metadata entry.

    NOTE(review): ``as_retriever`` forwards kwargs to the retriever
    constructor; node postprocessors are usually attached to the query
    engine instead — confirm this kwarg is actually honored.
    """
    postprocessors = [MetadataReplacementPostProcessor(target_metadata_key="window")]
    return index.as_retriever(node_postprocessors=postprocessors)
def custom_retriever(index, mode=0):
    """Build a hybrid ``CustomRetriever`` over *index*.

    mode 0 -> "AND" (intersection) merge; mode 1 -> "OR" (union) merge.

    Raises:
        ValueError: if *mode* is outside 0..1.

    Bug fix: the original passed only two of ``CustomRetriever``'s three
    required retriever arguments in each branch, raising TypeError at
    runtime. All three sub-retrievers are now supplied in both modes.
    """
    if mode < 0 or mode > 1:
        raise ValueError("Invalid mode for custom retriever." + str(mode))
    vector_r = vector_retriever(index)
    bm25_r = bm25_retriever(index)
    keyword_r = keyword_retriever(index)
    merge_mode = "AND" if mode == 0 else "OR"
    return CustomRetriever(
        vector_retriever_c=vector_r,
        bm25_retriever_c=bm25_r,
        keyword_retriever_c=keyword_r,
        mode=merge_mode,
    )
# refine:通过按顺序遍历每个检索到的文本块来创建和优化答案。 这将为每个节点/检索到的块进行单独的 LLM 调用。 # compact(默认):事先压缩(连接)块,从而减少 LLM 调用。填充尽可能多的文本(从检索到的块中串联/打包)可以放入上下文窗口 # tree_summarize:根据需要多次使用提示查询 LLM,以便所有串联的块 被查询,从而产生同样多的答案,这些答案本身在 LLM 调用中以递归方式用作块 依此类推,直到只剩下一个块,因此只有一个最终答案。 # simple_summarize:截断所有文本块以适应单个 LLM 提示。适合快速 摘要目的,但可能会因截断而丢失详细信息。 # accumulate:给定一组文本块和查询,将查询应用于每个文本 块,同时将响应累积到数组中。返回 all 的串联字符串。当您需要对每个文本分别运行相同查询时,非常有用 # compact_accumulate:与 accumulate 相同,但会“压缩”每个类似于 的 LLM 提示符,并对每个文本块运行相同查询。compact
def response_synthesizer(mode=0):
    """Return a response synthesizer for the given mode index.

    Mode table:
        0 REFINE, 1 COMPACT, 2 COMPACT_ACCUMULATE, 3 ACCUMULATE,
        4 TREE_SUMMARIZE, 5 SIMPLE_SUMMARIZE, 6 NO_TEXT, 7 GENERATION.

    Out-of-range modes fall back to COMPACT (index 1).

    Bug fix: the original only clamped ``mode > 7``; a negative mode
    silently indexed the list from the end (e.g. -1 -> GENERATION).
    """
    if mode > 7 or mode < 0:
        mode = 1
    choose = [
        ResponseMode.REFINE,
        ResponseMode.COMPACT,
        ResponseMode.COMPACT_ACCUMULATE,
        ResponseMode.ACCUMULATE,
        ResponseMode.TREE_SUMMARIZE,
        ResponseMode.SIMPLE_SUMMARIZE,
        ResponseMode.NO_TEXT,
        ResponseMode.GENERATION,
    ]
    return get_response_synthesizer(response_mode=choose[mode])
# 检索器类型包括: # BM25 # Vector # Summary: 汇总检索器 mode = 0 1 2 对应 最简单的文档汇总索引 基于编码的文档汇总索引 基于大模型的文档汇总索引 必须为关键字表索引 * # Keyword: 关键字表检索器 mode = 0 1 对应 基本关键字表检索器 GPT关键字表检索器 必须为关键字表索引 * # Custom # QueryFusion: 融合检索器 其将来自多个文档的索引作为输入,并自动进行问题扩充,以获得多次查询结果用于合并,index需为一个list # mode = 0, 1 分别使用llama simple重排序方法,1采用RRF # AutoMerging # Recursive # SentenceWindow * # Tree: index必须为树索引 * # mode: 1,2,3,0 对应 TreeAllLeafRetriever TreeSelectLeafRetriever TreeSelectLeafEmbeddingRetriever TreeRootRetriever # mode确定检索器的模式
def get_retriver(type: str, index, mode: int = 0, node=None, hierarchical_storage_context=None):
    """Factory returning a retriever by name.

    (Function name kept as-is — it is the public API, though misspelled.)

    Supported types: "BM25", "Vector", "Summary", "Tree", "Keyword",
    "Custom", "QueryFusion", "AutoMerging" (needs
    *hierarchical_storage_context*), "Recursive" (needs *node*),
    "SentenceWindow". *mode* selects the sub-variant where applicable.

    Raises:
        Exception: if *type* is not one of the supported names.
    """
    if type == "BM25":
        retriever = bm25_retriever(index)
    elif type == "Vector":
        retriever = vector_retriever(index)
    elif type == "Summary":
        retriever = summary_retriever(index, mode=mode)
    elif type == "Tree":
        retriever = tree_retriever(index, mode=mode)
    elif type == "Keyword":
        retriever = keyword_retriever(index, mode=mode)
    elif type == "Custom":
        retriever = custom_retriever(index, mode=mode)
    elif type == "QueryFusion":
        retriever = query_fusion_retriever(index, mode=mode)
    elif type == "AutoMerging":
        retriever = auto_merging_retriever(index, hierarchical_storage_context)
    elif type == "Recursive":
        retriever = recursive_retriever(base_nodes=node)
    elif type == "SentenceWindow":
        retriever = sentence_window_retriever(index)
    else:
        # Bug fix: the original interpolated `mode` into this message,
        # hiding which *type* was actually unsupported.
        raise Exception("retriever not supported: %s" % type)
    return retriever
def query_expansion(ret, query_number=4, similarity_top_k=10):
    """Wrap existing retrievers in a query-expanding fusion retriever.

    Args:
        ret: list of retrievers to fuse; ``None`` warns and falls back to
             an empty list.
        query_number: number of generated queries (1 disables generation).
        similarity_top_k: results to keep after fusion.
    """
    if ret is None:
        warnings.warn("query_expansion未传入检索器")
        ret = []
    return QueryFusionRetriever(
        ret,
        similarity_top_k=similarity_top_k,
        num_queries=query_number,  # set this to 1 to disable query generation
        use_async=True,
        verbose=True,
        # query_gen_prompt="...",  # the generation prompt could be overridden here
    )
class AllRetriever:
    """Convenience holder that builds one instance of every retriever flavour.

    Bug fixes vs. the original:
    - Class attributes were written with trailing commas
      (``_doc = None,`` etc.), turning each into a 1-tuple.
    - ``_retriever = []`` was a *class-level* mutable list, so every
      instance appended into the same shared list; it is now created
      per instance in ``__init__``.
    """

    _doc = None
    _nodes = None
    _index = None

    def __init__(self, nodes_, vector_index_, summary_index_, tree_index_, keyword_index_, sentence_index_, mode=0):
        # Per-instance state (was shared class state in the original).
        self._retriever = []
        self._query_number = 4
        self._similar_k_top = 10
        self._syn = 0

        self.bm25_retriever = bm25_retriever(vector_index_)
        self._retriever.append(self.bm25_retriever)
        self.vector_retriever = vector_retriever(vector_index_)
        self._retriever.append(self.vector_retriever)
        self.summary_retriever = summary_retriever(summary_index_)
        self._retriever.append(self.summary_retriever)
        self.tree_retriever = tree_retriever(tree_index_)
        self._retriever.append(self.tree_retriever)
        self.keyword_retriever = keyword_retriever(keyword_index_)
        self._retriever.append(self.keyword_retriever)
        self.doc_s_retriever = document_summary_retrievers(vector_index_)
        self._retriever.append(self.doc_s_retriever)
        self.custom_retriever = custom_retriever(vector_index_)
        self._retriever.append(self.custom_retriever)
        self.query_fusion_retriever = query_fusion_retriever(vector_index_, mode=1)
        self._retriever.append(self.query_fusion_retriever)
        # NOTE(review): `nodes_` is passed where auto_merging_retriever
        # expects a hierarchical storage context — confirm at the call site.
        self.auto_merging_retriever = auto_merging_retriever(vector_index_, nodes_)
        self._retriever.append(self.auto_merging_retriever)
        self.recursive_retriever = recursive_retriever(nodes_)
        self._retriever.append(self.recursive_retriever)
        self.router_retriever = get_query_engine_by_router(summary_index_, vector_index_, keyword_index_)
        self._retriever.append(self.router_retriever)
        self.sentence_window_retriever = sentence_window_retriever(sentence_index_)
        self._retriever.append(self.sentence_window_retriever)
        self.response_syn = response_synthesizer(mode)

    def query_expansion(self, retriever, query_number: Optional[int], similarity_number: Optional[int]):
        """Wrap *retriever* in a query-expanding fusion retriever.

        Non-``None`` arguments update and persist the instance defaults.
        """
        if query_number is not None:
            self._query_number = query_number
        if similarity_number is not None:
            self._similar_k_top = similarity_number
        return query_expansion(
            [retriever],
            query_number=self._query_number,
            similarity_top_k=self._similar_k_top,
        )

    def get_response_mode(self, retriever_, mode=0):
        """Build a query engine over *retriever_* with the given synthesis mode.

        A non-zero *mode* is remembered in ``self._syn``.
        """
        if mode != 0:
            self._syn = mode
        return RetrieverQueryEngine(
            retriever=retriever_,
            response_synthesizer=response_synthesizer(mode),
            node_postprocessors=[LongContextReorder()],
        )
from llama_index.core.query_engine import RouterQueryEngine from llama_index.core.selectors import PydanticSingleSelector from llama_index.core.tools import QueryEngineTool from llama_index.core import VectorStoreIndex, SummaryIndex # define query engines ...
def get_query_engine_by_router(summary_index=None, vector_index=None, keyword_index=None):
    """Build a RouterQueryEngine that picks a tool per query.

    Each supplied index contributes one ``QueryEngineTool``; a ``None``
    index just emits a warning and is skipped.

    Bug fixes vs. the original:
    - Retrievers were passed directly as ``query_engine`` — tools expect a
      query engine, so they are now wrapped in ``RetrieverQueryEngine``.
    - Missing indexes produced ``None`` entries in ``query_engine_tools``;
      they are now filtered out.
    """
    tools = []
    if summary_index is None:
        warnings.warn("Summary_index is None")
    else:
        tools.append(
            QueryEngineTool.from_defaults(
                query_engine=RetrieverQueryEngine.from_args(summary_retriever(summary_index)),
                description="Useful for summarization questions related to the data source",
            )
        )
    if vector_index is None:
        warnings.warn("vector_index is None")
    else:
        tools.append(
            QueryEngineTool.from_defaults(
                query_engine=RetrieverQueryEngine.from_args(custom_retriever(vector_index)),
                description="Useful for retrieving specific context related to the data source",
            )
        )
    if keyword_index is None:
        warnings.warn("keyword_index is None")
    else:
        tools.append(
            QueryEngineTool.from_defaults(
                query_engine=RetrieverQueryEngine.from_args(keyword_retriever(keyword_index)),
                description="Useful for retrieving keyword related to the data source",
            )
        )
    return RouterQueryEngine(
        selector=PydanticSingleSelector.from_defaults(),
        query_engine_tools=tools,
    )
if __name__ == '__main__':
    Settings.llm = OpenAI(temperature=0.2, model="gpt-3.5-turbo")
    # Requires a local directory containing the files directly.
    documents = SimpleDirectoryReader(f"E:\\junior_second_\\benchmark_llm\\RAG-benchmark").load_data()
    index_ = VectorStoreIndex.from_documents(
        documents,
        transformations=[SentenceSplitter(chunk_size=512, chunk_overlap=20)],
        show_progress=True,
    )
    # Alternative: split into nodes first, then build the index.
    # splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
    # nodes = splitter.get_nodes_from_documents(documents)
    # index_ = VectorStoreIndex(nodes)

    retriever = vector_retriever(index_)  # works
    # Other retrievers tried here (status in comments):
    # retriever = summary_retriever(SummaryIndex(nodes))  # works
    # retriever = tree_retriever(TreeIndex(nodes))  # works, results look poor
    # retriever = keyword_retriever(index_)  # untested
    # vector_retriever = vector_retriever(index=index_)
    # bm25_retriever = bm25_retriever(index=index_)
    # retriever = CustomRetriever(vector_retriever=vector_retriever, bm25_retriever=bm25_retriever)  # works
    # retriever = query_fusion_retriever([index_, index_], mode=1)  # works
    # retriever = auto_merging_retriever(index_, nodes)  # works
    # node_parser = SentenceWindowNodeParser.from_defaults(
    #     window_size=3,
    #     window_metadata_key="window",
    #     original_text_metadata_key="original_text",
    # )
    # retriever = sentence_window_retriever(index_)  # works
    # retriever = recursive_retriever(nodes)  # works

    response_syn = response_synthesizer(0)
    nodes = retriever.retrieve("请用中文回答我的毕业设计题目是什么")
    print(nodes)

    query_engine = RetrieverQueryEngine(
        retriever=get_retriver("QueryFusion", index_, mode=0),
        response_synthesizer=response_syn,
        node_postprocessors=[LongContextReorder()],
    )
    query_e = query_expansion([query_engine], query_number=4, similarity_top_k=3)
    query_engine = RetrieverQueryEngine.from_args(query_e)
    response = query_engine.query("请用中文回答我的毕业设计题目是什么")
    print(response)