import torch
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import numpy as np
import argparse
from tqdm import tqdm
import pandas as pd
import torch
import re
import warnings


# Suppress every Python warning for the whole process; this runs as a
# module-level side effect as soon as the file is imported.
warnings.filterwarnings("ignore")

def print_gpu_memory():
    """Print the CUDA memory PyTorch reports as allocated and reserved, in MB."""
    bytes_per_mb = 1024 ** 2
    allocated_mb = torch.cuda.memory_allocated() / bytes_per_mb
    reserved_mb = torch.cuda.memory_reserved() / bytes_per_mb
    print(f"Allocated: {allocated_mb:.2f} MB | Reserved: {reserved_mb:.2f} MB")

# Category name -> MMLU task identifiers in that category.
# The order of categories and of tasks inside each category is significant:
# SUBJECTS below is derived from it and is iterated in that order.
TASK_NAME_MAPPING = {
    "stem": [
        "abstract_algebra",
        "anatomy",
        "astronomy",
        "college_biology",
        "college_chemistry",
        "college_computer_science",
        "college_mathematics",
        "college_physics",
        "computer_security",
        "conceptual_physics",
        "electrical_engineering",
        "elementary_mathematics",
        "high_school_biology",
        "high_school_chemistry",
        "high_school_computer_science",
        "high_school_mathematics",
        "high_school_physics",
        "high_school_statistics",
        "machine_learning",
    ],
    "Humanities": [
        "formal_logic",
        "high_school_european_history",
        "high_school_us_history",
        "high_school_world_history",
        "international_law",
        "jurisprudence",
        "logical_fallacies",
        "moral_disputes",
        "moral_scenarios",
        "philosophy",
        "prehistory",
        "professional_law",
        "world_religions",
    ],
    "other": [
        "business_ethics",
        "college_medicine",
        "human_aging",
        "management",
        "marketing",
        "medical_genetics",
        "miscellaneous",
        "nutrition",
        "professional_accounting",
        "professional_medicine",
        "virology",
        "global_facts",
        "clinical_knowledge",
    ],
    "social": [
        "econometrics",
        "high_school_geography",
        "high_school_government_and_politics",
        "high_school_macroeconomics",
        "high_school_microeconomics",
        "high_school_psychology",
        "human_sexuality",
        "professional_psychology",
        "public_relations",
        "security_studies",
        "sociology",
        "us_foreign_policy",
    ],
}

# Flat list of every task identifier, category by category.
SUBJECTS = [
    task_name
    for category_tasks in TASK_NAME_MAPPING.values()
    for task_name in category_tasks
]

# Answer letters offered per question: four for MMLU, five for CommonsenseQA.
choicesmmlu = list("ABCD")
choicescommon = list("ABCDE")

def extract_ans(sentence, dataset='mmlu'):
    """Return the first answer letter that occurs in ``sentence``.

    Args:
        sentence: Text to scan, e.g. the continuation generated by the model.
        dataset: ``'mmlu'`` accepts the letters A-D, ``'common'`` accepts A-E.

    Returns:
        The first matching capital letter, or ``""`` when no letter is found,
        when ``sentence`` is not a string, or when ``dataset`` is not one of
        the two supported names (so the return type is always ``str``).
    """
    # NOTE: the match is not word-bounded, so a capital letter inside a word
    # (e.g. the "B" of "Because") counts as an answer, as it always has here.
    patterns = {
        'mmlu': r'A|B|C|D',
        'common': r'A|B|C|D|E',
    }
    pattern = patterns.get(dataset)
    if pattern is None or not isinstance(sentence, str):
        return ""

    first_letter = re.search(pattern, sentence)
    return first_letter.group(0) if first_letter else ""
