import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import time
import pathlib
import os
import pickle
from tqdm import tqdm
import pdb
import torchvision.transforms as trn
def sort_sum(scores):
    """Sort every row of ``scores`` in descending order.

    Args:
        scores: NumPy array; sorting and accumulation run along axis 1.

    Returns:
        Tuple ``(I, ordered, cumsum)``: the column indices in descending
        score order, the scores in that order, and the running sum of the
        ordered scores along axis 1.
    """
    # NumPy sorts ascending, so sort first and then reverse the columns.
    ascending_idx = scores.argsort(axis=1)
    ascending_val = np.sort(scores, axis=1)
    descending_idx = ascending_idx[:, ::-1]
    descending_val = ascending_val[:, ::-1]
    return descending_idx, descending_val, descending_val.cumsum(axis=1)

class AverageMeter(object):
    """Keep the most recent value of a metric and its running weighted mean."""

    def __init__(self, name, fmt=':f'):
        # ``fmt`` is the format-spec suffix (e.g. ':.3f') applied by __str__.
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render as '<name> <val> (<avg>)' using the configured format spec."""
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))

def validate(val_loader, model, print_bool):
    """Run ``model`` over ``val_loader`` and average accuracy and set statistics.

    Inputs and targets are moved to CUDA. The model is called as
    ``output, S = model(x)``: ``output`` feeds ``accuracy`` and ``S`` feeds
    ``coverage_size``.

    Args:
        val_loader: iterable yielding ``(x, target)`` batches.
        model: callable returning ``(output, S)``; switched to eval mode here.
        print_bool: when truthy, print a single-line running progress report.

    Returns:
        Tuple ``(top1, top5, coverage, size)`` of batch-size-weighted running
        averages. The two accuracies are divided by 100 before averaging.
    """
    batch_time = AverageMeter('batch_time')
    top1 = AverageMeter('top1')
    top5 = AverageMeter('top5')
    coverage = AverageMeter('RAPS coverage')
    size = AverageMeter('RAPS size')

    model.eval()
    N = 0
    with torch.no_grad():
        tick = time.time()
        for x, target in val_loader:
            n_batch = x.shape[0]
            target = target.cuda()
            output, S = model(x.cuda())

            # With fewer than five classes the "top-5" slot falls back to top-1.
            wide_k = 5 if output.shape[1] >= 5 else 1
            prec1, prec5 = accuracy(output, target, topk=(1, wide_k))
            cvg, sz = coverage_size(S, target)

            top1.update(prec1.item()/100.0, n=n_batch)
            top5.update(prec5.item()/100.0, n=n_batch)
            coverage.update(cvg, n=n_batch)
            size.update(sz, n=n_batch)

            # Wall-clock time spent on this batch.
            batch_time.update(time.time() - tick)
            tick = time.time()
            N = N + n_batch
            if print_bool:
                print(f'\rN: {N} | Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) | Cvg@1: {top1.val:.3f} ({top1.avg:.3f}) | Cvg@5: {top5.val:.3f} ({top5.avg:.3f}) | Cvg@RAPS: {coverage.val:.3f} ({coverage.avg:.3f}) | Size@RAPS: {size.val:.3f} ({size.avg:.3f})', end='')
    if print_bool:
        # Terminate the carriage-return progress line.
        print('')

    return top1.avg, top5.avg, coverage.avg, size.avg

def coverage_size(S,targets):
    """Compute empirical coverage and mean size of a batch of prediction sets.

    Args:
        S: indexable collection holding one prediction set per sample; each
            set must support ``in`` and expose ``shape[0]``.
        targets: labels with a ``shape`` attribute whose elements provide
            ``.item()``.

    Returns:
        Tuple ``(coverage, mean_size)``: the fraction of samples whose label
        lies in its set, and the average number of elements per set.
    """
    n_samples = targets.shape[0]
    hits = sum(1 for i in range(n_samples) if targets[i].item() in S[i])
    total_elems = sum(S[i].shape[0] for i in range(n_samples))
    return float(hits) / n_samples, total_elems / n_samples

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: score tensor; the top-k search runs along dimension 1.
        target: tensor of true class indices; its first dimension is the
            batch size.
        topk: iterable of k values to evaluate.

    Returns:
        List with one scalar tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    batch_size = target.size(0)

    # Rank once for the largest k; each smaller k reuses a prefix of the rows.
    _, ranked = output.topk(max(topk), 1, True, True)
    hits = ranked.t() == target.view(1, -1)

    percent_per_hit = 100.0 / batch_size
    return [hits[:k].float().sum() * percent_per_hit for k in topk]

def data2tensor(data):
    """Batch a sequence of ``(image, label)`` pairs into two tensors.

    Args:
        data: sequence whose items hold an image tensor at index 0 and a
            label convertible with ``int`` at index 1.

    Returns:
        Tuple ``(imgs, targets)``: the images concatenated along a new
        leading dimension and moved to CUDA, and the labels as a 1-D
        ``long`` tensor (not moved to CUDA).
    """
    # Indexing with None adds the leading batch dimension to each image.
    imgs = torch.cat([sample[0][None] for sample in data], dim=0).cuda()
    label_values = [int(sample[1]) for sample in data]
    targets = torch.cat([torch.Tensor([value]) for value in label_values], dim=0).long()
    return imgs, targets

def split2ImageFolder(path, transform, n1, n2):
    """Load an ``ImageFolder`` dataset and draw two disjoint random subsets.

    Args:
        path: root directory handed to ``torchvision.datasets.ImageFolder``.
        transform: transform handed to ``ImageFolder``.
        n1: number of samples in the first subset.
        n2: number of samples in the second subset, drawn from what is left
            after the first split.

    Returns:
        Tuple ``(data1, data2)``; the remaining samples are discarded.
    """
    full_set = torchvision.datasets.ImageFolder(path, transform)
    n_left = len(full_set) - n1
    first, rest = torch.utils.data.random_split(full_set, [n1, n_left])
    second, _ = torch.utils.data.random_split(rest, [n2, n_left - n2])
    return first, second

def split2(dataset, n1, n2):
    """Randomly draw two disjoint subsets from a ``TensorDataset``-like object.

    Args:
        dataset: dataset exposing ``tensors``; its size is read from
            ``dataset.tensors[0].shape[0]``.
        n1: number of samples in the first subset.
        n2: number of samples in the second subset (may be 0), drawn from the
            samples left over after the first split.

    Returns:
        Tuple ``(data1, data2, indices1, indices2)``: the two subsets followed
        by the positions of their samples in the original ``dataset``.
        ``indices2`` is a list of Python ints.
    """
    total = dataset.tensors[0].shape[0]
    data1, remainder = torch.utils.data.random_split(dataset, [n1, total - n1])
    data2, _ = torch.utils.data.random_split(remainder, [n2, total - n1 - n2])

    # data2.indices address `remainder`, not `dataset`; translate them through
    # remainder.indices. A plain list lookup also handles an empty data2,
    # where NumPy fancy-indexing with an empty (float) array would raise.
    remainder_indices = remainder.indices
    data2_indices = [int(remainder_indices[int(j)]) for j in data2.indices]
    return data1, data2, data1.indices, data2_indices

# Computes logits and targets from a model and loader
def get_logits_targets_CLIP(model,datasetname, dataset,num_classes,mask=None):
    """Compute CLIP zero-shot logits, one sample at a time, for ``dataset``.

    A text prompt "a photo of a <class>" is encoded for every class name.
    Each image is then scored as 100 times the product of its normalised
    image features with the normalised text features. Only samples whose
    label appears in ``mask`` are kept; their logits are restricted to the
    masked columns and their labels are re-indexed as positions in ``mask``.

    Args:
        model: pair ``(clip_model, preprocess)``.
        datasetname: "imagenet"/"imagenetv2" read class names from
            ``~/data/imagenet/human_readable_labels.json``; any other name
            uses ``dataset.classes``.
        dataset: indexable dataset yielding ``(image, label)``.
        num_classes: number of classes, used only to build the default mask.
        mask: ``None`` (keep all classes) or a NumPy array of class ids.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(logits, labels)``.

    Raises:
        NotImplementedError: if ``mask`` is neither ``None`` nor an ndarray.
    """
    print(f'Computing logits for model (only happens once).')
    if mask is None:
        mask = np.arange(num_classes)
    elif not isinstance(mask, np.ndarray):
        raise NotImplementedError

    from models import clip
    import json
    if datasetname == "imagenet" or datasetname == "imagenetv2":
        data_dir = os.path.join(os.path.expanduser('~'), "data", "imagenet")
        with open(os.path.join(data_dir, 'human_readable_labels.json')) as f:
            readable_labels = json.load(f)
    else:
        readable_labels = dataset.classes
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in readable_labels]).cuda()
    clip_model, preprocess = model[0], model[1]

    # Text features are computed once and reused for every image.
    with torch.no_grad():
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    logits_list = []
    labels_list = []
    mask = torch.tensor(mask).squeeze()
    with torch.no_grad():
        for index in tqdm(range(len(dataset))):
            tmp_x, tmp_label = dataset[index]

            # Skip samples whose label is outside the mask before paying for
            # the image encoder.
            mask_indices = torch.where(torch.isin(tmp_label, mask))[0]
            if len(mask_indices) != 1:
                continue

            tmp_x = preprocess(tmp_x).unsqueeze(0)
            image_features = clip_model.encode_image(tmp_x.cuda())
            image_features /= image_features.norm(dim=-1, keepdim=True)
            tmp_logits = (100.0 * image_features @ text_features.T)
            tmp_logits = tmp_logits.detach().cpu()

            logits_list.append(tmp_logits[mask_indices, :][:, mask])
            # Re-index the label as its position inside ``mask``.
            indices = torch.nonzero(mask == tmp_label, as_tuple=False)
            labels_list.append(indices.reshape(-1))
        logits = torch.cat(logits_list)
        labels = torch.cat(labels_list)
    return torch.utils.data.TensorDataset(logits, labels.long())

# Computes logits and targets from a model and loader
def get_logits_targets(model, loader,num_classes,mask=None,modelname=""):
    """Run ``model`` over ``loader`` and collect logits and labels.

    Only samples whose label appears in ``mask`` are kept. Their logits are
    restricted to the masked columns and their labels are re-indexed as
    positions in ``mask``. Inputs are moved to CUDA for the forward pass and
    the results are gathered on the CPU.

    Args:
        model: callable mapping an input batch to a logits tensor.
        loader: iterable of batches with inputs at index 0, labels at index 1.
        num_classes: number of classes, used only to build the default mask.
        mask: ``None`` (keep all classes) or a NumPy array of class ids.
        modelname: currently unused.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(logits, labels)``.

    Raises:
        NotImplementedError: if ``mask`` is neither ``None`` nor an ndarray.
    """
    print(f'Computing logits for model (only happens once).')
    if mask is None:
        mask = np.arange(num_classes)
    elif not isinstance(mask, np.ndarray):
        raise NotImplementedError

    logits_list = []
    labels_list = []
    mask = torch.tensor(mask).squeeze()
    with torch.no_grad():
        for examples in tqdm(loader):
            tmp_x, tmp_label = examples[0], examples[1]
            tmp_logits = model(tmp_x.cuda()).detach().cpu()

            # Rows of this batch whose label is one of the masked classes.
            mask_indices = torch.where(torch.isin(tmp_label, mask))[0]
            logits_list.append(tmp_logits[mask_indices, :][:, mask])

            # Compare every kept label against every mask entry; column 1 of
            # the non-zero coordinates is the label's position inside ``mask``.
            matches = mask == tmp_label[mask_indices][:, None]
            labels_list.append(torch.nonzero(matches, as_tuple=False)[:, 1])
        logits = torch.cat(logits_list)
        labels = torch.cat(labels_list)

    return torch.utils.data.TensorDataset(logits, labels.long())


from models.connetor import build_common_model
from datasets.connector import build_dataset


def check_transform(model_name,dataset_name):
    """Return the transform for a model/dataset pair, or ``None``.

    A transform is returned only for "ViT" or "Inception" on "cifar10" or
    "cifar100"; every other combination yields ``None``. Images are resized
    to 224 for both models on cifar10, but only for "ViT" on cifar100.
    """
    if model_name not in ("ViT", "Inception"):
        return None

    if dataset_name == "cifar10":
        mean = (0.492, 0.482, 0.446)
        std = (0.247, 0.244, 0.262)
        return trn.Compose([
            trn.Resize(224),
            trn.ToTensor(),
            trn.Normalize(mean, std),
        ])

    if dataset_name == "cifar100":
        CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
        CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
        # Only ViT gets the resize step on cifar100.
        steps = [trn.Resize(224)] if model_name == "ViT" else []
        steps += [
            trn.ToTensor(),
            trn.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD),
        ]
        return trn.Compose(steps)

    return None

def get_logits_dataset(modelname, datasetname,mask=None,category=""):
    """Return the logits dataset for a model/dataset pair, cached on disk.

    The cache lives at ``<cwd>/data/<datasetname>/pkl/<modelname>.pkl``
    (``data_hua`` instead of ``data`` when ``category == "hua"``). On a cache
    hit the pickled object is returned as-is. Otherwise the "test" split and
    the model are built, logits are computed, and the result is pickled.

    Args:
        modelname: model identifier; "CLIP" takes the CLIP-specific path.
        datasetname: dataset identifier passed to ``build_dataset``.
        mask: optional class mask forwarded to the logits computation.
        category: selects the cache root directory.

    Returns:
        The logits dataset produced by ``get_logits_targets`` or
        ``get_logits_targets_CLIP`` (or the cached copy of it).
    """
    data_root = "data_hua" if category == "hua" else "data"
    cache = os.path.join(os.getcwd(), data_root, datasetname, 'pkl')
    # makedirs rather than mkdir: the parent directories may not exist yet.
    os.makedirs(cache, exist_ok=True)
    fname = cache +'/' + modelname + '.pkl'

    if os.path.exists(fname):
        # NOTE(review): the cache key ignores `mask`, so a file written for
        # one mask is returned for any other mask as well - confirm intended.
        # pickle.load can execute arbitrary code; only use trusted cache files.
        with open(fname, 'rb') as handle:
            return pickle.load(handle)

    if modelname == "CLIP":
        # The CLIP path applies its own preprocessing, so keep raw samples.
        transform = lambda x:x
    else:
        transform = check_transform(modelname,datasetname)
    dataset,num_classes = build_dataset(datasetname,"test",transform)
    # NOTE(review): category is hard-coded to "hua" here regardless of the
    # `category` argument - confirm this is intended.
    model = build_common_model(modelname,datasetname,category="hua")

    if modelname == "CLIP":
        dataset_logits = get_logits_targets_CLIP(model, datasetname,dataset, num_classes, mask)
    else:
        loader = torch.utils.data.DataLoader(dataset, batch_size = 320, shuffle=False, pin_memory=True)
        dataset_logits = get_logits_targets(model, loader,num_classes,mask)

    with open(fname, 'wb') as handle:
        pickle.dump(dataset_logits, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return dataset_logits

def get_logits_dataset_mask(modelname, datasetname,mask=None):
    """Compute the logits dataset for the "test" split without disk caching.

    Args:
        modelname: model identifier passed to ``build_common_model``.
        datasetname: dataset identifier passed to ``build_dataset``.
        mask: optional class mask forwarded to ``get_logits_targets``.

    Returns:
        The dataset returned by ``get_logits_targets``.
    """
    test_set, n_classes = build_dataset(datasetname,"test")
    net = build_common_model(modelname,datasetname)
    batches = torch.utils.data.DataLoader(
        test_set, batch_size=320, shuffle=False, pin_memory=True
    )
    return get_logits_targets(net, batches, n_classes, mask)

def _capture_hook(bucket, use_input):
    """Build a forward hook appending a copy of a layer's input or output."""
    def hook_fn(module, inputs, output):
        bucket.append((inputs[0] if use_input else output).clone())
    return hook_fn


