import os
import pickle
import time

import numpy as np
import torch
from tqdm import tqdm
import random

def fix_randomness(seed=0):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random``, NumPy's global generator, PyTorch's global generator
    (``torch.manual_seed``) and the current CUDA device's generator
    (``torch.cuda.manual_seed``) with the same ``seed``.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)



def sort_sum(scores):
    """Sort each row of a 2-D score array in descending order.

    Args:
        scores: 2-D NumPy array; each row is sorted independently.

    Returns:
        ``(I, ordered, cumsum)``: ``I`` holds the column indices that sort each row
        from largest to smallest. ``ordered`` holds the sorted values themselves.
        ``cumsum`` is the running sum of ``ordered`` along each row.
    """
    descending_idx = np.flip(scores.argsort(axis=1), axis=1)
    descending_vals = np.flip(np.sort(scores, axis=1), axis=1)
    running_total = descending_vals.cumsum(axis=1)
    return descending_idx, descending_vals, running_total


class AverageMeter:
    """Track the latest value and the running, count-weighted mean of a metric.

    Attributes:
        name: label shown when the meter is printed.
        fmt: format spec (e.g. ``':f'``, ``':.3f'``) applied to ``val`` and ``avg``
            by ``__str__``.
        val: most recent value passed to ``update``.
        sum: weighted total ``sum(val_i * n_i)``.
        count: total weight ``sum(n_i)``.
        avg: ``sum / count`` after the latest ``update``.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero all statistics; ``name`` and ``fmt`` are kept."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. "{name} {val:f} ({avg:f})" and fills it from the instance attributes.
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**vars(self))


def validate(val_loader, model, print_bool):
    """Evaluate ``model`` on ``val_loader`` and return averaged accuracy and set statistics.

    ``model(x)`` must return a pair ``(output, S)``: class scores for the batch and one
    prediction set per sample (as consumed by ``coverage_size``). Inputs and targets are
    moved to the GPU with ``.cuda()``.

    Args:
        val_loader: iterable of ``(x, target)`` batches.
        model: the wrapped model described above; it is switched to ``eval()`` mode.
        print_bool: if true, rewrite a one-line progress report after every batch.

    Returns:
        ``(top1, top5, coverage, size)``, each a sample-weighted average over the loader:
        - ``top1`` and ``top5`` are accuracies as fractions in [0, 1]. ``top5`` falls back
          to top-1 when there are fewer than 5 classes.
        - ``coverage`` is the fraction of targets contained in their prediction set.
        - ``size`` is the mean prediction-set size.
    """
    with torch.no_grad():
        batch_time = AverageMeter('batch_time')
        top1 = AverageMeter('top1')
        top5 = AverageMeter('top5')
        coverage = AverageMeter('RAPS coverage')
        size = AverageMeter('RAPS size')
        model.eval()
        tick = time.time()
        N = 0
        for x, target in val_loader:
            target = target.cuda()
            output, S = model(x.cuda())
            # topk() cannot ask for more entries than there are classes.
            second_k = 1 if output.shape[1] < 5 else 5
            prec1, prec5 = accuracy(output, target, topk=(1, second_k))
            cvg, sz = coverage_size(S, target)

            batch_n = x.shape[0]
            # accuracy() reports percentages; the meters store fractions.
            top1.update(prec1.item() / 100.0, n=batch_n)
            top5.update(prec5.item() / 100.0, n=batch_n)
            coverage.update(cvg, n=batch_n)
            size.update(sz, n=batch_n)

            batch_time.update(time.time() - tick)
            tick = time.time()
            N += batch_n
            if print_bool:
                print(
                    f'\rN: {N} | Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) | Cvg@1: {top1.val:.3f} ({top1.avg:.3f}) | Cvg@5: {top5.val:.3f} ({top5.avg:.3f}) | Cvg@RAPS: {coverage.val:.3f} ({coverage.avg:.3f}) | Size@RAPS: {size.val:.3f} ({size.avg:.3f})',
                    end='')
    if print_bool:
        print('')  # terminate the carriage-return progress line
    return top1.avg, top5.avg, coverage.avg, size.avg


def coverage_size(S, targets):
    """Measure empirical coverage and mean size of a batch of prediction sets.

    Args:
        S: sequence of per-sample prediction sets; each must support ``in`` and
            ``.shape[0]`` (e.g. 1-D NumPy arrays of class indices).
        targets: tensor/array of true labels, one per sample in ``S``.

    Returns:
        ``(coverage, size)``: the fraction of samples whose target lies in its set, and
        the average set length.
    """
    n_samples = targets.shape[0]
    hits = 0
    total_len = 0
    for idx in range(n_samples):
        pred_set = S[idx]
        if targets[idx].item() in pred_set:
            hits += 1
        total_len += pred_set.shape[0]
    return float(hits) / n_samples, total_len / n_samples


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor of shape ``(batch, num_classes)``; higher scores rank first.
        target: tensor of ``batch`` true class indices.
        topk: k values to evaluate; ``max(topk)`` must not exceed ``num_classes``.

    Returns:
        List of 0-d float tensors, one per k in ``topk`` order. Each equals
        ``100 * (#samples whose target is among the top k) / batch``.
    """
    n_samples = target.size(0)
    largest_k = max(topk)
    _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
    ranked = ranked.t()  # (largest_k, batch): row r is every sample's rank-r prediction
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))
    scale = 100.0 / n_samples
    return [hits[:k].float().sum().mul_(scale) for k in topk]