@torch.no_grad()
def format_common(line, include_answer=True):
    """Render one CommonsenseQA-style record as prompt text.

    Args:
        line: Mapping-like record; the code reads ``line["question"]``,
            ``line["choices"]["text"]`` (indexed by position) and, when
            ``include_answer`` is true, ``line["answerKey"]``.
        include_answer: When true the answer key and a blank line are
            appended (few-shot example); otherwise the text ends with an
            open ``"Answer:"`` for the model to complete.

    Returns:
        The formatted question block as a single string.
    """
    pieces = ["Question: " + line["question"]]
    # One "<letter>. <text>" line per entry of choicescommon, by position.
    pieces.extend(
        f'\n{letter}. {line["choices"]["text"][position]}'
        for position, letter in enumerate(choicescommon)
    )
    pieces.append(
        ("\nAnswer: " + line["answerKey"] + "\n\n")
        if include_answer
        else "\nAnswer:"
    )
    return "".join(pieces)

@torch.no_grad()
def generate_few_shot_common(k, subject, dev_df):
    """Build the few-shot prefix for the CommonsenseQA prompt.

    Args:
        k: Number of leading rows of ``dev_df`` to use as solved examples;
            ``-1`` means every row.
        subject: Not used by this function.
        dev_df: DataFrame whose rows are formatted with ``format_common``.

    Returns:
        The instruction header followed by the ``k`` solved examples.
    """
    shot_count = dev_df.shape[0] if k == -1 else k
    header = "The following are multiple choice questions (with answers)\n\n"
    solved_examples = [
        format_common(dev_df.iloc[row_idx, :], include_answer=True)
        for row_idx in range(shot_count)
    ]
    return header + "".join(solved_examples)

@torch.no_grad()
def format_mmlu(line, include_answer=True):
    """Render one MMLU record as prompt text.

    Args:
        line: Mapping-like record; the code reads ``line["question"]``, one
            entry per letter in ``choicesmmlu`` and, when ``include_answer``
            is true, ``line["answer"]``.
        include_answer: When true the answer and a blank line are appended
            (few-shot example); otherwise the text ends with an open
            ``"Answer:"`` for the model to complete.

    Returns:
        The formatted question block as a single string.
    """
    parts = ["Question: " + line["question"]]
    parts.extend(f"\n{letter}. {line[letter]}" for letter in choicesmmlu)
    if include_answer:
        parts.append("\nAnswer: " + line["answer"] + "\n\n")
    else:
        parts.append("\nAnswer:")
    return "".join(parts)

@torch.no_grad()
def generate_few_shot_mmlu(k, subject, dev_df):
    """Build the few-shot prefix for an MMLU subject.

    Args:
        k: Number of leading rows of ``dev_df`` to use as solved examples;
            ``-1`` means every row.
        subject: Task identifier; underscores are shown as spaces in the
            instruction header.
        dev_df: DataFrame whose rows are formatted with ``format_mmlu``.

    Returns:
        The instruction header followed by the ``k`` solved examples.
    """
    readable_subject = " ".join(subject.split("_")).strip()
    header = (
        "The following are multiple choice questions (with answers) about "
        f"{readable_subject}.\n\n"
    )

    shot_count = dev_df.shape[0] if k == -1 else k
    solved_examples = [
        format_mmlu(dev_df.iloc[row_idx, :], include_answer=True)
        for row_idx in range(shot_count)
    ]
    return header + "".join(solved_examples)

@torch.no_grad()
def get_output(tokenizer, model, inputs, max_seq_len=4096, max_new_tokens=64):
    """Greedy-decode a batch of prompts and return the decoded sequences.

    Args:
        tokenizer: Tokenizer used to encode ``inputs`` (padded to the longest
            prompt in the batch) and to decode the generated ids.
        model: Model exposing ``device`` and ``generate``.
        inputs: List of prompt strings.
        max_seq_len: Prompts longer than this many tokens are cut from the
            left so that only their last ``max_seq_len - 1`` tokens remain.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        One decoded string per prompt, with special tokens removed. The
        evaluation functions in this file split this text on the prompt, so
        it is presumably prompt + continuation — confirm for the model type
        in use.
    """
    encoded = tokenizer(inputs, padding='longest')
    input_ids = torch.tensor(encoded["input_ids"], device=model.device)

    # Prefer the tokenizer's own mask: rebuilding it from pad_token_id would
    # also mask genuine tokens that happen to share the pad id.
    if "attention_mask" in encoded:
        attention_mask = torch.tensor(
            encoded["attention_mask"], device=model.device
        )
    else:
        attention_mask = input_ids.ne(tokenizer.pad_token_id)

    if input_ids.shape[1] > max_seq_len:
        # Keep the end of the prompt (the question being asked) and drop the
        # oldest tokens; ids and mask must stay aligned.
        keep_from = input_ids.shape[1] - max_seq_len + 1
        input_ids = input_ids[:, keep_from:]
        attention_mask = attention_mask[:, keep_from:]

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "repetition_penalty": 1.0,
        "length_penalty": 1.0,
        "use_cache": True,
        "pad_token_id": tokenizer.pad_token_id,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
    generated_tokens = model.generate(**gen_kwargs)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)



