import numpy as np
import pickle
from scipy.special import softmax
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as tdata
import pandas as pd
import time
from tqdm import tqdm
from utils import validate, validate_logits, validate_adv_logits, get_logits_targets, get_logits_targets_adv, sort_sum
import pdb

# Conformalize a model with a calibration set.
# Save it to a file in .cache/modelname
# The only difference is that the forward method of ConformalModel also outputs a set.

def convert_to_multi_hot(integers, num_classes=1000):
    """Encode a batch of index sets as rows of a 0/1 matrix.

    Args:
        integers: Sequence whose i-th element holds the class indices (list,
            tuple or integer array) to switch on in row i. An element may be
            empty, in which case the row stays all zeros.
        num_classes: Width of the encoding. Defaults to 1000, the value that
            used to be hard-coded in this function.

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(len(integers), num_classes)``
        holding 1 at every listed index and 0 elsewhere.
    """
    multi_hot = np.zeros((len(integers), num_classes))

    for row, members in enumerate(integers):
        members = np.asarray(members)
        # An empty plain list converts to a float64 array, which NumPy refuses
        # to use as an index; there is nothing to set for it anyway.
        if members.size:
            multi_hot[row, members] = 1

    return multi_hot

class ConformalModel(nn.Module):
    """Conformal predictor built from precomputed calibration logits.

    Construction chooses the regularisation parameters (when they are not
    supplied), fits the affine logit transform ``(T, bias)`` with
    ``platt_pss`` and calibrates the threshold ``Qhat``. ``forward`` returns
    the logits it was given together with one prediction set per row.

    Args:
        calib_loader: Calibration data. Despite the name, it is passed to
            ``tdata.DataLoader`` as a dataset, so it must be an indexable
            collection of ``(logits, target)`` pairs.
        args: Namespace; ``num_classes`` and ``batch_size`` are read here and
            the object is forwarded to ``platt``, ``pick_parameters`` and
            ``platt_pss``.
        alpha: Miscoverage level; ``Qhat`` is the ``1 - alpha`` quantile of
            the calibration scores.
        kreg: Index from which the penalty ``lamda`` is applied. Chosen by
            ``pick_parameters`` when ``None``.
        lamda: Penalty added to every rank at or beyond ``kreg``. Chosen by
            ``pick_parameters`` when ``None``.
        randomized: Default for the ``randomized`` flag of ``forward``.
        allow_zero_sets: Default for the ``allow_zero_sets`` flag of
            ``forward``.
        pct_paramtune: Fraction of the calibration data that
            ``pick_parameters`` holds out for parameter selection.
        batch_size: Batch size forwarded to ``pick_parameters``.
        lamda_criterion: Selection criterion forwarded to
            ``pick_parameters``.
    """

    def __init__(self, calib_loader, args, alpha, kreg=None, lamda=None, randomized=True, allow_zero_sets=False, pct_paramtune = 0.3, batch_size=32, lamda_criterion='size'):
        super(ConformalModel, self).__init__()
        self.alpha = alpha
        # platt() hands back the data it received; the T it returns is
        # replaced by the platt_pss() result further down.
        self.T, calib_logits = platt(calib_loader, args)
        self.randomized=randomized
        self.allow_zero_sets=allow_zero_sets
        self.num_classes = args.num_classes

        if kreg is None or lamda is None:
            print('RAPS')
            # pick_parameters() also returns the part of the data that was
            # not used for tuning; only that part is used from here on.
            kreg, lamda, calib_logits = pick_parameters(args, calib_logits, alpha, kreg, lamda, randomized, allow_zero_sets, pct_paramtune, batch_size, lamda_criterion)

        # Per-rank penalty: zero for the first `kreg` ranks, `lamda` afterwards.
        self.penalties = np.zeros((1, self.num_classes))
        self.penalties[:, kreg:] += lamda

        calib_loader = tdata.DataLoader(calib_logits, batch_size = args.batch_size, shuffle=True, pin_memory=True)

        self.T, self.bias = platt_pss(calib_loader, args, alpha, kreg, lamda, randomized, allow_zero_sets)

        self.Qhat = conformal_calibration_logits(self, calib_loader)
        print('Qhat')
        print(self.Qhat)

    def forward(self, logits, *args, randomized=None, allow_zero_sets=None, **kwargs):
        """Return ``(logits, S_list)`` for a batch of logits.

        ``logits`` is returned untouched. ``S_list`` holds, per row, the class
        indices selected by ``gcq`` from ``softmax(logits @ T + bias)``.
        ``randomized`` / ``allow_zero_sets`` fall back to the values given at
        construction when left as ``None``. Extra positional and keyword
        arguments are accepted and ignored.
        """
        if randomized is None:
            randomized = self.randomized
        if allow_zero_sets is None:
            allow_zero_sets = self.allow_zero_sets

        with torch.no_grad():
            logits_numpy = logits.detach().cpu().numpy()
            scores = softmax(np.matmul(logits_numpy, self.T.numpy())+self.bias.numpy(), axis=1)

            I, ordered, cumsum = sort_sum(scores)

            S_list = gcq(scores, self.Qhat, I=I, ordered=ordered, cumsum=cumsum, penalties=self.penalties, randomized=randomized, allow_zero_sets=allow_zero_sets)

        return logits, S_list

