import argparse
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from sklearn.linear_model import LinearRegression
import pickle
import keras
import random
import os
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input
from PIL import Image
import sys

class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    """
    def __init__(self, n_bins=20):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

def tune_temp(logits, labels, binary_search=True, lower=0.2, upper=10.0, eps=0.0001, device='cuda'):
    """
    Fit a single softmax temperature T by minimising the mean cross-entropy
    of ``softmax(logits / T)`` against ``labels``.

    Args:
        logits: array-like (NumPy array or tensor) of shape (N, C).
        labels: array-like of shape (N,) with integer class indices.
        binary_search (bool): if True, bisect on the sign of dNLL/dT using
            PyTorch autograd. If False, solve the convex problem in the inverse
            temperature b = 1/T with cvxpy (``cvxpy`` must be installed and the
            inputs must be convertible to NumPy arrays).
        lower, upper (float): bounds on the temperature T (both branches);
            ``lower`` must be > 0.
        eps (float): bisection stops once ``upper - lower <= eps``.
        device: device of the returned tensor (default ``'cuda'``, as before).

    Returns:
        torch.Tensor of shape (1,) holding the fitted temperature T.

    Raises:
        RuntimeError: if the cvxpy solver returns no solution.
    """
    if binary_search:
        logits_t = torch.as_tensor(logits, dtype=torch.float32)
        labels_t = torch.as_tensor(labels, dtype=torch.long)

        def nll(temp):
            return F.cross_entropy(logits_t / temp, labels_t)

        # The NLL is convex in 1/T, hence unimodal in T on (0, inf), so the sign
        # of its derivative tells which half of [lower, upper] holds the minimum.
        while upper - lower > eps:
            mid = 0.5 * (lower + upper)
            # A fresh leaf each step keeps the autograd graph from growing.
            t_guess = torch.tensor([mid], device=logits_t.device, requires_grad=True)
            (grad,) = torch.autograd.grad(nll(t_guess), t_guess)
            if grad.item() > 0:
                upper = mid
            else:
                lower = mid

        t = min([lower, 0.5 * (lower + upper), upper], key=lambda x: float(nll(x)))
    else:
        import cvxpy as cx

        logits_np = np.asarray(logits, dtype=float)
        labels_np = np.asarray(labels, dtype=int)
        n = logits_np.shape[0]

        # Optimise the inverse temperature b = 1/T, in which the NLL is convex:
        # sum_i logsumexp(b * z_i) - b * z_i[y_i].
        b = cx.Variable()
        objective = (cx.sum(cx.log_sum_exp(b * logits_np, axis=1))
                     - b * logits_np[np.arange(n), labels_np].sum())
        # lower/upper bound T, so b must lie in [1/upper, 1/lower] (the original
        # code bounded b itself by [lower, upper], disagreeing with the other branch).
        problem = cx.Problem(cx.Minimize(objective), [1.0 / upper <= b, b <= 1.0 / lower])
        problem.solve()
        if b.value is None:
            raise RuntimeError(
                'cvxpy failed to solve the temperature problem (status={})'.format(problem.status))
        t = 1.0 / float(b.value)
    return torch.tensor([t]).to(device)


# Set up hyperparameters
NUM_BINS = 20
ece_criterion = _ECELoss(n_bins=NUM_BINS).cuda()

# Set seed (numpy's global RNG drives the calibration/evaluation split below)
SEED = 1
print('SEED: ', SEED)
np.random.seed(SEED)

# Positional command-line arguments, in the same order as before. argparse
# gives a usage message and type checking instead of a bare IndexError or
# ValueError; parse_known_args keeps tolerating extra trailing arguments the
# way direct sys.argv indexing did.
_parser = argparse.ArgumentParser(description='Train the MD-TS temperature regressor.')
_parser.add_argument('arch_type', help="architecture type: 'vit', 'resnet50', 'clip'")
_parser.add_argument('level', type=int, help='level of the class hierarchy for BREEDS')
_parser.add_argument('num_classes', type=int, help='number of subclasses in a domain')
_args, _ = _parser.parse_known_args()
ARCH_TYPE = _args.arch_type
LEVEL = _args.level
NUM_CLASSES = _args.num_classes

main_dir = ''
data_dir = os.path.join(main_dir, f'MDTS_features/{ARCH_TYPE}/{LEVEL}_{NUM_CLASSES}')
save_dir = os.path.join(main_dir, f"MDTS_regressor/{LEVEL}_{NUM_CLASSES}")

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_dir, exist_ok=True)

# Load the pre-extracted in-distribution arrays from data_dir
# Z: representation features
# Y: labels
# G: group labels
# Logits: logit outputs
Z_ind = np.load(os.path.join(data_dir, 'Z_ind.npy'))
print(Z_ind.shape)
# Labels come from Y_ind.npy (this previously re-loaded G_ind.npy by mistake,
# so group ids were used as class labels for calibration and ECE).
Y_ind = np.load(os.path.join(data_dir, 'Y_ind.npy'))
print(Y_ind.shape)
G_ind = np.load(os.path.join(data_dir, 'G_ind.npy'))
print(G_ind.shape)
Logits_ind = np.load(os.path.join(data_dir, 'Logits_ind.npy'))
print(Logits_ind.shape)

# All arrays are indexed by the same sample index below, so their lengths must agree.
if not (Z_ind.shape[0] == Y_ind.shape[0] == G_ind.shape[0] == Logits_ind.shape[0]):
    raise ValueError(
        'Mismatched sample counts: Z={}, Y={}, G={}, Logits={}'.format(
            Z_ind.shape[0], Y_ind.shape[0], G_ind.shape[0], Logits_ind.shape[0]))

# Visualize data information
print('in-distribution group: ', np.unique(G_ind))
print('Logits_ind shape: ', Logits_ind.shape)

# Partition the in-distribution samples into two disjoint halves
permute_index = np.random.permutation(np.arange(Z_ind.shape[0]))
calibrate_size = int(Z_ind.shape[0] * 0.5)
calibrate_index = permute_index[:calibrate_size]
# The rest of the permutation is exactly the complement of calibrate_index;
# sorting reproduces the order np.setdiff1d returned, without the set operation.
eval_index = np.sort(permute_index[calibrate_size:])

# Part 1 for calibration
Z_ind_calibrate = Z_ind[calibrate_index]
G_ind_calibrate = G_ind[calibrate_index]
Y_ind_calibrate = Y_ind[calibrate_index]
Logits_ind_calibrate = Logits_ind[calibrate_index]

# Part 2 for evaluation
Z_ind_eval = Z_ind[eval_index]
G_ind_eval = G_ind[eval_index]
Y_ind_eval = Y_ind[eval_index]
Logits_ind_eval = Logits_ind[eval_index]

print('======== running MD-TS =========')
# These shapes do not change per group, so report them once.
print("Z_ind shape: {}".format(Z_ind_calibrate.shape))
print("G_ind shape: {}".format(G_ind_calibrate.shape))
print("Y_ind shape: {}".format(Y_ind_calibrate.shape))
print("Logits_ind shape: {}".format(Logits_ind_calibrate.shape))

# Fit one temperature per calibration group; every sample of the group gets
# that temperature as its regression target for the feature -> T model.
Temperature_md = []
Z_md = []
for group_ in np.unique(G_ind_calibrate):
    index_group = np.where(G_ind_calibrate == group_)[0]
    print('============group={}, number of samples={}============'.format(group_, index_group.shape[0]))
    # tune_temp consumes host arrays, so no host -> GPU -> host round trip is needed.
    temperature = tune_temp(Logits_ind_calibrate[index_group],
                            Y_ind_calibrate[index_group],
                            binary_search=True,
                            lower=0.1, upper=20.0, eps=0.0001)

    Temperature_md.append(np.full(index_group.shape[0], temperature.item()))
    Z_md.append(Z_ind_calibrate[index_group])

# Learn a linear model
Temperature_md = np.concatenate(Temperature_md, axis=0)
Z_md = np.concatenate(Z_md, axis=0)
model_LR = LinearRegression().fit(Z_md, Temperature_md)
print('==================finished MD-TS training==================')
with open(os.path.join(save_dir, f'MDTS_linear_regressor_{ARCH_TYPE}.pkl'), 'wb') as f:
    pickle.dump(model_LR, f)

# InD evaluation: predict one temperature per held-out sample from its
# features and measure the ECE of the temperature-scaled logits.
temperature_pred = torch.from_numpy(model_LR.predict(Z_ind_eval))
print(temperature_pred)
# The linear regressor is unconstrained and can predict T <= 0, which flips
# (T < 0) or breaks (T == 0) the softmax; report it rather than fail silently.
num_nonpositive = int((temperature_pred <= 0).sum())
if num_nonpositive > 0:
    print('WARNING: {} predicted temperatures are <= 0'.format(num_nonpositive))
# A trailing singleton dim broadcasts each row's temperature over all classes.
temperature_pred = temperature_pred.unsqueeze(1).cuda()
Logits_ind_eval = torch.from_numpy(Logits_ind_eval).cuda()
Y_ind_eval = torch.from_numpy(Y_ind_eval).cuda()
ECE_MDTS = ece_criterion((Logits_ind_eval / temperature_pred), Y_ind_eval)
print('Ind ECE (MD-TS): ', ECE_MDTS.item())