from pathlib import Path
import json
import numpy as np
import random
import pandas as pd
from tqdm import tqdm
import itertools

from typing import Union

from robustness.tools import breeds_helpers
import robustness.datasets

import sys
# Make the directory two levels above this script's directory importable so the
# project-local `base_utils` imports below resolve. Anchored on `__file__`
# (not `__name__`, which would resolve against the current working directory).
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))

# data utils
from base_utils.pytorch_utils.data_utils import get_dataset
from base_utils.pytorch_utils.model_utils import fetch_checkpoints_by_params, weights_update
from base_utils.pytorch_utils.inference import get_sorted_softmax, get_logits, confidence_histogram_with_accuracy, get_class_confidence_accuracy
from base_utils.pytorch_utils.prediction import MLPConformalPredictor, QuantileConfidenceThreshold

# conformal utils
from base_utils.conformal_utils.conformal import ConformalModel


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models


# All prediction methods and their configurations.
#
# Each entry pairs a regression `method` name (also used as the sub-directory
# of the regression checkpoint dir that holds its checkpoints) with
# `model_args`, the keyword arguments forwarded to `MLPConformalPredictor`.
# `num_features` presumably equals the length of what `feature_extractor`
# returns (10 bins, 9 bins, 1000 classes, one scalar) -- confirm against
# `MLPConformalPredictor`.


def _normalised_confidence_histogram(logits):
    """First output of a 10-bin `confidence_histogram_with_accuracy`, divided by `len(logits)`."""
    histogram = confidence_histogram_with_accuracy(logits, bins=10)[0]
    return histogram / len(logits)


def _normalised_confidence_histogram_without_last_bin(logits):
    """Same as `_normalised_confidence_histogram` but with the last bin dropped."""
    histogram = confidence_histogram_with_accuracy(logits, bins=10)[0]
    return histogram[:-1] / len(logits)


def _predicted_class_confidence(logits):
    """Per-class confidence (grouped by predicted class, NaNs replaced) over 1000 classes."""
    per_class = get_class_confidence_accuracy(
        logits,
        num_classes=1000,
        by_predictions=True,
        replace_nans=True,
    )
    return per_class[0]


def _mean_top_confidence(logits):
    """Mean of column 0 of the first output of `get_sorted_softmax` (presumably the top-1 softmax score)."""
    sorted_scores = get_sorted_softmax(logits.tensors[0])[0]
    return sorted_scores[:, 0].mean()


regression_methods_configs = [
    {
        'method': 'confidence_histogram_regression',
        'model_args': {
            'num_features': 10,
            'feature_extractor': _normalised_confidence_histogram,
        },
    },
    {
        'method': 'confidence_histogram_less_last_bin_regression',
        'model_args': {
            'num_features': 9,
            'feature_extractor': _normalised_confidence_histogram_without_last_bin,
        },
    },
    {
        'method': 'predicted_class_confidence_regression',
        'model_args': {
            'num_features': 1000,
            'feature_extractor': _predicted_class_confidence,
        },
    },
    {
        'method': 'average_confidence_regression',
        'model_args': {
            'num_features': 1,
            'feature_extractor': _mean_top_confidence,
        },
    },
    {
        'method': 'difference_of_confidence_regression',
        'model_args': {
            'num_features': 1,
            # Placeholder taking an explicit offset; the `__main__` loop swaps
            # in a one-argument extractor with the validation offset bound.
            'feature_extractor': lambda logits, offset: _mean_top_confidence(logits) - offset,
        },
    },
]