# Computes the conformal calibration
# def conformal_calibration(cmodel, calib_loader):
#     print("Conformal calibration")
#     with torch.no_grad():
#         E = np.array([])
#         for x, targets in tqdm(calib_loader):
#             logits = cmodel.model(x.cuda()).detach().cpu().numpy()
#             scores = softmax(logits/cmodel.T.item(), axis=1)

#             I, ordered, cumsum = sort_sum(scores)

#             E = np.concatenate((E,giq(scores,targets,I=I,ordered=ordered,cumsum=cumsum,penalties=cmodel.penalties,randomized=True, allow_zero_sets=True)))
            
#         Qhat = np.quantile(E,1-cmodel.alpha,interpolation='higher')

#         return Qhat 

# Temperature scaling
def platt(calib_logits, args, max_iters=10, lr=0.01, epsilon=0.01):
    """Return a unit temperature together with the calibration data.

    The temperature fit is currently disabled, so ``T`` is always the
    constant ``1.0`` and the data is passed through unchanged.

    Args:
        calib_logits: Calibration data; returned as-is.
        args: Unused; kept so existing callers keep working.
        max_iters, lr, epsilon: Unused; kept so existing callers keep working.

    Returns:
        ``(T, calib_logits)`` where ``T`` is a one-element tensor equal to
        1.0, placed on the GPU when CUDA is available and on the CPU otherwise.
    """
    print("Begin Platt scaling.")

    T = torch.Tensor([1.0])
    # T is a constant, so a missing GPU should not be fatal here.
    if torch.cuda.is_available():
        T = T.cuda()
    print(f"Optimal T={T.item()}")

    return T, calib_logits

def platt_pss(calib_loader, args, alpha, kreg, lamda, randomized, allow_zero_sets, max_iters=10, lr=0.001, epsilon=0.01):
    """Fit the affine logit transform ``(T, bias)`` by alternating rounds.

    Each of the ``args.round_opt`` rounds recalibrates a
    ``ConformalModelLogits`` with the current ``(T, bias)`` and then updates
    ``(T, bias)`` with ``platt_logits_pss``.

    Args:
        calib_loader: Loader of ``(logits, target)`` batches.
        args: Namespace; ``num_classes``, ``lr_opt`` and ``round_opt`` are
            read here and the object is forwarded to the helpers.
        alpha, kreg, lamda, randomized, allow_zero_sets: Forwarded to
            ``ConformalModelLogits``.
        max_iters, epsilon: Forwarded to ``platt_logits_pss``.
        lr: Ignored; the learning rate is always ``args.lr_opt``.

    Returns:
        ``(T, bias)``: a ``(num_classes, num_classes)`` matrix, initialised to
        the identity, and a ``(1, num_classes)`` bias, initialised to zero.
    """
    print("Begin Platt scaling.")
    # This is a bi-level optimization problem: the prediction sets depend on
    # (T, bias), and (T, bias) is optimised against those sets.

    T = torch.eye(args.num_classes)
    print(T)
    bias = torch.zeros(1,args.num_classes)
    lr = args.lr_opt

    for i_iter in range(args.round_opt):
        # Independent copy, so the change printed below is measured against
        # the values T had at the start of the round.
        T_prev = T.clone()
        conformal_model = ConformalModelLogits(calib_loader, args, T, bias, alpha=alpha, kreg=kreg, lamda=lamda, randomized=randomized, allow_zero_sets=allow_zero_sets, naive=False)
        T, bias = platt_logits_pss(conformal_model, calib_loader, args, T, bias, max_iters=max_iters, lr=lr, epsilon=epsilon)
        print('round '+str(i_iter))
        print((T - T_prev).abs().sum())

    return T, bias

