import csv
import json
import argparse
import os
import torch
import random
import transformers
import time
import re
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest
from tqdm import tqdm
import logging
import sys
from datasets import load_dataset
import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
from vllm.distributed.parallel_state import destroy_model_parallel
import gc


# Turn off vLLM V1 multiprocessing; set at import time, i.e. before main() builds any LLM
# (presumably so the engine can be shut down in-process, see batch_inference).
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

# Option letters: the i-th option of a question is labelled choices[i].
choices = list("ABCDEFGHIJKLMNOP")

# Token budget: eval_cot keeps prompts shorter than max_model_length - max_new_tokens.
max_model_length = 4096
max_new_tokens = 2048

# Seed every RNG the script uses (Python, NumPy, torch CPU and CUDA) for reproducibility.
seed = 12345
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False

def load_arc():
    """Load and preprocess the evaluation data.

    Despite the name, this loads CommonsenseQA (``tau/commonsense_qa``): the
    ``validation`` split is the evaluated test set and the ``train`` split is the
    few-shot example pool. Both go through ``preprocess_arc``.

    Returns:
        ``(test_records, fewshot_records)`` as lists of dicts.
    """
    csqa = load_dataset('tau/commonsense_qa')
    test_records = preprocess_arc(csqa["validation"])
    fewshot_records = preprocess_arc(csqa["train"])
    return test_records, fewshot_records

def preprocess_arc(test_df):
    """Normalize CommonsenseQA records into the layout used by the rest of the script.

    Each record is updated in place and collected into a new list. The following keys are set:

    * ``options``: the choice texts, excluding ``"N/A"`` placeholders;
    * ``answer_index``: position of ``answerKey`` among the labels of the kept options;
    * ``answer``: the option letter for that position (from ``choices``);
    * ``category``: always ``'csqa'``.

    Raises:
        ValueError: if ``answerKey`` is not among the kept labels.
    """
    processed = []
    for record in test_df:
        kept = [
            (text, label)
            for text, label in zip(record["choices"]["text"], record["choices"]["label"])
            if text != "N/A"
        ]
        record["options"] = [text for text, _ in kept]
        kept_labels = [label for _, label in kept]
        record['answer_index'] = kept_labels.index(record['answerKey'])
        record['answer'] = choices[record['answer_index']]
        record['category'] = 'csqa'
        processed.append(record)
    return processed

def _generation_value(generation_config, name, default):
    """Return ``generation_config.<name>``, or ``default`` when it is missing or None."""
    value = getattr(generation_config, name, None)
    return default if value is None else value


def load_model():
    """Build the vLLM generator, its sampling parameters, the tokenizer and an optional LoRA request.

    Sampling parameters come from the checkpoint's ``generation_config`` when one can
    be loaded (otherwise from ``transformers.GenerationConfig()`` defaults). Unset
    values fall back to temperature 0.6, top_k 20, top_p 0.95 and min_p 0.

    Reads the module-level ``args`` (``model``, ``gpu_util``, ``lora_path``).

    Returns:
        ``((llm, sampling_params, lora_request), tokenizer)``. ``lora_request`` is None
        when no ``--lora_path`` was given.
    """
    llm = LLM(model=args.model, gpu_memory_utilization=float(args.gpu_util),
              tensor_parallel_size=torch.cuda.device_count(),
              max_model_len=max_model_length,
              trust_remote_code=True,
              enable_lora=bool(args.lora_path))
    try:
        generation_config = transformers.GenerationConfig.from_pretrained(args.model)
    except Exception:
        # Best effort: not every checkpoint ships a generation config.
        logging.warning("Could not load a generation config for %s; using transformers defaults.", args.model)
        generation_config = transformers.GenerationConfig()

    # Use "is None" rather than truthiness so an explicit temperature=0.0 (greedy
    # decoding) is honoured instead of being replaced with 0.6.
    temperature = _generation_value(generation_config, "temperature", 0.6)
    top_k = _generation_value(generation_config, "top_k", 20)
    if top_k == 0:
        # In HF, top_k=0 disables top-k filtering; -1 is vLLM's "disabled" value.
        top_k = -1
    # top_p == 0 is not a valid nucleus size, so it is treated as unset (as before).
    top_p = getattr(generation_config, "top_p", None) or 0.95
    min_p = getattr(generation_config, "min_p", None) or 0.
    sampling_params = SamplingParams(temperature=temperature,
                                     top_k=top_k,
                                     top_p=top_p,
                                     min_p=min_p,
                                     max_tokens=max_new_tokens, stop=["Question:"], seed=seed)

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if 'ministral' in str(args.model).lower():
        from transformers import MistralCommonBackend
        print('USING MINISTRAL TOKENIZER')
        tokenizer = MistralCommonBackend.from_pretrained(args.model)

    lora_request = None
    if args.lora_path:
        lora_request = LoRARequest(
            lora_name="lora",
            lora_path=args.lora_path,
            lora_int_id=1
        )

    return (llm, sampling_params, lora_request), tokenizer


