from functools import partial


import torch
import torch.nn as nn
from torch.optim import SGD,Adam,AdamW, lr_scheduler
from torch.nn import NLLLoss,MSELoss, CrossEntropyLoss

# Layer types through which pass_layer forwards the bias vector by calling the
# layer itself on it.
CHANNEL_WISE_LAYERS = (
    nn.ReLU,
    nn.BatchNorm2d,
    nn.BatchNorm1d,
    nn.MaxPool2d,
    nn.AdaptiveAvgPool2d,
    nn.Flatten,
    nn.Dropout,
)
# Layer types whose weight combines input channels; a bias vector reaching one
# of these is folded into the layer's own bias.
CROSS_CHANNEL_LAYERS = (nn.Linear, nn.Conv2d)

# NOTE(review): NORMALIZE_LAYERS is assigned a second time further down and
# that later value is the one in effect at import time; this broader tuple is
# never observed by the functions in this file. Confirm which one is intended.
NORMALIZE_LAYERS = (
    nn.modules.batchnorm._BatchNorm,
    nn.modules.normalization.LayerNorm,
    nn.GroupNorm,
    nn.modules.instancenorm._InstanceNorm,
)

# "Semilinear" is a loose use of the term: f(x + b) = f(x) + b.
SEMILINEAR_LAYERS = (nn.MaxPool2d, nn.AdaptiveAvgPool2d)

# Additive layers: f(x + b) = f(x) + f(b).
ADDITIVE_LAYERS = (nn.Flatten, nn.Dropout)

# Normalization layers: f_{a}(x + b) = f_{a'}(x), i.e. the shift can be
# absorbed by changing the layer's parameters from a to a'.
NORMALIZE_LAYERS = (nn.BatchNorm2d, nn.BatchNorm1d)

# Optimizer classes selectable by a short lowercase name.
AVAILABLE_OPTIMIZERS = {
    'sgd': SGD,
    'adam': Adam,
    'adamw': AdamW,
}

# For each optimizer name above, the keys of the per-parameter state entries
# that are tracked for it. Each optimizer gets its own list object.
OPTIM_STATE_KEYS = {
    'sgd': ['momentum_buffer'],
    'adam': ['step', 'exp_avg', 'exp_avg_sq'],
    'adamw': ['step', 'exp_avg', 'exp_avg_sq'],
}

# Loss classes selectable by name.
AVAILABLE_LOSSES = {
    'categorical_crossentropy': CrossEntropyLoss,
    'mse': MSELoss,
}

# Learning-rate scheduler classes selectable by name.
AVAILABLE_SCHEDULERS = {
    'multi_step': lr_scheduler.MultiStepLR,
    'cosine': lr_scheduler.CosineAnnealingLR,
    'cos_restart': lr_scheduler.CosineAnnealingWarmRestarts,
    'linear': lr_scheduler.LinearLR,
}

# Named norm-like reductions. Every entry is called as ``fn(x, dim)``.
NORM_FUNCTIONS = {
    # 1-norm via torch.linalg.norm; for an int ``dim`` this is sum(|x|) along dim.
    'l1': lambda x, dim: torch.linalg.norm(x, ord=1, dim=dim),
    # mean(|x|) along dim.
    'l1mean': lambda x, dim: torch.mean(abs(x), dim=dim),
    # Root-mean-square along dim.
    'l2mean': lambda x, dim: torch.sqrt(torch.mean(x ** 2, dim=dim)),
    # 2-norm via torch.linalg.norm; for an int ``dim`` this is sqrt(sum(x^2)) along dim.
    'l2': lambda x, dim: torch.linalg.norm(x, ord=2, dim=dim),
    # mean(x^2) along dim.
    'mse': lambda x, dim: torch.mean(x ** 2, dim=dim),
    # Mean over groups of the 2-norm taken along dim (scalar result).
    # The previous form called torch.linalg.norm(x, ord=1) without ``dim``,
    # which raises for inputs with more than two dimensions and never applied
    # ``dim`` to the grouping.
    # NOTE(review): assumed to be the mean counterpart of 'l21' -- confirm.
    'l21mean': lambda x, dim: torch.mean(torch.linalg.vector_norm(x, dim=dim)),
    # Sum over groups of the 2-norm taken along dim (scalar result).
    'l21': lambda x, dim: torch.sum(torch.linalg.vector_norm(x, dim=dim)),
}