"""


        INTERNAL FUNCTIONS


"""

### Precomputed-logit versions of the above functions.

class ConformalModelLogits(nn.Module):
    """Conformal predictor that operates directly on logits.

    Args:
        calib_loader: Loader of ``(logits, target)`` batches; its ``dataset``
            attribute is also indexed directly.
        args: Namespace; ``num_classes`` is read when ``T`` or ``bias`` is
            not supplied, and the object is forwarded to ``pick_parameters``.
        T: Transform applied to the logits during calibration. Defaults to
            ``1.3 * ones(1, num_classes)``.
        bias: Offset used during calibration. Defaults to
            ``zeros(1, num_classes)``.
        alpha: Miscoverage level.
        kreg, lamda: Penalty start index and size. When either is ``None``
            (and neither ``naive`` nor ``LAC`` is set) they are chosen by
            ``pick_parameters``.
        randomized, allow_zero_sets: Defaults for the ``forward`` flags of
            the same name.
        naive: When true, no calibration is run and ``Qhat = 1 - alpha``.
        LAC: When true, sets are formed by thresholding ``1 - score``
            instead of calling ``gcq``.
        pct_paramtune, batch_size, lamda_criterion: Forwarded to
            ``pick_parameters``.
    """

    def __init__(self, calib_loader, args, T=None, bias=None, alpha=None, kreg=None, lamda=None, randomized=True, allow_zero_sets=False, naive=False, LAC=False, pct_paramtune = 0.3, batch_size=32, lamda_criterion='size'):
        super(ConformalModelLogits, self).__init__()
        self.alpha = alpha
        self.randomized = randomized
        self.LAC = LAC
        self.allow_zero_sets = allow_zero_sets
        # Identity checks: T and bias are tensors when supplied, and `==`
        # on a tensor is a comparison operator, not a None test.
        # NOTE(review): the default T has shape (1, num_classes), but
        # conformal_calibration_logits() computes np.matmul(logits, T), which
        # needs a (num_classes, num_classes) matrix -- confirm what the
        # intended default is.
        self.T = 1.3*torch.ones(1,args.num_classes) if T is None else T
        self.bias = torch.zeros(1,args.num_classes) if bias is None else bias

        uses_penalties = not naive and not LAC

        if (kreg is None or lamda is None) and uses_penalties:
            kreg, lamda, calib_logits = pick_parameters(args, calib_loader.dataset, alpha, kreg, lamda, randomized, allow_zero_sets, pct_paramtune, batch_size, lamda_criterion)
            # Calibrate only on the data pick_parameters() did not tune on.
            calib_loader = tdata.DataLoader(calib_logits, batch_size=batch_size, shuffle=False, pin_memory=True)

        # One penalty per class; the class count is taken from the first sample.
        self.penalties = np.zeros((1, calib_loader.dataset[0][0].shape[0]))
        if kreg is not None and uses_penalties:
            self.penalties[:, kreg:] += lamda
        self.Qhat = 1-alpha
        if uses_penalties:
            self.Qhat = conformal_calibration_logits(self, calib_loader)
        elif not naive and LAC:
            # Rank of the true label in each sample's descending logit order.
            gt_locs_cal = np.array([np.where(np.argsort(x[0]).flip(dims=(0,)) == x[1])[0][0] for x in calib_loader.dataset])
            # Score = 1 - probability assigned to the true label.
            scores_cal = 1-np.array([np.sort(torch.softmax(calib_loader.dataset[i][0]/self.T+self.bias, dim=0))[::-1][gt_locs_cal[i]] for i in range(len(calib_loader.dataset))]) 
            self.Qhat = np.quantile( scores_cal , np.ceil((scores_cal.shape[0]+1) * (1-alpha)) / scores_cal.shape[0] )

    def forward_no_set(self, logits, randomized=None, allow_zero_sets=None):
        """Return ``logits`` unchanged; the flags are accepted and ignored."""
        return logits
    
    def forward(self, logits, randomized=None, allow_zero_sets=None):
        """Return ``(logits, S)`` with one prediction set per row of ``logits``.

        Scores are the plain softmax of ``logits``; ``self.T`` and
        ``self.bias`` are not applied here. ``randomized`` and
        ``allow_zero_sets`` fall back to the constructor values when ``None``.
        """
        if randomized is None:
            randomized = self.randomized
        if allow_zero_sets is None:
            allow_zero_sets = self.allow_zero_sets
        
        with torch.no_grad():
            logits_numpy = logits.detach().cpu().numpy()
            scores = softmax(logits_numpy, axis=1)

            if not self.LAC:
                I, ordered, cumsum = sort_sum(scores)

                S = gcq(scores, self.Qhat, I=I, ordered=ordered, cumsum=cumsum, penalties=self.penalties, randomized=randomized, allow_zero_sets=allow_zero_sets)
            else:
                # Keep every class whose score 1 - p falls below the threshold.
                S = [ np.where( (1-scores[i,:]) < self.Qhat )[0] for i in range(scores.shape[0]) ]
        return logits, S

