from pathlib import Path
import json
import numpy as np
import random
import pandas as pd
from tqdm import tqdm
import itertools

from typing import Union

import sys
# Make the directory two levels above this script's folder importable
# (presumably where the `base_utils` package lives -- verify against the
# repository layout).  The path is anchored on `__file__` so that it does not
# depend on the working directory the script is launched from; `__name__` is
# just the string "__main__" here and would be resolved relative to the cwd.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))

# data utils
from base_utils.pytorch_utils.model_utils import fetch_checkpoints_by_params
from base_utils.pytorch_utils.data_utils import get_dataset
from base_utils.pytorch_utils.inference import get_sorted_softmax, get_logits, confidence_histogram_with_accuracy, get_class_confidence_accuracy
from base_utils.pytorch_utils.prediction import MLPConformalPredictor, QuantileConfidenceThreshold

# conformal utils
from base_utils.conformal_utils.conformal import ConformalModel


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models


# All regression-based prediction methods and their configurations.
#
# Every entry carries the method name (also used as the checkpoint
# sub-directory by the driver below) and the `model_args` that are forwarded
# as keyword arguments to `MLPConformalPredictor`: the number of input
# features and the callable that turns a logits dataset into those features.
# NOTE(review): the semantics of the `base_utils` helpers used here are not
# visible in this file; the descriptions below are inferred from their names.


def _histogram_fractions(logits):
    """Return the 10-bin confidence histogram divided by the dataset length."""
    histogram = confidence_histogram_with_accuracy(logits, bins=10)[0]
    return histogram / len(logits)


def _histogram_fractions_without_last_bin(logits):
    """Same as `_histogram_fractions`, with the final bin dropped."""
    histogram = confidence_histogram_with_accuracy(logits, bins=10)[0]
    return histogram[:-1] / len(logits)


def _per_class_confidence(logits):
    """Return the first output of the per-class confidence helper (1000 classes)."""
    return get_class_confidence_accuracy(
        logits,
        num_classes=1000,
        by_predictions=True,
        replace_nans=True,
    )[0]


def _mean_leading_confidence(logits):
    """Return the mean of column 0 of the sorted softmax scores."""
    # presumably column 0 holds the top-1 confidence -- confirm in base_utils
    sorted_scores = get_sorted_softmax(logits.tensors[0])[0]
    return sorted_scores[:, 0].mean()


def _mean_leading_confidence_minus_offset(logits, offset):
    """Return the mean of column 0 of the sorted softmax scores minus `offset`."""
    sorted_scores = get_sorted_softmax(logits.tensors[0])[0]
    return sorted_scores[:, 0].mean() - offset


regression_methods_configs = [
    {
        'method': method_name,
        'model_args': {
            'num_features': feature_count,
            'feature_extractor': extractor,
        },
    }
    for method_name, feature_count, extractor in (
        ('confidence_histogram_regression', 10, _histogram_fractions),
        ('confidence_histogram_less_last_bin_regression', 9, _histogram_fractions_without_last_bin),
        ('predicted_class_confidence_regression', 1000, _per_class_confidence),
        ('average_confidence_regression', 1, _mean_leading_confidence),
        ('difference_of_confidence_regression', 1, _mean_leading_confidence_minus_offset),
    )
]




# Column layout of the cached results table.  "on_classes_of" is kept last so
# that the layout matches cache files written by earlier runs of this script.
_RESULT_COLUMNS = [
    "Dataset", "Method", "Classifier", "ConformalPredictor", "Top1", "Top5",
    "alpha", "raps_kreg", "raps_lambda", "use_platt_scaling", "calibrated_tau",
    "original_coverage", "original_size", "oracle_tau", "predicted_tau",
    "predicted_coverage", "predicted_size", "on_classes_of",
]


def _make_result_row(dataset_name, on_classes_of, method, classifier, alpha,
                     kreg, lamda, use_platt_scaling, calibrated_tau,
                     oracle_tau, predicted_tau, original_metrics,
                     predicted_metrics):
    """Assemble one row of the results table.

    Args:
        dataset_name: name of the evaluated dataset.
        on_classes_of: dataset whose class subset was used, or None.
        method: name of the threshold-prediction method.
        classifier: architecture name of the classifier.
        alpha, kreg, lamda, use_platt_scaling: conformal hyper-parameters.
        calibrated_tau: threshold calibrated on the validation logits.
        oracle_tau: threshold calibrated directly on the evaluated dataset.
        predicted_tau: threshold produced by `method`.
        original_metrics: 4 values (top1, top5, coverage, size) obtained with
            the calibrated threshold.
        predicted_metrics: 4 values in the same order obtained with the
            predicted threshold; only coverage and size are stored.

    Returns:
        dict keyed by every name in `_RESULT_COLUMNS`; metrics are rounded to
        three decimals.
    """
    orig_top1, orig_top5, orig_cvg, orig_sz = original_metrics
    _, _, pred_cvg, pred_sz = predicted_metrics
    return {
        "Dataset": dataset_name,
        "on_classes_of": on_classes_of,
        "Method": method,
        "Classifier": classifier,
        "ConformalPredictor": "RAPS",
        "Top1": np.round(orig_top1, 3),
        "Top5": np.round(orig_top5, 3),
        "alpha": alpha,
        "raps_kreg": kreg,
        "raps_lambda": lamda,
        "use_platt_scaling": use_platt_scaling,
        "calibrated_tau": calibrated_tau,
        "original_coverage": np.round(orig_cvg, 3),
        "original_size": np.round(orig_sz, 3),
        "oracle_tau": oracle_tau,
        "predicted_tau": predicted_tau,
        "predicted_coverage": np.round(pred_cvg, 3),
        "predicted_size": np.round(pred_sz, 3),
    }


