import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import models.backbones.timm_vit as vit
import models.backbones.densenet as densenet
import models.backbones.resnet as resnet
import models.backbones.wide_resnet as wide_resnet
import models.ps.layer_warpper as ps

def get_all_leaf_modules(module: nn.Module, layers: list, leaf_layers: list):
    """Collect every module reachable from ``module`` and, separately, the leaves.

    The module tree is walked depth-first in registration order. Both output
    lists are filled in place; nothing is returned.

    Args:
        module: Root of the (sub)tree to walk.
        layers: Output list that receives ``module`` and each descendant. It
            doubles as the visited set: a child already present in it is not
            walked again.
        leaf_layers: Output list that receives every visited module that has
            no child modules.
    """
    layers.append(module)
    is_leaf = True
    # Direct children are enough: the recursion reaches deeper descendants, so
    # listing the whole subtree again at every level would be wasted work.
    for child in module.children():
        is_leaf = False
        if child not in layers:
            get_all_leaf_modules(child, layers, leaf_layers)
    if is_leaf:
        leaf_layers.append(module)
        
def resume_pruned_model(post_model: nn.Module, pruned_model: nn.Module):
    """Load a pruned model's state into a model with larger tensors.

    For every key of ``post_model.state_dict()`` the tensor with the same key
    in ``pruned_model.state_dict()`` is padded at the end of each dimension
    until it has the shape ``post_model`` uses, and the padded tensors are
    then loaded into ``post_model``.

    Padding is filled with zeros, except for keys containing
    ``"running_var"``, which are filled with ones.

    Args:
        post_model: Model that receives the padded state; modified in place.
        pruned_model: Model providing the values. Its state dict must contain
            every key of ``post_model``'s state dict.

    Returns:
        ``post_model`` itself.
    """
    pruned_state = pruned_model.state_dict()
    padded_state = {}

    for key, wide in post_model.state_dict().items():
        narrow = pruned_state[key]
        pad = []
        # F.pad expects (before, after) amounts starting from the LAST
        # dimension, hence the reversed walk over the shapes.
        for wide_dim, narrow_dim in reversed(list(zip(wide.shape, narrow.shape))):
            pad += [0, wide_dim - narrow_dim]
        fill_value = 1.0 if "running_var" in key else 0.0
        padded_state[key] = F.pad(narrow, pad, 'constant', fill_value)

    post_model.load_state_dict(padded_state)
    return post_model

class FogettingSchedular():
    """Hands out a forget rate and updates it after every step.

    Args:
        forget_rate: Rate returned by the first call to ``step``.
        forget_fn: Callable mapping the current rate to the next one. The
            default multiplies the rate by 0.8.
    """

    def __init__(self, forget_rate, forget_fn=lambda x: x * 0.8):
        self.forget_fn = forget_fn
        self.forget_rate = forget_rate

    def step(self):
        """Return the current rate, then replace it with ``forget_fn(rate)``."""
        current = self.forget_rate
        self.forget_rate = self.forget_fn(current)
        return current

def accuarcy(output: torch.Tensor, target: torch.Tensor, topk=(1,), mod="cls"):
    """Compute top-k classification accuracy in percent.

    Args:
        output: 2-D score tensor; dimension 1 is the class dimension.
        target: Class indices, one per row of ``output``.
        topk: Iterable of k values to evaluate.
        mod: Evaluation mode; only ``"cls"`` is supported.

    Returns:
        A list with one single-element tensor per entry of ``topk``, holding
        the percentage of samples whose target is among the k highest scores.

    Raises:
        AssertionError: If ``mod`` is not ``"cls"``.
    """
    assert mod in ["cls"]

    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred becomes (maxk, batch): row i holds every sample's i-th best class.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape, not view: `correct` can inherit the transposed
            # (non-contiguous) layout of `pred`, and view(-1) then raises a
            # RuntimeError for k > 1.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
    return res

def mean(array):
    """Return the arithmetic mean of ``array``; asserts that it is non-empty."""
    count = len(array)
    assert count > 0
    return sum(array) / count