def conformal_calibration_logits(cmodel, calib_loader):
    """Compute the calibrated threshold ``Qhat`` from precomputed logits.

    Args:
        cmodel: Object exposing ``T`` and ``bias`` (tensors convertible with
            ``.numpy()``), ``penalties`` and ``alpha``.
        calib_loader: Iterable of ``(logits, targets)`` batches.

    Returns:
        The ``1 - cmodel.alpha`` quantile (rounded up to an observed value) of
        the per-sample scores returned by ``giq``.

    Raises:
        ValueError: If ``calib_loader`` yields no batches.
    """
    with torch.no_grad():
        batch_scores = []
        for logits, targets in calib_loader:
            logits = logits.detach().cpu().numpy()

            scores = softmax(np.matmul(logits, cmodel.T.numpy()) + cmodel.bias.numpy(), axis=1)

            I, ordered, cumsum = sort_sum(scores)

            # Collect per batch and join once, instead of re-copying the
            # whole array on every iteration.
            batch_scores.append(giq(scores,targets,I=I,ordered=ordered,cumsum=cumsum,penalties=cmodel.penalties,randomized=True,allow_zero_sets=True))

        if not batch_scores:
            raise ValueError("calib_loader yielded no batches; cannot calibrate Qhat")
        E = np.concatenate(batch_scores)

        level = 1-cmodel.alpha
        try:
            Qhat = np.quantile(E, level, method='higher')
        except TypeError:
            # NumPy releases that predate the `method` keyword.
            Qhat = np.quantile(E, level, interpolation='higher')

        print('Qhat:')
        print(Qhat)
        return Qhat

