import argparse
import json
import os
import sys

import numpy as np
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt

from utils import check_dir

# Dataset tag: a path component under ./outputs/mlm and part of the figure name.
task = "wikitext-103-raw-v1"
# Tag embedded in every model directory name and in the figure name.
model_task = "qqp"
# NOTE(review): not referenced anywhere in the visible part of this file.
seed = 42
# Embedded in every model directory name and in the figure name.
mlm_prob = 0.15

# One directory name per compared model; `label_list` and `marker_list`
# below are index-aligned with this list.
models = [
    f"roberta-base_{variant}_{model_task}_mlm_prob={mlm_prob}"
    for variant in ("pt", "ft", "lora", "dvaed", "dvaep")
]
label_list = ["Pre-trained", "Full-FT", "LoRA", "DVAE-D", "DVAE-P"]
marker_list = ['*', 'o', 'v', 's', 'D']

# Number of equal-width confidence buckets used by get_bucket_scores.
buckets = 10

# Location of each model's JSON-lines result file, in the order of `models`.
path_list = []
for model_name in models:
    path_list.append(f'./outputs/mlm/{task}/{model_name}/res.json')

# Directory the finished figure is written to.
output_dir = './outputs/figs/rd'

def load_output(path):
    """Load a JSON-lines results file and wrap selected fields in tensors.

    Every non-blank line of ``path`` is parsed as one JSON object. In each
    parsed object the ``'true'`` entry is replaced by a ``torch.long`` tensor
    and the ``'conf'`` entry by a ``torch.float`` tensor; all other keys are
    left exactly as parsed.

    Blank (whitespace-only) lines are skipped so that a stray empty line,
    e.g. at the end of the file, does not abort loading.

    Args:
        path: Path of the JSON-lines file to read (decoded as UTF-8).

    Returns:
        A list with one dict per record, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``'true'`` or ``'conf'`` key.
    """

    elems = []
    # JSON text is UTF-8 by specification, so do not depend on the platform's
    # default encoding.
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            elem = json.loads(line)
            elem['true'] = torch.tensor(elem['true']).long()
            elem['conf'] = torch.tensor(elem['conf']).float()
            elems.append(elem)
    return elems

def get_bucket_scores(y_score, n_buckets=None):
    """
    Organizes real-valued posterior probabilities into equal-width buckets.
    For example, if we have 10 buckets, the probabilities 0.0, 0.1,
    0.2 are placed into buckets 0 (0.0 <= p < 0.1), 1 (0.1 <= p < 0.2),
    and 2 (0.2 <= p < 0.3), respectively.

    A score is assigned to the first bucket whose upper bound is strictly
    greater than the score. Scores that are not below any upper bound
    (p >= 1.0, or NaN) go to the last bucket; scores below 0 go to bucket 0.

    Args:
        y_score: Iterable of scores; each must support ``score < float``.
        n_buckets: Number of buckets. When ``None`` (the default) the
            module-level ``buckets`` setting is used.

    Returns:
        A tuple ``(bucket_values, bucket_indices)`` of two lists with
        ``n_buckets`` entries each: the scores that fell into each bucket,
        and the positions of those scores within ``y_score``.

    Raises:
        ValueError: If the number of buckets is smaller than 1.
    """

    if n_buckets is None:
        n_buckets = buckets
    if n_buckets < 1:
        raise ValueError(f"n_buckets must be at least 1, got {n_buckets}")

    # The bucket boundaries do not depend on the score; compute them once.
    upper_bounds = [(j + 1) / n_buckets for j in range(n_buckets)]
    last_bucket = n_buckets - 1

    bucket_values = [[] for _ in range(n_buckets)]
    bucket_indices = [[] for _ in range(n_buckets)]
    for i, score in enumerate(y_score):
        # Explicitly clamp to the last bucket when no bound exceeds the score
        # (e.g. a probability of exactly 1.0).
        j = next(
            (k for k, bound in enumerate(upper_bounds) if score < bound),
            last_bucket,
        )
        bucket_values[j].append(score)
        bucket_indices[j].append(i)
    return (bucket_values, bucket_indices)


def get_bucket_confidence(bucket_values):
    """
    Returns the mean score of every bucket, in bucket order.
    A bucket that holds no scores is reported as -1.
    """

    mean_per_bucket = []
    for scores in bucket_values:
        if len(scores) > 0:
            mean_per_bucket.append(np.mean(scores))
        else:
            # Sentinel marking an empty bucket.
            mean_per_bucket.append(-1.)
    return mean_per_bucket


def get_bucket_accuracy(bucket_values, y_true, y_pred):
    """
    Returns the fraction of correct predictions in every bucket, in
    bucket order. A bucket without any members is reported as -1.

    Despite its name, each entry of ``bucket_values`` is used as a list
    of positions into ``y_true`` / ``y_pred``; a position counts as
    correct when ``y_true[i] == y_pred[i]``.
    """

    accuracy_per_bucket = []
    for members in bucket_values:
        hits = [int(y_true[pos] == y_pred[pos]) for pos in members]
        if len(hits) > 0:
            accuracy_per_bucket.append(np.mean(hits))
        else:
            # Sentinel marking an empty bucket.
            accuracy_per_bucket.append(-1.)
    return accuracy_per_bucket

