import argparse
from loguru import logger
from src.datasets.xinhua import get_task_datasets
from evaluator import BaseEvaluator
from src.llms import GPT
from src.llms import Qwen_7B_Chat
from src.tasks.summary import Summary
from src.tasks.continue_writing import ContinueWriting
from src.tasks.hallucinated_modified import HalluModified
from src.tasks.quest_answer import QuestAnswer1Doc, QuestAnswer2Docs, QuestAnswer3Docs
from src.retrievers import BaseRetriever, CustomBM25Retriever, EnsembleRetriever, EnsembleRerankRetriever
from src.embeddings.base import HuggingfaceEmbeddings

def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This converter maps the common spellings
    explicitly instead.

    Args:
        value: Raw string from the command line (or an already-parsed bool).

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


parser = argparse.ArgumentParser(
    description="Evaluate an LLM + retriever pipeline on the selected benchmark task(s)."
)

# Model related options
parser.add_argument('--model_name', default='qwen7b', help="Name of the model to use ('qwen7b' or any name starting with 'gpt')")
parser.add_argument('--temperature', type=float, default=0.1, help="Controls the randomness of the model's text generation")
parser.add_argument('--max_new_tokens', type=int, default=1280, help="Maximum number of new tokens to be generated by the model")

# Dataset related options
parser.add_argument('--data_path', default='data/crud_split/split_merged.json', help="Path to the dataset")
# NOTE(review): --shuffle is not read anywhere else in this script as shown;
# confirm whether it should be forwarded to get_task_datasets.
# `nargs='?', const=True` keeps `--shuffle True/False` working and also accepts a bare `--shuffle`.
parser.add_argument('--shuffle', type=_str2bool, nargs='?', const=True, default=True, help="Whether to shuffle the dataset (true/false)")

parser.add_argument('--embedding_name', default='sentence-transformers/bge-base-zh-v1.5', help="Name of the HuggingFace embedding model")
parser.add_argument('--embedding_dim', type=int, default=768, help="Embedding vector dimension (passed to the retriever as embed_dim)")

# Index related options
parser.add_argument('--docs_path', default='data/tmp', help="Path to the retrieval documents")
# NOTE(review): --docs_type is not read anywhere else in this script as shown.
parser.add_argument('--docs_type', default="txt", help="Type of the documents")
parser.add_argument('--chunk_size', type=int, default=128, help="Chunk size")
parser.add_argument('--chunk_overlap', type=int, default=0, help="Overlap chunk size")
parser.add_argument('--construct_index', action='store_true', help="Whether to construct an index")
parser.add_argument('--add_index', action='store_true', help="Whether to add an index")
parser.add_argument('--collection_name', default="docs_80k_chuncksize_128_0", help="Name of the collection")

# Retriever related options
parser.add_argument('--retrieve_top_k', type=int, default=8, help="Top k documents to retrieve")
parser.add_argument('--retriever_name', default="base", help="Name of the retriever: base, bm25, hybrid, or hybrid-rerank")

# Metric related options
parser.add_argument('--quest_eval', action='store_true', help="Whether to use QA metrics(RAGQuestEval)")
parser.add_argument('--bert_score_eval', action='store_true', help="Whether to use bert_score metrics")

# Evaluation related options
parser.add_argument('--task', default='event_summary', help="Task to perform: event_summary, continuing_writing, hallu_modified, quest_answer, or all")
parser.add_argument('--num_threads', type=int, default=1, help="Number of threads")
parser.add_argument('--show_progress_bar', type=_str2bool, nargs='?', const=True, default=True, help="Whether to show a progress bar (true/false)")
parser.add_argument('--contain_original_data', action='store_true', help="Whether to contain original data")

args = parser.parse_args()
logger.info(args)

# Instantiate the generator LLM selected by --model_name. Any name starting
# with "gpt" goes to the GPT wrapper; "qwen7b" goes to the local Qwen chat model.
if args.model_name.startswith("gpt"):
    llm = GPT(model_name=args.model_name, temperature=args.temperature, max_new_tokens=args.max_new_tokens)