def platt_logits(calib_loader, max_iters=10, lr=0.01, epsilon=0.01):
    """Fit a scalar temperature by minimising cross-entropy with SGD.

    Args:
        calib_loader: Iterable of ``(logits, targets)`` batches.
        max_iters: Maximum number of passes over ``calib_loader``.
        lr: SGD learning rate.
        epsilon: Stop once a full pass changes ``T`` by less than this.

    Returns:
        The temperature as a one-element ``nn.Parameter`` (initial value
        1.3), on the GPU when CUDA is available and on the CPU otherwise.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    nll_criterion = nn.CrossEntropyLoss().to(device)

    T = nn.Parameter(torch.Tensor([1.3]).to(device))

    optimizer = optim.SGD([T], lr=lr)
    print(max_iters)
    for _ in range(max_iters):
        T_old = T.item()
        for x, targets in calib_loader:
            optimizer.zero_grad()
            # Only T is optimised, so the inputs need no gradient.
            out = x.to(device)/T
            loss = nll_criterion(out, targets.long().to(device))
            loss.backward()
            optimizer.step()
        if abs(T_old - T.item()) < epsilon:
            break
    return T

def _prediction_set_mask(S_list, num_classes):
    """Return a float64 ``(len(S_list), num_classes)`` 0/1 membership matrix."""
    mask = np.zeros((len(S_list), num_classes))
    for row, members in enumerate(S_list):
        members = np.asarray(members)
        if members.size:
            mask[row, members] = 1
    return mask


def platt_logits_pss(cmodel, calib_loader, args, T, bias, max_iters=10, lr=0.01, epsilon=0.01):
    """Optimise the affine logit transform ``(T, bias)`` with SGD.

    For every batch the transformed logits ``x @ T + bias`` are passed to
    ``cmodel`` to obtain prediction sets, and the loss is the sum of

    * the squared gap between ``1 / size**args.lambda_k`` and the expected
      accuracy of sampling a label inside the set (rows whose set has 400 or
      more members are excluded),
    * ``args.lambda_ent`` times the cross-entropy of rows whose set has at
      most one member, and
    * ``0.0001`` times the squared distance of ``T`` from the identity plus
      the squared norm of ``bias``,

    with the first two terms divided by the batch size.

    Args:
        cmodel: Callable mapping a logits tensor to ``(logits, S_list)``,
            where ``S_list`` holds one array of class indices per row.
        calib_loader: Iterable of ``(logits, targets)`` batches.
        args: Namespace providing ``num_classes``, ``lambda_temp_sample``,
            ``lambda_k`` and ``lambda_ent``.
        T: Initial ``(num_classes, num_classes)`` matrix. Not modified.
        bias: Initial ``(1, num_classes)`` bias. Not modified.
        max_iters: Maximum number of passes over ``calib_loader``.
        lr: SGD learning rate.
        epsilon: Stop once a full pass changes the parameters by less than
            this amount (summed absolute difference).

    Returns:
        ``(T, bias)`` as CPU tensors detached from the graph.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Clone so that optimiser steps never write into the caller's tensors,
    # even when no device transfer takes place.
    T_para = nn.Parameter(T.detach().clone().to(device))
    bias_para = nn.Parameter(bias.detach().clone().to(device))

    T_default = torch.eye(args.num_classes,args.num_classes).to(device)

    optimizer = optim.SGD([T_para,bias_para], lr=lr)
    nll_criterion = nn.CrossEntropyLoss(reduction='none').to(device)
    print(lr)

    loss = None  # stays None if no batch is ever processed
    for _ in range(max_iters):
        # Real copies: on the CPU a plain .numpy() would alias the parameters.
        T_old = T_para.detach().cpu().numpy().copy()
        bias_old = bias_para.detach().cpu().numpy().copy()

        for x, targets in calib_loader:
            optimizer.zero_grad()
            x = x.to(device)
            out = torch.matmul(x, T_para) + bias_para
            output_val, S_list = cmodel(out)

            # Per-sample cross-entropy (reduction='none').
            loss_ps = nll_criterion(output_val, targets.long().to(device))

            # 1 inside the prediction set, 0 outside; width follows the logits.
            cp_mask = torch.from_numpy(_prediction_set_mask(S_list, output_val.shape[1])).to(device)
            # Push classes outside the set to a very low logit so the softmax
            # is effectively restricted to the set.
            mask_logits = (1. - cp_mask) * (-1e5) + output_val * cp_mask
            output_val_mask = F.softmax(mask_logits, dim=1)

            # Sampling distribution: in-set probabilities raised to the
            # lambda_temp_sample power and renormalised.
            powered = output_val_mask**args.lambda_temp_sample
            sampling_prob = powered/torch.sum(powered,dim=1,keepdim=True)
            expect_acc = torch.sum(sampling_prob*output_val_mask, dim=1)

            # Prediction-set size and the power-law calibration target.
            pss = torch.sum(cp_mask, dim=1)
            calib_target = 1./torch.pow(pss, args.lambda_k)

            # Cross-entropy only for sets with at most one member.
            pss_weight = (pss <= 1).to(torch.float32)
            # Calibration gap only for sets with fewer than 400 members.
            pss_weight_gap = (pss < 400).to(torch.float32)

            loss = (( calib_target - expect_acc).square()*pss_weight_gap).sum()/len(x) + args.lambda_ent*(loss_ps*pss_weight).sum()/len(x) + 0.0001*(torch.square(T_para-T_default).sum()+torch.square(bias_para).sum())

            loss.backward()
            optimizer.step()

        if abs(T_old - T_para.detach().cpu().numpy()).sum() + abs(bias_old - bias_para.detach().cpu().numpy()).sum() < epsilon:
            break
    print(loss)
    return T_para.detach().cpu(), bias_para.detach().cpu()

### CORE CONFORMAL INFERENCE FUNCTIONS

