import argparse
import json
import torch
import numpy as np
from transformers import AutoTokenizer,AutoModelForCausalLM

# Task names, in the order the evaluation loops iterate over them.
tasks = ["sst2", "qqp", "mnli", "qnli", "mnli-mm", "rte"]

# Per task: the names of its (first, second) text fields. A task with a single
# text field stores None in the second slot.
task_to_keys = dict(
    [
        ("mnli", ("premise", "hypothesis")),
        ("mnli-mm", ("premise", "hypothesis")),
        ("qnli", ("question", "sentence")),
        ("qqp", ("question1", "question2")),
        ("rte", ("sentence1", "sentence2")),
        ("sst2", ("sentence", None)),
    ]
)


def format_example(task_name, question, origin=False, k=5):
    """Render the first ``k`` examples as solved few-shot demonstrations.

    Each demonstration is the prompt produced by ``gen_prompt`` followed by
    the answer option that corresponds to the example's ``'label'`` and a
    blank line.

    Args:
        task_name: One of "mnli", "mnli-mm", "qnli", "rte", "qqp", "sst2".
        question: Indexable collection of example dicts; items ``0 .. k-1``
            are used and each must provide ``'label'`` plus the text fields
            ``gen_prompt`` reads for the task.
        origin: Forwarded to ``gen_prompt`` (prefer ``original_*`` fields).
        k: Number of demonstrations to render.

    Returns:
        The concatenated demonstrations, or "" when ``k`` is 0.

    Raises:
        ValueError: If ``task_name`` is not supported and ``k`` > 0.
    """
    # task -> ({label id: answer text}, answer used for every other label).
    # For sst2 and qqp option "A" stands for label 1, which is how eval()
    # maps a predicted "A" back to a label for those tasks; rendering label 0
    # as "A" there would make the demonstrations contradict the scorer.
    # NOTE(review): rte keeps label 0 -> "A. yes" (same as qnli), but eval()
    # scores "A" as label 1 for rte -- confirm against the dataset's label ids.
    nli_answers = ({0: "A. yes", 1: "B. maybe"}, "C. no")
    answer_specs = {
        "mnli": nli_answers,
        "mnli-mm": nli_answers,
        "qnli": ({0: "A. yes"}, "B. no"),
        "rte": ({0: "A. yes"}, "B. no"),
        "qqp": ({1: "A. yes"}, "B. no"),
        "sst2": ({1: "A. positive"}, "B. negative"),
    }
    spec = answer_specs.get(task_name)

    demos = []
    for i in range(k):
        # Only reject an unknown task once a demonstration is actually
        # requested, so k == 0 keeps returning "" as it always did.
        if spec is None:
            raise ValueError("Unsupported task: {}".format(task_name))
        by_label, fallback = spec
        example = question[i]
        stem = gen_prompt(task_name, example, origin=origin)
        answer = by_label.get(example['label'], fallback)
        demos.append(stem + " {}\n\n".format(answer))
    return "".join(demos)

def gen_prompt(task_name, question, origin=False):
    """Build the unanswered prompt for a single example.

    The prompt is the task instruction, one "<Name>: <text>" line per input
    field, and a trailing "Answer: " line.

    Args:
        task_name: One of "mnli", "mnli-mm", "qnli", "rte", "qqp", "sst2".
        question: Example dict holding the task's text fields.
        origin: When true, the one field per task marked below is read from
            its ``original_<field>`` key if the example has that key.

    Returns:
        The prompt string, ending in "\\nAnswer: ".

    Raises:
        ValueError: If ``task_name`` is not one of the supported tasks.
    """
    nli_instruction = "Please identify whether the premise entails the hypothesis. The answer should be exactly 'A. yes', 'B. maybe' or 'C. no'\n"
    yes_no = " The answer should be exactly 'A. yes' or 'B. no'\n"

    # task -> (instruction, fields); every field is
    # (line prefix, example key, whether `origin` may swap in original_<key>).
    layouts = {
        "mnli": (
            nli_instruction,
            (("Premise: ", "premise", True), ("Hypothesis: ", "hypothesis", False)),
        ),
        "mnli-mm": (
            nli_instruction,
            (("Premise: ", "premise", False), ("Hypothesis: ", "hypothesis", True)),
        ),
        "qnli": (
            "Please identify whether the sentence answers the question." + yes_no,
            (("Question: ", "question", True), ("Sentence: ", "sentence", False)),
        ),
        "rte": (
            "Please identify whether the sentence1 entails the sentence2." + yes_no,
            (("Sentence 1: ", "sentence1", True), ("Sentence 2: ", "sentence2", False)),
        ),
        "qqp": (
            "Please identify whether Question 1 has the same meaning as Question 2." + yes_no,
            (("Question 1: ", "question1", True), ("Question 2: ", "question2", False)),
        ),
        "sst2": (
            "For each snippet of text, label the sentiment of the text as positive or negative. The answer should be exactly 'A. positive' or 'B. negative'\n",
            (("Sentence: ", "sentence", True),),
        ),
    }

    if task_name not in layouts:
        raise ValueError("Unsupported task:", task_name)

    instruction, fields = layouts[task_name]
    rendered = []
    for prefix, key, swappable in fields:
        original_key = "original_" + key
        if swappable and origin and original_key in question.keys():
            key = original_key
        rendered.append(prefix + question[key])

    return instruction + "\n".join(rendered) + "\nAnswer: "