# Extracts per-sample feature vectors from a model via forward hooks
def getFinalLayer(model, loader,modelname,dataset_name):
    """Collect feature vectors for every sample in ``loader``.

    A forward hook is attached to a model-specific layer. For the ResNet
    family the output of ``avgpool`` (``avg_pool`` on cifar100) is captured.
    For all other supported models the input of the head module is captured
    (``classifier``, ``classifier[6]``, ``fc``/``linear``, ``heads``, or both
    ``head`` and ``head_dist`` for DeiT). Inputs are moved to CUDA.

    Args:
        model: the network to probe.
        loader: iterable of ``(x, targets)`` batches.
        modelname: one of the supported architecture names.
        dataset_name: "cifar100" selects the alternative attribute names
            (``avg_pool``, ``linear``).

    Returns:
        NumPy array with one row per sample. For DeiT the two captured
        feature blocks are concatenated along axis 1.

    Raises:
        NotImplementedError: for an unsupported ``modelname``.
    """
    print(f'Extract embdedding x for model (only happens once).')

    # Each tap is (layer to hook, capture input instead of output, row width).
    if modelname in ["ResNeXt101","ResNet152","ResNet101","ResNet50","ResNet18"]:
        layer = model.avg_pool if dataset_name == "cifar100" else model.avgpool
        taps = [(layer, False, 512 if modelname == "ResNet18" else 2048)]
    elif modelname == "DenseNet161":
        taps = [(model.classifier, True, 2208)]
    elif modelname == "VGG16":
        taps = [(model.classifier[6], True, 4096)]
    elif modelname == "Inception":
        layer = model.linear if dataset_name == "cifar100" else model.fc
        taps = [(layer, True, 2048)]
    elif modelname == "ShuffleNet":
        taps = [(model.fc, True, 1024)]
    elif modelname == "ViT":
        taps = [(model.heads, True, 768)]
    elif modelname == "DeiT":
        taps = [(model.head, True, 768), (model.head_dist, True, 768)]
    else:
        raise NotImplementedError

    buckets = [[] for _ in taps]
    handles = []
    try:
        for (layer, use_input, _), bucket in zip(taps, buckets):
            handles.append(layer.register_forward_hook(_capture_hook(bucket, use_input)))
        with torch.no_grad():
            for x, targets in tqdm(loader):
                model(x.cuda())
    finally:
        # Always detach the hooks so later forward passes through the same
        # model neither keep accumulating activations nor fail inside a
        # stale hook.
        for handle in handles:
            handle.remove()

    features = [
        torch.reshape(torch.cat(bucket, dim=0), (-1, width)).detach().cpu().numpy()
        for (_, _, width), bucket in zip(taps, buckets)
    ]
    if len(features) == 1:
        return features[0]
    return np.concatenate(features, axis=1)


