import argparse
import os
import torch
import numpy as np
import pandas as pd
from utils.crop import crop
from transformers import AutoModelForCausalLM, AutoTokenizer

# Option labels used when rendering questions and naming the probability
# columns in the result CSVs; the 4-way list is a prefix of the 5-way one.
choices_5 = list("ABCDE")
choices_4 = choices_5[:4]


def format_subject(subject):
    """Convert an underscore-separated subject name into spaced words.

    Every underscore-delimited piece is prefixed with a single space, so the
    result always begins with a space (e.g. ``"a_b"`` -> ``" a b"``).

    Args:
        subject: Subject identifier whose words are separated by ``"_"``.

    Returns:
        The pieces of ``subject`` concatenated, each preceded by one space.
    """
    return "".join(" " + word for word in subject.split("_"))


def format_example(df, idx, include_answer=True, num_choices=4):
    """Render one row of a question table as prompt text.

    The row is read positionally: column 0 is the question, the following
    ``df.shape[1] - 2`` columns are the answer options, and the column after
    the options holds the answer.

    Args:
        df: Table of questions, indexed positionally via ``iloc``.
        idx: Row position of the question to render.
        include_answer: When true, append the answer followed by a blank
            line; otherwise the text ends right after ``"Answer:"``.
        num_choices: ``4`` selects the A-D labels; any other value selects
            the A-E labels.

    Returns:
        The question, one ``"<label>. <option>"`` line per option, and the
        ``"Answer:"`` cue (plus the answer when requested).
    """
    labels = choices_4 if num_choices == 4 else choices_5
    option_count = df.shape[1] - 2

    text = df.iloc[idx, 0]
    for pos in range(option_count):
        text += f"\n{labels[pos]}. {df.iloc[idx, pos + 1]}"
    text += "\nAnswer:"

    if not include_answer:
        return text
    return text + f" {df.iloc[idx, option_count + 1]}\n\n"


def gen_prompt(train_df, subject=None, k=-1, num_choices=4):
    """Build the few-shot prefix: a header line plus ``k`` solved examples.

    Args:
        train_df: Table of solved example questions (see ``format_example``
            for the positional column layout).
        subject: Optional underscore-separated subject name mentioned in the
            header; ``None`` produces a subject-free header.
        k: Number of leading rows of ``train_df`` to include. ``-1`` means
            all rows. Values larger than the number of rows are clamped to
            it; other negative values yield no examples.
        num_choices: Number of answer options, forwarded to
            ``format_example`` to select the option labels.

    Returns:
        The header followed by the formatted examples, each with its answer.
    """
    if subject is not None:
        prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject))
    else:
        prompt = "The following are multiple choice questions (with answers).\n\n"

    available = train_df.shape[0]
    # Never index past the end of the table when fewer examples exist than
    # were requested.
    k = available if k == -1 else min(k, available)

    for i in range(k):
        # Keyword arguments are required here: passing num_choices
        # positionally would bind it to include_answer and silently fall
        # back to the 4-choice labels.
        prompt += format_example(train_df, i, include_answer=True, num_choices=num_choices)
    return prompt


@torch.no_grad()
def eval(args, subject, model, tokenizer, dev_df, test_df, num_choices=4):
    """Evaluate a causal LM on one multiple-choice test table.

    For every test row a few-shot prompt is built from ``dev_df``, shortened
    one example at a time while ``crop(prompt)`` differs from the prompt, and
    completed with an answer-format instruction. The model generates a single
    token; the scores of the option-letter tokens are softmax-normalised and
    the most probable letter is compared with the last column of the row.

    Note: the function name shadows the ``eval`` builtin; it is kept because
    callers use it.

    Args:
        args: Namespace providing ``ntrain`` (number of few-shot examples)
            and ``device`` (where the tokenized prompt is moved).
        subject: Subject name for the prompt header and the log line, or
            ``None`` for a subject-free benchmark.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        dev_df: Table of solved examples used as few-shot demonstrations.
        test_df: Table of questions to evaluate; the last column is the
            reference answer letter.
        num_choices: Number of answer options, ``4`` or ``5``.

    Returns:
        A tuple ``(cors, acc, all_probs)``: per-question correctness flags as
        an array, their mean, and an array with one row of option
        probabilities per question.

    Raises:
        ValueError: If ``num_choices`` is neither 4 nor 5.
    """
    if num_choices == 4:
        instruction = "Only provide the option of the correct answer (A, B, C or D) without any additional information.\n\n"
    elif num_choices == 5:
        instruction = "Only provide the option of the correct answer (A, B, C, D or E) without any additional information.\n\n"
    else:
        raise ValueError(f"Only support num_choices = 4/5, got {num_choices}.")

    letters = "ABCDE"[:num_choices]
    # The option token ids do not depend on the question, so look them up
    # once. Special tokens are disabled so that index 0 is the letter itself
    # rather than a prepended BOS token on tokenizers that add one.
    choice_ids = [
        tokenizer(letter, add_special_tokens=False).input_ids[0]
        for letter in letters
    ]

    # Resolve the requested shot count to a concrete value in
    # [0, number of dev rows] so the shrink loop below can count down to 0.
    n_dev = dev_df.shape[0]
    if args.ntrain == -1:
        max_shots = n_dev
    else:
        max_shots = max(0, min(args.ntrain, n_dev))

    cors = []
    all_probs = []

    for i in range(test_df.shape[0]):
        # get prompt and make sure it fits
        k = max_shots
        prompt_end = format_example(test_df, i, include_answer=False, num_choices=num_choices)
        train_prompt = gen_prompt(dev_df, subject=subject, k=k, num_choices=num_choices)
        prompt = train_prompt + prompt_end

        # Drop demonstrations until crop() leaves the prompt untouched. Stop
        # at zero shots: below that the prompt cannot get any shorter, and
        # k == -1 would mean "all examples" again.
        while k > 0 and crop(prompt) != prompt:
            k -= 1
            train_prompt = gen_prompt(dev_df, subject=subject, k=k, num_choices=num_choices)
            prompt = train_prompt + prompt_end

        prompt += instruction

        inputs = tokenizer(prompt, return_tensors="pt").to(args.device)
        label = test_df.iloc[i, test_df.shape[1] - 1]

        outputs = model.generate(
            inputs.input_ids,
            max_new_tokens=1,
            temperature=1,
            output_scores=True,
            return_dict_in_generate=True
        )
        # Scores of the single generated step for the only sequence in the batch.
        logits = outputs.scores[0][0]

        # Normalise over the option letters only.
        probs = (
            torch.nn.functional.softmax(logits[choice_ids].float(), dim=0)
            .detach()
            .cpu()
            .numpy()
        )
        pred = letters[int(np.argmax(probs))]

        cors.append(pred == label)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)

    # Keep a 2-D (questions, options) layout even when there were no questions.
    all_probs = np.array(all_probs).reshape(-1, num_choices)
    if subject is not None:
        print("Average accuracy {:.3f} - {}".format(acc, subject))
    else:
        print("Average accuracy {:.3f}".format(acc))

    return cors, acc, all_probs


