"""
Calibration results for the quantile network on CIFAR-10-C.

- Results are generated with a pretrained quantile network.
- As a baseline, the standard pretrained network (without the extra quantile
  input channel) is evaluated as well.
"""
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from tqdm import tqdm
import pickle

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from quantile_datamodule import get_base_datasets
from dataset_distorted import DistortedCIFAR10
import config

import calibration as cal

from utils_train import get_base_datasets, get_pretrained_model

# To stop lightning from complaining about precision
torch.set_float32_matmul_precision("high")

# Corruption names passed to DistortedCIFAR10 (each evaluated at severities 1-5).
distortions_paper = [
    "gaussian_noise", "shot_noise", "impulse_noise", "defocus_blur",
    "glass_blur", "motion_blur", "zoom_blur", "snow", "frost",
    "fog", "brightness", "contrast", "elastic_transform", "pixelate",
    "jpeg_compression",
]

# Quantile-network checkpoint for each base model; every file follows the
# pattern CHECKPOINT_DIR + "quantile_<base model name>.ckpt".
checkpoints_pretrained = {
    name: config.CHECKPOINT_DIR + f"quantile_{name}.ckpt"
    for name in (
        "resnet34_cifar10",
        "resnet34_cifar100",
        "resnet34_svhn",
        "densenet_cifar10",
        "densenet_svhn",
        "densenet_cifar100",
    )
}

# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU or the CPU so the script still runs on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def establish_pretrained_model_quantile(name_base_model: str):
    """Build the quantile network for ``name_base_model`` and load its weights.

    The backbone returned by ``get_pretrained_model`` gets its ``conv1``
    replaced by a convolution with one extra input channel (the quantile level
    that ``get_quantile_probabilities`` appends to every image). The checkpoint
    listed in ``checkpoints_pretrained`` is then loaded strictly, after
    stripping the leading ``backbone.`` from each parameter name.

    Args:
        name_base_model: Key into ``checkpoints_pretrained``.

    Returns:
        The model in eval mode with every parameter frozen.
    """
    model, _num_classes, _size_dataset = get_pretrained_model(name_base_model)
    # NOTE(review): assumes the backbone exposes ``conv1`` (as the ResNet
    # variants do); confirm for the DenseNet entries in checkpoints_pretrained.
    model.conv1 = nn.Conv2d(
        in_channels=model.conv1.in_channels + 1,
        out_channels=model.conv1.out_channels,
        kernel_size=model.conv1.kernel_size,
        stride=model.conv1.stride,
        padding=model.conv1.padding,
        bias=False,
    )

    # Load onto the CPU regardless of the device the checkpoint was saved from;
    # the probability helpers move the model to ``device`` afterwards.
    checkpoint = torch.load(checkpoints_pretrained[name_base_model], map_location="cpu")
    state_dict = checkpoint["state_dict"]
    # ``quantiles_list`` is not a backbone parameter and would make the strict
    # load below fail, so drop it (tolerating checkpoints that lack it).
    state_dict.pop("quantiles_list", None)

    # Strip only the leading wrapper prefix so backbone keys match exactly.
    prefix = "backbone."
    backbone_state = {
        (name[len(prefix):] if name.startswith(prefix) else name): param
        for name, param in state_dict.items()
    }
    model.load_state_dict(backbone_state, strict=True)

    model.eval()

    # Make all the parameters in the backbone non-trainable
    for param in model.parameters():
        param.requires_grad_(False)

    return model

def establish_pretrained_model_baseline(name_base_model: str):
    """Return the standard (non-quantile) pretrained network, frozen for inference.

    Args:
        name_base_model: Model name understood by ``get_pretrained_model``.

    Returns:
        The model in eval mode with every parameter frozen.
    """
    model, _num_classes, _size_dataset = get_pretrained_model(name_base_model)

    # Inference only: switch BatchNorm/Dropout to eval behaviour, matching
    # establish_pretrained_model_quantile (this call was previously missing).
    model.eval()

    # Make all the parameters in the backbone non-trainable
    for param in model.parameters():
        param.requires_grad_(False)

    return model

