import os, sys, inspect

from conformal import *
from utils import *
from multi_conformal import ensemble_logits, LACPredictionSets

# Import other standard packages
import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import torch.backends.cudnn as cudnn
import random
from itertools import product
from pdb import set_trace


# Seed every random number generator the script relies on (Python's `random`,
# NumPy's global RNG, PyTorch CPU and the current CUDA device) with one value
# so repeated runs are reproducible. Change `seed` to explore other draws.
seed = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)
del _seed_fn  # keep the module namespace free of the loop helper

# Standard ImageNet evaluation preprocessing, following the torchvision
# reference scripts: resize to 256, center-crop to 224x224, convert to a
# tensor, then normalize with the ImageNet per-channel mean/std.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
_eval_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
]
transform = transforms.Compose(_eval_steps)

# Let cuDNN autotune convolution algorithms for fixed input sizes.
cudnn.benchmark = True
batch_size = 64

# NOTE(review): the on-the-fly dataset/loader below is disabled, so
# `transform` and `batch_size` are not used by the visible script; logits are
# obtained through get_logits_dataset instead.
# Get the conformal calibration dataset
# imagenet_data = torchvision.datasets.ImageFolder('./imagenet_val/', transform)

# Initialize loaders
# loader = torch.utils.data.DataLoader(imagenet_data, batch_size=batch_size, shuffle=False, pin_memory=True)

# Get logits.
# Models whose logits are ensembled. Other candidate architectures, currently
# disabled: 'ResNet50', 'ResNet18', 'VGG16', 'Inception', 'ShuffleNet'.
modelnames = [
    'ResNeXt101', 'ResNet152', 'ResNet101', 'DenseNet161'
]

# Root directory of the ImageNet validation images. The default is a
# machine-specific path; set the IMAGENET_VAL_DIR environment variable to
# point elsewhere without editing the script.
data_path = os.environ.get(
    'IMAGENET_VAL_DIR', '/export/home/multi-conformal/imagenet/imagenet_val/'
)

# One logits dataset per model, kept in the same order as `modelnames`.
logits = list()
for model_name in modelnames:
    logits.append(get_logits_dataset(model_name, 'Imagenet', data_path))
    print(f"Logits from model {model_name} loaded.")

# Ensemble
# Local imports so this block stays self-contained (used for weight dedup).
from functools import reduce
from math import gcd

# Passed to LACPredictionSets; presumably the conformal miscoverage level
# (target coverage 1 - alpha) — confirm against multi_conformal.
alpha = 0.1
c_vec = np.linspace(0, 1, 11)  # NOTE(review): not used by the sweep below
base_weights = [0, 1, 3]
# Every assignment of a base weight to each model (len(base_weights) ** n combos).
weights = product(base_weights, repeat=len(modelnames))

# Sizes of the calibration / validation / second-validation subsets. They must
# add up to the number of ensembled samples, or random_split raises.
n_data_conf, n_data_val, n_data_val2 = 30000, 10000, 10000
bsz = 32
# Seed for the data split only; kept separate from the global RNG `seed` so it
# does not overwrite it. Reseeding a fresh generator every iteration gives all
# weight combinations identical cal/val/val2 splits, making them comparable.
split_seed = 42

# Integer weight vectors that are scalar multiples of one another (e.g.
# (0, 1, 1, 0) and (0, 3, 3, 0)) normalize to the same mixture. Reducing each
# vector by its gcd gives a canonical key, so each mixture is evaluated once.
seen_mixtures = set()
for weight in weights:
    if sum(weight) == 0:
        continue  # the all-zero vector cannot be normalized
    divisor = reduce(gcd, weight)
    canonical = tuple(w // divisor for w in weight)
    if canonical in seen_mixtures:
        continue
    seen_mixtures.add(canonical)

    c = np.array(weight)
    c = c / np.sum(c)
    print(f"c={[np.round(cc, 2) for cc in c]}")
    logits_mixed = ensemble_logits(logits, c)
    logits_mixed_cal, logits_mixed_val, logits_mixed_val2 = torch.utils.data.random_split(
        logits_mixed,
        [n_data_conf, n_data_val, n_data_val2],
        generator=torch.Generator().manual_seed(split_seed),
    )
    # Calibrate on the cal split, then evaluate on each held-out split.
    top1_avg, top5_avg, cvg_avg, sz_avg = LACPredictionSets(logits_mixed_cal, logits_mixed_val, alpha, bsz)
    print(f"coverage_val={cvg_avg:.4f}, size_val={sz_avg:.4f}")
    top1_avg2, top5_avg2, cvg_avg2, sz_avg2 = LACPredictionSets(logits_mixed_cal, logits_mixed_val2, alpha, bsz)
    print(f"coverage_val2={cvg_avg2:.4f}, size_val2={sz_avg2:.4f}")