def preprocess(test_df):
    """Remove ``"N/A"`` placeholders from every record's ``options``.

    Each record's ``options`` is replaced in place by a filtered list. The same record
    objects are returned in a new list, in input order.
    """
    cleaned = []
    for record in test_df:
        record["options"] = [opt for opt in record["options"] if opt != "N/A"]
        cleaned.append(record)
    return cleaned


def args_generate_path(input_args):
    """Return the components that identify a run: ``[model_name, scoring_method, subjects]``.

    ``model_name`` is the last ``/``-separated part of ``input_args.model``. ``subjects``
    is ``input_args.selected_subjects`` with commas turned into ``-`` and spaces into ``_``.
    """
    scoring_method = "CoT"
    model_name = input_args.model.split("/")[-1]
    # Use the parameter, not the module-level ``args``, so the function honours its input.
    subjects = input_args.selected_subjects.replace(",", "-").replace(" ", "_")
    return [model_name, scoring_method, subjects]

import copy 

def select_by_category(df, subject, validation=True):
    """Return the records of ``df`` whose ``category`` equals ``subject``.

    When ``--swap_answer`` is set and ``validation`` is False, each selected record is
    modified in place. Its gold option is swapped with the option at the requested
    letter, and ``answer_index``/``answer`` are updated to that letter. Records with
    fewer options than the letter requires are left unchanged.

    Note: ``choices.index(args.swap_answer)`` is evaluated for every matching record
    whenever ``--swap_answer`` is set, even when ``validation`` is True.
    """
    selected = []
    for item in df:
        if item["category"] != subject:
            continue
        if args.swap_answer:
            target = choices.index(args.swap_answer)
            if target < len(item['options']) and not validation:
                gold = item['answer_index']
                opts = item['options']
                opts[gold], opts[target] = opts[target], opts[gold]
                item['answer_index'] = target
                item['answer'] = args.swap_answer
        selected.append(item)
    return selected


def format_cot_example(example, including_answer=True):
    """Render one question as a prompt section.

    Layout: ``Question:``, the question text, ``Options:`` and one ``- <option>`` line
    per option. Few-shot examples (``including_answer=True``) end with
    ``Answer: The answer is <gold option text>`` and a blank line. The query question
    ends with ``Answer: Let's think step by step.`` so that generation continues from there.

    Args:
        example: dict with ``question``, ``options`` and (when including the answer) ``answer_index``.
        including_answer: whether to append the gold answer.
    """
    options = example["options"]
    parts = ["Question:\n", example["question"] + "\n", "Options:\n"]
    parts.extend(f"- {opt}\n" for opt in options)
    if including_answer:
        # Use different quote types inside the replacement field: reusing the outer
        # quote is only legal from Python 3.12 (PEP 701).
        parts.append(f"Answer: The answer is {options[example['answer_index']]}\n\n")
    else:
        parts.append("Answer: Let's think step by step.")
    return "".join(parts)


# Contents of instruction-header files, keyed by path; each file is read from disk once.
_INITIAL_PROMPT_CACHE = {}


def _read_initial_prompt(path):
    """Return the text of the instruction header at ``path``, caching it after the first read."""
    if path not in _INITIAL_PROMPT_CACHE:
        with open(path, "r", encoding="utf-8") as fi:
            _INITIAL_PROMPT_CACHE[path] = fi.read()
    return _INITIAL_PROMPT_CACHE[path]