def build_model(args):
    """Build the backbone selected by ``args.backbone``.

    The backbone module is chosen by substring match on ``args.backbone`` and
    its ``build_model(args)`` is called.

    Args:
        args: Configuration object with a ``backbone`` string; it is passed
            through unchanged to the selected module's ``build_model``.

    Returns:
        Whatever the selected backbone module's ``build_model`` returns.

    Raises:
        ValueError: If ``args.backbone`` matches no known backbone family.
    """
    backbone = args.backbone
    if "vit" in backbone:
        return vit.build_model(args)
    # "wide_resnet" must be tested before "resnet": the latter is a substring
    # of the former.
    if "wide_resnet" in backbone:
        return wide_resnet.build_model(args)
    if "resnet" in backbone:
        return resnet.build_model(args)
    if "densenet" in backbone:
        return densenet.build_model(args)
    raise ValueError(
        f"Unsupported backbone {backbone!r}: expected a name containing "
        "'vit', 'wide_resnet', 'resnet' or 'densenet'"
    )

def build_optimizer(args, params) -> optim.Optimizer:
    """Create the optimizer named by ``args.optim``.

    Args:
        args: Configuration object providing ``optim`` (``"SGD"``, ``"ADAM"``
            or ``"ADAMW"``), ``lr`` and ``wd`` (weight decay); ``momentum`` is
            read for SGD only.
        params: Parameters (or parameter groups) to optimize.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``args.optim`` is not one of the supported names.
    """
    name = args.optim
    if name == "SGD":
        return optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
    if name == "ADAM":
        return optim.Adam(params, lr=args.lr, weight_decay=args.wd)
    if name == "ADAMW":
        return optim.AdamW(params, lr=args.lr, weight_decay=args.wd)
    raise ValueError(
        f"Unsupported optimizer {name!r}: expected 'SGD', 'ADAM' or 'ADAMW'"
    )

def build_lr_schedular(args, opt, mod="cosine"):
    """Create a learning-rate scheduler for ``opt``.

    Args:
        args: Configuration object; ``nEpochs`` is used as ``T_max`` of the
            cosine schedule and as ``step_size`` of the step schedule.
        opt: Optimizer whose learning rate is scheduled.
        mod: ``"cosine"`` for ``CosineAnnealingLR`` or ``"step"`` for
            ``StepLR`` with ``gamma=0.1``.

    Returns:
        The constructed scheduler.
    """
    assert mod in ["cosine", "step"]
    schedulers = optim.lr_scheduler
    if mod == "cosine":
        return schedulers.CosineAnnealingLR(opt, T_max=args.nEpochs)
    return schedulers.StepLR(opt, step_size=args.nEpochs, gamma=0.1)

# Layer classes from the ps wrapper module that the ps_* helpers in this file
# act on; they are matched with exact-type checks (``type(m) in PS_LAYERS``).
PS_LAYERS = [
    ps.Linear,
    ps.Conv2d,
    ps.TConv2d,
    ps.TLinear,
]

def _swap_submodule(model: nn.Module, qualified_name: str, replacement: nn.Module):
    """Rebind the submodule at dotted path ``qualified_name`` to ``replacement``."""
    parent_name, _, child_name = qualified_name.rpartition('.')
    # get_submodule("") returns the model itself, which covers direct children.
    setattr(model.get_submodule(parent_name), child_name, replacement)


def convert_ps(model: nn.Module, ignore_layers=(), mask_trainable=False):
    """Replace plain ``nn.Linear`` / ``nn.Conv2d`` layers with ps layers, in place.

    Only modules whose type is exactly ``nn.Linear`` or ``nn.Conv2d`` are
    replaced; subclasses are left alone. Each replacement is constructed from
    the original layer's hyper-parameters and then takes over the original
    ``weight`` and ``bias`` objects, so parameters are shared rather than
    copied.

    Args:
        model: Model to convert; modified in place.
        ignore_layers: Module instances that must not be replaced. Membership
            is tested on each candidate layer itself.
        mask_trainable: If true, use ``ps.TLinear`` / ``ps.TConv2d`` instead
            of ``ps.Linear`` / ``ps.Conv2d``.

    Returns:
        None.
    """
    # Snapshot first: the loop rebinds attributes on the tree it iterates over.
    for name, module in list(model.named_modules()):
        module_type = type(module)
        if module_type not in (nn.Linear, nn.Conv2d) or module in ignore_layers:
            continue

        if module_type is nn.Linear:
            layer_cls = ps.TLinear if mask_trainable else ps.Linear
            replacement = layer_cls(module.in_features, module.out_features)
        else:
            layer_cls = ps.TConv2d if mask_trainable else ps.Conv2d
            replacement = layer_cls(
                module.in_channels,
                module.out_channels,
                kernel_size=module.kernel_size,
                stride=module.stride,
                padding=module.padding,
                dilation=module.dilation,
                groups=module.groups,
                bias=module.bias is not None,
                padding_mode=module.padding_mode,
            )

        _swap_submodule(model, name, replacement)
        # Hand over the original parameter objects instead of the freshly
        # initialised ones. NOTE(review): anything else the ps layer creates
        # in its constructor keeps its defaults — confirm that is intended.
        replacement.weight = module.weight
        replacement.bias = module.bias
            
