from pathlib import Path
import json
import numpy as np
import random
from tqdm import tqdm
import itertools

import time
import datetime

from robustness.tools import breeds_helpers
import robustness.datasets

import sys
# Make the repository root (three levels above this file) importable so the
# `base_utils` package below can be found.  `__file__` is used rather than
# `__name__`: Path("__main__") resolves against the current working directory,
# so the computed root was only right when launched from this script's folder.
sys.path.append(str(Path(__file__).resolve().parent.parent.parent))

# pytorch utils
from base_utils.pytorch_utils.data_utils import get_dataset
from base_utils.pytorch_utils.model_utils import fetch_checkpoints_by_params, weights_update
from base_utils.pytorch_utils.inference import get_logits, get_sorted_softmax,confidence_histogram_with_accuracy, get_class_confidence_accuracy
from base_utils.pytorch_utils.prediction import MLPConformalPredictor

# conformal utils
from base_utils.conformal_utils.conformal import ConformalModel


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models





# All prediction methods and their configurations.
#
# Each entry pairs a method name with the keyword arguments later unpacked
# into MLPConformalPredictor: the number of input features and a callable
# that turns a logits object into that feature vector.
def _method_config(method, num_features, feature_extractor):
    """Bundle a method name with its MLP model arguments."""
    return {
        'method': method,
        'model_args': {
            'num_features': num_features,
            'feature_extractor': feature_extractor,
        },
    }


regression_methods_configs = [
    # Full 10-bin confidence histogram, divided by the number of samples.
    _method_config(
        'confidence_histogram_regression',
        10,
        lambda logits: confidence_histogram_with_accuracy(logits, bins=10)[0] / len(logits),
    ),
    # Same histogram with the last bin dropped, hence 9 features.
    _method_config(
        'confidence_histogram_less_last_bin_regression',
        9,
        lambda logits: confidence_histogram_with_accuracy(logits, bins=10)[0][:-1]  / len(logits),
    ),
    # One feature per class, grouped by predicted class.
    # NOTE(review): 1000 is hard-coded here, while the script below builds
    # classifiers whose class count is parsed from the BREEDS dataset name —
    # confirm this is intended.
    _method_config(
        'predicted_class_confidence_regression',
        1000,
        lambda logits: get_class_confidence_accuracy(logits, num_classes=1000, by_predictions=True, replace_nans=True)[0],
    ),
    # Mean of the first column of the sorted softmax output.
    _method_config(
        'average_confidence_regression',
        1,
        lambda logits: get_sorted_softmax(logits.tensors[0])[0][:,0].mean(),
    ),
    # Same raw feature as the average-confidence method; the difference
    # itself is computed at runtime once all features have been collected.
    _method_config(
        'difference_of_confidence_regression',
        1,
        lambda logits: get_sorted_softmax(logits.tensors[0])[0][:,0].mean(),
    ),
]





def _record_calibration_point(method_configs, features_by_method, logits, q_hat):
    """Append one (features, target) pair per method for a single dataset.

    Args:
        method_configs: entries shaped like ``regression_methods_configs``.
        features_by_method: dict keyed by method name holding ``'features'``
            and ``'targets'`` lists; extended in place.
        logits: object handed unchanged to each method's feature extractor.
        q_hat: regression target stored alongside the extracted features.
    """
    for cur_method_config in method_configs:
        collected = features_by_method[cur_method_config['method']]
        collected['features'].append(cur_method_config['model_args']['feature_extractor'](logits))
        collected['targets'].append(q_hat)


