import sys
import numpy as np
import scipy
import torch
import os
import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import warnings
import argparse
import utils_model
import copy


# Seed NumPy's global RNG, PyTorch's CPU RNG and the RNGs of all CUDA devices
# with 0 at import time. Anything that draws random numbers after this module
# is imported starts from a fixed state. Note that cuDNN determinism flags are
# not set here.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)


def Asym_CE(output, label, expert):
    """Asymmetric-softmax cross-entropy loss for learning to defer.

    Args:
        output: Raw model scores of shape (batch, num_class + 1). The last
            column is the defer/rejector score, as interpreted by
            ``Asym_SM_trans``.
        label: Ground-truth class indices, shape (batch,).
        expert: Object whose ``predict(labels=..., input=[])`` returns the
            expert's predicted class for each sample in the batch.

    Returns:
        A scalar tensor: the batch mean of
        ``-log p_label + BCE(rejector probit, expert is correct)``.
    """
    # Use CUDA when present (as before); otherwise stay on the input's device.
    device = torch.device("cuda") if torch.cuda.is_available() else output.device
    output = output.to(device)
    label = label.to(device)

    # torch.as_tensor accepts lists, arrays and tensors without the
    # copy-construct warning that torch.tensor(tensor) emits.
    expert_pred = torch.as_tensor(expert.predict(labels=label, input=[]), device=device)
    expert_correctness = (expert_pred == label).float()

    output_probit = Asym_SM_trans(output).to(device)
    # The last column is the rejector probit; derive the class count from the width.
    num_class = output_probit.size(1) - 1

    sm = output_probit[:, :num_class]
    bsm = output_probit[:, num_class]

    # Squeeze to (batch,) so the sum with loss2 (batch,) stays 1-D; a
    # (batch, 1) + (batch,) sum would broadcast to a (batch, batch) matrix.
    loss1 = -torch.log(sm + 1e-7).gather(-1, label.view(-1, 1).long()).squeeze(-1)
    loss2 = -(
        expert_correctness * torch.log(bsm + 1e-7)
        + (1 - expert_correctness) * torch.log(1 - bsm + 1e-7)
    )
    return torch.mean(loss1 + loss2)

def _collect_predictions(model, loader, expert, device):
    """Run ``model`` and ``expert`` over every batch of ``loader``.

    Returns ``(output, label, expert_pred)``, each concatenated along dim 0.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    outputs, labels, expert_preds = [], [], []
    for data, label in loader:
        data = data.to(device)
        label = label.to(device).long()
        # The expert is queried with the ground-truth labels and an empty input list.
        expert_pred = torch.as_tensor(expert.predict(labels=label, input=[]), device=device)
        outputs.append(model(data).to(device))
        labels.append(label)
        expert_preds.append(expert_pred)
    if not outputs:
        raise ValueError("check_01d: loader yielded no batches")
    # Concatenate once at the end rather than re-concatenating after every
    # batch, which copies the accumulated data repeatedly (quadratic).
    return (
        torch.cat(outputs, dim=0),
        torch.cat(labels, dim=0),
        torch.cat(expert_preds, dim=0),
    )


def _budgeted_error_count(budget, reject, expert_probit, fallback_error, expert_error):
    """Count errors when at most ``budget`` rejected samples may go to the expert.

    If fewer than ``budget`` samples are rejected, all of them are deferred.
    Otherwise the ``budget`` rejected samples with the highest rejector probit
    are deferred. Every non-deferred sample is scored with ``fallback_error``.
    """
    if reject.sum() < budget:
        deferred = reject
    else:
        scores = expert_probit * reject  # non-rejected samples score 0
        chosen = torch.zeros_like(reject)
        chosen[torch.topk(scores, k=budget, largest=True)[1]] = True
        # Intersect with `reject` so zero-scored non-rejected picks never defer.
        deferred = reject & chosen
    return (fallback_error & ~deferred).sum() + (expert_error & deferred).sum()


def check_01d(model, loader, expert):
    """Evaluate a learning-to-defer model together with its expert on ``loader``.

    The model outputs num_class + 1 scores per sample; argmax on the last
    column means "defer to the expert".

    Returns:
        A tuple ``(error_01d, coverage, ECE_classifier, ECE_expert, ECE_total,
        error_10, error_20, error_30)``:

        - error_01d: error rate when accepted samples use the classifier's
          prediction and rejected samples use the expert's.
        - coverage: fraction of samples not rejected.
        - ECE_classifier: ECE of the class probits over all samples.
        - ECE_expert: ECE of the binary rejector probit against expert
          correctness over all samples.
        - ECE_total: count-weighted ECE of the accepted subset (classifier)
          and the rejected subset (rejector vs. expert correctness), divided by
          the sample count. Each subset contributes only if it has more than
          one sample.
        - error_10/20/30: error rate when at most 1/2/3 * floor(n / 10)
          samples may be deferred (see ``_budgeted_error_count``).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Use CUDA when present (as before); otherwise run on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        model = model.to(device)
        output, label, expert_pred = _collect_predictions(model, loader, expert, device)
        n_total = len(label)
        # The last output column is the defer option; derive K from the width.
        num_class = output.size(1) - 1

        expert_correctness = (expert_pred == label).float()
        probit = Asym_SM_trans(output).to(device)
        expert_probit = probit[:, num_class]
        classifier_probit = probit[:, :num_class]

        _, prediction = torch.max(output, dim=-1)
        reject = prediction == num_class
        accept = ~reject

        classifier_error = prediction != label
        expert_error = expert_pred != label

        error_count = (classifier_error & accept).sum() + (expert_error & reject).sum()
        rejection_count = reject.sum()

        # Under a deferral budget, rejected samples that cannot be deferred must
        # be answered by the classifier. Score them with its best real class,
        # not with the defer index, which can never equal a label. For
        # accepted samples this equals `prediction`.
        _, class_prediction = torch.max(output[:, :num_class], dim=-1)
        fallback_error = class_prediction != label
        budget_unit = n_total // 10
        error_count_10, error_count_20, error_count_30 = (
            _budgeted_error_count(
                m * budget_unit, reject, expert_probit, fallback_error, expert_error
            )
            for m in (1, 2, 3)
        )

        ECE_classifier = ECE(probits=classifier_probit, labels=label)
        binary_probit = torch.stack([1 - expert_probit, expert_probit], dim=-1)
        ECE_expert = ECE(probits=binary_probit, labels=expert_correctness)

        n_reject = reject.sum()
        n_accept = accept.sum()
        weighted_ece = 0
        if n_reject > 1:
            selected = expert_probit[reject]
            selected_binary = torch.stack([1 - selected, selected], dim=-1)
            weighted_ece = weighted_ece + n_reject * ECE(
                probits=selected_binary, labels=expert_correctness[reject]
            )
        if n_accept > 1:
            weighted_ece = weighted_ece + n_accept * ECE(
                probits=classifier_probit[accept], labels=label[accept]
            )

        error_01d = error_count / n_total
        coverage = 1 - rejection_count / n_total
        ECE_total = weighted_ece / n_total
        error_10 = error_count_10 / n_total
        error_20 = error_count_20 / n_total
        error_30 = error_count_30 / n_total

    return error_01d, coverage, ECE_classifier, ECE_expert, ECE_total, error_10, error_20, error_30