def _common_answer_segment(decoded, few_shot_prompt, n_shots):
    """Return the text after the evaluated question's "Answer:" marker, or None."""
    if few_shot_prompt:
        # Preferred path: the decoded text reproduces the few-shot prefix
        # verbatim, so everything after it belongs to the evaluated question.
        after_prefix = decoded.split(few_shot_prompt)
        if len(after_prefix) > 1:
            after_marker = after_prefix[1].split("Answer:")
            if len(after_marker) > 1:
                return after_marker[1]
        print("commonqa fewshot wrong")

    # Fallback: skip one "Answer:" marker per few-shot example and take the
    # text after the next one.
    by_marker = decoded.split("Answer:")
    target = n_shots + 1
    if len(by_marker) > target:
        return by_marker[target]
    return None


@torch.no_grad()
def eval_common(
    model,
    tokenizer,
    subject_name,
    test_df,
    k=5,
    dev_df=None,
    few_shot=False,
    save_result_dir=None,
    batch_size=1,
    **kwargs,
):
    """Evaluate the model on CommonsenseQA-style questions.

    Args:
        model: Model handed to ``get_output`` for generation.
        tokenizer: Tokenizer handed to ``get_output``.
        subject_name: Label passed to the few-shot builder and written to
            ``noans.txt`` when no answer can be located.
        test_df: DataFrame of questions; each row is formatted with
            ``format_common``. Rows are scored against ``answerKey`` when
            that column is present.
        k: Number of few-shot examples; ``-1`` uses every row of ``dev_df``.
        dev_df: DataFrame supplying the few-shot examples; needed when
            ``few_shot`` is true.
        few_shot: Whether to prepend solved examples to every question.
        save_result_dir: Accepted for interface compatibility; not used.
        batch_size: Number of questions sent to the model per call.
        **kwargs: Ignored.

    Returns:
        A list with 1 for each correctly answered question and 0 otherwise
        (empty when ``test_df`` has no ``answerKey`` column).

    Side effects:
        Prints the few-shot prompt and the number of scored questions, and
        appends unparsable generations to ``noans.txt`` in the working
        directory.
    """
    score = []

    # An empty *string* (not a list) keeps prompt concatenation valid when
    # running zero-shot.
    few_shot_prompt = (
        generate_few_shot_common(k, subject_name, dev_df) if few_shot else ""
    )
    # Number of solved examples in the prompt; drives the fallback parser.
    if not few_shot:
        n_shots = 0
    elif k == -1:
        n_shots = dev_df.shape[0]
    else:
        n_shots = max(k, 0)

    print(few_shot_prompt)

    for start in tqdm(range(0, len(test_df), batch_size)):
        rows = test_df.iloc[start:start + batch_size].to_dict(orient='records')
        prompts = [
            few_shot_prompt + format_common(row, include_answer=False)
            for row in rows
        ]
        answers = [row['answerKey'] for row in rows if 'answerKey' in row]

        decoded_pred = get_output(tokenizer, model, prompts)

        for pos, decoded in enumerate(decoded_pred):
            segment = _common_answer_segment(decoded, few_shot_prompt, n_shots)
            if segment is None:
                # Nothing parsable: log it and scan the whole generation.
                print("no answer###", decoded)
                with open("noans.txt", 'a', encoding='utf-8') as file:
                    file.write(subject_name)
                    file.write("\n")
                    file.write(decoded)
                    file.write("\n")
                segment = decoded

            pred = extract_ans(segment, dataset="common")
            if answers:
                score.append(1 if pred == answers[pos] else 0)

    print(len(score))
    return score