def name_id_dict(model: torch.nn.Module, optim: torch.optim.Optimizer):
    """Map model parameter names to their positional index in the optimizer.

    Indices are assigned by walking ``optim.param_groups`` in order and
    counting parameters across groups; a parameter object that shows up more
    than once keeps the first index it was given.

    Args:
        model: module whose ``named_parameters()`` supply the names.
        optim: optimizer whose ``param_groups`` define the indices.

    Returns:
        ``(name2id, id2name)``: a dict from parameter name to index and the
        reverse dict from index to parameter name.

    Raises:
        KeyError: if a parameter of ``model`` is not present in any of the
            optimizer's parameter groups.
    """
    index_of = {}
    offset = 0
    for group in optim.param_groups:
        group_params = group['params']
        for position, param in enumerate(group_params, offset):
            # First occurrence wins; later duplicates keep the earlier index.
            index_of.setdefault(id(param), position)
        offset += len(group_params)

    name2id = {
        name: index_of[id(param)] for name, param in model.named_parameters()
    }
    id2name = {index: name for name, index in name2id.items()}
    return name2id, id2name

def get_norm_vector(tensor: torch.Tensor, axis=0, p=1, mean: bool = False):
    """Return one p-norm per slice of ``tensor`` along ``axis``.

    Args:
        tensor: input tensor.
        axis: dimension whose entries each get their own norm value.
        p: order of the norm, forwarded to ``Tensor.norm``.
        mean: accepted for interface compatibility; the value is currently
            not used by this function.

    Returns:
        For a 1-D ``tensor``, its element-wise absolute value (``axis`` and
        ``p`` are not consulted). Otherwise a 1-D tensor of length
        ``tensor.shape[axis]`` holding the p-norm of each slice.
    """
    if tensor.dim() == 1:
        return tensor.abs()
    # Bring the chosen axis to the front, then flatten everything else so each
    # row holds all the elements belonging to one slice.
    per_slice = tensor.transpose(0, axis).reshape(tensor.shape[axis], -1)
    return per_slice.norm(p=p, dim=-1)
    

def _cross_channel_layer_update(layer, A_mat, bias_vector):
    """Fold a bias vector into ``layer.bias`` through the weight ``A_mat``.

    ``A_mat`` is summed over every dimension after the first two, and the
    resulting 2-D matrix is multiplied with ``bias_vector``. The product is
    added to ``layer.bias``; when the layer has no bias, a new trainable
    parameter holding the product is installed instead.

    Args:
        layer: module with a ``bias`` attribute (possibly ``None``).
        A_mat: weight tensor with at least two dimensions.
            Presumably (out_channels, in_channels, *kernel) -- verify against caller.
        bias_vector: vector whose length matches ``A_mat.shape[1]``.

    Returns:
        Always ``None``; the layer is modified in place.
    """
    # Collapse any trailing (e.g. kernel) dimensions so a matrix remains.
    collapsed_weight = torch.einsum('ij...->ij', A_mat)
    bias_shift = collapsed_weight @ bias_vector
    if layer.bias is None:
        layer.bias = torch.nn.Parameter(data=bias_shift, requires_grad=True)
        return None
    layer.bias.data = layer.bias.data + bias_shift
    return None
