import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.distributed as dist

from fastargs import get_current_config
from fastargs.decorators import param

# NOTE(review): called only for its side effect; the returned config object is discarded.
# Presumably this ensures fastargs' global config exists before the @param-decorated
# definitions below are called — confirm.
get_current_config()

# Straight-through masking: the forward pass uses the masked weight, while the backward pass
# hands the incoming gradient to ``weight`` unmasked, so masked-out entries keep receiving
# (dense) gradients.
class MaskDenseGrad(torch.autograd.Function):
    """Autograd function computing ``weight * mask`` with an identity gradient for ``weight``."""

    @staticmethod
    def forward(ctx, weight, mask):
        """Return the element-wise product of ``weight`` and ``mask``."""
        masked = torch.mul(weight, mask)
        return masked

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` straight through to ``weight``; ``mask`` receives no gradient."""
        grad_weight, grad_mask = grad_output, None
        return grad_weight, grad_mask
    

class ConvMask(nn.Conv2d):
    """
    nn.Conv2d subclass that multiplies its weights by a ``mask`` buffer on every forward pass.

    The mask is all ones after construction (nothing pruned) and can be replaced, e.g. by
    ``set_er_mask``. When ``dense_grad`` is True the masking goes through MaskDenseGrad,
    so masked-out weights still receive gradients.

    Args:
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
        **kwargs: Keyword arguments for nn.Conv2d.
    """

    @param("experiment_params.dense_grad")
    def __init__(self, dense_grad: bool, **kwargs: any) -> None:
        super().__init__(**kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad
        print(f"the layer will evaluate dense gradients: {dense_grad}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the masked weights."""
        mask = self.mask.to(self.weight.device)
        masked_weight = (
            MaskDenseGrad.apply(self.weight, mask) if self.dense_grad else mask * self.weight
        )
        return F.conv2d(
            x, masked_weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random 0/1 mask for random pruning.

        Args:
            p (float): Probability that an entry is kept (set to 1).
        """
        keep = torch.zeros_like(self.weight)
        keep.bernoulli_(p)
        self.mask = keep


class LinearMask(nn.Linear):
    """
    Linear layer which inherits from the original PyTorch Linear layer with an additionally initialized mask buffer.
    The mask (all ones after construction) is multiplied into the weights during the forward pass.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
            If True, masking goes through MaskDenseGrad so masked-out weights still get gradients.
        bias (bool): If True, the layer has a learnable bias. Default: True.
        **kwargs: Further keyword arguments for nn.Linear.
    """

    @param("experiment_params.dense_grad")
    def __init__(self, in_features, out_features, dense_grad, bias=True, **kwargs: any) -> None:
        # Forward the caller's ``bias``. It used to be hard-coded to True, so a bias-free
        # nn.Linear swapped in via replace_vit_layers gained a randomly initialised bias that
        # the pretrained state dict (which has no such key) could not overwrite.
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad
        print(f"the layer will evaluate dense gradients: {dense_grad}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear map using the masked weights.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        if self.dense_grad:
            sparseWeight = MaskDenseGrad.apply(self.weight, self.mask.to(self.weight.device))
        else:
            sparseWeight = self.mask.to(self.weight.device) * self.weight
        return F.linear(x, sparseWeight, self.bias)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random 0/1 mask for random pruning.

        Args:
            p (float): Probability that an entry is kept (set to 1).
        """
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)


class Conv1dMask(nn.Conv1d):
    """
    Masked Conv1d with kernel size 1, used as an equivalent drop-in for an nn.Linear layer.

    A ``mask`` buffer (all ones after construction) multiplies the weights on every forward
    pass. The input gets a trailing length-1 axis before the convolution and loses it again
    afterwards, so presumably (batch, in_features) in gives (batch, out_features) out.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): If True, adds a learnable bias to the output. Default: False.
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
            If True, masking goes through MaskDenseGrad so masked-out weights still get gradients.
    """

    @param("experiment_params.dense_grad")
    def __init__(self, in_features: int, out_features: int, bias: bool=False, dense_grad=False) -> None:
        super().__init__(
            in_channels=in_features, out_channels=out_features, kernel_size=1, stride=1, bias=bias
        )
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad
        print(f"the layer will evaluate dense gradients: {dense_grad}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` (with an added trailing axis) using the masked weights."""
        mask = self.mask.to(self.weight.device)
        if self.dense_grad:
            masked_weight = MaskDenseGrad.apply(self.weight, mask)
        else:
            masked_weight = mask * self.weight

        out = F.conv1d(
            x.unsqueeze(-1),
            masked_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return out.squeeze(-1)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random 0/1 mask for random pruning.

        Args:
            p (float): Probability that an entry is kept (set to 1).
        """
        keep = torch.zeros_like(self.weight)
        keep.bernoulli_(p)
        self.mask = keep

