import os 
import pickle
import shutil
import numpy as np  
import os.path as osp

import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.nn.utils.prune as prune



def pruning_model(model,px):
    """Apply global L1 unstructured pruning to every ``nn.Conv2d`` weight.

    Args:
        model: module whose ``modules()`` are scanned for ``nn.Conv2d`` layers.
        px: forwarded unchanged as the ``amount`` argument of
            ``prune.global_unstructured``.

    Returns:
        None. The selected layers are pruned in place by ``torch.nn.utils.prune``.
    """
    # Only convolution weights take part in the global ranking.
    conv_targets = tuple(
        (layer, 'weight')
        for layer in model.modules()
        if isinstance(layer, nn.Conv2d)
    )

    prune.global_unstructured(
        conv_targets,
        pruning_method=prune.L1Unstructured,
        amount=px,
    )

def prune_model_custom(model, mask_dict):
    """Prune each ``nn.Conv2d`` weight with a user-supplied mask.

    Args:
        model: module whose ``named_modules()`` are scanned for ``nn.Conv2d``.
        mask_dict: mapping indexed by ``<module name> + '.weight_mask'``; a
            missing key raises ``KeyError``.

    Returns:
        None. Masks are attached through ``prune.CustomFromMask.apply``.
    """
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, nn.Conv2d):
            continue
        print('pruning layer with custom mask:', layer_name)
        layer_mask = mask_dict[layer_name+'.weight_mask']
        prune.CustomFromMask.apply(layer, 'weight', mask=layer_mask)

def rewind(model, checkpoint_state_dict, prune_flag):
    """Collect the checkpointed conv weights under the key names the model expects.

    For every ``nn.Conv2d`` found in ``model.named_modules()`` the value is read
    from ``<name>.weight`` when that key exists in the checkpoint, otherwise
    from ``<name>.weight_orig`` (``KeyError`` if neither is present).

    Args:
        model: module that determines which layer names are looked up.
        checkpoint_state_dict: mapping to read the weights from.
        prune_flag: when truthy the output keys end in ``.weight_orig``,
            otherwise in ``.weight``.

    Returns:
        A new dict holding one entry per conv layer.
    """
    out_suffix = '.weight_orig' if prune_flag else '.weight'

    rewound = {}
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, nn.Conv2d):
            continue

        plain_key = layer_name+'.weight'
        if plain_key in checkpoint_state_dict.keys():
            source_key = plain_key
        else:
            source_key = layer_name+'.weight_orig'

        rewound[layer_name+out_suffix] = checkpoint_state_dict[source_key]

    return rewound

def check_sparsity(model, report=False):
    """Print and return the percentage of zero entries in all ``nn.Conv2d`` weights.

    Args:
        model: module whose ``modules()`` are scanned for ``nn.Conv2d`` layers.
        report: only selects the printed prefix (``'log zero rate = '`` when
            truthy, ``'zero rate = '`` otherwise).

    Returns:
        ``100 * zeros / total`` as a float. Raises ``ZeroDivisionError`` when
        the model contains no conv weights.
    """
    total_weights = 0
    zero_weights = 0
    for layer in model.modules():
        if not isinstance(layer, nn.Conv2d):
            continue
        total_weights += float(layer.weight.nelement())
        zero_weights += float((layer.weight == 0).sum())

    zero_rate = 100*zero_weights/total_weights
    prefix = 'log zero rate = ' if report else 'zero rate = '
    print(prefix, zero_rate,'%')
    return zero_rate

def check_sparsity_mask(model_dict):
    """Print and return the percentage of zero entries over all tensors in a dict.

    Args:
        model_dict: mapping whose values expose ``nelement()`` and support
            ``== 0`` (every entry is counted, whatever its key).

    Returns:
        ``100 * zeros / total`` as a float. Raises ``ZeroDivisionError`` for an
        empty mapping.
    """
    total_entries = 0
    zero_entries = 0
    for tensor in model_dict.values():
        total_entries += float(tensor.nelement())
        zero_entries += float((tensor == 0).sum())

    zero_rate = 100*zero_entries/total_entries
    print('zero_rate = ', zero_rate,'%')
    return zero_rate

