import os

import numpy as np
import torch
from PIL import Image


def clip_tiny(image_tensor):
    """
    Clamp a normalized Tiny-ImageNet batch channel by channel, in place.

    For each of the 3 channels (dimension 1), values are limited to the
    interval that maps back into [0, 1] once the per-channel normalization
    (x - mean) / std is undone. The same tensor object is returned.
    """
    channel_means = np.array([0.4802, 0.4481, 0.3975])
    channel_stds = np.array([0.2302, 0.2265, 0.2262])

    for channel, (mu, sigma) in enumerate(zip(channel_means, channel_stds)):
        lowest = -mu / sigma          # normalized value of a 0.0 pixel
        highest = (1 - mu) / sigma    # normalized value of a 1.0 pixel
        image_tensor[:, channel] = torch.clamp(image_tensor[:, channel], lowest, highest)

    return image_tensor


def denormalize_tiny(image_tensor):
    """
    Undo Tiny-ImageNet per-channel normalization, in place.

    Each of the 3 channels (dimension 1) is mapped back with x * std + mean
    and then clamped to [0, 1]. The same tensor object is returned.
    """
    channel_means = np.array([0.4802, 0.4481, 0.3975])
    channel_stds = np.array([0.2302, 0.2265, 0.2262])

    for channel, (mu, sigma) in enumerate(zip(channel_means, channel_stds)):
        restored = image_tensor[:, channel] * sigma + mu
        image_tensor[:, channel] = torch.clamp(restored, 0, 1)

    return image_tensor


# class BNFeatureHook:
#     def __init__(self, module):
#         self.hook = module.register_forward_hook(self.hook_fn)

#     def hook_fn(self, module, input, output):
#         nch = input[0].shape[1]
#         mean = input[0].mean([0, 2, 3])
#         var = input[0].permute(1, 0, 2, 3).contiguous().reshape([nch, -1]).var(1, unbiased=False)
#         r_feature = torch.norm(module.running_var.data - var, 2) + torch.norm(module.running_mean.data - mean, 2)
#         self.r_feature = r_feature

#     def close(self):
#         self.hook.remove()