@param('model_params.conv_type')
def replace_layers(conv_type: str, model: nn.Module) -> nn.Module:
    """
    Replaces nn.Linear and nn.Conv2d layers in the model with corresponding masked layers.
    Skips layers whose qualified name contains "downsample" (the shortcut connections).

    nn.Conv2d layers become the class named ``conv_type`` (looked up in this module's globals)
    with the same channels, kernel size, stride, padding, dilation, groups and bias flag.
    nn.Linear layers become a kernel-size-1 convolutional equivalent: STRConv1d for 'STRConv',
    Conv1dMaskMW for 'ConvMaskMW', Conv1dMask otherwise. Replacement layers are freshly
    constructed; the original layers' weights are not copied.

    Args:
        conv_type (str): The type of masked layer to use (e.g., 'ConvMask').
        model (nn.Module): The model in which to replace layers (modified in place).

    Returns:
        nn.Module: The model with replaced layers.
    """
    conv_layer_of_type = globals().get(conv_type)
    print(conv_layer_of_type)

    # Collect first, then mutate: swapping children while named_modules() iterates is unsafe.
    layers_to_replace = [
        (name, layer)
        for name, layer in model.named_modules()
        if "downsample" not in name and isinstance(layer, (nn.Linear, nn.Conv2d))
    ]

    for name, layer in layers_to_replace:
        *parent_path, attr_name = name.split(".")
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        if isinstance(layer, nn.Linear):
            in_features = layer.in_features
            out_features = layer.out_features
            bias = layer.bias is not None

            if conv_type == 'STRConv':
                conv_layer = STRConv1d(in_channels=in_features,
                                       out_channels=out_features,
                                       kernel_size=1,
                                       bias=bias)
            elif conv_type == 'ConvMaskMW':
                conv_layer = Conv1dMaskMW(in_features, out_features, bias)
            else:
                conv_layer = Conv1dMask(in_features, out_features, bias)
            setattr(parent_module, attr_name, conv_layer)
        elif isinstance(layer, nn.Conv2d):
            conv_mask_layer = conv_layer_of_type(
                in_channels=layer.in_channels,
                out_channels=layer.out_channels,
                kernel_size=layer.kernel_size,
                stride=layer.stride,
                padding=layer.padding,
                # dilation/groups were previously dropped, which silently turned dilated or
                # grouped/depthwise convolutions into plain dense ones with different weight shapes.
                dilation=layer.dilation,
                groups=layer.groups,
                bias=layer.bias is not None,
            )
            setattr(parent_module, attr_name, conv_mask_layer)

    print(model)
    return model



@param('model_params.conv_type')
def replace_vit_layers(conv_type: str, model: nn.Module) -> nn.Module:
    """
    For a ViT, swap every nn.Linear for the masked linear class named by ``conv_type``
    while keeping the parameter values the model currently holds (e.g. pretrained ones).

    Args:
        conv_type (str): Name of the masked linear class in this module (e.g. 'LinearMask').
        model (nn.Module): The model in which to replace layers (modified in place).

    Returns:
        nn.Module: The model with replaced layers.
    """
    linear_layer_of_type = globals().get(conv_type)

    # Snapshot the current (possibly pretrained) parameters before any swap.
    curr_state_dict = model.state_dict()

    linear_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    ]

    for name, layer in linear_layers:
        *parent_path, attr_name = name.split(".")
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        replacement = linear_layer_of_type(
            layer.in_features, layer.out_features, bias=layer.bias is not None
        )
        setattr(parent_module, attr_name, replacement)

    # Start from the new modules' entries (mask, m, ...) and overwrite every key that also
    # existed before the swap, so the old weight/bias values are carried over.
    model_state_dict = model.state_dict()
    model_state_dict.update(curr_state_dict)
    model.load_state_dict(model_state_dict)

    # LinearMaskMW stores the weight as a product m * w, so the freshly loaded weight
    # has to be split again.
    if conv_type == 'LinearMaskMW':
        print('Loading pretrained weights and splitting into m and w')
        for module in model.modules():
            if isinstance(module, LinearMaskMW):
                module.init_from_weight()

    print('Replaced layers with the Masks which can be sparsified.')
    print(model)
    return model