def _save_results(test_df, cors, probs, choices, model_name, out_path):
    """Attach correctness and per-option probability columns, then write a CSV."""
    test_df["{}_correct".format(model_name)] = cors
    for j in range(probs.shape[1]):
        test_df["{}_choice{}_probs".format(model_name, choices[j])] = probs[:, j]
    test_df.to_csv(out_path, index=None)


def main(args):
    """Load the model, run the selected benchmark and write per-question results.

    Results are written as CSV files under
    ``<save_dir>/<benchmark>/results_<model>/``: one file per subject for
    MMLU-Med and a single ``test.csv`` for MedQA and MedMCQA.

    Args:
        args: Parsed command-line namespace providing ``llm_dir``, ``model``,
            ``device``, ``data_dir``, ``benchmark``, ``save_dir`` and
            ``ntrain``.

    Raises:
        ValueError: If ``args.benchmark`` is not a supported benchmark.
    """
    model_path = os.path.join(args.llm_dir, args.model)
    model = AutoModelForCausalLM.from_pretrained(model_path).to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()
    data_dir = os.path.join(args.data_dir, args.benchmark)

    # makedirs creates the intermediate <save_dir>/<benchmark> level as well,
    # which a plain mkdir on the nested path cannot do.
    results_dir = os.path.join(args.save_dir, args.benchmark, "results_{}".format(args.model))
    os.makedirs(results_dir, exist_ok=True)

    print(args)

    if args.benchmark == "MMLU-Med":
        subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(data_dir, "test")) if "_test.csv" in f])

        all_cors = []

        for subject in subjects:
            dev_df = pd.read_csv(os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
            test_df = pd.read_csv(os.path.join(data_dir, "test", subject + "_test.csv"), header=None)

            cors, acc, probs = eval(args, subject, model, tokenizer, dev_df, test_df, num_choices=4)
            all_cors.append(cors)

            # Write into the directory created above, consistent with the
            # other benchmarks.
            _save_results(
                test_df, cors, probs, choices_4, args.model,
                os.path.join(results_dir, "{}.csv".format(subject)),
            )

        # Question-weighted accuracy across all subjects.
        weighted_acc = np.mean(np.concatenate(all_cors))
        print("Average accuracy: {:.3f}".format(weighted_acc))

    elif args.benchmark in ("MedQA", "MedMCQA"):
        # NOTE(review): option counts per benchmark are kept as they were
        # (MedQA -> 4, MedMCQA -> 5); confirm they match the CSV layout.
        if args.benchmark == "MedQA":
            num_choices, choices = 4, choices_4
        else:
            num_choices, choices = 5, choices_5

        dev_df = pd.read_csv(os.path.join(data_dir, "dev", "dev.csv"), header=None)[:args.ntrain]
        test_df = pd.read_csv(os.path.join(data_dir, "test", "test.csv"), header=None)

        cors, acc, probs = eval(args, None, model, tokenizer, dev_df, test_df, num_choices=num_choices)

        _save_results(
            test_df, cors, probs, choices, args.model,
            os.path.join(results_dir, "test.csv"),
        )

        acc = np.mean(cors)
        print("Average accuracy: {:.3f}".format(acc))

    else:
        raise ValueError(f"Only support MMLU-Med, MedQA and MedMCQA, got {args.benchmark}.")


if __name__ == "__main__":
    # Command-line entry point: declare the options in a table, register
    # them in order, then hand the parsed namespace to main().
    option_table = (
        (("--ntrain", "-k"), dict(type=int, default=5)),
        (("--device", "-v"), dict(type=str, default="cuda" if torch.cuda.is_available() else "cpu")),
        (("--llm_dir", "-l"), dict(type=str, default="model_hub")),
        (("--data_dir", "-d"), dict(type=str, default="benchmarks")),
        (("--benchmark", "-b"), dict(type=str, default="MMLU-Med",
                                     choices=["MMLU-Med", "MedQA", "MedMCQA"])),
        (("--save_dir", "-s"), dict(type=str, default="results")),
        (("--model", "-m"), dict(type=str, default="Qwen2-0.5B-Instruct")),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    args = parser.parse_args()
    main(args)