from pathlib import Path
import json
import numpy as np
import random
from tqdm import tqdm
import itertools

import time
import datetime

import sys
sys.path.append(str(Path(__name__).resolve().parent.parent.parent))

# data utils
from base_utils.pytorch_utils.data_utils import get_dataset
from base_utils.pytorch_utils.inference import get_logits, get_sorted_softmax,confidence_histogram_with_accuracy, get_class_confidence_accuracy
from base_utils.pytorch_utils.prediction import MLPConformalPredictor

# conformal utils
from base_utils.conformal_utils.conformal import ConformalModel


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models





# All prediction methods and their configurations.
#
# Every entry names a regression method (the name is also used as the output
# sub-directory by the main script) and the keyword arguments that are passed
# to ``MLPConformalPredictor``:
#   num_features      -- length of the feature vector produced by the extractor
#   feature_extractor -- callable applied to a logits ``TensorDataset``


def _normalized_confidence_histogram(logits, drop_last_bin=False):
    """Return the first output of a 10-bin ``confidence_histogram_with_accuracy``
    divided by the number of samples in *logits*.

    With ``drop_last_bin=True`` the final bin is removed before dividing.
    """
    histogram = confidence_histogram_with_accuracy(logits, bins=10)[0]
    if drop_last_bin:
        histogram = histogram[:-1]
    return histogram / len(logits)


def _mean_top1_confidence(logits):
    """Return the mean of column 0 of ``get_sorted_softmax`` applied to the logits
    tensor of *logits* (presumably the top-1 softmax probability -- verify)."""
    sorted_softmax = get_sorted_softmax(logits.tensors[0])[0]
    return sorted_softmax[:, 0].mean()


def _predicted_class_confidence(logits):
    """Return the first output of ``get_class_confidence_accuracy`` for 1000
    classes with ``by_predictions=True`` and ``replace_nans=True``."""
    return get_class_confidence_accuracy(
        logits, num_classes=1000, by_predictions=True, replace_nans=True
    )[0]


_METHOD_SPECS = (
    # (method name, num_features, feature extractor)
    ('confidence_histogram_regression', 10, _normalized_confidence_histogram),
    (
        'confidence_histogram_less_last_bin_regression',
        9,
        lambda logits: _normalized_confidence_histogram(logits, drop_last_bin=True),
    ),
    ('predicted_class_confidence_regression', 1000, _predicted_class_confidence),
    ('average_confidence_regression', 1, _mean_top1_confidence),
    # difference computed at runtime
    ('difference_of_confidence_regression', 1, _mean_top1_confidence),
)

regression_methods_configs = [
    {
        'method': method,
        'model_args': {
            'num_features': num_features,
            'feature_extractor': extractor,
        },
    }
    for method, num_features, extractor in _METHOD_SPECS
]





def _load_logits(dataset_name, arch, limit_classes_by_dataset):
    """Load precomputed logits for *dataset_name* as a ``TensorDataset``.

    The ``.samples`` and ``.targets`` arrays of the object returned by
    ``get_dataset(..., dataset_type='logits')`` are wrapped (without copying)
    via ``torch.from_numpy``.
    """
    logits_numpy = get_dataset(
        dataset_name,
        dataset_type='logits',
        arch=arch,
        limit_classes_by_dataset=limit_classes_by_dataset,
    )
    return torch.utils.data.TensorDataset(
        torch.from_numpy(logits_numpy.samples),
        torch.from_numpy(logits_numpy.targets),
    )


def _calibrate_and_extract(model, logits, features_by_method, **conformal_kwargs):
    """Calibrate a ``ConformalModel`` on *logits* and append one training row per method.

    For every entry of ``regression_methods_configs`` the method's feature
    extractor is applied to *logits* and appended to
    ``features_by_method[method]['features']``; the calibrated ``q_hat`` is
    appended to ``features_by_method[method]['targets']``. Mutates
    *features_by_method* in place and returns ``None``.
    """
    conformal_model = ConformalModel(model, logits, is_logits=True, **conformal_kwargs)
    for config in regression_methods_configs:
        record = features_by_method[config['method']]
        record['features'].append(config['model_args']['feature_extractor'](logits))
        record['targets'].append(conformal_model.q_hat)