def fetch_regression_checkpoints_by_params(checkpoint_dir: Union[str, Path], hparams_filename: str = 'hparams.json', checkpoint_filename_ext: str = 'pth.tar', verbose: bool = False, search_dict: dict = None, **kwargs) -> list:
    r"""Deprecated. Use `model_utils.fetch_checkpoints_by_params` instead.

    Collect checkpoint files from run directories whose hyper-parameters match
    every requested key/value pair.

    Each immediate sub-directory of ``checkpoint_dir`` is treated as one run.
    A run matches when its JSON ``hparams_filename`` contains every item of
    ``search_dict`` (merged with ``kwargs``) with an equal value. All files
    named ``*.{checkpoint_filename_ext}`` directly inside a matching run are
    returned.

    Args:
        checkpoint_dir: Directory whose sub-directories are candidate runs.
        hparams_filename: Name of the JSON hyper-parameter file inside each run.
        checkpoint_filename_ext: Checkpoint file extension, without the leading dot.
        verbose: If True, print the parameters being searched for.
        search_dict: Parameters to match. The caller's dict is not modified.
        **kwargs: Additional parameters to match; they override ``search_dict`` keys.

    Returns:
        List of matching checkpoint paths, ordered by run directory and then
        by file name so that picking ``[-1]`` is deterministic. Empty when
        ``checkpoint_dir`` does not exist or nothing matches.

    Raises:
        AssertionError: if ``search_dict`` is not a dict, or if neither
            ``search_dict`` nor ``kwargs`` is provided.
    """
    checkpoint_dir = Path(checkpoint_dir)

    if search_dict is not None:
        assert isinstance(search_dict, dict), f"A dictionary should be provided for `search_dict` but {type(search_dict)} type was given"
    else:
        assert kwargs, "Either search_dict must be set or target parameters should be passed as keyword arguments."

    # Merge into a fresh dict so the caller's `search_dict` is never mutated.
    search_dict = {**(search_dict or {}), **kwargs}

    if verbose:
        print("Searching for the following parameters...")
        print(search_dict.items())

    # A method that has no output directory yet simply has no checkpoints.
    if not checkpoint_dir.is_dir():
        if verbose:
            print(f"[INFO] Checkpoint directory {checkpoint_dir} does not exist; no matches.")
        return []

    match_list = []
    # Sorted so the result order (and thus `[-1]` in callers) is reproducible.
    for cur_dir in sorted(checkpoint_dir.iterdir()):
        # Stray files next to the run directories are not runs; opening
        # `<file>/hparams.json` would raise NotADirectoryError.
        if not cur_dir.is_dir():
            continue

        try:
            with open(cur_dir / hparams_filename) as fn:
                hparams = json.load(fn)
        except FileNotFoundError:
            # Directories without an hparams file are skipped.
            continue

        # dict_items comparison: every (key, value) of search_dict is in hparams.
        if search_dict.items() <= hparams.items():
            match_list.extend(sorted(cur_dir.glob(f"*.{checkpoint_filename_ext}")))

    return match_list
            


    