def generate_cot_prompt(val_df, curr, k, dataset:str):
    """Build the full CoT prompt for ``curr``.

    The prompt is the contents of ``initial_prompt_M&D.txt``, then up to ``k`` solved
    examples from ``val_df`` in the same category as ``curr``, then ``curr`` itself
    without its answer.

    Args:
        val_df: few-shot example pool.
        curr: question record (``category``, ``question``, ``options``).
        k: requested number of shots; values below 0 are treated as 0.
        dataset: unused; kept for interface compatibility.

    Returns:
        ``(prompt, curr["options"])``.
    """
    # The header file is cached: this function runs once per question and per retry in eval_cot.
    prompt = _read_initial_prompt("initial_prompt_M&D.txt")
    subject = curr["category"]
    val_df = select_by_category(val_df, subject, validation=True)
    # A negative k would slice from the end (val_df[:-1] keeps almost every shot).
    val_df = val_df[: max(k, 0)]
    for example in val_df:
        prompt += format_cot_example(example, including_answer=True)
    prompt += format_cot_example(curr, including_answer=False)
    return prompt, curr["options"]


def extract_answer(text, answers, prompt):
    pattern = r"answer is (?!.*answer is ).+"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    if match:
        query = str(text[match.start():]).replace('answer is','').split('\n')[0].strip()
        return query, 0 
    else:
        return extract_again(text, answers, prompt)

def extract_again(text, answers, prompt):
    """Second extraction strategy: take the text after the last ``Answer:`` marker (strategy 1).

    The marker is matched case-insensitively. Whitespace (including newlines) after it
    is skipped, and the rest of that line, stripped, is returned. If no marker is
    followed by any text, it delegates to ``extract_final``.

    Returns:
        ``(query, extraction_try)``, where ``extraction_try`` is 1 for this strategy.
    """
    # Capture only what follows the last marker. The previous pattern began with a
    # greedy ".*", so match.start() was the start of the text and the query became
    # the response's first line.
    pattern = r'answer:\s*(?!.*answer:)(.+)'
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    if match:
        query = match.group(1).split('\n')[0].strip()
        return query, 1
    return extract_final(text, answers, prompt)


