from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

from scipy.stats import norm

from data.explanation_dataset import SummarizationDataset
from data.rep_dataset import RepDataset
from utils import train_linear_model, compute_ece, normalize_data, train_linear_regressor, train_linear_regressor_pt
from llm import load_llm

import sys

import argparse

if __name__ == "__main__":
    # Script purpose: fit one linear regressor per feature set (log-probs,
    # pre/post confidences, logits, PCA of logits, explanation features, and
    # explanation features + scalar scores), repeat over several seeds, and
    # print the mean test MSE of each feature set.

    # Imported once here rather than on every iteration of the seed loop.
    from sklearn.decomposition import PCA

    # set random seed
    np.random.seed(0)
    torch.manual_seed(0)

    parser = argparse.ArgumentParser()
    parser.add_argument("--llm", type=str, default="llama-7b")
    # NOTE: --inv_cdf_norm is accepted but not read anywhere in this script.
    parser.add_argument("--inv_cdf_norm", action="store_true", default=False, help="Use inverse cdf normalization")
    parser.add_argument("--dataset", type=str, default="cnn", help="Dataset to use")
    parser.add_argument("--gpt_exp", action="store_true", default=False, help="Use GPT explanations")
    parser.add_argument("--gpt_state", action="store_true", default=False, help="Use GPT state prompts")
    args = parser.parse_args()

    dataset = SummarizationDataset(args.llm, dataset=args.dataset, gpt_explanations=args.gpt_exp, gpt_state=args.gpt_state)

    train_data, train_labels, train_log_probs = \
        dataset.train_data, dataset.train_labels, dataset.train_log_probs
    test_data, test_labels, test_log_probs = \
        dataset.test_data, dataset.test_labels, dataset.test_log_probs
    train_logits, train_pre_conf, train_post_conf = \
        dataset.train_logits, dataset.train_pre_confs, dataset.train_post_confs
    test_logits, test_pre_conf, test_post_conf = \
        dataset.test_logits, dataset.test_pre_confs, dataset.test_post_confs

    # NOTE(review): the representation features are loaded, but no regressor is
    # fit on them below, so "rep_mse" (as well as "pca_exp" and "pca_rep") never
    # receives a value and is reported as nan. Confirm whether a regression on
    # train_rep / test_rep was meant to be part of this script.
    rep_dataset = RepDataset(args.dataset, args.llm)
    train_rep = rep_dataset.train_rep
    test_rep = rep_dataset.test_rep

    # One list of per-seed test MSEs for every metric name.
    results = {
        "logprob_mse": [],
        "logits_mse": [],
        "preconf_mse": [],
        "postconf_mse": [],
        "exp_mse": [],
        "exp_all_mse": [],
        "rep_mse": [],
        "pca_mse": [],
        "pca_exp": [],
        "pca_rep": [],
    }

    seeds = range(5)

    # The scalar scores are turned into single-column 2-D arrays so they can be
    # used as feature matrices.
    train_pre_conf = train_pre_conf.reshape(-1, 1)
    test_pre_conf = test_pre_conf.reshape(-1, 1)
    train_post_conf = train_post_conf.reshape(-1, 1)
    test_post_conf = test_post_conf.reshape(-1, 1)
    train_log_probs = train_log_probs.reshape(-1, 1)
    test_log_probs = test_log_probs.reshape(-1, 1)

    # Artifact of --gpt_exp: the explanation features can arrive flattened
    # (e.g. (40n,)) and must be reshaped to one row per label (e.g. (n, 40)).
    train_data = train_data.reshape(train_labels.shape[0], -1)
    test_data = test_data.reshape(test_labels.shape[0], -1)

    print(train_data.shape, train_labels.shape, train_log_probs.shape, train_pre_conf.shape, train_post_conf.shape, train_logits.shape)

    # Features are used as loaded: normalize_data (imported at the top of the
    # file) is not applied, and the training set is not sub-sampled.

    # Explanation features plus all scalar scores. This does not depend on the
    # seed, so it is built once instead of once per iteration.
    train_data_all = np.concatenate([train_data, train_log_probs, train_pre_conf, train_post_conf], axis=1)
    test_data_all = np.concatenate([test_data, test_log_probs, test_pre_conf, test_post_conf], axis=1)

    def _record_mse(key, train_x, test_x, seed):
        """Fit a linear regressor on train_x and append its test MSE to results[key]."""
        clf = train_linear_regressor(train_x, train_labels, test_x, test_labels, seed=seed)
        y_pred = clf.predict(test_x)
        results[key].append(((test_labels - y_pred) ** 2).mean())

    for seed in seeds:

        # Re-seed both RNGs so each repetition is reproducible on its own.
        np.random.seed(seed)
        torch.manual_seed(seed)

        # The evaluation order below is kept fixed so that any consumption of
        # the global RNG state happens in the same sequence for every run.
        _record_mse("logprob_mse", train_log_probs, test_log_probs, seed)
        _record_mse("preconf_mse", train_pre_conf, test_pre_conf, seed)
        _record_mse("postconf_mse", train_post_conf, test_post_conf, seed)
        _record_mse("logits_mse", train_logits, test_logits, seed)

        # PCA is deliberately refit for every seed: it is constructed without an
        # explicit random_state, so a randomized solver (if one is selected)
        # would draw from the global NumPy RNG seeded above. Hoisting the fit
        # out of the loop could therefore change the per-seed results.
        pca = PCA(n_components=5)
        train_pca = pca.fit_transform(train_logits)
        test_pca = pca.transform(test_logits)
        _record_mse("pca_mse", train_pca, test_pca, seed)

        _record_mse("exp_mse", train_data, test_data, seed)
        _record_mse("exp_all_mse", train_data_all, test_data_all, seed)

    # Average over seeds and round to 5 decimals. Metrics that were never
    # computed stay nan, without triggering NumPy's "Mean of empty slice"
    # RuntimeWarning that np.mean([]) would emit.
    results = {
        k: round(np.mean(v), 5) if v else float("nan")
        for k, v in results.items()
    }
    for k in ["logits_mse", "rep_mse", "logprob_mse", "preconf_mse", "postconf_mse", "exp_mse", "exp_all_mse"]:
        print(k, results[k])