def _save_results(rows, cache_path):
    """Write all rows collected so far to `cache_path` and return them as a frame.

    The parent directory is created when missing so that a long-running
    experiment does not fail at its first save.
    """
    results = pd.DataFrame(rows, columns=_RESULT_COLUMNS)
    cache_path = Path(cache_path)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    results.to_csv(cache_path)
    return results


if __name__ == "__main__":
    # Driver: for every classifier, evaluation dataset and conformal
    # hyper-parameter combination, calibrate a conformal model on the
    # validation logits, then replace its threshold (`q_hat`) with the one
    # predicted by each available regression checkpoint and by QTC, and record
    # the resulting coverage / set size next to the calibrated and oracle
    # thresholds.  Results are cached as CSV; an existing cache skips the run.

    # same as the filename
    cache_fname = './cache/natural_datasets_raps_predictions.csv'
    checkpoint_dir = Path('./outputs')

    # random seed
    seed = 42

    # Set random states
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True

    # NOTE(review): benchmark mode lets cuDNN auto-tune its algorithm choice,
    # which the PyTorch reproducibility notes advise disabling when
    # run-to-run determinism is wanted.  Kept as in the original -- confirm.
    cudnn.benchmark = True

    # dataset paths
    val_dataset_name = 'ImageNet-Val'
    # Named `eval_dataset_names` (not `datasets`) so the imported
    # `torchvision.datasets` module is not shadowed.
    eval_dataset_names = ['ImageNetV2', 'ImageNet-R', 'ImageNet-Sketch']

    # dataset hparams
    # use_classes_of_datasets = [None, 'ImageNet-R']
    use_classes_of_datasets = [None]

    # pre-trained classifier
    # arches = ['resnet50', 'densenet121']
    arches = ['resnet50']

    # dataset utils params
    use_encodings = False

    # data params
    batch_size = 128
    num_workers = 4

    # gpu
    gpu = 0

    # conformal parameters
    alphas = [0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.01]

    randomized = True
    allow_zero_sets = True

    # APS
    kregs = [2, 4]
    lamdas = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]

    # platt scaling
    # platt_scaling_uses = [False, True]
    platt_scaling_uses = [False]

    # Only the cache lookup is guarded, so that errors raised by the
    # experiment itself are not chained onto the FileNotFoundError.
    try:
        df = pd.read_csv(cache_fname)
    except FileNotFoundError:
        df = None

    if df is not None:
        print("[INFO] Using cache file for the conformal thresholds...")
    else:
        ### Perform the experiment
        rows = []
        df = pd.DataFrame(columns=_RESULT_COLUMNS)

        for arch, use_classes_of_dataset in itertools.product(arches, use_classes_of_datasets):

            # get pretrained model (only its final layer when working on encodings)
            model = models.__dict__[arch](pretrained=True)
            if use_encodings:
                model = model.fc

            if gpu is not None:
                model = model.cuda(gpu)

            # NOTE(review): the model is never switched to eval mode here;
            # presumably `get_logits` / `ConformalModel` take care of it -- confirm.

            val_dataset = get_dataset(val_dataset_name, dataset_type='raw', limit_classes_by_dataset=use_classes_of_dataset)

            val_logits = get_logits(
                model,
                val_dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                gpu=gpu
            )

            # Offset of the Difference-of-Confidence method.  It only depends
            # on the validation logits, so it is computed once per classifier
            # instead of once per hyper-parameter combination.
            val_mean_confidence = get_sorted_softmax(val_logits.tensors[0])[0][:, 0].mean()

            for cur_dataset_name in tqdm(eval_dataset_names, desc='Iter. datasets'):

                cur_dataset = get_dataset(cur_dataset_name, dataset_type='raw', limit_classes_by_dataset=use_classes_of_dataset)

                cur_logits = get_logits(
                    model,
                    cur_dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    gpu=gpu
                )

                for alpha, kreg, lamda, use_platt_scaling in itertools.product(alphas, kregs, lamdas, platt_scaling_uses):

                    conformal_kwargs = dict(
                        is_logits=True,
                        alpha=alpha,
                        randomized=randomized,
                        allow_zero_sets=allow_zero_sets,
                        use_platt_scaling=use_platt_scaling,
                        kreg=kreg,
                        lamda=lamda,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        gpu=gpu,
                    )
                    # calibrated on the validation logits
                    conformal_model = ConformalModel(model, val_logits, **conformal_kwargs)
                    # calibrated directly on the evaluated dataset ("oracle")
                    oracle_model = ConformalModel(model, cur_logits, **conformal_kwargs)

                    orig_calib_tau = conformal_model.q_hat
                    oracle_tau = oracle_model.q_hat

                    original_metrics = conformal_model.validate_conformal(cur_logits, is_logits=True)

                    # fields shared by every row of this hyper-parameter combination
                    row_common = dict(
                        dataset_name=cur_dataset_name,
                        on_classes_of=use_classes_of_dataset,
                        classifier=arch,
                        alpha=alpha,
                        kreg=kreg,
                        lamda=lamda,
                        use_platt_scaling=use_platt_scaling,
                        calibrated_tau=orig_calib_tau,
                        oracle_tau=oracle_tau,
                        original_metrics=original_metrics,
                    )

                    # search regression methods with existing checkpoint
                    hparam_search_dict = {'classifier': arch, 'alpha': alpha}
                    if use_classes_of_dataset:
                        hparam_search_dict.update({'limit_classes_by_dataset': use_classes_of_dataset.lower()})
                    if use_platt_scaling is True:
                        hparam_search_dict.update({'use_platt_scaling': use_platt_scaling})
                    if lamda != 0:
                        hparam_search_dict.update({'raps_kreg': kreg, 'raps_lambda': lamda})

                    matching_regression_configs = []
                    for cur_method_config in regression_methods_configs:
                        match_checkpoints = fetch_checkpoints_by_params(
                            checkpoint_dir / cur_method_config['method'],
                            verbose=True,
                            search_dict=hparam_search_dict,
                            exclude_dict={'qtc_boosted': True}
                        )
                        if match_checkpoints:
                            # Copy `model_args` as well: it is modified below
                            # for the DoC method and must not leak into the
                            # shared `regression_methods_configs` table.
                            matching_regression_configs.append({
                                **cur_method_config,
                                'model_args': dict(cur_method_config['model_args']),
                                'checkpoint_file': match_checkpoints[-1],
                            })

                    if not matching_regression_configs:
                        print(f'[INFO] No matching regression method checkpoints were provided for\t---\t classifier: {arch} | alpha: {alpha}')

                    for cur_method_config in matching_regression_configs:
                        method_name = cur_method_config['method']
                        is_doc_method = method_name == "difference_of_confidence_regression"

                        # set feature extractor of DoC method
                        if is_doc_method:
                            cur_method_config['model_args']['feature_extractor'] = (
                                lambda logits: get_sorted_softmax(logits.tensors[0])[0][:, 0].mean() - val_mean_confidence
                            )

                        mlp_model = MLPConformalPredictor(**cur_method_config['model_args'], gpu=gpu)
                        mlp_model.load_from_checkpoint(checkpoint_file=cur_method_config['checkpoint_file'])

                        predicted_tau = mlp_model.predict_by_logits(cur_logits)

                        # account for the offset of DoC method
                        if is_doc_method:
                            predicted_tau += orig_calib_tau

                        # set tau of conformal model to predicted tau in order to get coverage results
                        conformal_model.q_hat = predicted_tau

                        predicted_metrics = conformal_model.validate_conformal(cur_logits, is_logits=True)

                        rows.append(_make_result_row(
                            method=method_name,
                            predicted_tau=predicted_tau,
                            predicted_metrics=predicted_metrics,
                            **row_common,
                        ))

                        # save the current progress
                        df = _save_results(rows, cache_fname)

                    # Quantile Thresholded Confidence
                    # NOTE(review): this regulariser is 0 when lamda == 0 (or
                    # when kreg - 1 is at least the number of logit columns),
                    # and the division below is then undefined; neither occurs
                    # with the grids configured above.
                    total_qtc_reg = max(0, (val_logits.tensors[0].shape[1] - (kreg-1))) * lamda

                    cur_qtc = QuantileConfidenceThreshold(val_logits.tensors[0], val_logits.tensors[1], quantile=orig_calib_tau/total_qtc_reg, softmax_temperature=conformal_model.softmax_temp)

                    predicted_tau = cur_qtc.predict(cur_logits.tensors[0]) * total_qtc_reg

                    # set tau of conformal model to predicted tau in order to get coverage results
                    conformal_model.q_hat = predicted_tau

                    predicted_metrics = conformal_model.validate_conformal(cur_logits, is_logits=True)

                    rows.append(_make_result_row(
                        method='QTC',
                        predicted_tau=predicted_tau,
                        predicted_metrics=predicted_metrics,
                        **row_common,
                    ))

                    df = _save_results(rows, cache_fname)