def extract_mask(model_dict):
    """Return a new dict with only the entries whose key contains ``'mask'``."""
    return {
        key: value
        for key, value in model_dict.items()
        if 'mask' in key
    }

def reverse_mask(orig_mask):
    """Return a new dict mapping every key to ``1 - orig_mask[key]``.

    For 0/1 masks this flips kept and pruned positions.
    """
    return {key: 1-mask for key, mask in orig_mask.items()}

def concat_mask(mask1,mask2):
    """Return a new dict with the element-wise sum of two mask dicts.

    Keys are taken from ``mask1``; a key absent from ``mask2`` raises
    ``KeyError``.
    """
    return {key: mask1[key] + mask2[key] for key in mask1.keys()}

def check_mask(current_mask, model_dict):
    """Print, per key, how much of a mask survives multiplication by another tensor.

    For every key of ``current_mask`` the printed number is the fraction of
    positions where ``current_mask[key]`` equals
    ``current_mask[key] * model_dict[key]``.

    Returns:
        None; the result is only printed.
    """
    for key, mask in current_mask.items():
        masked = mask*model_dict[key]
        match_ratio = (mask == masked).float().mean()
        print(key, 'if equal', match_ratio.item())

def extract_weight(model_dict):
    """Return a new dict with every entry whose key does not contain ``'mask'``."""
    return {
        key: value
        for key, value in model_dict.items()
        if 'mask' not in key
    }

def reverse_rewind(model_dict):
    """Rename ``*_orig`` entries back to their plain parameter names.

    Keys that end with ``'_orig'`` lose that suffix; all other keys are copied
    unchanged. Only a real suffix is stripped, so a key that merely contains
    the text ``orig`` somewhere else (for example ``origin.bias``) is no
    longer truncated by five characters.

    Args:
        model_dict: mapping with string keys.

    Returns:
        A new dict with the renamed keys and the same values.
    """
    suffix = '_orig'
    out_dict = {}
    for key, value in model_dict.items():
        if key.endswith(suffix):
            out_dict[key[:-len(suffix)]] = value
        else:
            out_dict[key] = value

    return out_dict

def union_mask(mask1,mask2):
    """Return a new dict with the element-wise maximum of two mask dicts.

    Keys are taken from ``mask1``. For 0/1 masks the maximum acts as a union.
    """
    return {key: torch.max(mask1[key], mask2[key]) for key in mask1.keys()}

def substract_mask(mask1,mask2):
    """Return a new dict with ``mask1[key] - mask2[key]`` for every key of ``mask1``."""
    return {key: mask1[key] - mask2[key] for key in mask1.keys()}

def remove_model_custom(model):
    """Call ``prune.remove`` on the ``weight`` of every ``nn.Conv2d`` in ``model``.

    Returns:
        None. ``prune.remove`` is invoked for each conv layer found by
        ``model.modules()``.
    """
    print('remove pruning')
    conv_layers = (layer for layer in model.modules() if isinstance(layer, nn.Conv2d))
    for layer in conv_layers:
        prune.remove(layer,'weight')



# dataset 
class k150_dataset(Dataset):
    """Dataset backed by one pickle file holding ``'data'`` and ``'label'`` entries.

    Args:
        _dir: path of the pickle file (despite the name, it is opened as a file).
        transform: callable applied to each sample in ``__getitem__``.
    """

    def __init__(self, _dir, transform):
        super(k150_dataset, self).__init__()

        self.imgdir=_dir
        self.transforms=transform
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # The context manager closes the handle instead of leaking it.
        with open(self.imgdir,'rb') as handle:
            self.all_data = pickle.load(handle)
        self.image = self.all_data['data']
        self.label = self.all_data['label']

        # The sample count is the first dimension of the data array.
        self.number = self.image.shape[0]

    def __len__(self):
        """Return the number of samples."""
        return self.number

    def __getitem__(self, index):
        """Return ``(transform(image[index]), label[index])``."""
        img = self.image[index]
        target = self.label[index]
        img = self.transforms(img)

        return img, target