def _make_output_dir(base_dir, method):
    """Create and return a new, empty output directory for *method*.

    The directory is ``base_dir / method / <YYYY-mm-dd_HH:MM:SS>``. The
    timestamp only has one-second resolution, so if that directory already
    exists a numeric suffix (``_1``, ``_2``, ...) is appended rather than
    reusing -- and overwriting -- an earlier run's outputs.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    candidate = base_dir / method / stamp
    suffix = 0
    while True:
        try:
            # exist_ok=False makes creation atomic: an existing dir raises.
            candidate.mkdir(parents=True)
            return candidate
        except FileExistsError:
            suffix += 1
            candidate = base_dir / method / f"{stamp}_{suffix}"


if __name__ == "__main__":
    # random seed
    seed = 42

    # Set random states
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Reproducibility requires deterministic kernels AND benchmark mode off:
    # cuDNN autotuning may select different algorithms from run to run.
    cudnn.deterministic = True
    cudnn.benchmark = False

    # pre-trained classifier
    # arches = ['resnet50', 'densenet121']
    arches = ['resnet50']

    # class_subsets
    use_classes_of_datasets = [None]

    # iter params
    params = list(itertools.product(arches, use_classes_of_datasets))

    # dataset utils params
    use_encodings = False

    # data params
    batch_size = 128
    num_workers = 4

    # --- hyperparameters for the MLP ---
    # optimizer parameters
    max_iter = int(5e4)
    lr = 1e-4
    weight_decay = 1e-3
    momentum = 0.9

    # gpu (None -> run on CPU)
    gpu = 0
    device = torch.device(f'cuda:{gpu}') if gpu is not None else torch.device('cpu')

    # --- conformal parameters ---
    alphas = [0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.01]

    randomized = True
    allow_zero_sets = True

    # APS
    kregs = [2, 4]
    lamdas = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]

    # platt scaling
    # platt_scaling_uses = [False, True]
    platt_scaling_uses = [False]

    output_base_dir = Path("./outputs").resolve()

    for arch, use_classes_of_dataset in params:

        # get pretrained model
        model = models.__dict__[arch](pretrained=True)
        if use_encodings:
            model = model.fc
        model = model.to(device)

        # Path.glob yields entries in arbitrary filesystem order; sort so the
        # row order of the regression data (and the seeded MLP fit) is stable.
        synthetic_datasets = sorted(
            s.stem
            for s in Path(f"/root/dockspace/datasets/logits/{arch}/imagenet-c").glob('imagenet-c*')
        )
        orig_dataset = 'imagenet-val'

        val_logits = _load_logits(orig_dataset, arch, use_classes_of_dataset)

        for alpha, kreg, lamda, use_platt_scaling in itertools.product(alphas, kregs, lamdas, platt_scaling_uses):

            conformal_kwargs = dict(
                alpha=alpha,
                randomized=randomized,
                allow_zero_sets=allow_zero_sets,
                use_platt_scaling=use_platt_scaling,
                kreg=kreg,
                lamda=lamda,
                batch_size=batch_size,
                num_workers=num_workers,
                gpu=gpu,
            )

            features_by_method = {s['method']: {'features': [], 'targets': []} for s in regression_methods_configs}

            # Row 0 of every method's data comes from ImageNet-Val ...
            _calibrate_and_extract(model, val_logits, features_by_method, **conformal_kwargs)

            # ... followed by one row per synthetic (ImageNet-C) dataset.
            for cur_sel in tqdm(synthetic_datasets, desc="Calibrating on synthetics"):
                cur_logits = _load_logits(cur_sel, arch, use_classes_of_dataset)
                _calibrate_and_extract(model, cur_logits, features_by_method, **conformal_kwargs)

            # Fix DoC features by subtracting the AC values of ImageNet-Val
            # (row 0) from every row, features and targets alike.
            doc = features_by_method['difference_of_confidence_regression']
            ac = features_by_method['average_confidence_regression']
            feat_offset = ac['features'][0]
            target_offset = ac['targets'][0]
            doc['features'] = [s - feat_offset for s in doc['features']]
            doc['targets'] = [s - target_offset for s in doc['targets']]

            # Hyperparameters are the same for every method of this configuration.
            hparams = {
                'classifier': arch,
                'limit_classes_by_dataset': use_classes_of_dataset,
                'alpha': alpha,
                'randomized': randomized,
                'allow_zero_sets': allow_zero_sets,
                'kreg': kreg,
                'lamda': lamda,
                'lr': lr,
                'weight_decay': weight_decay,
                'momentum': momentum,
                'max_iter': max_iter,
                'batch_size': batch_size,
                'num_workers': num_workers,
                'seed': seed,
                'gpu': gpu,
            }
            if lamda > 0:
                hparams.update({'raps_kreg': kreg, 'raps_lambda': lamda})
            if use_platt_scaling:
                hparams.update({'use_platt_scaling': use_platt_scaling})

            for cur_method_config in regression_methods_configs:

                # same as the filename
                cur_method = cur_method_config['method']

                # set up experiment output directory (never reuses an existing one)
                output_dir = _make_output_dir(output_base_dir, cur_method)

                with open(output_dir / 'hparams.json', 'w') as fn:
                    json.dump(hparams, fn, indent=2)

                mlp_model = MLPConformalPredictor(**cur_method_config['model_args'], gpu=gpu)

                cur_feat = torch.stack(features_by_method[cur_method]['features']).to(device)
                cur_target = torch.Tensor(
                    features_by_method[cur_method]['targets']
                ).reshape(-1, 1).to(device)

                # scalar features stack to 1-D; the MLP expects (rows, features)
                if cur_feat.ndim < 2:
                    cur_feat = cur_feat.unsqueeze(-1)

                mlp_model.fit_model(
                    cur_feat,
                    cur_target,
                    lr=lr,
                    weight_decay=weight_decay,
                    momentum=momentum,
                    max_iter=max_iter,
                    gpu=gpu,
                    outdir=output_dir
                )