def eval(model, tokenizer, dataset,  ntrain, test_origin):
    """Print next-token multiple-choice accuracy for every task in ``tasks``.

    For each task the first ``ntrain`` examples are rendered with their
    answers as a few-shot prefix and every remaining example is scored: the
    logits at the last prompt position are read for the tokens "A", "B" (and
    "C" for mnli / mnli-mm) and the highest one is the prediction. Per-task
    accuracy and the accuracy over all scored examples are printed; nothing
    is returned.

    NOTE: the function name shadows the ``eval`` builtin; it is kept because
    callers use it.

    Args:
        model: Called as ``model(input_ids=...)``; the result must expose
            ``.logits``. One prompt is scored per call.
        tokenizer: Called on a string; used both as
            ``tokenizer(text, return_tensors="pt")["input_ids"]`` and as
            ``tokenizer(letter).input_ids``.
        dataset: Mapping from task name to a sequence of example dicts that
            carry ``"label"`` and the task's text fields.
        ntrain: Number of leading examples per task used as demonstrations
            and excluded from scoring.
        test_origin: Forwarded as ``origin`` to the prompt builders.
    """
    # Predicted option index (0 = "A", 1 = "B") -> dataset label id.
    # NOTE(review): for rte this scores "A" as label 1, whereas
    # format_example() renders rte label 0 as "A. yes" -- confirm which one
    # matches the dataset's label ids before trusting rte numbers.
    task_mappings = {
        'qqp': {0: 1, 1: 0},
        'sst2': {0: 1, 1: 0},
        'qnli': {0: 0, 1: 1},
        'rte': {0: 1, 1: 0},
    }

    # The answer letters never change, so resolve their token ids once. The
    # last id is taken, presumably to skip any special token the tokenizer
    # prepends.
    letter_ids = [tokenizer(letter).input_ids[-1] for letter in ("A", "B", "C")]

    # Follow the model's own device when it reports one; otherwise keep the
    # previous hard-coded default.
    device = getattr(model, "device", "cuda")

    cors = []
    for task_name in tasks:
        task_cors = []
        test = dataset[task_name]
        three_way = task_name in ("mnli", "mnli-mm")
        choice_ids = letter_ids if three_way else letter_ids[:2]

        # The few-shot prefix is identical for every scored example of a
        # task, so build it once (and only if something will be scored).
        few_shot = ""
        if len(test) > ntrain:
            few_shot = format_example(task_name, test, origin=test_origin, k=ntrain)

        for i in range(ntrain, len(test)):
            prompt = few_shot + gen_prompt(task_name, test[i], origin=test_origin)
            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

            # Pure inference: without no_grad every forward pass would keep
            # an autograd graph alive for nothing.
            with torch.no_grad():
                logits = model(input_ids=input_ids).logits[:, -1].flatten()
                probs = (
                    torch.nn.functional.softmax(logits[choice_ids].float(), dim=0)
                    .cpu()
                    .numpy()
                )

            choice = int(np.argmax(probs))
            # mnli labels are the option index itself; two-way tasks go
            # through the table above.
            pred = choice if three_way else task_mappings[task_name][choice]

            cor = pred == test[i]["label"]
            task_cors.append(cor)
            cors.append(cor)

        # Avoid numpy's empty-mean warning when a task had nothing to score.
        task_acc = np.mean(task_cors) if task_cors else float("nan")
        print("Accuracy {:.4f} - Task {}".format(task_acc, task_name))

    acc = np.mean(cors) if cors else float("nan")
    print("Average accuracy {:.4f}".format(acc))