class BNFeatureHookLT:
    """Forward hook that collects and matches per-class BatchNorm input statistics.

    The hook runs in one of two modes, switched with ``set_hook``:

    * ``pre_hook_fn`` accumulates, per class label (from ``set_label``), the
      sum of per-sample channel means and variances of the layer input and a
      per-class sample count. ``save`` turns those sums into averages and
      writes one ``running.npz`` per class.
    * ``post_hook_fn`` (the mode registered by ``__init__``) stores in
      ``self.r_feature`` a loss that pulls the batch statistics towards the
      module's ``running_mean``/``running_var`` (weight ``1 - alpha``) and
      towards the stored statistics of the current labels (weight ``alpha``).

    Per-class statistics found at
    ``<save_path>/BNFeatureHook/class_<i>/<name>/running.npz`` are loaded at
    construction time and moved to CUDA.
    """

    def __init__(self, module, save_path, training_momentum, name, class_number, alpha):
        """Register ``post_hook_fn`` on ``module`` and load any stored class stats.

        Args:
            module: layer to hook. ``post_hook_fn`` reads its ``running_mean``
                and ``running_var`` (presumably a BatchNorm layer).
            save_path: root directory for the per-class ``running.npz`` files.
            training_momentum: EMA momentum for the batch statistics tracked
                by ``post_hook_fn``.
            name: layer name; last path component of each class directory.
            class_number: number of classes (one statistics file per class).
            alpha: weight of the per-class term in ``r_feature``.

        Raises:
            ModuleNotFoundError: if ``module`` or ``name`` is None.
        """
        self.module = module
        if module is not None and name is not None:
            self.hook = module.register_forward_hook(self.post_hook_fn)
        else:
            raise ModuleNotFoundError("Module or name is None")

        self.dd_var = 0
        self.dd_mean = 0
        self.momentum = training_momentum
        self.bn_statis_list = []
        self.class_number = class_number
        self.alpha = alpha
        # Per-class sample counts used by pre_hook_fn/save. Always sized to
        # class_number: previously this was hard-coded to 100 entries and
        # only created when some class file was missing.
        self.counter = [0] * self.class_number

        self.category_save_path_list = []
        for i in range(self.class_number):
            cls_dir = os.path.join(save_path, "BNFeatureHook", f"class_{i}", name)
            os.makedirs(cls_dir, exist_ok=True)
            self.category_save_path_list.append(os.path.join(cls_dir, "running.npz"))

        self.category_running_dd_var_list = [0. for _ in range(self.class_number)]
        self.category_running_dd_mean_list = [0. for _ in range(self.class_number)]
        # True only if every class had a stored file; then the lists are
        # stacked into (class_number, C) tensors usable by post_hook_fn.
        self.load_tag = True
        for i, category_save_path in enumerate(self.category_save_path_list):
            if os.path.exists(category_save_path):
                self._load_class_stats(i, category_save_path)
            else:
                self.load_tag = False
        if self.load_tag:
            self.category_running_dd_mean_list = torch.stack(self.category_running_dd_mean_list, 0)
            self.category_running_dd_var_list = torch.stack(self.category_running_dd_var_list, 0)

    def _load_class_stats(self, i, path):
        """Load class ``i``'s stored mean/var from ``path`` onto CUDA."""
        # NpzFile keeps its file open until closed; the context manager
        # releases the handle once both arrays have been read.
        with np.load(path) as npz_file:
            self.category_running_dd_var_list[i] = torch.from_numpy(npz_file["running_dd_var"]).cuda()
            self.category_running_dd_mean_list[i] = torch.from_numpy(npz_file["running_dd_mean"]).cuda()

    @staticmethod
    def _to_numpy(value):
        """Convert a tensor to a host numpy array; pass other values through."""
        return value.cpu().numpy() if isinstance(value, torch.Tensor) else value

    def set_label(self, targets):
        """Set the labels of the batch about to be forwarded (read by both hooks)."""
        self.targets = targets

    def set_hook(self, pre=True):
        """Replace the active hook with ``pre_hook_fn`` (``pre``) or ``post_hook_fn``."""
        if hasattr(self, "hook"):
            self.close()
        if pre:
            self.hook = self.module.register_forward_hook(self.pre_hook_fn)
        else:
            self.hook = self.module.register_forward_hook(self.post_hook_fn)

    def save(self):
        """Average the accumulated per-class sums, write them, and reload them.

        Classes with a zero sample count keep their current (already
        averaged or loaded) statistics instead of being divided by zero.
        Counters are reset after averaging, so repeated calls are idempotent.

        Raises:
            ValueError: if a class has neither accumulated samples nor
                previously loaded statistics.
        """
        for i, category_save_path in enumerate(self.category_save_path_list):
            count = self.counter[i]
            if count > 0:
                self.category_running_dd_mean_list[i] = self.category_running_dd_mean_list[i] / count
                self.category_running_dd_var_list[i] = self.category_running_dd_var_list[i] / count
                self.counter[i] = 0
            elif not isinstance(self.category_running_dd_mean_list[i], torch.Tensor):
                raise ValueError(
                    f"No BN statistics collected for class {i}; cannot save {category_save_path}")
            npz_file = {
                "running_dd_var": self._to_numpy(self.category_running_dd_var_list[i]),
                "running_dd_mean": self._to_numpy(self.category_running_dd_mean_list[i]),
            }
            np.savez(category_save_path, **npz_file)
            # Round-trip through disk so the in-memory stats match the file.
            self._load_class_stats(i, category_save_path)
        if not isinstance(self.category_running_dd_mean_list, torch.Tensor):
            self.category_running_dd_mean_list = torch.stack(self.category_running_dd_mean_list, 0)
        if not isinstance(self.category_running_dd_var_list, torch.Tensor):
            self.category_running_dd_var_list = torch.stack(self.category_running_dd_var_list, 0)

    @torch.no_grad()
    def pre_hook_fn(self, module, input, output):
        """Accumulate per-sample channel mean/variance into the sample's class slot."""
        nch = input[0].shape[1]
        bs = input[0].shape[0]
        input_0 = input[0]
        # Per-sample, per-channel statistics over the spatial dims 2 and 3.
        dd_mean = input_0.mean([2, 3])
        dd_var = input_0.view(bs, nch, -1).var(2, unbiased=False)
        for i in range(bs):
            c_m, c_v, cls = dd_mean[i], dd_var[i], self.targets[i].int().item()
            self.category_running_dd_mean_list[cls] += c_m
            self.category_running_dd_var_list[cls] += c_v
            self.counter[cls] += 1

    def post_hook_fn(self, module, input, output):
        """Compute ``self.r_feature``, the BN-statistics matching loss for this batch.

        Requires the per-class lists to be stacked tensors (after a full load
        in ``__init__`` or after ``save``) so they can be indexed by labels.
        """
        nch = input[0].shape[1]
        input_0 = input[0]
        # Batch-wide channel statistics over dims 0, 2, 3.
        mean = input_0.mean([0, 2, 3])
        var = (input_0.permute(1, 0, 2, 3).contiguous().reshape([nch, -1])).var(1, unbiased=False)
        with torch.no_grad():
            if isinstance(self.dd_var, int):
                self.dd_var = var
                self.dd_mean = mean
            else:
                self.dd_var = self.momentum * self.dd_var + (1 - self.momentum) * var
                self.dd_mean = self.momentum * self.dd_mean + (1 - self.momentum) * mean
        # `ema + x - x.detach()` has the EMA's value but routes the gradient
        # through the current batch statistic x.
        r_feature = (torch.norm(module.running_var.data - (self.dd_var + var - var.detach()), 2) + \
                    torch.norm(module.running_mean.data - (self.dd_mean + mean - mean.detach()), 2)) * (1-self.alpha)
        # Target: average of the stored statistics of the labels in the batch.
        category_dd_var = self.category_running_dd_var_list[self.targets.long()].mean(0)
        category_dd_mean = self.category_running_dd_mean_list[self.targets.long()].mean(0)
        r_feature += (torch.norm(category_dd_var - (self.dd_var + var - var.detach()), 2) + \
                    torch.norm(category_dd_mean - (self.dd_mean + mean - mean.detach()), 2)) * self.alpha
        self.r_feature = r_feature

    def close(self):
        """Remove the currently registered forward hook."""
        self.hook.remove()