def sparseFunction(x, s, activation=torch.relu, f=torch.sigmoid):
    """
    Soft-threshold ``x`` element-wise: ``sign(x) * activation(|x| - f(s))``.

    With the defaults, every magnitude shrinks by sigmoid(s) and entries whose magnitude
    does not exceed sigmoid(s) become zero. The result broadcasts ``x`` against ``f(s)``.
    """
    threshold = f(s)
    shrunk_magnitude = activation(torch.abs(x) - threshold)
    return torch.sign(x) * shrunk_magnitude

@param('prune_params.str_init_val')
def initialize_sInit(str_init_val):
    """Return a (1, 1) tensor filled with ``str_init_val`` (supplied via fastargs' prune_params.str_init_val)."""
    ones = torch.ones([1, 1])
    return str_init_val * ones

class STRConv(nn.Conv2d):
    """
    Conv2d whose weights are soft-thresholded by a learnable threshold (STR-style).

    The weight actually used is ``sparseFunction(weight, sparseThreshold)``, i.e.
    ``sign(w) * relu(|w| - sigmoid(sparseThreshold))``. ``sparseThreshold`` is a (1, 1)
    parameter initialised by ``initialize_sInit``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.activation = torch.relu
        # The threshold is always mapped through a sigmoid.
        self.f = torch.sigmoid
        self.sparseThreshold = nn.Parameter(initialize_sInit())

    def _sparse_weight(self):
        """Soft-thresholded weight, still attached to the autograd graph."""
        return sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f)

    def forward(self, x):
        # The thresholded weight is rebuilt on every call, so gradients reach both
        # ``weight`` and ``sparseThreshold``.
        weight = self._sparse_weight().to(self.weight.device)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    def getSparsity(self, f=torch.sigmoid):
        """
        Returns:
            tuple: (percentage of zero weights, total number of weights, ``f(sparseThreshold)``).
        """
        support = self.get_mask()
        density = support.mean().item()
        return (100 - density * 100), support.numel(), f(self.sparseThreshold).item()

    def get_mask(self, f=torch.sigmoid):
        """Return a detached CPU tensor set to 1 wherever the thresholded weight is non-zero (``f`` is unused)."""
        support = self._sparse_weight().detach().cpu()
        support[support != 0] = 1
        return support

class STRConv1d(STRConv):
    """
    STRConv used as a stand-in for nn.Linear (replace_layers builds it with kernel_size=1).

    Inputs get two trailing singleton spatial axes, go through the soft-thresholded
    convolution, and exactly those two axes are removed again.
    """

    def forward(self, x):
        # Add H = W = 1 so the 2-D convolution acts like a linear layer.
        x = x.unsqueeze(-1).unsqueeze(-1)
        sparseWeight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f).to(self.weight.device)
        x = F.conv2d(
            x, sparseWeight, self.bias, self.stride
        )
        # Drop only the two axes added above. A bare ``squeeze()`` also removed a batch (or
        # output) axis of size 1, returning the wrong shape e.g. for batch size 1.
        return x.squeeze(-1).squeeze(-1)
    

class ConvMaskMW(ConvMask):
    """
    Conv2d layer which inherits from the ConvMask layer with a parametrization of m*w for the weights.
    The effective weight is the element-wise product ``m * weight`` of two learnable tensors, and
    the ``mask`` buffer is additionally applied to that product during the forward pass.

    At construction the freshly initialised weight x is split, with alpha = 0.5, into
    m = (u + alpha/u)/sqrt(2) and w = (u - alpha/u)/sqrt(2), where u = sqrt(x + sqrt(x^2 + alpha^2));
    this choice gives m * w == x (up to rounding).

    Args:
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
            If True, masking goes through MaskDenseGrad so masked-out entries still get gradients.
        **kwargs: Keyword arguments for nn.Conv2d.
    """
    @param("experiment_params.dense_grad")
    def __init__(self, dense_grad: bool, **kwargs: any) -> None:
        super().__init__(**kwargs)
        # ConvMask.__init__ already registered an all-ones mask and set dense_grad;
        # these two lines overwrite them.
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad

        alpha = 0.5
        # u**2 = x + sqrt(x^2 + alpha^2) is the root for which the m, w below satisfy m * w == x.
        u = torch.sqrt(self.weight.to(self.weight.device) + torch.sqrt(self.weight.to(self.weight.device)*self.weight.to(self.weight.device) + alpha*alpha))
        
        self.m =  nn.Parameter((u + alpha/u)/np.sqrt(2), requires_grad=True)        
        self.weight.data = (u - alpha/u)/np.sqrt(2)

        print(f"the layer will evaluate dense gradients: {dense_grad}")

    def merge(self):
        # Folds m into the stored weight (weight <- m * weight) so that weight holds the effective
        # weight, e.g. for magnitude-based pruning such as LRR or IMP.
        # NOTE(review): m itself is left unchanged, so forward() afterwards computes
        # m * (m * w_old) unless the caller resets m (e.g. to ones) — confirm this is intended.
        self.weight.data = self.m.to(self.weight.device) * self.weight

    def rescale_mw(self):
        # Re-split the current product x = m * w as m = sqrt(alpha + sqrt(x^2 + alpha^2)) (always
        # positive) and w = x / m. The effective weight m * w is unchanged and the sign of x ends
        # up entirely in w. Author's stated intent: reset m and w to keep momentum for sign flips.
        alpha = 0.5 # make schedule
        # Disabled alternative: derive alpha from the current m and w, averaged across devices.
        # alpha = self.m.to(self.weight.device)*self.m.to(self.weight.device) + self.w.to(self.weight.device) * self.w.to(self.weight.device)/2
        # alpha_m = torch.mean(alpha)
        # # Average across devices
        # dist.all_reduce(alpha_m, op=dist.ReduceOp.SUM)
        # alpha_m = alpha_m / dist.get_world_size()

        # alpha = torch.min(alpha_m, alpha)

        x = self.m.to(self.weight.device) * self.weight.to(self.weight.device)

        u = torch.sqrt(alpha + torch.sqrt(x*x + alpha*alpha))
    
        self.m.data = u #(u + alpha/u)/np.sqrt(2)
        self.weight.data = x/u # (u - alpha/u)/np.sqrt(2)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the effective weight ``mask * m * weight``."""
        
        if self.dense_grad:
            # Gradient flows to m * weight as if the mask were all ones.
            sparseWeight = MaskDenseGrad.apply(self.m.to(self.weight.device) * self.weight.to(self.weight.device), self.mask.to(self.weight.device))
        else:
            sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.weight.to(self.weight.device)
        return F.conv2d(
            x,
            sparseWeight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )


class Conv1dMaskMW(Conv1dMask):
    """
    Conv1dMask variant (kernel size 1, used to replace nn.Linear layers) whose effective weight
    is the element-wise product ``m * weight`` of two learnable tensors; the ``mask`` buffer is
    additionally applied to that product during the forward pass.

    At construction the freshly initialised weight x is split, with alpha = 0.5, into
    m = (u + alpha/u)/sqrt(2) and w = (u - alpha/u)/sqrt(2), where u = sqrt(x + sqrt(x^2 + alpha^2));
    this choice gives m * w == x (up to rounding).

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): If True, adds a learnable bias to the output. Default: False.
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
            If True, masking goes through MaskDenseGrad so masked-out entries still get gradients.
    """

    @param("experiment_params.dense_grad")
    def __init__(self, in_features: int, out_features: int, bias: bool=False, dense_grad=False) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            dense_grad=dense_grad
            )

        # Conv1dMask.__init__ already registered an all-ones mask and set dense_grad;
        # these two lines overwrite them.
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad

        alpha = 0.5
        # u**2 = x + sqrt(x^2 + alpha^2) is the root for which the m, w below satisfy m * w == x.
        u = torch.sqrt(self.weight.to(self.weight.device) + torch.sqrt(self.weight.to(self.weight.device)*self.weight.to(self.weight.device) + alpha*alpha))
        
        self.m =  nn.Parameter((u + alpha/u)/np.sqrt(2), requires_grad=True)        
        self.weight.data = (u - alpha/u)/np.sqrt(2)
        print(f"the layer will evaluate dense gradients: {dense_grad}")

    def rescale_mw(self):
        # Re-split the current product x = m * w as m = sqrt(alpha + sqrt(x^2 + alpha^2)) (always
        # positive) and w = x / m. The effective weight m * w is unchanged and the sign of x ends
        # up entirely in w. Author's stated intent: reset m and w to keep momentum for sign flips.
        alpha = 0.5 # make schedule
        # Disabled alternative: derive alpha from the current m and w, averaged across devices.
        # alpha = self.m.to(self.weight.device)*self.m.to(self.weight.device) + self.w.to(self.weight.device) * self.w.to(self.weight.device)/2
        # alpha_m = torch.mean(alpha)
        # # Average across devices
        # dist.all_reduce(alpha_m, op=dist.ReduceOp.SUM)
        # alpha_m = alpha_m / dist.get_world_size()

        # alpha = torch.min(alpha_m, alpha)

        x = self.m.to(self.weight.device) * self.weight.to(self.weight.device)

        u = torch.sqrt(alpha + torch.sqrt(x*x + alpha*alpha))
    
        self.m.data = u #(u + alpha/u)/np.sqrt(2)
        self.weight.data = x/u # (u - alpha/u)/np.sqrt(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with masked weights: ``x`` gets a trailing length-1 axis, is convolved with
        ``mask * m * weight``, and the trailing axis is removed again.
        """
        x = x.unsqueeze(-1)
        if self.dense_grad:
            # Gradient flows to m * weight as if the mask were all ones.
            sparseWeight = MaskDenseGrad.apply(self.m.to(self.weight.device) * self.weight.to(self.weight.device), self.mask.to(self.weight.device))
        else:
            sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.weight.to(self.weight.device)
        
        x = F.conv1d(
            x,
            sparseWeight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        return x.squeeze(-1)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random 0/1 mask for random pruning.

        Args:
            p (float): Probability that an entry is kept (set to 1).
        """
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)