class Labeled_dataset(Dataset):
    """Dataset assembled from selected entries of a pickled collection of arrays.

    Each ``target_list`` entry ``idx`` selects ``all_image[idx]``; all of its
    samples receive the label ``position_in_target_list + offset``.

    Args:
        _dir: path of the pickle file (despite the name, it is opened as a file).
        transform: callable applied to each sample in ``__getitem__``.
        target_list: indices/keys into the unpickled object, in label order.
        offset: value added to every generated label.
        num: if truthy, keep a random subset of at most ``int(num)`` samples,
            drawn with ``np.random.permutation``.
    """

    def __init__(self, _dir, transform, target_list, offset=0, num=None):
        super(Labeled_dataset, self).__init__()

        self.imgdir=_dir
        self.transforms=transform
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # The context manager closes the handle instead of leaking it.
        with open(self.imgdir,'rb') as handle:
            self.all_image = pickle.load(handle)
        self.img = []
        self.target = []

        print('target list = ', target_list)
        for i,idx in enumerate(target_list):
            class_images = self.all_image[idx]
            self.img.append(class_images)
            # One label per sample; np.ones makes these float64 values.
            self.target.append((i+offset)*np.ones(class_images.shape[0]))

        self.image = np.concatenate(self.img, 0)
        self.label = np.concatenate(self.target, 0)
        self.number = self.image.shape[0]

        if num:
            index = np.random.permutation(self.number)
            select_index = index[:int(num)]
            self.image = self.image[select_index]
            self.label = self.label[select_index]
            # Use the real subset size: ``num`` may exceed the available
            # samples or be a float, which would make ``__len__`` wrong.
            self.number = self.image.shape[0]

    def __len__(self):
        """Return the number of samples actually held."""
        return self.number

    def __getitem__(self, index):
        """Return ``(transform(image[index]), label[index])``."""
        img = self.image[index]
        target = self.label[index]
        img = self.transforms(img)

        return img, target

class unlabel_feature_dataset(Dataset):
    """Dataset reading three aligned ``.npy`` arrays that share a path prefix.

    Each item is ``(img, soft-logits for random branch, soft-logits for
    balance branch)``, loaded from ``<prefix>_img.npy``,
    ``<prefix>_dis_label.npy`` and ``<prefix>_dis_label_main.npy``.

    Args:
        _dir: path prefix to which the three file suffixes are appended.
    """

    def __init__(self, _dir):
        super().__init__()

        # Derive the three file names from the common prefix.
        self.imgdir = _dir + '_img.npy'
        self.softlogit_dir = _dir + '_dis_label.npy'
        self.softlogit_main_dir = _dir + '_dis_label_main.npy'

        self.image, self.softlogit, self.softlogit_main = (
            np.load(path)
            for path in (self.imgdir, self.softlogit_dir, self.softlogit_main_dir)
        )

        self.number = self.image.shape[0]

    def __len__(self):
        """Return the number of rows in the image array."""
        return self.number

    def __getitem__(self, index):
        """Return the three entries at ``index`` converted with ``torch.from_numpy``."""
        rows = (
            self.image[index],
            self.softlogit[index],
            self.softlogit_main[index],
        )
        img, target, out_target = (torch.from_numpy(row) for row in rows)

        return img, target, out_target

# loss function
def loss_fn_kd(scores, target_scores, T=2.):
    """Knowledge-distillation loss between [scores] and [target_scores].

    Both inputs are softened with temperature ``T`` along dim 1
    (``log_softmax`` for ``scores``, ``softmax`` for ``target_scores``). The
    cross-entropy between them is summed over dim 1, averaged over the batch
    and multiplied by ``T**2``.

    When ``scores`` has more columns than ``target_scores`` the softened
    targets are detached and zero-padded to the same width. Any width mismatch
    prints ``'size does not match'``.

    Args:
        scores: 2-D tensor of student logits.
        target_scores: 2-D tensor of teacher logits.
        T: temperature hyperparameter.

    Returns:
        A scalar tensor with the distillation loss.
    """
    student_log_probs = F.log_softmax(scores / T, dim=1)
    teacher_probs = F.softmax(target_scores / T, dim=1)

    num_classes = scores.size(1)
    num_target_classes = target_scores.size(1)
    if num_classes != num_target_classes:
        print('size does not match')

    # Pad the teacher distribution with zero probability for the extra classes.
    if num_classes > num_target_classes:
        padding = torch.zeros(scores.size(0), num_classes-num_target_classes)
        padding = padding.to(scores.device)
        teacher_probs = torch.cat([teacher_probs.detach(), padding], dim=1)

    # Cross-entropy per sample (see e.g., Li and Hoiem, 2017), then batch mean.
    per_class = -(teacher_probs * student_log_probs)
    batch_mean = per_class.sum(dim=1).mean()

    # Rescale by T^2.
    return batch_mean * T**2

