from transformers import AutoTokenizer
from datasets import Dataset
from models import AvaModel
import os
import pickle
import click
from tqdm import tqdm
from utils import NoPrint
from utils import sp_auac, sp_auroc
import numpy as np
from sklearn import metrics
import json

def score_auroc(preds_scores, golds, mask_labels=()):
    """Compute the AUROC of the confidences for separating unmasked from masked examples.

    An example is a positive when its gold label is not in ``mask_labels`` and
    a negative otherwise; the ``"maxprob"`` confidences are the ranking scores.

    NOTE(review): the predicted labels in ``preds_scores["pred"]`` do not
    influence the result. The previous implementation parsed them and then
    immediately replaced them with a (masked) copy of ``golds``, so the target
    has always been "gold label is not masked" rather than "prediction is
    correct". The returned value is unchanged here; confirm with the author
    whether a correctness-based AUROC was intended.

    Args:
        preds_scores: Mapping with a ``"maxprob"`` entry holding one
            confidence per example.
        golds: Gold labels, aligned with the confidences.
        mask_labels: Labels whose examples form the negative class. Defaults
            to an empty tuple (previously a shared mutable list).

    Returns:
        The value returned by ``sklearn.metrics.roc_auc_score``.

    Note:
        AUROC is undefined when every example falls into the same group (for
        instance with the default empty ``mask_labels``); the outcome in that
        case is whatever ``roc_auc_score`` does for single-class targets.
    """
    confidences = list(preds_scores["maxprob"])
    # zip stops at the shorter input, mirroring the original pairing.
    pairs = list(zip(confidences, golds))
    is_unmasked = np.array([int(gold not in mask_labels) for _, gold in pairs])
    scores = [conf for conf, _ in pairs]
    # No pre-sorting needed: roc_auc_score ranks the scores itself.
    return metrics.roc_auc_score(is_unmasked, scores)

def score_auac(preds_scores, labels):
    """Return the area under the accuracy-vs-coverage curve.

    Examples are ranked from most to least confident (ties keep their input
    order). For every prefix of that ranking the running accuracy is computed,
    and the curve of running accuracy against the fraction of examples covered
    is integrated with ``sklearn.metrics.auc``.

    Args:
        preds_scores: Mapping with ``"pred"`` (values convertible with
            ``int``) and ``"maxprob"`` (one confidence per example).
        labels: Gold labels, aligned with the predictions.

    Returns:
        The area reported by ``metrics.auc``.
    """
    predicted = [int(p) for p in preds_scores["pred"]]
    confs = [c for c in preds_scores["maxprob"]]
    hits = np.array(predicted) == np.array(labels)

    # Rank by confidence only; the key keeps ties from comparing the hit flags.
    ranked = sorted(zip(confs, hits), key=lambda pair: pair[0], reverse=True)
    ordered_hits = tuple(hit for _, hit in ranked)

    counts = np.arange(1, len(ordered_hits) + 1)
    running_acc = np.cumsum(ordered_hits) / counts
    coverage = counts / len(counts)
    return metrics.auc(coverage, running_acc)


def top1_acc(preds_scores, golds, ood_label):
    """Compute top-1 accuracy over the in-distribution (ID) examples.

    An example is in-distribution when its gold label differs from
    ``ood_label``. Only those examples contribute to the numerator and the
    denominator, so the result always lies in ``[0, 1]``. (Previously a
    prediction that matched the gold label of an OOD example was counted as
    correct while being excluded from the denominator.)

    Args:
        preds_scores: Mapping with a ``"pred"`` entry whose values are
            convertible with ``int``.
        golds: Gold labels, aligned with the predictions.
        ood_label: Gold label marking out-of-distribution examples.

    Returns:
        Number of correctly predicted ID examples divided by the number of ID
        examples in ``golds``.

    Raises:
        ZeroDivisionError: If ``golds`` contains no ID example.
    """
    preds = [int(x) for x in preds_scores["pred"]]
    num_correct = sum(
        pred == gold
        for pred, gold in zip(preds, golds)
        if gold != ood_label
    )
    num_id_examples = sum(1 for gold in golds if gold != ood_label)
    return num_correct / num_id_examples

@click.command()
@click.argument("model_path")
@click.argument("pkl_path")
@click.argument("mode")
@click.option("--train_dataset", required=True,
              help="Path of the saved training dataset (loaded with Dataset.load_from_disk).")