def lr_policy(lr_fn):
    """Wrap a schedule ``lr_fn(iteration, epoch) -> lr`` into an optimizer updater.

    The returned callable ``(optimizer, iteration, epoch)`` evaluates the
    schedule once and writes the result into every parameter group's "lr".
    """

    def _alr(optimizer, iteration, epoch):
        current = lr_fn(iteration, epoch)
        for group in optimizer.param_groups:
            group["lr"] = current

    return _alr


def lr_cosine_policy(base_lr, warmup_length, epochs):
    """Linear warm-up followed by cosine decay, wrapped with ``lr_policy``.

    For ``epoch < warmup_length`` the rate is ``base_lr * (epoch + 1) /
    warmup_length``; afterwards it follows half a cosine from ``base_lr``
    towards 0 over the remaining ``epochs - warmup_length`` epochs. The
    ``iteration`` argument is accepted but ignored.
    """

    def _lr_fn(iteration, epoch):
        if epoch >= warmup_length:
            elapsed = epoch - warmup_length
            span = epochs - warmup_length
            return 0.5 * (1 + np.cos(np.pi * elapsed / span)) * base_lr
        return base_lr * (epoch + 1) / warmup_length

    return lr_policy(_lr_fn)


def save_images(args, images, targets, ipc_id):
    """Write every image of a batch as a JPEG grouped by class.

    Each image goes to
    ``<args.syn_data_path>/new<class:03d>/class<class:03d>_id<ipc_id:03d>.jpg``.

    Args:
        args: object with a ``syn_data_path`` attribute (output root).
        images: batch indexed as ``images[i]`` with channels first; each image
            is transposed to (H, W, C) before saving. Values are expected in
            [0, 1] and are clipped to that range before conversion to uint8.
        targets: 1-D tensor of class ids, or 2-D scores whose row-wise argmax
            is the class id.
        ipc_id: index of this image within its class (used in the file name).
    """
    for index in range(images.shape[0]):
        if targets.ndimension() == 1:
            class_id = targets[index].item()
        else:
            class_id = targets[index].argmax().item()

        # makedirs creates syn_data_path and any missing parents too;
        # exist_ok avoids the check-then-create race.
        dir_path = '{}/new{:03d}'.format(args.syn_data_path, class_id)
        os.makedirs(dir_path, exist_ok=True)
        place_to_store = dir_path + '/class{:03d}_id{:03d}.jpg'.format(class_id, ipc_id)

        image_np = images[index].data.cpu().numpy().transpose((1, 2, 0))
        # Clip so out-of-range floats do not wrap around in the uint8 cast.
        pixels = (np.clip(image_np, 0.0, 1.0) * 255).astype(np.uint8)
        Image.fromarray(pixels).save(place_to_store)