def get_quantile_probabilities(model, dataset, batch_size=512, num_workers=5, num_quantiles=100):
    """Estimate class probabilities by sweeping the quantile input of ``model``.

    Every image gets one extra constant input channel holding a quantile level
    tau. The model is run once for each of ``num_quantiles`` evenly spaced
    levels strictly inside (0, 1). A class's probability is the fraction of
    levels at which its logit is positive.

    - assumed that model outputs logits (presumably shaped (batch, num_classes);
      TODO confirm)

    Args:
        model: Quantile network taking images with one extra input channel.
        dataset: Dataset yielding ``(image, label)`` pairs.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        num_quantiles: Number of tau levels evaluated per image.

    Returns:
        ``(probs, labels)`` as numpy arrays concatenated over all batches, in
        dataset order.
    """
    model = model.to(device)
    model.eval()
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )
    # Interior levels only: drop the endpoints 0 and 1 of the evenly spaced grid.
    taus = torch.linspace(0, 1, num_quantiles + 2)[1:-1]

    probs = []
    labels = []
    with torch.no_grad():
        for x, label in tqdm(dataloader, total=len(dataloader)):
            x = x.to(device)
            # Append the quantile channel; its value is overwritten for each tau.
            quant_val = torch.ones((x.shape[0], 1, x.shape[2], x.shape[3]), device=device)
            x = torch.cat([x, quant_val], dim=1)

            votes = []
            for tau in taus:
                x[:, -1, :, :] = tau
                logits = model(x)
                votes.append((logits > 0).float())
            probs.append(torch.stack(votes, dim=0).mean(dim=0).cpu().numpy())
            labels.append(label.numpy())

    return np.concatenate(probs, axis=0), np.concatenate(labels, axis=0)

def get_baseline_probabilities(model, dataset, batch_size=512, num_workers=5):
    """Return softmax class probabilities of ``model`` over ``dataset``.

    - assumed that model outputs logits (softmax is taken over dim 1)

    Args:
        model: Network mapping a batch of images to logits.
        dataset: Dataset yielding ``(image, label)`` pairs.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.

    Returns:
        ``(probs, labels)`` as numpy arrays concatenated over all batches, in
        dataset order.
    """
    model = model.to(device)
    # Ensure inference-mode BatchNorm/Dropout; the baseline model is not
    # otherwise guaranteed to have been switched to eval mode.
    model.eval()
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )
    probs = []
    labels = []
    with torch.no_grad():
        for x, label in tqdm(dataloader, total=len(dataloader)):
            logits = model(x.to(device))
            probs.append(torch.softmax(logits, dim=1).cpu().numpy())
            labels.append(label.numpy())
    return np.concatenate(probs, axis=0), np.concatenate(labels, axis=0)

# Column order of the result CSVs produced from ``_run_experiment``.
_RESULT_COLUMNS = [
    "distortion", "severity", "accuracy", "ECE", "calib_error",
    "accuracy_corrected", "ECE_corrected", "calib_error_corrected",
]


def _load_or_compute(distortion_name, distortion_severity, tag, compute=None):
    """Return ``(probs, labels)`` for one evaluation split, cached under ``./dump/``.

    Args:
        distortion_name: Corruption name, or ``"valid"`` for the clean validation set.
        distortion_severity: Corruption severity (``0`` for the clean set).
        tag: Cache tag (``"quant"`` or ``"baseline"``) selecting the model's files.
        compute: Optional zero-argument callable returning fresh ``(probs, labels)``.
            When given, its result is written to the cache and returned; when
            ``None``, the cached pickles are read instead.
    """
    prefix = f"./dump/{distortion_name}_{distortion_severity}_{tag}"
    probs_path, labels_path = f"{prefix}.pkl", f"{prefix}_labels.pkl"
    if compute is not None:
        probs, labels = compute()
        with open(probs_path, "wb") as f:
            pickle.dump(probs, f)
        with open(labels_path, "wb") as f:
            pickle.dump(labels, f)
        return probs, labels
    # NOTE: pickle.load can execute arbitrary code; only read dumps this script produced.
    with open(probs_path, "rb") as f:
        probs = pickle.load(f)
    with open(labels_path, "rb") as f:
        labels = pickle.load(f)
    return probs, labels