# Generalized conditional quantile function.
def gcq(scores, tau, I, ordered, cumsum, penalties, randomized, allow_zero_sets):
    """Build one prediction set per row from sorted scores and a threshold.

    Args:
        scores: ``(n, k)`` score matrix; only its column count is used.
        tau: Threshold on the penalised cumulative score.
        I: ``(n, k)`` class indices, row-wise in the order of ``ordered``.
        ordered: ``(n, k)`` scores arranged in the same order as ``I``.
        cumsum: ``(n, k)`` running sum of ``ordered`` along each row.
        penalties: ``(1, k)`` per-rank penalties.
        randomized: When true, the last element of each set is dropped at
            random (one ``np.random.random`` draw per row).
        allow_zero_sets: When false, empty sets are replaced by the
            top-ranked class.

    Returns:
        List of ``n`` arrays, each a leading slice of the matching row of ``I``.
    """
    num_classes = scores.shape[1]
    running_penalty = np.cumsum(penalties, axis=1)

    # Number of ranks whose penalised cumulative score stays within tau, plus
    # the one that crosses it; capped at the number of classes.
    within_tau = (cumsum + running_penalty) <= tau
    sizes_base = np.minimum(within_tau.sum(axis=1) + 1, num_classes)

    if not randomized:
        sizes = sizes_base
    else:
        V = np.zeros(sizes_base.shape)
        for row, size in enumerate(sizes_base):
            last = size - 1  # zero-based position of the final set member
            last_score = ordered[row, last]
            V[row] = 1/last_score * \
                    (tau-(cumsum[row, last]-last_score)-running_penalty[0, last])

        # Drop the final member whenever the uniform draw is at least V.
        drop_last = (np.random.random(V.shape) >= V).astype(int)
        sizes = sizes_base - drop_last

    if tau == 1.0:
        sizes[:] = cumsum.shape[1] # always predict max size if alpha==0. (Avoids numerical error.)

    if not allow_zero_sets:
        sizes[sizes == 0] = 1 # never return an empty set (costs coverage if 1-alpha < top-1 accuracy)

    # Construct S from equation (5)
    return [I[row, 0:sizes[row]] for row in range(I.shape[0])]

# Get the 'p-value'
def get_tau(score, target, I, ordered, cumsum, penalty, randomized, allow_zero_sets): # For one example
    """Return the conformity score of ``target`` for a single example.

    Args:
        score: Unused; kept for interface compatibility.
        target: Class label to look up in ``I``.
        I: ``(1, k)`` class indices in ranked order.
        ordered: ``(1, k)`` scores in the same order as ``I``.
        cumsum: ``(1, k)`` running sum of ``ordered``.
        penalty: 1-D array of per-rank penalties.
        randomized: When true, one ``np.random.random`` draw is consumed and
            used to randomise the score.
        allow_zero_sets: Only relevant when ``randomized`` is true and the
            target is ranked first: the cumulative score is then scaled by
            the random draw instead of being returned in full.

    Returns:
        A one-element ``numpy.ndarray`` holding the score.

    Raises:
        ValueError: If ``target`` does not occur exactly once in ``I``.
    """
    idx = np.where(I==target)
    if idx[0].size != 1:
        raise ValueError(
            f"target {target!r} must occur exactly once in I, found {idx[0].size} matches"
        )
    is_top_ranked = idx[0][0] == 0 and idx[1][0] == 0
    tau_nonrandom = cumsum[idx]

    if not randomized:
        return tau_nonrandom + penalty[0]
    
    # Drawn before branching so the random stream is consumed identically on
    # every path.
    U = np.random.random()

    if is_top_ranked:
        if not allow_zero_sets:
            return tau_nonrandom + penalty[0]
        return U * tau_nonrandom + penalty[0]

    # Randomise only the target's own mass; everything ranked above it and the
    # penalties up to and including its rank count in full.
    return U * ordered[idx] + cumsum[(idx[0],idx[1]-1)] + (penalty[0:(idx[1][0]+1)]).sum()

# Gets the histogram of Taus. 
def giq(scores, targets, I, ordered, cumsum, penalties, randomized, allow_zero_sets):
    """
        Generalized inverse quantile conformity score function.
        E from equation (7) in Romano, Sesia, Candes.  Find the minimum tau in [0, 1] such that the correct label enters.

        Calls get_tau once per row, passing one-row slices of scores, I,
        ordered and cumsum together with the first row of penalties, and
        returns the results as a 1-D float array with one entry per row of
        scores.
    """
    E = -np.ones((scores.shape[0],))
    for i in range(scores.shape[0]):
        tau = get_tau(scores[i:i+1,:],targets[i].item(),I[i:i+1,:],ordered[i:i+1,:],cumsum[i:i+1,:],penalties[0,:],randomized=randomized, allow_zero_sets=allow_zero_sets)
        # get_tau returns a one-element array; pull the scalar out explicitly
        # rather than relying on NumPy's deprecated implicit conversion.
        E[i] = np.asarray(tau).item()

    return E