def pass_layer(bias_vector, layer, prune_indices=None):
    '''
    Push a bias vector through one layer.

    bias_vector: current residual vector (the W.bias after it has passed the
        layers between W and the current layer). ``None`` means there is
        nothing to propagate.
    layer: current layer.
    prune_indices: optional index tensor selecting channels. When ``None`` the
        call is delegated to ``pass_layer_bypass``.

    returns:
        - ``None`` when ``bias_vector`` is ``None``.
        - the result of ``pass_layer_bypass`` when ``prune_indices`` is ``None``.
        - for CROSS_CHANNEL_LAYERS: ``None``; the selected weight columns and
          bias entries are folded into ``layer.bias``.
        - for SEMILINEAR_LAYERS: ``bias_vector`` unchanged.
        - for NORMALIZE_LAYERS: the layer applied to ``bias_vector`` viewed as
          (1, C, 1, 1), squeezed.
        - for CHANNEL_WISE_LAYERS: the layer applied to ``bias_vector`` with a
          leading batch dimension, which is removed again.
        - ``None`` for any other layer type.
    '''
    if bias_vector is None:  # skipping sign
        return None
    if prune_indices is None:
        return pass_layer_bypass(bias_vector, layer)

    # The order of the checks matters: the type tuples overlap.
    if isinstance(layer, CROSS_CHANNEL_LAYERS):  # A layer
        selected_weight = layer.weight.index_select(1, prune_indices)
        selected_bias = bias_vector.index_select(0, prune_indices)
        return _cross_channel_layer_update(layer, selected_weight, selected_bias)
    if isinstance(layer, SEMILINEAR_LAYERS):
        return bias_vector
    if isinstance(layer, NORMALIZE_LAYERS):
        as_feature_map = bias_vector.view(1, -1, 1, 1)
        return layer(as_feature_map).squeeze()
    if isinstance(layer, CHANNEL_WISE_LAYERS):
        batched = bias_vector.unsqueeze(0)
        return layer(batched).squeeze(0)
    return None
        
def pass_layer_bypass(bias_vector, layer):
    """Push a full bias vector through one layer (no channel selection).

    Args:
        bias_vector: vector to propagate.
        layer: current layer.

    Returns:
        - for CROSS_CHANNEL_LAYERS: ``None``; the vector is folded into
          ``layer.bias`` using the full ``layer.weight``.
        - for SEMILINEAR_LAYERS: ``bias_vector`` unchanged.
        - for ADDITIVE_LAYERS: the layer applied to ``bias_vector`` with a
          leading batch dimension, which is removed again.
        - for NORMALIZE_LAYERS: ``None``; ``bias_vector`` is added in place to
          ``layer.running_mean`` when the layer has one.

    Raises:
        TypeError: if ``layer`` belongs to none of the groups above.
    """
    if isinstance(layer, CROSS_CHANNEL_LAYERS):  # A layer
        return _cross_channel_layer_update(layer, layer.weight, bias_vector)
    if isinstance(layer, SEMILINEAR_LAYERS):
        return bias_vector
    if isinstance(layer, ADDITIVE_LAYERS):
        return layer(bias_vector.unsqueeze(0)).squeeze(0)
    if isinstance(layer, NORMALIZE_LAYERS):
        # Batch-norm layers built with track_running_stats=False still have a
        # ``running_mean`` attribute, but it is None; treat that the same as a
        # missing attribute instead of failing on ``None.data``.
        running_mean = getattr(layer, 'running_mean', None)
        if running_mean is None:  # bias vector has no effect
            return None
        running_mean.data += bias_vector
        return None

    raise TypeError(f'layer {layer.__class__} is not supported')

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The one-hot target is mixed with a uniform distribution:
    ``(1 - epsilon) * one_hot + epsilon / num_classes``.
    """

    def __init__(self, num_classes, epsilon):
        """Store the number of classes and the smoothing factor."""
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Compute the smoothed loss.

        Args:
            inputs: scores with classes along dim 1.
            targets: class indices, one per row of ``inputs``.

        Returns:
            Scalar tensor: the loss averaged over dim 0 and summed over classes.
        """
        log_probs = self.logsoftmax(inputs)
        # Build one-hot rows by writing 1 at each target's class column.
        class_index = targets.unsqueeze(1)
        one_hot = torch.zeros_like(log_probs).scatter_(1, class_index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        per_class = -smoothed * log_probs
        return per_class.mean(0).sum()
        
    