if __name__ == "__main__":

    # Results cache (same as the filename). When it exists the experiment is skipped.
    cache_fname = './cache/breeds_tau_predictions.csv'

    # random seed
    seed = 42

    # Set random states
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True
    # cuDNN benchmark mode autotunes kernels and may pick non-deterministic
    # ones, which undermines `deterministic = True` above.
    cudnn.benchmark = False

    # dataset hparams
    use_classes_of_datasets = [None]
    breeds_datasets = ['entity13', 'entity30', 'living17', 'nonliving26']
    breeds_split = 'rand'

    # pre-trained classifier
    arches = ['resnet18']

    # checkpoints
    classifier_checkpoint_dir = "../../conformal_prediction/lightning_logs"
    regression_checkpoint_dir = Path('./outputs')

    # data dir
    info_dir = "../../BREEDS-Benchmarks/imagenet_class_hierarchy/modified"
    data_root_dir = "../../imagenet-testbed/s3_cache/datasets"

    # dataset utils params
    use_encodings = False

    # data params
    batch_size = 128
    num_workers = 4

    # gpu index; None skips moving the classifier with `.cuda(...)`
    gpu = 1

    # conformal parameters
    alphas = [0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.01]

    randomized = True
    allow_zero_sets = True

    # APS
    kreg = -1
    lamda = 0

    # Column order of the results CSV.
    result_columns = [
        "Dataset", "Method", "Classifier", "ConformalPredictor", "Top1", "Top5", "alpha",
        "calibrated_tau", "original_coverage", "original_size", "oracle_tau",
        "predicted_tau", "predicted_coverage", "predicted_size", "on_classes_of",
    ]

    def _make_doc_feature_extractor(offset):
        """Return the difference-of-confidence feature extractor with `offset` captured by value."""
        return lambda logits: get_sorted_softmax(logits.tensors[0])[0][:, 0].mean() - offset

    try:
        df = pd.read_csv(cache_fname)
    except FileNotFoundError:
        df = None
    else:
        print("[INFO] Using cache file for the conformal thresholds...")

    if df is None:
        # Perform the experiment. Rows are collected in a list (`DataFrame.append`
        # was removed in pandas 2.0) and the CSV is rewritten after each row so
        # partial progress is kept.
        rows = []
        df = pd.DataFrame(columns=result_columns)
        Path(cache_fname).parent.mkdir(parents=True, exist_ok=True)

        for arch, breeds_dataset, use_classes_of_dataset in itertools.product(arches, breeds_datasets, use_classes_of_datasets):

            # get breeds pretrained checkpoint
            params_search_dict = {'arch': arch, 'breeds_dataset': breeds_dataset}
            breeds_checkpoints = fetch_checkpoints_by_params(classifier_checkpoint_dir, hparams_filename='hparams.yaml', checkpoint_filename_ext='ckpt', verbose=False, search_dict=params_search_dict)

            if len(breeds_checkpoints) != 1:
                raise RuntimeError(f"Number of pretrained checkpoints found for this config is expected to be 1, but got {len(breeds_checkpoints)}")

            # get pretrained model
            # taken from '/root/dockspace/conformal_prediction/conformal_experiments/aps_raps.py'
            if use_encodings:
                # NOTE(review): this keeps only the `fc` layer, which has no `.fc`
                # attribute of its own, so the head replacement below would fail.
                # Looks unfinished -- confirm before enabling `use_encodings`.
                model = models.__dict__[arch](pretrained=True).fc
            else:
                model = models.__dict__[arch](pretrained=True)

            # Class count parsed from the last two characters of the name ('entity13' -> 13).
            num_classes = int(breeds_dataset[-2:])

            # Replace the linear layer.
            # WARNING: Only works for Resnets.
            # TODO: other arches.
            num_filters = model.fc.in_features
            model.fc = nn.Linear(num_filters, num_classes)

            # load pretrained checkpoint
            weights_update(model, torch.load(breeds_checkpoints[0]))

            if gpu is not None:
                model = model.cuda(gpu)

            breeds_constructor = getattr(breeds_helpers, 'make_' + breeds_dataset)

            split_metadata = breeds_constructor(info_dir, split=breeds_split)
            train_subclasses, test_subclasses = split_metadata[1]

            # Calibration logits come from the train subclasses' validation split.
            dataset = robustness.datasets.CustomImageNet(data_root_dir, train_subclasses)

            _, val_loader = dataset.make_loaders(num_workers, batch_size, only_val=True)

            val_logits = get_logits(
                model,
                val_loader,
                is_data_loader=True,
                num_classes=num_classes,
                batch_size=batch_size,
                num_workers=num_workers,
                gpu=gpu
            )

            # DoC offset depends only on the calibration logits; compute it once.
            doc_offset = get_sorted_softmax(val_logits.tensors[0])[0][:, 0].mean()

            for cur_dataset_name in tqdm([breeds_dataset], desc='Iter. datasets'):

                # Evaluation logits come from the test subclasses' validation split.
                dataset = robustness.datasets.CustomImageNet(data_root_dir, test_subclasses)

                _, cur_loader = dataset.make_loaders(num_workers, batch_size, only_val=True)

                cur_logits = get_logits(
                    model,
                    cur_loader,
                    is_data_loader=True,
                    num_classes=num_classes,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    gpu=gpu
                )

                for alpha in alphas:

                    # Calibrated on the source split ...
                    conformal_model = ConformalModel(
                        model,
                        val_logits,
                        is_logits=True,
                        alpha=alpha,
                        randomized=randomized,
                        allow_zero_sets=allow_zero_sets,
                        kreg=kreg,
                        lamda=lamda,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        gpu=gpu
                    )
                    # ... and on the evaluation split itself (the oracle threshold).
                    oracle_model = ConformalModel(
                        model,
                        cur_logits,
                        is_logits=True,
                        alpha=alpha,
                        randomized=randomized,
                        allow_zero_sets=allow_zero_sets,
                        kreg=kreg,
                        lamda=lamda,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        gpu=gpu
                    )

                    orig_calib_tau = conformal_model.q_hat
                    oracle_tau = oracle_model.q_hat

                    orig_top1, orig_top5, orig_cvg, orig_sz = conformal_model.validate_conformal(cur_logits, is_logits=True)

                    # Fields shared by every result row of this (dataset, alpha) pair.
                    base_row = {
                        "Dataset": cur_dataset_name,
                        "on_classes_of": use_classes_of_dataset,
                        "Classifier": arch,
                        "ConformalPredictor": "APS",
                        "Top1": np.round(orig_top1, 3),
                        "Top5": np.round(orig_top5, 3),
                        "alpha": alpha,
                        "calibrated_tau": orig_calib_tau,
                        "original_coverage": np.round(orig_cvg, 3),
                        "original_size": np.round(orig_sz, 3),
                        "oracle_tau": oracle_tau,
                    }

                    # search regression methods with existing checkpoint
                    hparam_search_dict = {'classifier': arch, 'breeds_dataset': cur_dataset_name, 'alpha': alpha}
                    if use_classes_of_dataset:
                        hparam_search_dict.update({'limit_classes_by_dataset': use_classes_of_dataset.lower()})
                    matching_regression_configs = []
                    for cur_method_config in regression_methods_configs:
                        match_checkpoints = fetch_regression_checkpoints_by_params(
                            regression_checkpoint_dir / cur_method_config['method'],
                            verbose=True,
                            search_dict=hparam_search_dict,
                        )
                        if match_checkpoints:
                            # Copy `model_args` as well, so the DoC override below
                            # never leaks into the shared `regression_methods_configs`.
                            matching_regression_configs.append({
                                **cur_method_config,
                                'model_args': dict(cur_method_config['model_args']),
                                'checkpoint_file': match_checkpoints[-1],
                            })

                    if not matching_regression_configs:
                        print(f'[INFO] No matching regression method checkpoints were provided for\t---\t classifier: {arch} | alpha: {alpha}')

                    for cur_method_config in matching_regression_configs:
                        is_doc = cur_method_config['method'] == "difference_of_confidence_regression"

                        # set feature extractor of DoC method
                        if is_doc:
                            cur_method_config['model_args']['feature_extractor'] = _make_doc_feature_extractor(doc_offset)

                        mlp_model = MLPConformalPredictor(**cur_method_config['model_args'], gpu=gpu)
                        mlp_model.load_from_checkpoint(checkpoint_file=cur_method_config['checkpoint_file'])

                        predicted_tau = mlp_model.predict_by_logits(cur_logits)

                        # account for the offset of DoC method
                        if is_doc:
                            predicted_tau += orig_calib_tau

                        # Evaluate with the predicted tau. The calibrated tau was saved
                        # in `orig_calib_tau`, and `conformal_model` is rebuilt per alpha.
                        conformal_model.q_hat = predicted_tau

                        _, _, cur_cvg, cur_sz = conformal_model.validate_conformal(cur_logits, is_logits=True)

                        rows.append({
                            **base_row,
                            "Method": cur_method_config['method'],
                            "predicted_tau": predicted_tau,
                            "predicted_coverage": np.round(cur_cvg, 3),
                            "predicted_size": np.round(cur_sz, 3),
                        })

                        # save the current progress
                        df = pd.DataFrame(rows, columns=result_columns)
                        df.to_csv(cache_fname)

                    # Quantile Thresholded Confidence
                    cur_qtc = QuantileConfidenceThreshold(val_logits.tensors[0], val_logits.tensors[1], quantile=orig_calib_tau)

                    predicted_tau = cur_qtc.predict_error(cur_logits.tensors[0])

                    # set tau of conformal model to predicted tau in order to get coverage results
                    conformal_model.q_hat = predicted_tau

                    _, _, cur_cvg, cur_sz = conformal_model.validate_conformal(cur_logits, is_logits=True)

                    rows.append({
                        **base_row,
                        "Method": 'QTC',
                        "predicted_tau": predicted_tau,
                        "predicted_coverage": np.round(cur_cvg, 3),
                        "predicted_size": np.round(cur_sz, 3),
                    })

                    df = pd.DataFrame(rows, columns=result_columns)
                    df.to_csv(cache_fname)