### AUTOMATIC PARAMETER TUNING FUNCTIONS
def pick_kreg(paramtune_logits, alpha):
    """Choose ``kreg`` from the rank of the true label on the tuning data.

    Args:
        paramtune_logits: Iterable of samples whose first element is a 1-D
            logits vector (CPU tensor or array) and whose second element is
            the integer label.
        alpha: Miscoverage level.

    Returns:
        One plus the ``1 - alpha`` quantile (rounded up to an observed value)
        of the zero-based rank of the true label in descending logit order.
    """
    gt_locs_kstar = []
    for x in paramtune_logits:
        # Ascending argsort reversed gives the descending order; done in
        # plain NumPy so tensors and arrays are handled alike.
        descending = np.argsort(np.asarray(x[0]))[::-1]
        gt_locs_kstar.append(np.where(descending == int(x[1]))[0][0])
    gt_locs_kstar = np.array(gt_locs_kstar)

    level = 1-alpha
    try:
        kstar = np.quantile(gt_locs_kstar, level, method='higher') + 1
    except TypeError:
        # NumPy releases that predate the `method` keyword.
        kstar = np.quantile(gt_locs_kstar, level, interpolation='higher') + 1
    return kstar

def pick_lamda_size(args, paramtune_loader, alpha, kreg, randomized, allow_zero_sets):
    """Pick the ``lamda`` from a fixed grid that gives the smallest average set.

    Args:
        args: Namespace; ``use_adv_calib`` selects the validation routine and
            the object is forwarded to ``ConformalModelLogits``.
        paramtune_loader: Loader of ``(logits, target)`` batches used both to
            calibrate and to evaluate each candidate.
        alpha, kreg, randomized, allow_zero_sets: Forwarded to
            ``ConformalModelLogits``.

    Returns:
        The grid value with the smallest reported average set size; the
        earliest value wins ties.
    """
    # Start from +inf so that a value is always selected, even when no
    # candidate reduces the average size below the number of classes.
    best_size = float('inf')
    lamda_star = None
    # Use the paramtune data to pick lamda.  Does not violate exchangeability.
    for temp_lam in [0.001, 0.01, 0.1, 0.2, 0.5]: # predefined grid, change if more precision desired.
        conformal_model = ConformalModelLogits(paramtune_loader, args, alpha=alpha, kreg=kreg, lamda=temp_lam, randomized=randomized, allow_zero_sets=allow_zero_sets, naive=False)
        # Kept as an equality test so non-boolean flag values behave as before.
        if args.use_adv_calib == True:
            top1_avg, top5_avg, cvg_avg, sz_avg = validate_adv_logits(paramtune_loader, conformal_model, args, print_bool=False)
        else:
            top1_avg, top5_avg, cvg_avg, sz_avg = validate_logits(paramtune_loader, conformal_model, print_bool=False)
        if sz_avg < best_size:
            best_size = sz_avg
            lamda_star = temp_lam
    return lamda_star

def pick_lamda_adaptiveness(paramtune_loader, alpha, kreg, randomized, allow_zero_sets, strata=None, args=None):
    """Pick the ``lamda`` from a fixed grid with the smallest coverage violation.

    Args:
        paramtune_loader: Loader of ``(logits, target)`` batches used both to
            calibrate and to evaluate each candidate.
        alpha, kreg, randomized, allow_zero_sets: Forwarded to
            ``ConformalModelLogits``.
        strata: Inclusive ``[low, high]`` set-size ranges passed to
            ``get_violation``. Defaults to
            ``[[0,1],[2,3],[4,6],[7,10],[11,100],[101,1000]]``.
        args: Namespace required by ``ConformalModelLogits``. When omitted, a
            minimal one carrying only ``num_classes`` (taken from the first
            sample of the loader's dataset) is created.

    Returns:
        The grid value with the smallest violation below 1, or 0 when no
        candidate gets below 1.
    """
    if strata is None:
        strata = [[0,1],[2,3],[4,6],[7,10],[11,100],[101,1000]]
    if args is None:
        # ConformalModelLogits takes `args` as a required positional argument.
        from types import SimpleNamespace
        args = SimpleNamespace(num_classes=paramtune_loader.dataset[0][0].shape[0])

    # Calculate lamda_star
    lamda_star = 0
    best_violation = 1
    # Use the paramtune data to pick lamda.  Does not violate exchangeability.
    for temp_lam in [0, 1e-5, 1e-4, 8e-4, 9e-4, 1e-3, 1.5e-3, 2e-3]: # predefined grid, change if more precision desired.
        conformal_model = ConformalModelLogits(paramtune_loader, args, alpha=alpha, kreg=kreg, lamda=temp_lam, randomized=randomized, allow_zero_sets=allow_zero_sets, naive=False)
        curr_violation = get_violation(conformal_model, paramtune_loader, strata, alpha)
        if curr_violation < best_violation:
            best_violation = curr_violation 
            lamda_star = temp_lam
    return lamda_star