def _mmlu_answer_segment(decoded, few_shot_prompt, n_shots):
    """Return the text after the evaluated question's "Answer:" marker, or None."""
    if few_shot_prompt:
        # Preferred path: the decoded text reproduces the few-shot prefix
        # verbatim, so everything after it belongs to the evaluated question.
        after_prefix = decoded.split(few_shot_prompt)
        if len(after_prefix) > 1:
            after_marker = after_prefix[1].split("Answer:")
            if len(after_marker) > 1:
                return after_marker[1]

    # Fallback: skip one "Answer:" marker per few-shot example and take the
    # text after the next one.
    by_marker = decoded.split("Answer:")
    target = n_shots + 1
    if len(by_marker) > target:
        return by_marker[target]
    return None


@torch.no_grad()
def eval_mmlu(
    model,
    tokenizer,
    subject_name,
    test_df,
    k=0,
    dev_df=None,
    few_shot=False,
    save_result_dir=None,
    batch_size=1,
    **kwargs,
):
    """Evaluate the model on the questions of one MMLU subject.

    Args:
        model: Model handed to ``get_output`` for generation.
        tokenizer: Tokenizer handed to ``get_output``.
        subject_name: Task identifier; used in the few-shot header and
            written to ``noans.txt`` when no answer can be located.
        test_df: DataFrame of questions; each row is formatted with
            ``format_mmlu``. Rows are scored against ``answer`` when that
            column is present.
        k: Number of few-shot examples; ``-1`` uses every row of ``dev_df``.
        dev_df: DataFrame supplying the few-shot examples; needed when
            ``few_shot`` is true.
        few_shot: Whether to prepend solved examples to every question.
        save_result_dir: Accepted for interface compatibility; not used.
        batch_size: Number of questions sent to the model per call.
        **kwargs: Ignored.

    Returns:
        A list with 1 for each correctly answered question and 0 otherwise
        (empty when ``test_df`` has no ``answer`` column).

    Side effects:
        Appends unparsable generations to ``noans.txt`` in the working
        directory and prints them.
    """
    score = []

    # An empty *string* (not a list) keeps prompt concatenation valid when
    # running zero-shot.
    few_shot_prompt = (
        generate_few_shot_mmlu(k, subject_name, dev_df) if few_shot else ""
    )
    # Number of solved examples in the prompt; drives the fallback parser.
    if not few_shot:
        n_shots = 0
    elif k == -1:
        n_shots = dev_df.shape[0]
    else:
        n_shots = max(k, 0)

    for start in tqdm(range(0, len(test_df), batch_size)):
        rows = test_df.iloc[start:start + batch_size].to_dict(orient='records')
        prompts = [
            few_shot_prompt + format_mmlu(row, include_answer=False)
            for row in rows
        ]
        answers = [row['answer'] for row in rows if 'answer' in row]

        decoded_pred = get_output(tokenizer, model, prompts)

        for pos, decoded in enumerate(decoded_pred):
            segment = _mmlu_answer_segment(decoded, few_shot_prompt, n_shots)
            if segment is None:
                # Nothing parsable: log it and scan the whole generation.
                print("no answer###", decoded)
                with open("noans.txt", 'a', encoding='utf-8') as file:
                    file.write(subject_name)
                    file.write("\n")
                    file.write(decoded)
                    file.write("\n")
                segment = decoded

            pred = extract_ans(segment)
            if answers:
                score.append(1 if pred == answers[pos] else 0)

    return score