@click.option("--valid_dataset", required=True,
              help="Path of the saved validation dataset (loaded with Dataset.load_from_disk).")
@click.option("--test_dataset", required=True,
              help="Path of the saved test dataset (loaded with Dataset.load_from_disk).")
@click.option("--use_cache", default=False,
              help="Boolean value; when true and PKL_PATH exists, load predictions "
                   "from it instead of running the model.")
def main(**kwargs):
    """Evaluate the model at MODEL_PATH on the test dataset.

    Predictions are written to (or, with --use_cache, read from) PKL_PATH.
    MODE is accepted but not used by the evaluation.
    """
    # All three dataset options are required because run() loads each of them
    # unconditionally; omitting one used to fail inside the dataset loader.
    run(**kwargs)
    
def run(model_path, pkl_path, mode, train_dataset, valid_dataset, test_dataset, use_cache):
    """Score test-set predictions for every label absent from the training set.

    Loads the three datasets from disk, obtains test predictions (from the
    model, or from ``pkl_path`` when caching is enabled and the file exists),
    and for every label that occurs in the test golds but not in the training
    labels prints and collects ID accuracy, AUROC and AUAC.

    Args:
        model_path: Passed to ``AvaModel`` to load the model.
        pkl_path: Pickle file the predictions are written to / read from.
        mode: Unused by this function.
        train_dataset: Path of the saved training dataset.
        valid_dataset: Path of the saved validation dataset.
        test_dataset: Path of the saved test dataset.
        use_cache: When truthy and ``pkl_path`` exists, reuse the pickled
            predictions instead of running the model.

    Returns:
        A dict mapping each out-of-distribution label to a dict with the keys
        ``"id-acc"``, ``"auroc"`` and ``"auac"``, each holding the value
        computed from the first entry of the stored test predictions.
    """
    train_dataset = Dataset.load_from_disk(train_dataset)
    train_labels = set(train_dataset["label"])
    valid_dataset = Dataset.load_from_disk(valid_dataset)
    test_dataset = Dataset.load_from_disk(test_dataset)
    valid_dataset, valid_golds = valid_dataset["text"], valid_dataset["label"]
    test_dataset, test_golds = test_dataset["text"], test_dataset["label"]

    gold_labels = {"validation": valid_golds, "test": test_golds}
    if not use_cache or not os.path.exists(pkl_path):
        # NOTE: validation predictions are not computed, so the "validation"
        # entry is stored empty.
        preds_confs = {"validation": [], "test": []}
        with NoPrint():
            ava_model = AvaModel(model_path, cuda=True)
        test_pred, test_confs_maxprob = ava_model.batched_call(test_dataset, metric="maxprob")
        preds_confs["test"].append({"sequence": test_dataset, "pred": test_pred, "maxprob": test_confs_maxprob})
        with open(pkl_path, "wb") as f:
            pickle.dump(preds_confs, f)
    else:
        # SECURITY: unpickling can execute arbitrary code; only point pkl_path
        # at files produced by this script.
        with open(pkl_path, "rb") as f:
            preds_confs = pickle.load(f)

    # Labels seen at test time but never during training.
    ood_labels = list(set(test_golds).difference(train_labels))

    statistics = ["id-acc", "auroc", "auac"]
    # Metrics are stored per label as soon as they are computed. Previously the
    # return value reused the last label's results for every label.
    all_results = {}
    for ood_label in ood_labels:
        results = {}
        # One entry per stored prediction run; only the first is reported below.
        for elem in preds_confs["test"]:
            results.setdefault("id-acc", []).append(
                top1_acc(elem, gold_labels["test"], ood_label=ood_label))
            results.setdefault("auroc", []).append(
                score_auroc(elem, gold_labels["test"], mask_labels=[ood_label]))
            results.setdefault("auac", []).append(
                score_auac(elem, gold_labels["test"]))

        print(f"Generalizing from {''.join([str(x) for x in train_labels])} to {str(ood_label)}:")
        for statistic in statistics:
            print(statistic, results[statistic][0])
        all_results[ood_label] = {statistic: results[statistic][0] for statistic in statistics}
    return all_results



# Invoke the click command only when this file is executed as a script.
if __name__ == "__main__":
    main()