def split2(dataset, n1, n2):
    """Randomly draw two disjoint subsets of ``n1`` and ``n2`` samples from ``dataset``.

    Args:
        dataset: any map-style dataset supporting ``len()``. A ``TensorDataset`` works as
            before, since its ``len`` is its first tensor's length.
        n1: size of the first subset.
        n2: size of the second subset, drawn from what remains; ``n1 + n2`` should not
            exceed ``len(dataset)``.

    Returns:
        ``(data1, data2, idx1, idx2)``: the two ``Subset`` objects and, for each, the list
        of indices into ``dataset`` it was drawn from.

    Uses torch's global RNG (through ``random_split``), so seed it for reproducible splits.
    """
    total = len(dataset)
    data1, rest = torch.utils.data.random_split(dataset, [n1, total - n1])
    data2, _ = torch.utils.data.random_split(rest, [n2, total - n1 - n2])
    # data2.indices point into ``rest``; map them back to positions in ``dataset``.
    # A comprehension (instead of NumPy fancy indexing) also works when n2 == 0, where
    # np.array([]) is a float array and cannot be used as an index.
    idx2 = [rest.indices[j] for j in data2.indices]
    return data1, data2, data1.indices, idx2


# Computes logits and targets from a model and loader
def get_logits_targets(model, loader):
    """Run ``model`` over every batch of ``loader`` and package the results as a dataset.

    Args:
        model: callable mapping a GPU input batch to a logits tensor.
        loader: iterable of batches whose first two items are ``(inputs, labels)``.

    Returns:
        ``TensorDataset(logits, labels.long())``. Logits are moved to the CPU, and both
        tensors are concatenated along the batch dimension.
    """
    print(f'Computing logits for model (only happens once).')

    logit_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(loader):
            inputs, batch_labels = batch[0], batch[1]
            logit_chunks.append(model(inputs.cuda()).detach().cpu())
            label_chunks.append(batch_labels)
        all_logits = torch.cat(logit_chunks)
        all_labels = torch.cat(label_chunks)

    return torch.utils.data.TensorDataset(all_logits, all_labels.long())


from models.utils import build_common_model
from dataset.utils import build_dataset

import torchvision.transforms as trn


def check_transform(model_name, dataset_name):
    """Return a custom torchvision preprocessing pipeline for a (model, dataset) pair, or ``None``.

    Only ``"ViT"`` and ``"Inception"`` on ``"cifar10"`` / ``"cifar100"`` get a pipeline here.
    Every other combination yields ``None``; presumably the caller (``build_dataset``) then
    uses its own default transform. TODO confirm.
    """
    if model_name not in ("ViT", "Inception"):
        return None

    if dataset_name == "cifar10":
        mean = (0.492, 0.482, 0.446)
        std = (0.247, 0.244, 0.262)
        return trn.Compose([trn.Resize(224), trn.ToTensor(), trn.Normalize(mean, std)])

    if dataset_name != "cifar100":
        return None

    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    # On CIFAR-100 only ViT is resized to 224; the Inception pipeline applies no Resize.
    steps = [trn.Resize(224)] if model_name == "ViT" else []
    steps += [trn.ToTensor(), trn.Normalize(CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)]
    return trn.Compose(steps)