def eval_generate(model, tokenizer, dataset,  ntrain, test_origin):
    """Print zero-shot next-token multiple-choice accuracy for every task.

    Every example of every task in ``tasks`` is scored with the bare prompt
    from ``gen_prompt`` (no few-shot prefix): the logits at the last prompt
    position are read for the tokens "A", "B" (and "C" for mnli / mnli-mm)
    and the highest one is the prediction. Per-task accuracy and the accuracy
    over all scored examples are printed; nothing is returned.

    Args:
        model: Called as ``model(input_ids=...)``; the result must expose
            ``.logits``. One prompt is scored per call.
        tokenizer: Called on a string; used both as
            ``tokenizer(text, return_tensors="pt")["input_ids"]`` and as
            ``tokenizer(letter).input_ids``.
        dataset: Mapping from task name to a sequence of example dicts that
            carry ``"label"`` and the task's text fields.
        ntrain: Not used by this function; accepted so the signature matches
            ``eval``.
        test_origin: Forwarded as ``origin`` to ``gen_prompt``.
    """
    # Predicted option index (0 = "A", 1 = "B") -> dataset label id.
    # NOTE(review): for rte this scores "A" as label 1, whereas
    # format_example() renders rte label 0 as "A. yes" -- confirm which one
    # matches the dataset's label ids before trusting rte numbers.
    task_mappings = {
        'qqp': {0: 1, 1: 0},
        'sst2': {0: 1, 1: 0},
        'qnli': {0: 0, 1: 1},
        'rte': {0: 1, 1: 0},
    }

    # The answer letters never change, so resolve their token ids once. The
    # last id is taken, presumably to skip any special token the tokenizer
    # prepends.
    letter_ids = [tokenizer(letter).input_ids[-1] for letter in ("A", "B", "C")]

    # Follow the model's own device when it reports one; otherwise keep the
    # previous hard-coded default.
    device = getattr(model, "device", "cuda")

    cors = []
    for task_name in tasks:
        task_cors = []
        test = dataset[task_name]
        three_way = task_name in ("mnli", "mnli-mm")
        choice_ids = letter_ids if three_way else letter_ids[:2]

        for example in test:
            prompt = gen_prompt(task_name, example, origin=test_origin)
            input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

            # Pure inference: without no_grad every forward pass would keep
            # an autograd graph alive for nothing.
            with torch.no_grad():
                logits = model(input_ids=input_ids).logits[:, -1].flatten()
                probs = (
                    torch.nn.functional.softmax(logits[choice_ids].float(), dim=0)
                    .cpu()
                    .numpy()
                )

            choice = int(np.argmax(probs))
            # mnli labels are the option index itself; two-way tasks go
            # through the table above.
            pred = choice if three_way else task_mappings[task_name][choice]

            # Record the outcome for every task. Previously this bookkeeping
            # sat inside the two-way branch only, so mnli / mnli-mm results
            # were silently dropped and their accuracy printed as nan.
            cor = pred == example["label"]
            task_cors.append(cor)
            cors.append(cor)

        # Avoid numpy's empty-mean warning when a task had nothing to score.
        task_acc = np.mean(task_cors) if task_cors else float("nan")
        print("Accuracy {:.4f} - Task {}".format(task_acc, task_name))

    acc = np.mean(cors) if cors else float("nan")
    print("Average accuracy {:.4f}".format(acc))



import json


def eval_advglu(model, tokenizer, ntrain=0, data_file='data/adv_glue/dev_ann.json', test_origin=False):
    """
    Load a JSON dataset from ``data_file`` and evaluate the model on it.

    The parsed JSON is handed to ``eval`` unchanged, which prints per-task
    and average accuracy. Nothing is returned.

    Args:
        model: Loaded model.
        tokenizer: Tokenizer paired with the model.
        ntrain (int): Number of leading examples per task used as few-shot
            demonstrations (and skipped when scoring), default is 0.
        data_file (str): Input JSON data file, default is 'data/adv_glue/dev_ann.json'.
        test_origin (bool): When True, prompts use the ``original_*`` fields
            of an example where it has them, default is False.

    Raises:
        OSError: If ``data_file`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # locale default.
    with open(data_file, encoding="utf-8") as f:
        dataset = json.load(f)

    # `eval` here is this module's evaluation function, not the builtin.
    eval(model, tokenizer, dataset, ntrain, test_origin)