# select unlabeled data
def select_knn(target_feature, dataset, number):
    """Greedily pick the nearest ``dataset`` rows for each ``target_feature`` row.

    Squared Euclidean distances are computed as ``|t|^2 + |d|^2 - 2 t.d``.
    Target rows are processed in order; each takes the ``number`` columns with
    the smallest remaining distance, and every column chosen so far is then
    overwritten with a large constant in all rows so later targets avoid it.

    Args:
        target_feature: 2-D tensor, one query feature per row.
        dataset: 2-D tensor, one candidate feature per row.
        number: how many candidates to take per query row.

    Returns:
        A list of the distinct selected ``dataset`` row indices (built through
        ``set``, so no particular order is guaranteed).
    """
    masked_value = 1e+20
    num_targets = target_feature.size(0)

    # Squared norms broadcast to a (targets x candidates) grid.
    target_sq = torch.norm(target_feature, p=2, dim=1).pow(2)
    dataset_sq = torch.norm(dataset, p=2, dim=1).pow(2)
    cross_term = torch.mm(target_feature, dataset.transpose(0, 1))

    distance_matrix = target_sq.unsqueeze(1) + dataset_sq.unsqueeze(0) - 2*cross_term

    print('select')
    select_img = []
    for row in range(num_targets):
        ranked = torch.argsort(distance_matrix[row])
        select_img.extend(ranked[0:number].tolist())
        # Block every column picked so far for all targets.
        distance_matrix[:,select_img] = masked_value

    select_img = list(set(select_img))
    print('selected numbers of unlabel images', len(select_img))
    return select_img

# set random seed
def setup_seed(seed):
    """Seed the torch (CPU and CUDA), NumPy and Python RNGs.

    Also sets ``torch.backends.cudnn.deterministic`` to ``True``.

    Args:
        seed: integer seed passed to every generator.
    """
    # `random` is not imported at module level; importing it here avoids the
    # NameError the call to random.seed would otherwise raise.
    import random

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum), ``count`` (total weight) and ``avg`` (``sum / count``).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Set all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of scores, one row per sample.
        target: 1-D tensor with the true class index of each sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one scalar tensor per k, holding the percentage of samples
        whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed layout, so its slice is not
        # contiguous for k > 1 and .view(-1) raises; reshape copies if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar', best_name='model_best.pth.tar'):
    """Save ``state`` with ``torch.save`` and optionally copy it as the best model.

    Args:
        state: object handed to ``torch.save``.
        is_best: when truthy, the saved file is also copied to ``best_name``.
        save_path: directory that receives the file(s).
        filename: name of the checkpoint file inside ``save_path``.
        best_name: name of the copy made when ``is_best`` is truthy.
    """
    checkpoint_file = os.path.join(save_path, filename)
    torch.save(state, checkpoint_file)
    if not is_best:
        return
    best_file = os.path.join(save_path, best_name)
    shutil.copyfile(checkpoint_file, best_file)

