Source code for xrag.launcher.launch

import os
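# Route Hugging Face Hub downloads through the hf-mirror.com mirror.
# Note: this must be set before huggingface_hub / transformers are imported,
# since those libraries read HF_ENDPOINT at import time.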
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core import Settings, PromptTemplate
from ..llms import get_llm
from ..index import get_index
from ..eval.evaluate_rag import evaluating
from ..embs.embedding import get_embedding
from ..data.qa_loader import get_qa_dataset
from ..config import Config
from ..retrievers.retriever import get_retriver, query_expansion, response_synthesizer
import warnings
from ..eval.evaluate_rag import EvaluationResult
from ..eval.EvalModelAgent import EvalModelAgent
from ..process.postprocess_rerank import get_postprocessor
from ..process.query_transform import transform_and_query
import random
import numpy as np
import torch
def seed_everything(seed):
    """Seed all RNGs (Python, NumPy, PyTorch CPU/CUDA) for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
def build_index(documents):
    cfg = Config()
    llm = get_llm(cfg.llm)
    # Create and download the embeddings instance
    # (may require: pip install llama-index-embeddings-langchain)
    embeddings = get_embedding(cfg.embeddings)
    Settings.chunk_size = cfg.chunk_size
    Settings.llm = llm
    Settings.embed_model = embeddings
    # Suffix the persist directory so each dataset/embedding/splitter/chunk-size
    # combination gets its own on-disk index cache.
    cfg.persist_dir = (cfg.persist_dir + '-' + cfg.dataset + '-' + cfg.embeddings
                       + '-' + cfg.split_type + '-' + str(cfg.chunk_size))
    index, hierarchical_storage_context = get_index(
        documents, cfg.persist_dir,
        split_type=cfg.split_type, chunk_size=cfg.chunk_size)
    return index, hierarchical_storage_context
def build_query_engine(index, hierarchical_storage_context):
    cfg = Config()
    query_engine = RetrieverQueryEngine(
        retriever=get_retriver(cfg.retriever, index,
                               hierarchical_storage_context=hierarchical_storage_context),
        response_synthesizer=response_synthesizer(0),
        node_postprocessors=[get_postprocessor(cfg)])
    # Install the QA and refine prompts from the config.
    text_qa_template = PromptTemplate(cfg.text_qa_template_str)
    refine_template = PromptTemplate(cfg.refine_template_str)
    query_engine.update_prompts({
        "response_synthesizer:text_qa_template": text_qa_template,
        "response_synthesizer:refine_template": refine_template})
    # Wrap the engine with query expansion, then rebuild a RetrieverQueryEngine
    # around the expanded retriever.
    query_engine = query_expansion([query_engine], query_number=4, similarity_top_k=10)
    query_engine = RetrieverQueryEngine.from_args(query_engine)
    return query_engine
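# A minimal sketch of what the two template strings in Config might contain.
# The actual values live in xrag's Config and are not shown here; "{context_str}",
# "{query_str}", "{existing_answer}" and "{context_msg}" are the standard
# llama_index template variables for these two prompts.
#
#   text_qa_template_str = (
#       "Context information is below.\n"
#       "---------------------\n"
#       "{context_str}\n"
#       "---------------------\n"
#       "Given the context information and not prior knowledge, "
#       "answer the query.\nQuery: {query_str}\nAnswer: "
#   )
#   refine_template_str = (
#       "The original query is: {query_str}\n"
#       "Existing answer: {existing_answer}\n"
#       "Refine the existing answer using the context below.\n"
#       "---------------------\n{context_msg}\n---------------------\n"
#       "Refined answer: "
#   )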
def eval_cli(qa_dataset, query_engine):
    cfg = Config()
    true_num = 0
    all_num = 0
    evaluateResults = EvaluationResult(metrics=cfg.metrics)
    evalAgent = EvalModelAgent(cfg)
    if cfg.experiment_1:
        if len(qa_dataset) < cfg.test_init_total_number_documents:
            warnings.filterwarnings('default')
            warnings.warn("The requested number of test documents exceeds the "
                          "size of the dataset; please adjust it. This run "
                          "cannot proceed.", UserWarning)
    else:
        cfg.test_init_total_number_documents = cfg.n
    for question, expected_answer, golden_context, golden_context_ids in zip(
            qa_dataset['test_data']['question'][:cfg.test_init_total_number_documents],
            qa_dataset['test_data']['expected_answer'][:cfg.test_init_total_number_documents],
            qa_dataset['test_data']['golden_context'][:cfg.test_init_total_number_documents],
            qa_dataset['test_data']['golden_context_ids'][:cfg.test_init_total_number_documents]):
        response = transform_and_query(question, cfg, query_engine)
        # Collect the retrieved nodes and their ids.
        retrieval_ids = []
        retrieval_context = []
        for source_node in response.source_nodes:
            retrieval_ids.append(source_node.metadata['id'])
            retrieval_context.append(source_node.get_content())
        actual_response = response.response
        eval_result = evaluating(question, response, actual_response,
                                 retrieval_context, retrieval_ids,
                                 expected_answer, golden_context,
                                 golden_context_ids, evaluateResults.metrics,
                                 evalAgent)
        evaluateResults.add(eval_result)
        all_num = all_num + 1
    evaluateResults.print_results()
    print("Total: " + str(all_num))
    return evaluateResults
def run(cli=True):
    seed_everything(42)
    cfg = Config()
    qa_dataset = get_qa_dataset(cfg.dataset)
    index, hierarchical_storage_context = build_index(qa_dataset['documents'])
    query_engine = build_query_engine(index, hierarchical_storage_context)
    if cli:
        evaluateResults = eval_cli(qa_dataset, query_engine)
        return evaluateResults
    else:
        return query_engine, qa_dataset
if __name__ == '__main__':
    run()
    print('Success')
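# Usage sketch (not part of the module): running the launcher programmatically
# instead of through the CLI evaluation loop, then issuing a single query. The
# choice of question below is illustrative, reusing the dataset structure seen
# in eval_cli above.
#
#   from xrag.launcher.launch import run
#
#   query_engine, qa_dataset = run(cli=False)   # build index + engine, skip eval
#   response = query_engine.query(qa_dataset['test_data']['question'][0])
#   print(response.response)                    # synthesized answer text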