def calculate_error(n_samples, bucket_values, bucket_confidence, bucket_accuracy):
    r"""
    Computes several metrics used to measure calibration error:
        - Expected Calibration Error (ECE): \sum_k (b_k / n) |acc(k) - conf(k)|
        - Maximum Calibration Error (MCE): max_k |acc(k) - conf(k)|
        - Total Calibration Error (TCE): \sum_k |acc(k) - conf(k)|

    Here b_k is the number of entries in bucket k and n is ``n_samples``.
    Empty buckets are skipped, so their placeholder confidence/accuracy
    values never enter any of the metrics.

    Args:
        n_samples: Total number of samples; must equal the combined size
            of all buckets.
        bucket_values: One list of scores per bucket (only the lengths
            are used).
        bucket_confidence: Mean confidence per bucket.
        bucket_accuracy: Accuracy per bucket.

    Returns:
        A tuple ``(ECE, MCE, TCE)``, each multiplied by 100.

    Raises:
        AssertionError: If the three per-bucket lists differ in length or
            the bucket sizes do not add up to ``n_samples``.
    """
    # NOTE: the docstring is a raw string because "\sum" would otherwise be
    # an invalid escape sequence ("\s").

    assert len(bucket_values) == len(bucket_confidence) == len(bucket_accuracy), \
        "per-bucket inputs must all have the same length"
    assert sum(map(len, bucket_values)) == n_samples, \
        "bucket sizes must add up to n_samples"

    expected_error, max_error, total_error = 0., 0., 0.
    for (bucket, accuracy, confidence) in zip(
        bucket_values, bucket_accuracy, bucket_confidence
    ):
        if len(bucket) > 0:
            delta = abs(accuracy - confidence)
            # ECE weights each gap by the share of samples in the bucket.
            expected_error += (len(bucket) / n_samples) * delta
            max_error = max(max_error, delta)
            total_error += delta
    return (expected_error * 100., max_error * 100., total_error * 100.)

# Load the per-example results of every model up front, in the order of
# `path_list`, so the list index selects the matching label, marker and
# colour below.
dict_list = [load_output(output_path) for output_path in path_list]

palette = plt.get_cmap('tab10')
font1 = {
    'family': 'sans-serif',
    'weight': 'normal',
    'size': 14,
}
# The seaborn style sheets were renamed to 'seaborn-v0_8-*' in matplotlib 3.6
# and the legacy names were removed in 3.8. Use whichever name this
# installation provides; if neither exists the default style is kept, since
# the style sheet only affects the look of the figure.
for style_name in ('seaborn-v0_8-paper', 'seaborn-paper'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(1, 1, 1)
# End points of the diagonal reference line (accuracy == confidence).
true_x = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
true_y = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]


# Draw one accuracy-vs-confidence curve per model.
for idx, elems in enumerate(dict_list):
    labels = [elem['true'] for elem in elems]
    preds = [elem['pred'] for elem in elems]
    confs = [elem['conf'] for elem in elems]

    bucket_values, bucket_indices = get_bucket_scores(confs)
    bucket_confidence = get_bucket_confidence(bucket_values)
    bucket_accuracy = get_bucket_accuracy(bucket_indices, labels, preds)
    ece = calculate_error(len(elems), bucket_values, bucket_confidence, bucket_accuracy)

    # Empty buckets carry the placeholder value -1 for both confidence and
    # accuracy. Drop every empty bucket (not only the leading ones) so that
    # no bogus (-1, -1) point is drawn.
    occupied = [k for k, bucket in enumerate(bucket_values) if len(bucket) > 0]
    ax.plot(
        [bucket_confidence[k] for k in occupied],
        [bucket_accuracy[k] for k in occupied],
        linewidth=2.0,
        marker=marker_list[idx],
        markersize=8,
        color=palette(idx),
        label=f'{label_list[idx]}, (ECE={ece[0]:.2f})',
    )
ax.plot(true_x, true_y, linewidth=3, linestyle='--', color='black', label='Zero Error')

ax.set_xlabel("Confidence", fontsize=20)
ax.set_ylabel("Accuracy", fontsize=20)
ax.tick_params(axis='both', which='major', labelsize=16)
leg = ax.legend(loc='upper left', prop=font1, frameon=True, fancybox=True, framealpha=0.4, borderpad=0.3)
leg.get_frame().set_linewidth(1.5)
# NOTE(review): check_dir comes from the project's utils module; presumably it
# makes sure the output directory exists before saving — confirm.
check_dir(output_dir)
plt.savefig(os.path.join(output_dir, f"{task}-{model_task}_rd_{mlm_prob}.png"), bbox_inches='tight')