import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from data.semantic_dataset import SquadExplanationSemanticDataset, BoolSemanticDataset, WinoGrandeSemanticDataset, MCQSemanticDataset
from data.rep_dataset import RepDataset
from utils import train_linear_model, compute_ece, get_linear_results
from llm import load_llm

import sys

import argparse


# Dataset names routed to BoolSemanticDataset; the name itself is forwarded
# as the loader's first positional argument.
# NOTE(review): "BooIQ" looks like a misspelling of "BoolQ". It is kept
# verbatim because the loader may use it as an on-disk identifier — confirm.
BOOL_DATASETS = ("BooIQ", "HaluEval", "ToxicEval")

# Every value accepted by --dataset, used to build a helpful error message.
SUPPORTED_DATASETS = ("WinoGrande", "CommonsenseQA", "squad") + BOOL_DATASETS


def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            reads the arguments from ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--llm", type=str, default="llama-7b")
    parser.add_argument("--dataset", type=str, default="squad")
    parser.add_argument("--inv_cdf_norm", action="store_true", default=False, help="Use inverse cdf normalization")
    parser.add_argument("--random", action="store_true", default=False, help="Use random prompts")
    parser.add_argument("--gpt_exp", action="store_true", default=False, help="Use GPT explanations")
    parser.add_argument("--gpt_state", action="store_true", default=False, help="Use GPT state prompts")
    return parser.parse_args(argv)


def load_confidences_and_labels(args):
    """Load the train/test labels and confidence features for ``args.dataset``.

    WinoGrande and CommonsenseQA are built as two separate dataset objects
    (one per split); every other dataset is a single object that exposes
    ``train_*`` / ``test_*`` attributes.

    Args:
        args: Namespace providing ``dataset``, ``llm``, ``random``,
            ``gpt_exp`` and ``gpt_state``.

    Returns:
        A tuple ``(train_labels, train_post_conf, test_labels, test_post_conf)``.

    Raises:
        ValueError: If ``args.dataset`` is not one of ``SUPPORTED_DATASETS``.
            Previously this case surfaced later as an unrelated ``NameError``.
    """
    name = args.dataset
    prompt_opts = {"random": args.random, "gpt_exp": args.gpt_exp, "gpt_state": args.gpt_state}

    if name == "WinoGrande":
        train_split = WinoGrandeSemanticDataset(args.llm, split="train", **prompt_opts)
        test_split = WinoGrandeSemanticDataset(args.llm, split="test", **prompt_opts)
        return (train_split.labels, train_split.post_confs,
                test_split.labels, test_split.post_confs)

    if name == "CommonsenseQA":
        train_split = MCQSemanticDataset("CommonsenseQA", args.llm, split="train", **prompt_opts)
        test_split = MCQSemanticDataset("CommonsenseQA", args.llm, split="test", **prompt_opts)
        # The multiple-choice dataset uses its option probabilities as the
        # confidence feature instead of ``post_confs``.
        return (train_split.labels, train_split.option_probs,
                test_split.labels, test_split.option_probs)

    if name == "squad":
        dataset = SquadExplanationSemanticDataset(args.llm, **prompt_opts)
    elif name in BOOL_DATASETS:
        # ``random`` is passed positionally, exactly as the loader was
        # originally invoked.
        dataset = BoolSemanticDataset(name, args.llm, args.random, gpt_exp=args.gpt_exp, gpt_state=args.gpt_state)
    else:
        raise ValueError(
            f"Unsupported dataset {name!r}; expected one of: {', '.join(SUPPORTED_DATASETS)}"
        )

    return (dataset.train_labels, dataset.train_post_confs,
            dataset.test_labels, dataset.test_post_confs)


def main(argv=None):
    """Evaluate the confidence features of one dataset and print the AUROC.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    train_labels, train_post_conf, test_labels, test_post_conf = load_confidences_and_labels(args)

    # get results for postconf; only the AUROC is reported, the other three
    # returned values are unpacked but unused.
    acc, f1, ece, auroc = get_linear_results(train_post_conf, train_labels, test_post_conf, test_labels, seed=0, balanced=True)
    print("Semantic Sim Auroc:", auroc)


if __name__ == "__main__":
    main()