elif args.model_name == "qwen7b":
    llm = Qwen_7B_Chat(model_name=args.model_name, temperature=args.temperature, max_new_tokens=args.max_new_tokens)
else:
    # Fail fast, like the retriever/task checks, instead of leaving `llm`
    # unbound and hitting a NameError later in the evaluation loop.
    raise ValueError(f"Unknown model: {args.model_name}")

embed_model = HuggingfaceEmbeddings(model_name=args.embedding_name)

# The vector-index retrievers are all constructed with the same arguments,
# so choose the class from a table rather than repeating the call three times.
_VECTOR_RETRIEVERS = {
    "base": BaseRetriever,
    "hybrid": EnsembleRetriever,
    "hybrid-rerank": EnsembleRerankRetriever,
}

if args.retriever_name == "bm25":
    # The BM25 retriever is built without embed_dim / add_index / collection_name.
    retriever = CustomBM25Retriever(
        args.docs_path,
        embed_model=embed_model,
        chunk_size=args.chunk_size,
        construct_index=args.construct_index,
        chunk_overlap=args.chunk_overlap,
        similarity_top_k=args.retrieve_top_k,
    )
elif args.retriever_name in _VECTOR_RETRIEVERS:
    retriever = _VECTOR_RETRIEVERS[args.retriever_name](
        args.docs_path,
        embed_model=embed_model,
        embed_dim=args.embedding_dim,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        construct_index=args.construct_index,
        add_index=args.add_index,
        collection_name=args.collection_name,
        similarity_top_k=args.retrieve_top_k,
    )
else:
    raise ValueError(f"Unknown retriever: {args.retriever_name}")

# Map each --task value to the task classes that are evaluated for it.
task_mapping = {
    'event_summary': [Summary],
    'continuing_writing': [ContinueWriting],
    'hallu_modified': [HalluModified],
    'quest_answer': [QuestAnswer1Doc, QuestAnswer2Docs, QuestAnswer3Docs],
    # NOTE: the QuestAnswer* tasks are currently left out of 'all';
    # append them to this list to include them.
    'all': [Summary, ContinueWriting, HalluModified],
}

if args.task not in task_mapping:
    raise ValueError(f"Unknown task: {args.task}")

tasks = [task(use_quest_eval=args.quest_eval, use_bert_score=args.bert_score_eval) for task in task_mapping[args.task]]

# Presumably one dataset per task class, in the same order as `tasks`;
# confirm against get_task_datasets.
datasets = get_task_datasets(args.data_path, args.task)

# zip() stops at the shorter input, so a count mismatch would silently skip
# tasks or datasets. Warn when the mismatch can be detected without
# consuming `datasets` (it may not be sized, e.g. if it is a generator).
try:
    num_datasets = len(datasets)
except TypeError:
    num_datasets = None
if num_datasets is not None and num_datasets != len(tasks):
    logger.warning(
        f"Task '{args.task}' has {len(tasks)} task(s) but {num_datasets} dataset(s); "
        f"only the first {min(len(tasks), num_datasets)} pair(s) will be evaluated"
    )

for task, dataset in zip(tasks, datasets):
    evaluator = BaseEvaluator(task, llm, retriever, dataset, num_threads=args.num_threads)
    evaluator.run(show_progress_bar=args.show_progress_bar, contain_original_data=args.contain_original_data)

# Example invocation (background run on GPUs 0 and 1):
# CUDA_VISIBLE_DEVICES=0,1 nohup python quick_start.py --model_name 'qwen7b' --temperature 0.1 --max_new_tokens 1280 --data_path 'data/crud_split/split_merged.json' --shuffle True --docs_path 'data/new_qa' --docs_type 'txt' --chunk_size 128 --chunk_overlap 0 --retriever_name 'base' --collection_name 'docs_80k_chuncksize_128_0_new_qa_1' --retrieve_top_k 8 --task 'quest_answer' --num_threads 1 --show_progress_bar True --construct_index --bert_score_eval &