def cal_mmlu(res):
    """Print per-category and overall MMLU accuracy and return the overall value.

    Args:
        res: Dict mapping a task identifier to its list of 0/1 scores.
            Subjects missing from ``res`` are treated as having no scores.

    Returns:
        Overall accuracy in ``[0, 1]`` across all scored questions, or
        ``0.0`` when nothing was scored.
    """
    total_correct = 0.0
    total_count = 0
    correct_by_category = {}
    count_by_category = {}

    for category, subjects in TASK_NAME_MAPPING.items():
        category_correct = 0.0
        category_count = 0
        for subject in subjects:
            scores = res.get(subject, [])
            subject_correct = sum(scores)
            total_correct += subject_correct
            total_count += len(scores)
            category_correct += subject_correct
            category_count += len(scores)
        correct_by_category[category] = category_correct
        count_by_category[category] = category_count

    print("\n\n\n", "total cnt:", total_count, "\n")
    for category in TASK_NAME_MAPPING:
        if count_by_category[category]:
            category_acc = (
                correct_by_category[category] / count_by_category[category]
            )
            print("%s ACC: %.2f " % (category, category_acc * 100))
        else:
            # Avoid dividing by zero for a category without any results.
            print("%s ACC: n/a (no results)" % category)

    if total_count == 0:
        print("AVERAGE ACC: n/a (no results)")
        return 0.0

    print("AVERAGE ACC:%.2f " % (total_correct / total_count * 100))
    return total_correct / total_count

def cal_common(res):
    """Print and return the CommonsenseQA accuracy.

    Args:
        res: Dict whose ``'comm'`` entry is the list of 0/1 scores.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` when the score list is empty.

    Raises:
        KeyError: If ``res`` has no ``'comm'`` entry.
    """
    scores = res['comm']
    correct = 0.0 + sum(scores)
    total = len(scores)

    print("\n\n\n", "total cnt:", total, "\n")

    if total == 0:
        # Nothing was scored; report it instead of dividing by zero.
        print("AVERAGE ACC: n/a (no results)")
        return 0.0

    print("AVERAGE ACC:%.2f " % (correct / total * 100))
    return correct / total

@torch.no_grad()
def mmlumain(model, tokenizer, args):
    """Run the 5-shot MMLU evaluation over every subject in ``SUBJECTS``.

    For each subject the dev and test CSV files are read from
    ``./tasks/mmlu/dev`` and ``./tasks/mmlu/test``; the dev rows supply the
    few-shot examples and the test rows are scored.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer matching ``model``.
        args: Object providing ``batch_size``.

    Returns:
        The overall accuracy returned by ``cal_mmlu``.
    """
    data_root = "./tasks/mmlu"

    def load_split(subject, split):
        # The CSV files are read with explicit column names.
        csv_path = os.path.join(data_root, split, f"{subject}_{split}.csv")
        return pd.read_csv(
            csv_path, names=["question", "A", "B", "C", "D", "answer"]
        )

    scores_by_subject = {}
    for subject_name in tqdm(SUBJECTS):
        print_gpu_memory()

        dev_df = load_split(subject_name, "dev")
        test_df = load_split(subject_name, "test")

        scores_by_subject[subject_name] = eval_mmlu(
            model,
            tokenizer,
            subject_name,
            test_df,
            dev_df=dev_df,
            k=5,
            few_shot=True,
            batch_size=args.batch_size,
        )
    return cal_mmlu(scores_by_subject)

@torch.no_grad()
def commonmain(model, tokenizer, args):
    """Run the 5-shot CommonsenseQA evaluation.

    The validation parquet file under ``./tasks/commonqa`` is scored, with
    few-shot examples taken from the train parquet file.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer matching ``model``.
        args: Object providing ``batch_size``.

    Returns:
        The accuracy returned by ``cal_common``.
    """
    data_root = "./tasks/commonqa"
    task_key = 'comm'

    eval_df = pd.read_parquet(
        os.path.join(data_root, "validation-00000-of-00001.parquet")
    )
    shots_df = pd.read_parquet(
        os.path.join(data_root, "train-00000-of-00001.parquet")
    )

    task_scores = eval_common(
        model,
        tokenizer,
        task_key,
        eval_df,
        dev_df=shots_df,
        k=5,
        few_shot=True,
        batch_size=args.batch_size,
    )
    return cal_common({task_key: task_scores})