def get_featureX_dataset(modelname, datasetname):
    """Return the feature matrix for a model/dataset pair, cached on disk.

    The cache lives at
    ``<cwd>/data/<datasetname>/pkl/<modelname>_featureX.pkl``. On a cache hit
    the pickled object is returned as-is. Otherwise the "test" split and the
    model are built, features are extracted with ``getFinalLayer``, and the
    result is pickled.

    Args:
        modelname: model identifier passed to ``build_common_model``.
        datasetname: dataset identifier passed to ``build_dataset``.

    Returns:
        The features returned by ``getFinalLayer`` (or the cached copy).
    """
    cache = os.path.join(os.getcwd(),"data",datasetname,'pkl')
    # makedirs rather than mkdir: the parent directories may not exist yet.
    os.makedirs(cache, exist_ok=True)
    fname = cache +'/' + modelname + '_featureX.pkl'

    if os.path.exists(fname):
        # pickle.load can execute arbitrary code; only use trusted cache files.
        with open(fname, 'rb') as handle:
            return pickle.load(handle)

    transform = check_transform(modelname,datasetname)
    dataset,_ = build_dataset(datasetname,"test",transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size = 320, shuffle=False, pin_memory=True)
    model = build_common_model(modelname,datasetname)
    embedding_X = getFinalLayer(model, loader,modelname,datasetname)

    with open(fname, 'wb') as handle:
        pickle.dump(embedding_X, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return embedding_X


from .loss_function import _ECELoss

def  get_optimal_parameters(transformation,calib_loader,device):
    """Fit ``transformation`` on calibration logits by minimising the NLL.

    All batches of ``calib_loader`` are gathered into one logits tensor and
    one labels tensor. The parameters of ``transformation`` are then
    optimised with a single L-BFGS step (up to 50 inner iterations) on the
    cross-entropy of the transformed logits.

    Args:
        transformation: module-like object applied to logits; must provide
            ``.to(device)`` and ``.parameters()``. It is modified in place.
        calib_loader: iterable of batches with logits at index 0 and labels
            at index 1.
        device: device the module, criteria and gathered data are moved to.

    Returns:
        The same ``transformation`` object after optimisation.
    """
    transformation.to(device)
    nll_criterion = nn.CrossEntropyLoss().to(device)
    ece_criterion = _ECELoss().to(device)

    # Gather the whole calibration set into a single pair of tensors.
    collected_logits = []
    collected_labels = []
    with torch.no_grad():
        for batch in calib_loader:
            collected_logits.append(batch[0])
            collected_labels.append(batch[1])
        logits = torch.cat(collected_logits).to(device)
        labels = torch.cat(collected_labels).to(device)

    # Diagnostics before fitting, on raw and transformed logits. They are
    # computed but not reported at the moment.
    raw_nll_before = nll_criterion(logits, labels).item()
    nll_before = nll_criterion(transformation(logits), labels).item()
    ece_before = ece_criterion(transformation(logits), labels).item()
    raw_ece_before = ece_criterion(logits, labels).item()

    optimizer = optim.LBFGS(transformation.parameters(), lr=0.1, max_iter=50)

    def closure():
        # L-BFGS re-evaluates the loss several times per step via this closure.
        optimizer.zero_grad()
        loss = nll_criterion(transformation(logits), labels)
        loss.backward()
        return loss

    optimizer.step(closure)

    # Diagnostics after fitting; likewise not reported at the moment.
    nll_after = nll_criterion(transformation(logits), labels).item()
    ece_after = ece_criterion(transformation(logits), labels).item()

    return transformation
  
  
from sklearn.neighbors import KDTree
import numpy as np
  
def computeInputyAtypicalityKNN(cal_Data,test_Data,k=5):
    """Score test points by their mean distance to the k nearest calibration points.

    Args:
        cal_Data: calibration points used to build the KD-tree.
        test_Data: query points to score.
        k: number of nearest neighbours to average over.

    Returns:
        Array with one mean neighbour distance per row of ``test_Data``.
    """
    neighbour_dists, _ = KDTree(cal_Data).query(test_Data, k=k)
    return np.mean(neighbour_dists, axis=1)

def compute_topn_accuracy(logits, target, n):
    """Return the fraction of rows whose target is among the n largest logits.

    Args:
        logits: array of scores, ranked along axis 1.
        target: 1-D array of true class indices, one per row of ``logits``.
        n: number of highest-scoring classes that count as a hit.

    Returns:
        Top-n accuracy as a NumPy float in [0, 1].
    """
    # argsort is ascending, so the last n columns hold the n largest scores.
    best_n = np.argsort(logits, axis=1)[:, -n:]
    row_hit = (best_n == target[:, np.newaxis]).any(axis=1)
    return np.mean(row_hit)