from asyncore import write
import os
import json
import csv

import numpy as np
import scipy.stats
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
import calibration as cal
from tqdm import tqdm

from utils import check_dir

def load_output(path):
    """
    Loads a JSON-lines output file and wraps label/logit fields in tensors.

    Each non-blank line of the file is parsed as one JSON object. For every
    parsed object, the 'true' entry is replaced by a long tensor and the
    'logits' entry by a float tensor; all other keys are left untouched.

    Args:
        path: Path of the JSON-lines file to read.

    Returns:
        List of the parsed objects (dicts), in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an object lacks the 'true' or 'logits' key.
    """

    with open(path, encoding="utf-8") as f:
        # Skip whitespace-only lines (e.g. a trailing empty line) instead of
        # letting json.loads fail on an empty string.
        elems = [json.loads(line) for line in f if line.strip()]

    for elem in elems:
        elem['true'] = torch.tensor(elem['true']).long()
        elem['logits'] = torch.tensor(elem['logits']).float()
    return elems

def get_bucket_scores(y_score, n_buckets=None):
    """
    Organizes real-valued posterior probabilities into buckets.
    For example, if we have 10 buckets, the probabilities 0.0, 0.1,
    0.2 are placed into buckets 0 (0.0 <= p < 0.1), 1 (0.1 <= p < 0.2),
    and 2 (0.2 <= p < 0.3), respectively. A score that is not below any
    upper bound (e.g. exactly 1.0) is placed into the last bucket.

    Args:
        y_score: Iterable of scores to bucket.
        n_buckets: Number of equal-width buckets. When omitted, the
            module-level `ece_buckets` setting is used.

    Returns:
        Tuple `(bucket_values, bucket_indices)`: two lists of `n_buckets`
        lists holding, per bucket, the scores and their positions in
        `y_score`.

    Raises:
        ValueError: If the bucket count is smaller than 1.
    """

    if n_buckets is None:
        # Fall back to the script-wide setting so existing callers that
        # pass only `y_score` keep working.
        n_buckets = ece_buckets
    if n_buckets < 1:
        raise ValueError("number of buckets must be at least 1")

    # Exclusive upper bound of every bucket, computed once up front.
    upper_bounds = [float((j + 1) / n_buckets) for j in range(n_buckets)]
    last_bucket = n_buckets - 1

    bucket_values = [[] for _ in range(n_buckets)]
    bucket_indices = [[] for _ in range(n_buckets)]
    for i, score in enumerate(y_score):
        # First bucket whose upper bound exceeds the score; scores at or
        # above the final bound fall into the last bucket.
        j = next(
            (k for k, bound in enumerate(upper_bounds) if score < bound),
            last_bucket,
        )
        bucket_values[j].append(score)
        bucket_indices[j].append(i)
    return (bucket_values, bucket_indices)


def get_bucket_confidence(bucket_values):
    """
    Averages the scores stored in each bucket.

    Returns a list with one entry per bucket: the mean of the bucket's
    scores, or -1 as a placeholder when the bucket is empty.
    """

    averages = []
    for scores in bucket_values:
        if len(scores) == 0:
            # Empty bucket: no mean to take, emit the sentinel.
            averages.append(-1.)
            continue
        averages.append(np.mean(scores))
    return averages


def get_bucket_accuracy(bucket_values, y_true, y_pred):
    """
    Computes the fraction of correct predictions in each bucket.

    Despite its name, each element of `bucket_values` is used as an index
    into `y_true` and `y_pred`. Returns one entry per bucket: the mean of
    the 0/1 correctness flags, or -1 when the bucket is empty.
    """

    accuracies = []
    for members in bucket_values:
        # 1 where the label matches the prediction, 0 otherwise.
        hits = [int(y_true[idx] == y_pred[idx]) for idx in members]
        accuracies.append(np.mean(hits) if len(hits) > 0 else -1.)
    return accuracies


def calculate_error(n_samples, bucket_values, bucket_confidence, bucket_accuracy):
    """
    Computes three calibration-error metrics over the non-empty buckets,
    where b_k is the size of bucket k and n is `n_samples`:
        - Expected Calibration Error (ECE): sum_k (b_k / n) |acc(k) - conf(k)|
        - Maximum Calibration Error (MCE): max_k |acc(k) - conf(k)|
        - Total Calibration Error (TCE): sum_k |acc(k) - conf(k)|

    Returns the tuple (ECE, MCE, TCE), each multiplied by 100.
    """

    assert len(bucket_values) == len(bucket_confidence) == len(bucket_accuracy)
    assert sum(map(len, bucket_values)) == n_samples

    ece = 0.
    mce = 0.
    tce = 0.
    per_bucket = zip(bucket_values, bucket_accuracy, bucket_confidence)
    for members, acc, conf in per_bucket:
        size = len(members)
        if size == 0:
            # Empty buckets carry placeholder values and contribute nothing.
            continue
        gap = abs(acc - conf)
        ece += (size / n_samples) * gap
        mce = max(mce, gap)
        tce += gap
    return (ece * 100., mce * 100., tce * 100.)

def create_one_hot(n_classes):
    """
    Creates the blank base for a one-hot label tensor.

    Returns a float32 vector of length `n_classes` filled with zeros; no
    entry is set to one here.
    """
    return torch.zeros(n_classes, dtype=torch.float32)