def extract_final(text, answers, prompt):
    """Last-resort answer extraction.

    * If any option text occurs literally in ``text``, return the option whose first
      occurrence ends furthest into the text (ties go to the earlier option) with code 2.
    * Otherwise return the last sentence of ``text`` with code 3.
    * If even that cannot be found, return a uniformly random option with code 4.

    ``prompt`` is unused.
    """
    ends = []
    for option in answers:
        found_at = text.find(option)
        ends.append(found_at + len(option) if found_at != -1 else -1)
    best = max(range(len(ends)), key=ends.__getitem__)
    if ends[best] != -1:
        return answers[best], 2

    # No option is mentioned: fall back to the final sentence of the response.
    last_sentence = re.search(r'([^.!?]+[.!?]*$)', text, re.DOTALL)
    if last_sentence:
        return last_sentence.group(1).strip(), 3
    return answers[random.randint(0, len(answers) - 1)], 4

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap ``query`` in the ``Instruct: ...\\nQuery:...`` template used for the embedding model."""
    template = "Instruct: {}\nQuery:{}"
    return template.format(task_description, query)

def batch_inference(llm:LLM, sampling_params, lora_request, inference_batch, answers_batch, query_prompt):
    """Generate CoT responses for a batch of prompts and map each response to an option letter.

    Steps:
    1. Run ``llm.generate`` on all prompts and keep the first completion of each.
    2. Extract a short answer phrase from each response (``extract_answer``) and wrap it
       with ``query_prompt`` via ``get_detailed_instruct``.
    3. Shut down the generator, then load the Qwen3 embedding model.
    4. Embed the wrapped answer phrase and every wrapped option. The prediction is the
       option with the highest cosine similarity.

    NOTE(review): this shuts down ``llm``'s engine core and deletes its model executor,
    so the caller's LLM object is presumably unusable afterwards. The function can then
    only be called once per process. The embedding model is not released before returning.

    Args:
        llm: vLLM generator.
        sampling_params: passed to ``llm.generate``.
        lora_request: optional LoRARequest (None for the base model).
        inference_batch: prompt strings.
        answers_batch: per-prompt lists of option texts, aligned with ``inference_batch``.
        query_prompt: task instruction used to build the embedding queries.

    Returns:
        ``(pred_batch, response_batch, extraction_tries, similarities_batch, queries_batch, inference_batch)``:
        the predicted letters (from ``choices``), raw responses, extraction strategy codes
        (0-4), per-option cosine similarities, embedding queries and the input prompts.
    """
    start = time.time()
    outputs = llm.generate(inference_batch, sampling_params, lora_request=lora_request)
    logging.info(str(len(inference_batch)) + "size batch costing time: " + str(time.time() - start))
    response_batch = []
    pred_batch = []
    extraction_tries = []
    similarities_batch = []
    queries_batch = []

    # Number of options per question; used below to slice the flattened option embeddings.
    answers_length = []
    for i,output in tqdm(enumerate(outputs), desc='Extracting answers...', total=len(outputs)):
        # Only the first completion of each request is used.
        generated_text = output.outputs[0].text
        response_batch.append(generated_text)
        extracted_answer, extraction_try = extract_answer(generated_text, answers=answers_batch[i], prompt=query_prompt)
        queries_batch.append(get_detailed_instruct(query_prompt,extracted_answer))
        extraction_tries.append(extraction_try)
        answers_length.append(len(answers_batch[i]))

    # Due to the memory limitations, we destroy the LLM generator and load the LLM embedder here
    llm.llm_engine.engine_core.shutdown()
    destroy_model_parallel()
    del llm.llm_engine.model_executor
    del llm # Drops only this local name; the caller's model tuple still references the LLM object
    gc.collect()
    torch.cuda.empty_cache()

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize() 

    # Embedding model used to match answer phrases to options; its context is 3x the generator's limit.
    answer_model = LLM(model="Qwen/Qwen3-Embedding-0.6B", runner="pooling", gpu_memory_utilization=0.2,max_model_len=max_model_length*3)


    start = time.time()
    # All options of all questions, flattened in order; question i owns answers_length[i] consecutive entries.
    flat_answers = [ get_detailed_instruct(query_prompt,option) for answers in answers_batch for option in answers ]
    answers_embedd = answer_model.embed(queries_batch, use_tqdm=False)
    options_embedd = answer_model.embed(flat_answers, use_tqdm=False)
    
    # indx is the offset of question i's first option in options_embedd.
    indx = 0
    for i in range(0, len(answers_embedd)):
        query_embedding = torch.tensor(answers_embedd[i].outputs.embedding)
        sentence_embedding = torch.tensor([ answer.outputs.embedding  for answer in options_embedd[indx:indx+answers_length[i]] ])
        # Cosine similarity of the extracted answer against each of this question's options.
        similarities = util.pytorch_cos_sim(query_embedding, sentence_embedding).flatten()
        ranking = similarities.argsort(descending=True)
        # The most similar option's position becomes the predicted letter.
        pred_batch.append(choices[int(ranking[0].item())])
        similarities_batch.append(similarities.tolist())
        indx += answers_length[i]
    logging.info(str(len(inference_batch)) + "size extraction costing time: " + str(time.time() - start))
    return pred_batch, response_batch, extraction_tries, similarities_batch, queries_batch, inference_batch

def save_res(res, output_path):
    """Write ``res`` to ``output_path`` as indented JSON, then score it.

    A record is correct when ``pred == answer``. A record with an empty ``pred`` is
    scored with a uniformly random guess over its options, drawn from the
    module-level ``random`` state.

    Returns:
        ``(accuracy, correct, wrong)`` as floats, or ``(0.0, 0.0, 0.0)`` when ``res`` is empty.
    """
    with open(output_path, "w") as fo:
        fo.write(json.dumps(res, indent=4))

    corr = wrong = 0.0
    for record in res:
        if record["pred"]:
            hit = record["pred"] == record["answer"]
        else:
            guess = random.randint(0, len(record["options"]) - 1)
            hit = guess == record["answer_index"]
        if hit:
            corr += 1
        else:
            wrong += 1

    total = corr + wrong
    if total == 0:
        return 0.0, 0.0, 0.0
    return corr / total, corr, wrong



def prepare_query_prompt():
    """Return the task instruction that prefixes every embedding query and option."""
    return 'Given an answer to a multiple choice question, retrieve a similar answer that matches the query'
    
@torch.no_grad()
def eval_cot(subject, model, tokenizer, val_df, test_df, output_path, dataset:str):
    """Evaluate one subject with few-shot CoT prompting and save per-question results.

    For each test question, the number of few-shot examples starts at ``args.ntrain``
    and is reduced until the tokenized prompt is shorter than
    ``max_model_length - max_new_tokens``. If even the zero-shot prompt is too long,
    it is used anyway and a warning is logged. All prompts are then answered in one
    ``batch_inference`` call, and the results are written onto the test records
    (in place) and dumped to ``output_path``.

    Args:
        subject: category name, used for logging.
        model: ``(llm, sampling_params, lora_request)`` tuple from ``load_model``.
        tokenizer: used only to measure prompt length.
        val_df: few-shot example pool.
        test_df: list of question dicts; each one gains prediction fields.
        output_path: JSON file that receives the results.
        dataset: forwarded to ``generate_cot_prompt``.

    Returns:
        ``(accuracy, correct, wrong)`` from ``save_res``.
    """
    llm, sampling_params, lora_request = model
    logging.info("evaluating " + subject)
    prompt_budget = max_model_length - max_new_tokens
    inference_batches = []
    answers_batches = []
    for curr in tqdm(test_df):
        k = args.ntrain
        while True:
            prompt, answers = generate_cot_prompt(val_df, curr, k, dataset)
            # Only the token count is needed, so the ids stay on the CPU.
            inputs = tokenizer(prompt, return_tensors="pt")
            length = len(inputs["input_ids"][0])
            if length < prompt_budget:
                break
            if k <= 0:
                # Removing shots cannot help any more (and a negative k would slice
                # val_df from the end), so stop here instead of looping forever.
                logging.warning("Prompt in %s has %d tokens even with 0 shots (budget %d); using it as is.",
                                subject, length, prompt_budget)
                break
            k -= 1
        inference_batches.append(prompt)
        answers_batches.append(answers)

    query_prompt = prepare_query_prompt()
    pred_batch, response_batch, extract_try, similarities_batch, queries_batch, prompt_batch = batch_inference(llm, sampling_params, lora_request, inference_batches, answers_batches, query_prompt=query_prompt)
    res = []
    for j, curr in enumerate(test_df):
        curr["pred"] = pred_batch[j]
        curr["model_outputs"] = response_batch[j]
        curr['query'] = queries_batch[j]
        curr['extraction_try'] = extract_try[j]
        curr['similarities'] = similarities_batch[j]
        # NOTE(review): cosine similarities can be negative, so this sum may be ~0.
        curr['simil_prob'] = list(np.array(similarities_batch[j])/float(sum(similarities_batch[j])))
        curr['prompt'] = prompt_batch[j]
        res.append(curr)
    # Score once: save_res draws random guesses for empty predictions, so a second call
    # could return numbers that differ from the logged ones.
    accu, corr, wrong = save_res(res, output_path)
    logging.info("this batch accu is: {}, corr: {}, wrong: {}\n".format(str(accu), str(corr), str(wrong)))
    return accu, corr, wrong


def main():
    """Run the evaluation end to end and record the results.

    Steps: load the model and data, pick the subjects requested by ``--selected_subjects``
    (``"all"``, or comma-separated substrings matched against subject names with spaces
    treated as underscores), evaluate each subject, append per-subject and overall
    accuracy to ``summary_path``, and append one CSV row to ``global_record_file``.

    Relies on module-level globals set in the ``__main__`` section: ``args``,
    ``save_result_dir``, ``summary_path``, ``global_record_file`` and ``time_str``.
    """
    model, tokenizer = load_model()
    os.makedirs(save_result_dir, exist_ok=True)

    full_test_df, full_val_df = load_arc()
    all_subjects = []
    for each in full_test_df:
        if each["category"] not in all_subjects:
            all_subjects.append(each["category"])
    if args.selected_subjects == "all":
        selected_subjects = all_subjects
    else:
        requested = [each.replace(" ", "_") for each in args.selected_subjects.split(",")]
        # Keep each subject at most once, even when several requested patterns match it.
        # A duplicate would be evaluated twice, and batch_inference shuts the generator
        # down after its first call.
        selected_subjects = [
            sub for sub in all_subjects
            if any(pattern in sub.replace(" ", "_") for pattern in requested)
        ]
    logging.info("selected subjects:\n" + "\n".join(selected_subjects))
    print("selected subjects:\n" + "\n".join(selected_subjects))
    sta_dict = {}
    selected_subjects = sorted(selected_subjects)
    with open(summary_path, 'a') as f:
        f.write("\n------category level sta------\n")
    for subject in selected_subjects:
        if subject not in sta_dict:
            sta_dict[subject] = {"corr": 0.0, "wrong": 0.0, "accu": 0.0}
        test_df = select_by_category(full_test_df, subject, validation=False)
        val_df = select_by_category(full_val_df, subject, validation=True)
        output_path = os.path.join(save_result_dir, "{}.json".format(subject))
        acc, corr_count, wrong_count = eval_cot(subject, model, tokenizer, val_df, test_df, output_path, dataset=args.dataset)
        sta_dict[subject]["corr"] = corr_count
        sta_dict[subject]["wrong"] = wrong_count
        sta_dict[subject]["accu"] = acc
        with open(summary_path, 'a') as f:
            f.write("Average accuracy {:.4f} - {}\n".format(sta_dict[subject]["accu"], subject))
    total_corr, total_wrong = 0.0, 0.0
    for v in sta_dict.values():
        total_corr += v["corr"]
        total_wrong += v["wrong"]
    # The epsilon keeps the division defined when no subject was evaluated.
    total_accu = total_corr / (total_corr + total_wrong + 0.000001)
    sta_dict["total"] = {"corr": total_corr, "wrong": total_wrong, "accu": total_accu}

    with open(summary_path, 'a') as f:
        f.write("\n------average acc sta------\n")
        weighted_acc = total_accu
        f.write("Average accuracy: {:.4f}\n".format(weighted_acc))
    with open(global_record_file, 'a', newline='') as file:
        writer = csv.writer(file)
        record = args_generate_path(args) + [time_str, weighted_acc]
        writer.writerow(record)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)
    parser.add_argument("--selected_subjects", "-sub", type=str, default="all")
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    parser.add_argument("--global_record_file", "-grf", type=str,
                        default="eval_record_collection.csv")
    parser.add_argument("--gpu_util", "-gu", type=str, default="0.8")
    parser.add_argument("--model", "-m", type=str, default="meta-llama/Llama-2-7b-hf")
    parser.add_argument("--lora_path", "-lp", type=str, default=None)
    # Restrict to the known option letters. select_by_category calls choices.index() on
    # this value, so anything else would only fail with a ValueError after main() has
    # already loaded the model.
    parser.add_argument("--swap_answer", "-sa", type=str, default=None, choices=choices)
    parser.add_argument("--dataset", "-d", type=str, default='mmlu-pro', choices=['mmlu-pro','mmlu','arc'])

    args = parser.parse_args()
    if args.lora_path and not os.path.exists(args.lora_path):
        raise ValueError(f"LoRA path {args.lora_path} does not exist.")
    os.makedirs(args.save_dir, exist_ok=True)
    global_record_file = args.global_record_file
    # Per-run result directory: <save_dir>/<model>/<scoring method>/<subjects>.
    save_result_dir = os.path.join(
        args.save_dir, "/".join(args_generate_path(args))
    )
    file_prefix = "-".join(args_generate_path(args))
    timestamp = time.time()
    time_str = time.strftime('%m-%d_%H-%M', time.localtime(timestamp))
    file_name = f"{file_prefix}_{time_str}_summary.txt"
    summary_path = os.path.join(args.save_dir, "summary", file_name)
    os.makedirs(os.path.join(args.save_dir, "summary"), exist_ok=True)
    os.makedirs(save_result_dir, exist_ok=True)
    save_log_dir = os.path.join(args.save_dir, "log")
    os.makedirs(save_log_dir, exist_ok=True)
    # Log to a file named after the summary file as well as to stdout.
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s',
                        handlers=[logging.FileHandler(os.path.join(save_log_dir,
                                                                   file_name.replace("_summary.txt",
                                                                                     "_logfile.log"))),
                                  logging.StreamHandler(sys.stdout)])

    main()