def pick_parameters(args, calib_logits, alpha, kreg, lamda, randomized, allow_zero_sets, pct_paramtune, batch_size, lamda_criterion):
    """Split off tuning data and fill in whichever of ``kreg``/``lamda`` is missing.

    Args:
        args: Forwarded to ``pick_lamda_size``.
        calib_logits: Dataset of ``(logits, target)`` samples.
        alpha: Miscoverage level.
        kreg: Used as given unless ``None``, in which case ``pick_kreg`` runs.
        lamda: Used as given unless ``None``, in which case it is selected
            according to ``lamda_criterion``.
        randomized, allow_zero_sets: Forwarded to the lamda selection routine.
        pct_paramtune: Fraction of ``calib_logits`` (rounded up) set aside
            for tuning.
        batch_size: Batch size of the tuning loader.
        lamda_criterion: ``"size"`` or ``"adaptiveness"``.

    Returns:
        ``(kreg, lamda, calib_logits)`` where ``calib_logits`` is the part of
        the random split that was not used for tuning.

    Raises:
        ValueError: If ``lamda`` must be selected and ``lamda_criterion`` is
            not one of the supported values.
    """
    num_paramtune = int(np.ceil(pct_paramtune * len(calib_logits)))
    paramtune_logits, calib_logits = tdata.random_split(calib_logits, [num_paramtune, len(calib_logits)-num_paramtune])
    paramtune_loader = tdata.DataLoader(paramtune_logits, batch_size=batch_size, shuffle=False, pin_memory=True)

    if kreg is None:
        kreg = pick_kreg(paramtune_logits, alpha)
    if lamda is None:
        if lamda_criterion == "size":
            lamda = pick_lamda_size(args, paramtune_loader, alpha, kreg, randomized, allow_zero_sets)
        elif lamda_criterion == "adaptiveness":
            lamda = pick_lamda_adaptiveness(paramtune_loader, alpha, kreg, randomized, allow_zero_sets)
        else:
            # Fail here rather than hand back lamda=None to the caller.
            raise ValueError(
                f"unknown lamda_criterion {lamda_criterion!r}; expected 'size' or 'adaptiveness'"
            )
    return kreg, lamda, calib_logits

def get_violation(cmodel, loader_paramtune, strata, alpha):
    """Return the worst gap between empirical and nominal coverage across strata.

    Args:
        cmodel: Callable mapping a batch of logits to ``(output, S)``, where
            ``S`` holds one array of class indices per row.
        loader_paramtune: Iterable of ``(logits, targets)`` batches.
        strata: Inclusive ``[low, high]`` set-size ranges.
        alpha: Miscoverage level; the nominal coverage is ``1 - alpha``.

    Returns:
        The maximum over non-empty strata of ``|coverage - (1 - alpha)|``, or
        0 when no sample falls into any stratum.
    """
    frames = []
    for logit, target in loader_paramtune:
        # compute output
        output, S = cmodel(logit) # This is a 'dummy model' which takes logits, for efficiency.
        # measure accuracy and record loss
        size = np.array([x.size for x in S])
        correct = np.zeros_like(size)
        for j in range(correct.shape[0]):
            correct[j] = int( int(target[j]) in list(S[j]) )
        frames.append(pd.DataFrame({'size': size, 'correct': correct}))

    if not frames:
        return 0
    # DataFrame.append no longer exists in pandas 2; concatenate once instead.
    df = pd.concat(frames, ignore_index=True)

    wc_violation = 0
    for stratum in strata:
        temp_df = df[ (df['size'] >= stratum[0]) & (df['size'] <= stratum[1]) ]
        if len(temp_df) == 0:
            continue
        stratum_violation = abs(temp_df.correct.mean()-(1-alpha))
        wc_violation = max(wc_violation, stratum_violation)
    return wc_violation # the violation