if __name__ == "__main__":
    # random seed
    seed = 42

    # Set random states
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True
    # Autotuning stays off: with benchmarking enabled cuDNN may select a
    # different kernel from run to run, which undermines the seeding above.
    cudnn.benchmark = False

    # pre-trained classifier
    arches = ['resnet18']

    # model checkpoints
    checkpoint_dir = "../../conformal_prediction/lightning_logs"

    # data dir
    info_dir = "../../BREEDS-Benchmarks/imagenet_class_hierarchy/modified"
    data_root_dir = "../../imagenet-testbed/s3_cache/datasets"
    synthetic_data_root = Path("../../imagenet-testbed/s3_cache/datasets/imagenet-c")

    # output dir (one sub-folder per method and launch time is created below)
    output_base_dir = Path("./outputs").resolve()

    # class_subsets
    use_classes_of_datasets = [None]
    breeds_datasets = ['entity13', 'entity30', 'living17', 'nonliving26']
    breeds_split = 'rand'

    # iter params
    params = list(itertools.product(arches, breeds_datasets, use_classes_of_datasets))

    # dataset utils params
    use_encodings = False

    # data params
    batch_size = 128
    num_workers = 4

    # hyperparameters for the MLP
    # model parameters (these two are not read anywhere else in this script)
    num_bins = 10  # histogram bins are used for features
    hist_end_offset = 1

    # optimizer parameters
    max_iter = int(5e4)
    lr = 1e-4
    weight_decay = 1e-3
    momentum = 0.9

    # gpu
    gpu = 1

    # conformal parameters
    alphas = [0.2, 0.15, 0.1, 0.075, 0.05, 0.025, 0.01]

    randomized = True
    allow_zero_sets = True

    # APS
    kreg = -1
    lamda = 0

    # Keyword arguments shared by every ConformalModel built below; only the
    # logits and alpha change between calls.
    conformal_kwargs = {
        'is_logits': True,
        'randomized': randomized,
        'allow_zero_sets': allow_zero_sets,
        'kreg': kreg,
        'lamda': lamda,
        'batch_size': batch_size,
        'num_workers': num_workers,
        'gpu': gpu,
    }

    # Directory listing order is filesystem-dependent; sort it so the order of
    # the regression training rows is the same on every run.
    synthetic_datasets = sorted(s.name for s in synthetic_data_root.glob('imagenet-c*'))

    for arch, breeds_dataset, use_classes_of_dataset in params:

        # get breeds pretrained checkpoint
        params_search_dict = {'arch': arch, 'breeds_dataset': breeds_dataset}
        breeds_checkpoints = fetch_checkpoints_by_params(
            checkpoint_dir,
            hparams_filename='hparams.yaml',
            checkpoint_filename_ext='ckpt',
            verbose=False,
            search_dict=params_search_dict,
        )

        if len(breeds_checkpoints) != 1:
            raise RuntimeError(f"Number of pretrained checkpoints found for this config is expected to be 1, but got {len(breeds_checkpoints)}")

        # Get the model
        # NOTE(review): with use_encodings=True `model` is only the final
        # layer, so the `model.fc` access below would presumably fail; that
        # path is not exercised while use_encodings is False.
        if use_encodings:
            model = models.__dict__[arch](pretrained=True).fc
        else:
            model = models.__dict__[arch](pretrained=True)

        # The class count is the two trailing digits of the dataset name
        # (e.g. 'entity13' -> 13).
        num_classes = int(breeds_dataset[-2:])

        # Replace the linear layer.
        # WARNING: Only works for Resnets.
        # TODO: other arches.
        num_filters = model.fc.in_features
        model.fc = nn.Linear(num_filters, num_classes)

        # load pretrained checkpoint
        weights_update(model, torch.load(breeds_checkpoints[0]))

        if gpu is not None:
            model = model.cuda(gpu)

        breeds_constructor = getattr(breeds_helpers, 'make_' + breeds_dataset)

        split_metadata = breeds_constructor(info_dir, split=breeds_split)
        train_subclasses, _test_subclasses = split_metadata[1]

        dataset = robustness.datasets.CustomImageNet(data_root_dir, train_subclasses)

        _, val_loader = dataset.make_loaders(num_workers, batch_size, only_val=True)

        # Snapshot the sample list: the loader's dataset is re-pointed at each
        # synthetic dataset below and every rewrite starts from these paths.
        orig_val_samples = list(val_loader.dataset.samples)

        # Keyword arguments shared by every get_logits call for this model.
        logits_kwargs = {
            'is_data_loader': True,
            'num_classes': num_classes,
            'batch_size': batch_size,
            'num_workers': num_workers,
            'gpu': gpu,
        }

        val_logits = get_logits(model, val_loader, **logits_kwargs)

        logits_synt_all = []

        for cur_sel in tqdm(synthetic_datasets, desc="Extracting logits for synthetics"):
            cur_sel_root = synthetic_data_root / cur_sel

            # set synthetic loaders by modifying val_loader of BREEDS:
            # keep the part of each path after "val/" and re-root it.
            val_loader.dataset.root = str(cur_sel_root)
            val_loader.dataset.samples = [
                (str(cur_sel_root / s.split("val/")[-1]), t) for s, t in orig_val_samples
            ]

            logits_synt_all.append(get_logits(model, val_loader, **logits_kwargs))

        for alpha in alphas:

            features_by_method = {s['method']: {'features': [], 'targets': []} for s in regression_methods_configs}

            # Entry 0 of every list comes from the clean validation logits;
            # the difference-of-confidence fix below relies on that position.
            conformal_model = ConformalModel(model, val_logits, alpha=alpha, **conformal_kwargs)
            _record_calibration_point(
                regression_methods_configs, features_by_method, val_logits, conformal_model.q_hat
            )

            for cur_logits in tqdm(logits_synt_all, desc="Calibrating on synthetics"):
                conformal_model = ConformalModel(model, cur_logits, alpha=alpha, **conformal_kwargs)
                _record_calibration_point(
                    regression_methods_configs, features_by_method, cur_logits, conformal_model.q_hat
                )

            # Fix DoC features by subtracting the validation-set average
            # confidence (and its q_hat) from every collected entry.
            feat_offset = features_by_method['average_confidence_regression']['features'][0]
            target_offset = features_by_method['average_confidence_regression']['targets'][0]

            doc_entry = features_by_method['difference_of_confidence_regression']
            doc_entry['features'] = [s - feat_offset for s in doc_entry['features']]
            doc_entry['targets'] = [s - target_offset for s in doc_entry['targets']]

            # Identical for every method at this alpha, so build it once.
            hparams = {
                'classifier': arch,
                'breeds_dataset': breeds_dataset,
                'limit_classes_by_dataset': use_classes_of_dataset,
                'alpha': alpha,
                'randomized': randomized,
                'allow_zero_sets': allow_zero_sets,
                'kreg': kreg,
                'lamda': lamda,
                'lr': lr,
                'weight_decay': weight_decay,
                'momentum': momentum,
                'max_iter': max_iter,
                'batch_size': batch_size,
                'num_workers': num_workers,
                'seed': seed,
                'gpu': gpu,
            }

            for cur_method_config in regression_methods_configs:

                # same as the filename
                cur_method = cur_method_config['method']

                # set up experiment parameters
                dt = datetime.datetime.now()
                output_dir = output_base_dir / cur_method / dt.strftime('%Y-%m-%d_%H:%M:%S')
                output_dir.mkdir(parents=True, exist_ok=True)

                with open(output_dir / 'hparams.json', 'w') as fn:
                    json.dump(hparams, fn, indent=2)

                mlp_model = MLPConformalPredictor(**cur_method_config['model_args'], gpu=gpu)

                # torch.stack needs tensors — assumes every feature extractor
                # returns one (TODO confirm for the histogram-based methods).
                cur_feat = torch.stack(features_by_method[cur_method]['features'])
                cur_target = torch.Tensor(features_by_method[cur_method]['targets']).reshape(-1, 1)

                # Same guard as for the classifier above: only move to a GPU
                # when one was requested.
                if gpu is not None:
                    cur_feat = cur_feat.cuda(gpu)
                    cur_target = cur_target.cuda(gpu)

                # Scalar features stack into a 1-D tensor; add a trailing
                # feature axis so it is 2-D like the other methods.
                if cur_feat.ndim < 2:
                    cur_feat = cur_feat.unsqueeze(-1)

                mlp_model.fit_model(
                    cur_feat,
                    cur_target,
                    lr=lr,
                    weight_decay=weight_decay,
                    momentum=momentum,
                    max_iter=max_iter,
                    gpu=gpu,
                    outdir=output_dir
                )