def label_extract(train_loader, model, criterion, _dir, fc_num):
    """Run ``model`` over a loader and save inputs, labels and both head outputs.

    Every batch is passed through the model twice with keyword arguments
    ``x``, ``out_idx=fc_num`` and ``main_fc`` set to ``False`` then ``True``.
    The concatenated results are written into ``_dir`` as
    ``task<fc_num>_img.npy``, ``task<fc_num>_dis_label.npy`` (``main_fc=False``),
    ``task<fc_num>_dis_label_main.npy`` (``main_fc=True``) and
    ``task<fc_num>_label.npy``.

    Args:
        train_loader: iterable yielding ``(input, target)`` tensor pairs; it
            must yield at least one batch (``np.concatenate`` rejects an
            empty list).
        model: module called as ``model(x=..., out_idx=..., main_fc=...)``; it
            is switched to eval mode and is not moved between devices here.
        criterion: unused; kept so existing callers keep working.
        _dir: existing directory that receives the ``.npy`` files.
        fc_num: head index forwarded as ``out_idx`` and used in the file names.
    """
    # Same target as the former hard-coded .cuda() when a GPU is present, but
    # CPU-only hosts no longer crash.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img = []
    label = []
    dis_label = []
    dis_label_main = []
    # switch to evaluate mode
    model.eval()

    for images, target in train_loader:
        images = images.to(device)
        target = target.long().to(device)

        # compute output of both heads without tracking gradients
        with torch.no_grad():
            output = model(x=images, out_idx=fc_num, main_fc=False)
            output_main = model(x=images, out_idx=fc_num, main_fc=True)

        img.append(images.cpu().numpy())
        label.append(target.cpu().numpy())
        dis_label.append(output.detach().cpu().numpy())
        dis_label_main.append(output_main.detach().cpu().numpy())

    img = np.concatenate(img,0)
    label = np.concatenate(label,0)
    dis_label = np.concatenate(dis_label,0)
    dis_label_main = np.concatenate(dis_label_main,0)

    print(img.shape)
    print(dis_label.shape)
    print(dis_label_main.shape)
    print(label.shape)

    prefix = os.path.join(_dir,'task'+str(fc_num))
    np.save(prefix+'_img.npy',img)
    np.save(prefix+'_dis_label.npy',dis_label)
    np.save(prefix+'_dis_label_main.npy',dis_label_main)
    np.save(prefix+'_label.npy',label)

def feature_extract_old(train_loader, model, criterion):
    """Collect the features ``model`` produces for every batch of a loader.

    Each batch input is passed as ``model(x=input, is_feature=True)`` under
    ``torch.no_grad()``; the outputs are moved to the CPU and concatenated
    along dim 0.

    Args:
        train_loader: iterable yielding ``(input, target)`` pairs; the target
            is ignored.
        model: module called with keyword arguments ``x`` and ``is_feature``;
            it is switched to eval mode and is not moved between devices here.
        criterion: unused; kept so existing callers keep working.

    Returns:
        A CPU tensor with all features stacked along dim 0.
    """
    # Same target as the former hard-coded .cuda() when a GPU is present, but
    # CPU-only hosts no longer crash.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    all_feature = []
    # switch to evaluate mode
    model.eval()

    for images, _ in train_loader:
        images = images.to(device)
        # compute output
        with torch.no_grad():
            feature = model(x=images, is_feature=True)

        all_feature.append(feature.cpu())

    all_feature = torch.cat(all_feature, dim=0)
    print('all_feature_size', all_feature.size())

    return all_feature

#logger
class Logger(object):
    """Minimal tab-separated text logger.

    Args:
        fpath: path of the log file, opened for writing (truncating any
            existing file). ``None`` creates a logger that discards everything.
    """

    def __init__(self, fpath):
        self.file = None
        if fpath is not None:
            self.file = open(fpath, 'w')

    def append(self, output):
        """Write one line: each element followed by a tab, then a newline.

        Strings are written as-is; every other element is formatted with two
        decimals. The file is flushed after each line. Does nothing when the
        logger was created with ``fpath=None``.
        """
        if self.file is None:
            return
        fields = [
            element if isinstance(element, str) else "{0:.2f}".format(element)
            for element in output
        ]
        self.file.write(''.join(field + '\t' for field in fields) + '\n')
        self.file.flush()

    def close(self):
        """Close the underlying file, if any. Safe to call more than once."""
        if self.file is not None:
            self.file.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False