class LinearMaskMW(LinearMask):
    """
    LinearMask variant whose effective weight is the element-wise product ``m * weight`` of two
    learnable tensors; the ``mask`` buffer is additionally applied to that product.

    A weight x is split, with alpha = 0.5, into m = (u + alpha/u)/sqrt(2) and
    w = (u - alpha/u)/sqrt(2), where u = sqrt(x + sqrt(x^2 + alpha^2)); this gives
    m * w == x (up to rounding).

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
            If True, masking goes through MaskDenseGrad so masked-out entries still get gradients.
        bias (bool): If True, the layer has a learnable bias. Default: True.
        **kwargs: Further keyword arguments for nn.Linear.
    """

    @param("experiment_params.dense_grad")
    def __init__(self, in_features, out_features, dense_grad, bias=True, **kwargs: any) -> None:
        # Forward ``bias`` instead of hard-coding True, so bias-free layers stay bias-free.
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad
        print(f"the layer will evaluate dense gradients: {dense_grad}")
        m, w = self._split_weight()
        self.m = nn.Parameter(m, requires_grad=True)
        self.weight.data = w

    def _split_weight(self, alpha: float = 0.5):
        """Split the current ``weight`` x into (m, w) with m * w == x (up to rounding)."""
        x = self.weight
        u = torch.sqrt(x + torch.sqrt(x * x + alpha * alpha))
        m = (u + alpha / u) / np.sqrt(2)
        w = (u - alpha / u) / np.sqrt(2)
        return m, w

    def init_from_weight(self):
        """Re-split ``weight`` into m and w, e.g. right after pretrained weights were loaded into it."""
        m, w = self._split_weight()
        # Update ``m`` in place instead of creating a new Parameter, so references already held
        # elsewhere (an optimizer, a DDP wrapper) keep pointing at the live tensor.
        self.m.data = m
        self.weight.data = w

    def rescale_mw(self):
        # Re-split the current product x = m * w as m = sqrt(alpha + sqrt(x^2 + alpha^2)) (always
        # positive) and w = x / m. The effective weight m * w is unchanged and the sign of x ends
        # up entirely in w. Author's stated intent: reset m and w to keep momentum for sign flips.
        alpha = 0.5  # make schedule
        # Disabled alternative: derive alpha from the current m and w, averaged across devices.
        # alpha = self.m.to(self.weight.device)*self.m.to(self.weight.device) + self.w.to(self.weight.device) * self.w.to(self.weight.device)/2
        # alpha_m = torch.mean(alpha)
        # # Average across devices
        # dist.all_reduce(alpha_m, op=dist.ReduceOp.SUM)
        # alpha_m = alpha_m / dist.get_world_size()

        # alpha = torch.min(alpha_m, alpha)

        x = self.m.to(self.weight.device) * self.weight.to(self.weight.device)

        u = torch.sqrt(alpha + torch.sqrt(x*x + alpha*alpha))

        self.m.data = u
        self.weight.data = x/u

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear map using the effective weight ``mask * m * weight``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        if self.dense_grad:
            # Gradient flows to m * weight as if the mask were all ones.
            sparseWeight = MaskDenseGrad.apply(self.m.to(self.weight.device) * self.weight.to(self.weight.device), self.mask.to(self.weight.device))
        else:
            sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.weight.to(self.weight.device)

        return F.linear(x, sparseWeight, self.bias)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random 0/1 mask for random pruning.

        Args:
            p (float): Probability that an entry is kept (set to 1).
        """
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

def MP_prob(N, p):
    """
    Discretised Marchenko-Pastur density with ratio parameter ``p``.

    Evaluates sqrt((x - a)(b - x)) / (2*pi*p*x) on ``N`` evenly spaced points from
    a + 1e-5 to b (inclusive), where a = (1 - sqrt(p))**2 and b = (1 + sqrt(p))**2,
    and normalises the values to sum to 1.

    Returns:
        np.ndarray: Probability weights of length ``N``.
    """
    root_p = np.sqrt(p)
    upper = (1 + root_p) ** 2
    lower = (1 - root_p) ** 2
    grid = np.linspace(lower + 0.00001, upper, N)
    density = np.sqrt((grid - lower) * (upper - grid)) / (grid * 2 * np.pi * p)
    return density / np.sum(density)

class ConvMaskScaled(ConvMask):
    """
    ConvMask variant whose mask holds real-valued scales instead of 0/1 entries.

    At construction the mask is filled (via ``set_er_mask``) with i.i.d. samples from a
    discretised Marchenko-Pastur distribution (see MP_prob). The forward pass is inherited
    from ConvMask: the weights are multiplied element-wise by the mask.

    Args:
        dense_grad (bool): Supplied via fastargs' ``@param("experiment_params.dense_grad")``.
        er_init (float): Supplied via fastargs' ``@param("prune_params.er_init")``; used as the
            ratio parameter ``p`` of the Marchenko-Pastur distribution.
        **kwargs: Keyword arguments for nn.Conv2d.
    """
    @param("experiment_params.dense_grad")
    @param("prune_params.er_init")
    def __init__(self, dense_grad: bool, er_init: float, **kwargs: any) -> None:
        super().__init__(**kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))
        self.dense_grad = dense_grad
        print(f"the layer will evaluate dense gradients: {dense_grad}")
        self.set_er_mask(er_init)

    def set_er_mask(self, p: float, num_points: int = 10000) -> None:
        """
        Fill the mask with i.i.d. samples from a discretised Marchenko-Pastur distribution.

        The support [a + 1e-5, b], with a = (1 - sqrt(p))**2 and b = (1 + sqrt(p))**2, is split
        into ``num_points`` grid points (the grid MP_prob evaluates). Each mask entry is a grid
        value drawn with its MP_prob weight.

        Args:
            p (float): Ratio parameter of the Marchenko-Pastur distribution.
            num_points (int): Number of grid points. Default: 10000.
        """
        lower = (1 - np.sqrt(p)) ** 2
        upper = (1 + np.sqrt(p)) ** 2
        grid = torch.from_numpy(np.linspace(lower + 0.00001, upper, num_points))
        probs = torch.tensor(MP_prob(num_points, p))
        idx = torch.multinomial(probs, num_samples=self.weight.numel(), replacement=True)
        # Index the grid directly. Rebuilding values as idx * width / N used a spacing of
        # width / N instead of the grid's width / (N - 1), so samples drifted off the points
        # their probabilities were computed on and never reached the top of the support.
        self.mask = grid[idx].to(self.weight.dtype).reshape(self.weight.size())