def Asym_SM_trans(scorer):
    """Map raw scores to asymmetric-softmax probits.

    Args:
        scorer: Tensor of shape (batch, K + 1). Columns 0..K-1 are class
            scores and column K is the defer score.

    Returns:
        Tensor of shape (batch, K + 1) with two parts:

        - Columns 0..K-1: softmax over the K class scores only.
        - Column K: ``p_K / (1 - max_{k<K} p_k + 1e-7)``, where ``p`` is the
          softmax over all K + 1 scores. Up to the epsilon this equals
          ``exp(s_K) / (sum_j exp(s_j) - max_{k<K} exp(s_k))``.
    """
    # Use CUDA when present (as before); otherwise stay on the input's device.
    device = torch.device("cuda") if torch.cuda.is_available() else scorer.device
    scorer = scorer.to(device)
    class_num = scorer.size(1) - 1

    class_probs = torch.softmax(scorer[:, :class_num], dim=-1)
    full_probs = torch.softmax(scorer, dim=-1)
    defer_prob = full_probs[:, class_num:]  # keep as a (batch, 1) column
    top_class_prob = full_probs[:, :class_num].max(dim=-1, keepdim=True).values
    defer_probit = defer_prob / (1 - top_class_prob + 1e-7)
    return torch.cat((class_probs, defer_probit), dim=-1)



def ECE(probits, labels, n_bins=15):
    """Expected calibration error with equal-width confidence bins.

    Args:
        probits: (N, C) tensor of per-class probabilities. The confidence is
            the row maximum and the prediction is its argmax.
        labels: (N,) tensor of targets compared element-wise with the
            predictions.
        n_bins: Number of equal-width bins on [0, 1] (default 15).

    Returns:
        0-dim tensor: the sum over bins of ``|mean confidence - accuracy|``,
        weighted by the fraction of samples in the bin. Returns 0 when no
        sample falls in any bin, e.g. for empty input.
    """
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(1, device=probits.device)
    confidences, predictions = torch.max(probits, 1)
    accuracies = predictions.eq(labels)

    # Cap confidences at 0.999 (as before), keeping them on probits' device
    # instead of forcing CUDA, so they match `accuracies` and `ece`.
    confidences = torch.clamp(confidences, max=1 - 0.001)

    for bin_lower, bin_upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
        # Bins are half-open (lower, upper].
        in_bin = confidences.gt(bin_lower.item()) & confidences.le(bin_upper.item())
        prop_in_bin = in_bin.float().mean()
        if prop_in_bin.item() > 0:
            accuracy_in_bin = accuracies[in_bin].float().mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

    return ece[0]