def ps_load_state_dict(model: nn.Module, path, prefix="fc"):
    """Load a checkpoint into ``model`` while skipping the ``prefix`` layer.

    The entries ``"<prefix>.weight"`` and ``"<prefix>.bias"`` are removed from
    the loaded state dict if present; the rest is loaded non-strictly and the
    result of ``load_state_dict`` (missing / unexpected keys) is printed.

    Args:
        model: Model that receives the state.
        path: Checkpoint location, passed to ``torch.load``.
        prefix: Name of the layer whose weight and bias are not loaded.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    # Mapping to CPU lets checkpoints saved on an unavailable device load;
    # load_state_dict copies the values into the model's own tensors.
    stat = torch.load(path, map_location="cpu")
    for suffix in (".weight", ".bias"):
        # A checkpoint without one of these keys (e.g. a bias-free head) is
        # fine: the only goal is to keep them out of the load.
        stat.pop(prefix + suffix, None)
    print(model.load_state_dict(stat, strict=False))

def mark_only_ps_as_trainable(model: nn.Module, ignore_layers=()):
    """Freeze the model except for ps layers and explicitly ignored modules.

    Modules are visited in ``model.modules()`` order. A module whose type is
    in ``PS_LAYERS`` decides its own trainable parameters through
    ``set_trainable_params()``. Any other module has ``requires_grad`` set on
    all of its parameters: ``True`` if the module is in ``ignore_layers``,
    ``False`` otherwise.

    Args:
        model: Model to update; modified in place.
        ignore_layers: Module instances whose parameters stay trainable.
    """
    for module in list(model.modules()):
        if type(module) in PS_LAYERS:
            module.set_trainable_params()
            continue
        # parameters() also yields the parameters of submodules, so a flag set
        # here can be overwritten when a submodule is visited later.
        # NOTE(review): an ignored container therefore does not keep its
        # non-ignored children trainable — confirm that is intended.
        trainable = module in ignore_layers
        for param in module.parameters():
            param.requires_grad = trainable

def ps_model_init(model: nn.Module, p=1):
    """Initialise every ps layer of ``model``.

    With gradient tracking disabled, each module whose type is in
    ``PS_LAYERS`` gets ``reset_mask(p=p)`` followed by ``copy_params()``.

    Args:
        model: Model containing the ps layers.
        p: Value forwarded to each layer's ``reset_mask``.
    """
    ps_layers = [layer for layer in model.modules() if type(layer) in PS_LAYERS]
    with torch.no_grad():
        for layer in ps_layers:
            layer.reset_mask(p=p)
            layer.copy_params()

def ps_step(model: nn.Module, p_schedular):
    """Advance the p schedule and refresh each ps layer's mask.

    With gradient tracking disabled, every module whose type is in
    ``PS_LAYERS`` recomputes its channel importance and then resets its mask
    with the newly scheduled value.

    Args:
        model: Model containing the ps layers.
        p_schedular: Object whose ``step()`` returns the value passed to each
            layer's ``reset_mask``.
    """
    p = p_schedular.step()
    ps_layers = [layer for layer in model.modules() if type(layer) in PS_LAYERS]
    with torch.no_grad():
        for layer in ps_layers:
            layer.compute_channel_importance()
            layer.reset_mask(p)
                
def global_ps_step(model, p_schedular):
    """Advance the p schedule and reset all ps masks against one global threshold.

    The channel importances of all ps layers are concatenated, and the last of
    the top ``int(p * N)`` values (``N`` being the length of the concatenated
    tensor) becomes the threshold passed to every layer's ``reset_mask``.

    Args:
        model: Model containing the ps layers.
        p_schedular: Object whose ``step()`` returns the scheduled ``p``.
    """
    p = p_schedular.step()
    ps_layers = [layer for layer in model.modules() if type(layer) in PS_LAYERS]
    with torch.no_grad():
        importances = torch.cat(
            [layer.compute_channel_importance() for layer in ps_layers], dim=0)
        keep = int(p * len(importances))
        top_values, _ = importances.topk(k=keep)
        # topk returns its values in descending order by default, so the last
        # one is the smallest of the kept importances.
        threshold = top_values[-1]
        for layer in ps_layers:
            layer.reset_mask(p, threshold=threshold)
            
class PSchedular():
    """Linear schedule that lowers ``p`` from ``start`` towards ``end``.

    Args:
        start: Initial value of ``p``.
        end: Lower bound of the values returned by ``step``.
        steps: Number of steps over which ``start - end`` is spread.
    """

    def __init__(self, start=1.0, end=0.2, steps=12):
        self.start, self.end, self.steps = start, end, steps
        self.p = start

    def step(self):
        """Lower ``p`` by one increment and return it, never below ``end``.

        The stored ``self.p`` itself is not clamped and keeps decreasing.
        """
        decrement = (self.start - self.end) / self.steps
        self.p = self.p - decrement
        return max(self.end, self.p)
                    

def build_p_schedular(args):
    """Create a ``PSchedular`` from ``args``.

    The schedule runs from ``args.p_start`` to ``args.p_end`` over
    ``args.nEpochs // args.p_T - 1`` steps.
    """
    return PSchedular(
        start=args.p_start,
        end=args.p_end,
        steps=args.nEpochs // args.p_T - 1,
    )

def ps_visualize(model):
    """Print per-layer sharing statistics for the ps layers and return the total.

    For each ps layer the share ratio is ``1 - sum(mask) / len(mask)``. The
    mask is ``ps_mask`` itself for ``ps.Linear`` / ``ps.Conv2d``, and
    ``sigmoid(ps_mask) > ps.SHARE_THRESHOLD`` for ``ps.TLinear`` /
    ``ps.TConv2d``.

    Returns:
        The sum over all ps layers of ``params * ratio``, where ``params`` is
        half of the layer's parameter count.
    """
    total = 0.0
    for n, m in list(model.named_modules()):
        layer_type = type(m)
        if layer_type not in PS_LAYERS:
            continue
        # NOTE(review): halved presumably because each ps layer holds two
        # equally sized parameter sets — confirm against the layer classes.
        params = sum(p.numel() for p in m.parameters()) / 2
        if layer_type in (ps.TConv2d, ps.TLinear):
            mask = torch.sigmoid(m.ps_mask) > ps.SHARE_THRESHOLD
        else:
            mask = m.ps_mask
        ratio = 1 - sum(mask) / len(mask)
        total += params * ratio
        print(f"layer {n} params:{params}, params share ratio: {ratio}")
    print(f"Total shared params: {total}")
    return total
    
def get_TConv_sparse_loss(model, target):
    """Compute the sparsity loss over all ps layers of ``model``.

    Each ps layer contributes the pair returned by ``get_sparse_ratio()``
    (shared-channel count, variance term) and ``len(ps_mask)`` channels.

    Args:
        model: Model containing the ps layers.
        target: Desired fraction of shared channels.

    Returns:
        ``|target_shared - shared| / channels - 100 * variance_sum``, where
        ``target_shared`` is ``channels * target``.
    """
    channel_count = 0.0
    shared_counts = []
    variances = []
    for layer in list(model.modules()):
        if type(layer) not in PS_LAYERS:
            continue
        shared, variance = layer.get_sparse_ratio()
        channel_count += len(layer.ps_mask)
        shared_counts.append(shared)
        variances.append(variance)

    variance_sum = sum(variances)
    shared_sum = sum(shared_counts)
    # Place the target on the same device as the layers' statistics.
    wanted_shared = torch.tensor(channel_count * target).to(variance_sum.device)
    ratio_loss = abs((wanted_shared - shared_sum) / channel_count)
    return ratio_loss - 100 * variance_sum

def fix_TConv_mask(model):
    """Call ``prune_layer()`` on every module whose type is in ``PS_LAYERS``."""
    for layer in list(model.modules()):
        if type(layer) in PS_LAYERS:
            layer.prune_layer()