def cross_entropy(output, target, n_classes):
    """
    Computes cross-entropy with KL divergence from predicted distribution
    and true distribution, specifically, the predicted log probability
    vector and the true one-hot label vector.

    Args:
        output: 1-D tensor of predicted log probabilities, one per class.
        target: Index of the true class (int or scalar integer tensor).
        n_classes: Number of classes, i.e. the length of `output`.

    Returns:
        The divergence as a Python float. With a one-hot true distribution
        this equals the negative log probability of the true class.
    """

    # Build the true distribution: all mass on the target class. Previously
    # `target` was ignored and the vector stayed all-zero, so the result was
    # always 0.0.
    model_prob = torch.zeros(n_classes, dtype=torch.float32)
    model_prob[target] = 1.
    return F.kl_div(output, model_prob, reduction='sum').item()

# Pairs each in-domain (ID) dataset with the out-of-domain (OOD) dataset
# that is evaluated alongside it.
id_ood_dict = {
    "snli": "mnli",
    "qqp": "TwitterPPDB",
    "swag": "hellaswag"
}

# Run configuration.
seeds = [13, 21, 42, 87, 100]
eval_split = 'val'
id_dataset_name = "snli"
ood_dataset_name = id_ood_dict[id_dataset_name]
conf_dir_name = "./outputs/conf/"
model_name_list = [ "roberta-base",]


epoch_num = 1
output_dir = "./outputs/csv/"
header_list = ["Epoch", "ID Acc", "ID ECE", "OOD Acc", "OOD ECE"]

# Number of equal-width confidence buckets used for the ECE computation.
ece_buckets = 10


def _evaluate_conf_file(conf_path):
    """
    Computes accuracy and ECE for one result file.

    Args:
        conf_path: Path of a JSON-lines file readable by `load_output`;
            each record must carry 'true', 'pred' and 'logits'.

    Returns:
        `[accuracy, expected_calibration_error]`, both as floats scaled
        by 100.

    Raises:
        ValueError: If the file contains no records.
    """

    elems = load_output(conf_path)
    if not elems:
        raise ValueError(f'no records found in {conf_path}')

    labels = [elem['true'] for elem in elems]
    preds = [elem['pred'] for elem in elems]
    # Confidence of a record = largest softmax probability over its logits.
    confs = [
        F.log_softmax(elem['logits'], 0).exp().max().item()
        for elem in elems
    ]

    bucket_values, bucket_indices = get_bucket_scores(confs)
    bucket_confidence = get_bucket_confidence(bucket_values)
    bucket_accuracy = get_bucket_accuracy(bucket_indices, labels, preds)

    accuracy = accuracy_score(labels, preds) * 100.
    # Only the expected calibration error is reported; the maximum and
    # total errors returned alongside it are not used.
    expected_error, _, _ = calculate_error(
        len(elems), bucket_values, bucket_confidence, bucket_accuracy
    )
    return [float(accuracy), float(expected_error)]


for model_name in model_name_list:
    # One entry per seed; each entry holds one result row per epoch.
    total_list = []
    for seed in tqdm(seeds):
        id_conf_dir = os.path.join(conf_dir_name, id_dataset_name, eval_split, f'{model_name}_seed={seed}')
        ood_conf_dir = os.path.join(conf_dir_name, ood_dataset_name, eval_split, f'{model_name}_seed={seed}')
        seed_res_list = []
        for epoch in range(epoch_num):
            # NOTE: the per-epoch sub-directory lookup is disabled, so every
            # epoch iteration reads the same 'res.json' of the run.
            # id_conf_path = os.path.join(id_conf_dir, f'epoch={epoch}', 'res.json')
            # ood_conf_path = os.path.join(ood_conf_dir, f'epoch={epoch}', 'res.json')
            id_conf_path = os.path.join(id_conf_dir, 'res.json')
            ood_conf_path = os.path.join(ood_conf_dir, 'res.json')

            # Row layout follows header_list: epoch, ID acc/ECE, OOD acc/ECE.
            epoch_res_list = [epoch + 1]
            epoch_res_list += _evaluate_conf_file(id_conf_path)
            epoch_res_list += _evaluate_conf_file(ood_conf_path)
            seed_res_list.append(epoch_res_list)
        total_list.append(seed_res_list)

    # Aggregate over the seed axis: one mean row and one std row per epoch.
    total_list = np.array(total_list)
    avg = np.mean(total_list, axis=0)
    std = np.std(total_list, axis=0)
    # se = scipy.stats.sem(total_list, axis=0)
    # std = se * scipy.stats.t.ppf((1 + 0.9) / 2., len(seeds) - 1)
    # std = np.nan_to_num(std)
    output_path = os.path.join(output_dir, f'{id_dataset_name}-{ood_dataset_name}', eval_split, f'{model_name}')
    # check_dir comes from utils; presumably it creates the directory when
    # it is missing -- TODO confirm.
    check_dir(output_path)
    output_path = os.path.join(output_path, "res.csv")
    with open(output_path, mode='w', encoding='utf-8-sig', newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header_list)
        # Mean rows first, then the standard-deviation rows.
        writer.writerows(avg)
        writer.writerows(std)