def get_logits_dataset(model_name, dataset_name, mode="test", bsz=128):
    """Return ``(logits_dataset, num_classes)`` for ``model_name`` on ``dataset_name``.

    Logits are computed once with ``get_logits_targets`` and pickled under
    ``./cache/<dataset_name>/logits/``; later calls load that file instead.

    Args:
        model_name: architecture name for ``build_common_model`` / ``check_transform``.
        dataset_name: dataset name for ``build_dataset``.
        mode: ``"test"`` or ``"train_eval"``; forwarded to ``build_dataset`` and used to
            pick the cache file name.
        bsz: DataLoader batch size used when the logits must be computed.

    Returns:
        The cached or freshly built logits ``TensorDataset``, and ``num_classes`` as
        reported by ``build_dataset``.

    Raises:
        NotImplementedError: if ``mode`` is not supported.
    """
    # Validate before touching the filesystem so a bad mode has no side effects.
    if mode == "test":
        basename = model_name + '.pkl'
    elif mode == "train_eval":
        basename = model_name + '_train.pkl'
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    cache = os.path.join(os.getcwd(), "cache", dataset_name, "logits")
    # makedirs, not mkdir: cache/<dataset_name> may not exist yet on a fresh checkout.
    os.makedirs(cache, exist_ok=True)
    fname = os.path.join(cache, basename)

    transform = check_transform(model_name, dataset_name)
    # num_classes comes from the dataset, so it is built even when the logits are cached.
    dataset, num_classes = build_dataset(dataset_name, mode, transform)
    if os.path.exists(fname):
        with open(fname, 'rb') as handle:
            return pickle.load(handle), num_classes

    model = build_common_model(model_name, dataset_name)
    loader = torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, pin_memory=True)
    dataset_logits = get_logits_targets(model, loader)

    with open(fname, 'wb') as handle:
        pickle.dump(dataset_logits, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return dataset_logits, num_classes


def get_labels_dataset(model_name, dataset_name, mode="test", bsz=128):
    """Return all labels of ``dataset_name``'s ``mode`` split as one tensor, with caching.

    The labels are pickled under ``./cache/<dataset_name>/logits/``; later calls load that
    file without building the dataset. The file name includes ``model_name`` because the
    input transform (via ``check_transform``) depends on it.

    Args:
        model_name: architecture name; selects the transform and the cache file name.
        dataset_name: dataset name for ``build_dataset``.
        mode: ``"test"`` or ``"train_eval"``; forwarded to ``build_dataset``.
        bsz: DataLoader batch size used when the labels must be extracted.

    Returns:
        The concatenated label tensor, or whatever was previously pickled at the cache path.

    Raises:
        NotImplementedError: if ``mode`` is not supported.
    """
    if mode == "test":
        # NOTE: the double underscore is historical; kept so existing cache files are found.
        basename = model_name + '_test__label.pkl'
    elif mode == "train_eval":
        basename = model_name + '_train_label.pkl'
    else:
        raise NotImplementedError(f"Unsupported mode: {mode!r}")

    cache = os.path.join(os.getcwd(), "cache", dataset_name, "logits")
    # makedirs, not mkdir: cache/<dataset_name> may not exist yet on a fresh checkout.
    os.makedirs(cache, exist_ok=True)
    fname = os.path.join(cache, basename)
    # Check the cache first so the dataset is only built on a miss.
    if os.path.exists(fname):
        with open(fname, 'rb') as handle:
            return pickle.load(handle)

    transform = check_transform(model_name, dataset_name)
    dataset, _ = build_dataset(dataset_name, mode, transform)
    loader = torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, pin_memory=True)
    print(f'Extracting labels for model (only happens once).')
    # The DataLoader still loads/transforms inputs; only the label item of each batch is kept.
    labels = torch.cat([batch[1] for batch in tqdm(loader)])

    with open(fname, 'wb') as handle:
        pickle.dump(labels, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return labels


def get_features_dataset(modelname, datasetname, mode="test", bsz=128, pc=0.2):
    """Return cached (or freshly extracted) feature embeddings for ``modelname``.

    Results are pickled under ``./cache/<datasetname>/pkl/``; later calls load that file.

    Args:
        modelname: architecture name for ``build_common_model`` / ``get_features``.
        datasetname: dataset name for ``build_dataset``.
        mode: ``"test"`` embeds the whole split. Any other value is forwarded to
            ``build_dataset``, and only a random ``pc`` fraction of that split is embedded.
        bsz: DataLoader batch size.
        pc: fraction of the non-test split to subsample.

    Returns:
        For ``mode == "test"``: the feature array from ``get_features``.
        Otherwise: ``{"selected_indices": [...], "embedding": features}``. The indices
        are positions in the dataset returned by ``build_dataset``.
    """
    cache = os.path.join(os.getcwd(), "cache", datasetname, 'pkl')
    # makedirs, not mkdir: cache/<datasetname> may not exist yet on a fresh checkout.
    os.makedirs(cache, exist_ok=True)
    is_test = mode == "test"
    suffix = '_features.pkl' if is_test else '_features_train.pkl'
    fname = os.path.join(cache, modelname + suffix)
    # If the file exists, load and return it.
    if os.path.exists(fname):
        with open(fname, 'rb') as handle:
            return pickle.load(handle)

    transform = check_transform(modelname, datasetname)
    dataset, _ = build_dataset(datasetname, mode, transform)
    results = {}
    if not is_test:
        save_num = int(len(dataset) * pc)
        # Reseed all global RNGs so the same subset is drawn on every run.
        fix_randomness()
        dataset, _ = torch.utils.data.random_split(dataset, [save_num, len(dataset) - save_num])
        results["selected_indices"] = dataset.indices

    loader = torch.utils.data.DataLoader(dataset, batch_size=bsz, shuffle=False, pin_memory=True)
    model = build_common_model(modelname, datasetname)
    embedding_X = get_features(model, loader, modelname, datasetname)

    if is_test:
        payload = embedding_X
    else:
        results["embedding"] = embedding_X
        payload = results
    with open(fname, 'wb') as handle:
        pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return payload

# Extracts per-sample feature embeddings (the tensors feeding each model's head) from a model and loader
def get_features(model, loader, modelname, dataset_name):
    """Collect per-sample feature vectors from an architecture-specific layer of ``model``.

    Forward hooks are attached to the layer(s) listed by ``_feature_hook_specs``. ``model``
    is run over every batch of ``loader``, with inputs moved by ``.cuda()``. The captured
    tensors are concatenated and reshaped to ``(num_samples, feature_dim)``. The hooks are
    always removed before returning, so ``model`` remains usable afterwards.

    Args:
        model: network exposing the layer attribute(s) used for ``modelname``.
        loader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        modelname: one of the architectures handled by ``_feature_hook_specs``.
        dataset_name: ``"cifar100"`` selects the CIFAR-100 variants' layer names.

    Returns:
        np.ndarray of shape ``(num_samples, feature_dim)``. For ``"DeiT"`` the ``head``
        and ``head_dist`` inputs (768 each) are concatenated to width 1536.

    Raises:
        NotImplementedError: if ``modelname`` is not supported.
    """
    print(f'Extract embdedding x for model (only happens once).')
    specs = _feature_hook_specs(model, modelname, dataset_name)

    stores = [[] for _ in specs]
    handles = [
        layer.register_forward_hook(_make_capture_hook(store, capture_input))
        for (layer, capture_input, _), store in zip(specs, stores)
    ]
    try:
        with torch.no_grad():
            for x, _ in tqdm(loader):
                model(x.cuda())
    finally:
        # Leaving the hooks registered would make every later forward pass of ``model``
        # keep appending to these lists.
        for handle in handles:
            handle.remove()

    features = [
        torch.reshape(torch.cat(store, dim=0), (-1, dim)).numpy()
        for store, (_, _, dim) in zip(stores, specs)
    ]
    if len(features) == 1:
        return features[0]
    return np.concatenate(features, axis=1)


def _feature_hook_specs(model, modelname, dataset_name):
    """Return ``[(layer, capture_input, feature_dim), ...]`` telling where to read features.

    ``capture_input=True`` records the tensor fed *into* ``layer``; ``False`` records its output.
    """
    cifar100 = dataset_name == "cifar100"
    if modelname in ["ResNeXt101", "ResNet152", "ResNet101", "ResNet50", "ResNet18"]:
        layer = model.avg_pool if cifar100 else model.avgpool
        return [(layer, False, 512 if modelname == "ResNet18" else 2048)]
    if modelname == "DenseNet161":
        if cifar100:
            return [(model.avgpool, False, 2208)]
        return [(model.classifier, True, 2208)]
    if modelname == "VGG16":
        return [(model.classifier[6], True, 4096)]
    if modelname == "Inception":
        return [(model.linear if cifar100 else model.fc, True, 2048)]
    if modelname == "ShuffleNet":
        return [(model.fc, True, 1024)]
    if modelname == "ViT":
        return [(model.heads, True, 768)]
    if modelname == "DeiT":
        return [(model.head, True, 768), (model.head_dist, True, 768)]
    raise NotImplementedError


def _make_capture_hook(store, capture_input):
    """Build a forward hook that appends a detached CPU copy of a layer's input/output to ``store``."""
    def hook_fn(module, inputs, output):
        tensor = inputs[0] if capture_input else output
        # Copy to host per batch so features do not pile up in GPU memory; copy=True also
        # protects against later in-place modification when the tensor is already on CPU.
        store.append(tensor.detach().to("cpu", copy=True))
    return hook_fn