def _evaluate(probs, labels, calibrator, train_calibrator=False):
    """Compute the six metric columns of one result row.

    Uncalibrated metrics are computed on ``probs`` first. If ``train_calibrator``
    is set, ``calibrator`` is then fitted on this split. The "corrected" metrics
    score ``calibrator.calibrate(probs)`` against whether the uncalibrated argmax
    prediction was correct.

    Returns:
        ``(accuracy, ECE, calib_error, accuracy_corrected, ECE_corrected,
        calib_error_corrected)``. ``accuracy_corrected`` equals ``accuracy``
        because the corrected metrics use the uncalibrated argmax predictions.
    """
    predictions = np.argmax(probs, axis=1)
    acc = np.mean(predictions == labels)
    ece = cal.get_ece_em(probs, labels, num_bins=5)
    calib_error = cal.get_calibration_error(probs, labels)

    if train_calibrator:
        calibrator.train_calibration(probs, labels)

    probs_calibrated = calibrator.calibrate(probs)
    labels_calibrated = (labels == predictions) * 1
    ece_calibrated = cal.get_ece_em(probs_calibrated, labels_calibrated, num_bins=5)
    calib_error_calibrated = cal.get_calibration_error(probs_calibrated, labels_calibrated)
    return acc, ece, calib_error, acc, ece_calibrated, calib_error_calibrated


def _run_experiment(base_model_name, tag, establish_model, get_probabilities, recompute=False):
    """Evaluate one model on the clean validation set and every CIFAR-10-C split.

    A Platt top-label calibrator is fitted on the clean validation split and
    then reused unchanged for every distortion/severity.

    Args:
        base_model_name: Base model key (e.g. ``"resnet34_cifar10"``).
        tag: Cache tag used in the ``./dump/`` file names.
        establish_model: Callable ``name -> model``; only used when recomputing.
        get_probabilities: Callable ``(model, dataset) -> (probs, labels)``; only
            used when recomputing.
        recompute: If True, load the model and datasets and re-run inference,
            refreshing the cache. Otherwise only the cached pickles are read and
            no model or dataset is loaded.

    Returns:
        DataFrame with columns ``_RESULT_COLUMNS``, one row per split.
    """
    model = base_valid = data_transform = None
    if recompute:
        model = establish_model(base_model_name)
        _, base_valid, _, data_transform = get_base_datasets(base_model_name)

    def probabilities(distortion_name, distortion_severity, make_dataset):
        compute = (lambda: get_probabilities(model, make_dataset())) if recompute else None
        return _load_or_compute(distortion_name, distortion_severity, tag, compute)

    rows = []

    # Clean validation split (severity 0).
    # NOTE(review): the calibrator is trained and evaluated on this same split,
    # so the "corrected" metrics of this row are in-sample.
    probs, labels = probabilities("valid", 0, lambda: base_valid)
    calibrator = cal.PlattTopCalibrator(num_calibration=len(probs), num_bins=100)
    rows.append(("valid", 0, *_evaluate(probs, labels, calibrator, train_calibrator=True)))

    for distortion_name in tqdm(distortions_paper):
        for distortion_severity in range(1, 6):
            def make_dataset(name=distortion_name, severity=distortion_severity):
                return DistortedCIFAR10(
                    root=config.DATA_DIR,
                    distortion=name,
                    severity=severity,
                    transform=data_transform,
                )

            probs, labels = probabilities(distortion_name, distortion_severity, make_dataset)
            rows.append((distortion_name, distortion_severity, *_evaluate(probs, labels, calibrator)))

    return pd.DataFrame(rows, columns=_RESULT_COLUMNS)


if __name__ == "__main__":
    base_model_name = "resnet34_cifar10"
    if base_model_name not in config.PRETRAINED_MODELS:
        raise ValueError(f"Model {base_model_name} not implemented.")

    # True: re-run inference and refresh ./dump/; False: reuse the cached pickles.
    recompute = False

    # Results of the quantile network
    df_quant = _run_experiment(
        base_model_name,
        "quant",
        establish_pretrained_model_quantile,
        get_quantile_probabilities,
        recompute=recompute,
    )
    df_quant.to_csv(config.RESULTS_DIR + f"calib_{base_model_name}2.csv", index=False)

    # Results of the baseline network
    df_baseline = _run_experiment(
        base_model_name,
        "baseline",
        establish_pretrained_model_baseline,
        get_baseline_probabilities,
        recompute=recompute,
    )
    df_baseline.to_csv(config.RESULTS_DIR + f"calib_{base_model_name}_baseline2.csv", index=False)

    