def validate(input, target, model):
    """Run ``model`` on ``input`` without gradients and print its top-1 accuracy.

    Top-1 and top-5 accuracies (in percent) are computed against ``target``,
    so the model output needs at least 5 columns; only top-1 is printed.
    """

    def accuracy(output, target, topk=(1,)):
        batch_size = target.size(0)
        _, top_indices = output.topk(max(topk), 1, True, True)
        ranked = top_indices.t()  # row r holds every sample's r-th best guess
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))
        return [hits[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size) for k in topk]

    with torch.no_grad():
        logits = model(input)
        prec1, prec5 = accuracy(logits.data, target, topk=(1, 5))

    print("Verifier accuracy: ", prec1.item())


def compute_adjustment(train_loader, tro, device, num_classes=None):
    """Return per-class ``log(p_c ** tro + 1e-12)`` from a loader's label counts.

    ``p_c`` is the empirical frequency of integer label ``c`` over every
    ``target`` yielded by ``train_loader``. The result is indexed by class
    id: a label that never occurs gets ``p_c = 0`` (and, for ``tro > 0``, an
    adjustment of ``log(1e-12)``) instead of shifting later classes.

    Args:
        train_loader: iterable of ``(inputs, target)`` pairs. ``target`` holds
            class labels, assumed non-negative and 0-based.
        tro: exponent applied to the frequencies before taking the log.
        device: device the returned tensor is moved to.
        num_classes: optional minimum length of the result, so classes with
            ids above the largest observed label are included.

    Returns:
        1-D float64 tensor on ``device``.
    """
    counts = np.zeros(num_classes or 0, dtype=np.int64)
    for _, target in train_loader:
        # Count on the CPU in one vectorized call rather than one .item()
        # per label (a host sync per element on GPU).
        labels = torch.as_tensor(target).reshape(-1).long().cpu().numpy()
        if labels.size == 0:
            continue
        batch_counts = np.bincount(labels, minlength=counts.size)
        if batch_counts.size > counts.size:
            counts = np.pad(counts, (0, batch_counts.size - counts.size))
        counts += batch_counts
    label_freq_array = counts / counts.sum()
    adjustments = np.log(label_freq_array ** tro + 1e-12)
    return torch.from_numpy(adjustments).to(device)

class AverageMeter:
    """Keep the most recent value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count