import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fastargs import get_current_config
from fastargs.decorators import param

# Called only for its side effect; the returned config object is discarded here.
# Presumably this makes sure the fastargs global config exists before the
# @param-decorated definitions below are used — TODO confirm.
get_current_config()


class ConvMask(nn.Conv2d):
    """
    Conv2d layer which inherits from the original PyTorch Conv2d layer with an additionally registered ``mask`` buffer.
    The mask (initialized to all ones) is multiplied element-wise into the weights on every forward pass.

    Args:
        *args: Positional arguments for nn.Conv2d.
        **kwargs: Keyword arguments for nn.Conv2d.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Registered as a buffer so it follows module.to(device) and is saved in state_dict.
        self.register_buffer("mask", torch.ones_like(self.weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with masked weights."""
        sparse_weight = self.mask.to(self.weight.device) * self.weight
        # nn.Conv2d._conv_forward honours padding_mode ('reflect', 'replicate', 'circular');
        # calling F.conv2d directly would silently zero-pad instead.
        return self._conv_forward(x, sparse_weight, self.bias)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random binary mask (random pruning).

        Args:
            p (float): Keep probability; each mask entry is independently 1 with probability ``p``.
        """
        # Assigning to a registered buffer name updates the buffer in place of the old one.
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)


class LinearMask(nn.Linear):
    """
    Linear layer which inherits from the original PyTorch Linear layer with an additionally registered ``mask`` buffer.
    The mask (initialized to all ones) is multiplied element-wise into the weights on every forward pass.

    Args:
        in_features (int): Size of each input sample.
        out_features (int): Size of each output sample.
        bias (bool): If True, the layer has a learnable bias. Default: True.
        **kwargs: Further keyword arguments for nn.Linear.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs) -> None:
        # Forward the caller's ``bias`` flag. It used to be hard-coded to True, so
        # bias-free layers silently gained a randomly initialised bias.
        super().__init__(in_features, out_features, bias=bias, **kwargs)
        self.register_buffer("mask", torch.ones_like(self.weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with masked weights.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        sparseWeight = self.mask.to(self.weight.device) * self.weight
        return F.linear(x, sparseWeight, self.bias)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random binary mask (random pruning).

        Args:
            p (float): Keep probability; each mask entry is independently 1 with probability ``p``.
        """
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)


class Conv1dMask(nn.Conv1d):
    """
    Conv1d layer which inherits from the original PyTorch Conv1d layer with an additionally registered ``mask`` buffer.
    The mask (initialized to all ones) is multiplied element-wise into the weights on every forward pass.
    Used for replacing linear layers with an equivalent 1D convolutional layer (kernel size 1).

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool): If True, adds a learnable bias to the output. Default: False.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
        super().__init__(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            bias=bias,
        )
        self.register_buffer("mask", torch.ones_like(self.weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the masked 1x1 convolution to the last dimension of ``x``, like nn.Linear.

        Args:
            x (torch.Tensor): Tensor of shape ``(*, in_features)``.

        Returns:
            torch.Tensor: Tensor of shape ``(*, out_features)``.
        """
        leading_shape = x.shape[:-1]
        # Fold all leading dims into one batch dim and add a length-1 "sequence" axis,
        # so inputs of any rank work (not only (features,) or (batch, features)).
        x = x.reshape(-1, x.shape[-1], 1)
        sparseWeight = self.mask.to(self.weight.device) * self.weight
        x = F.conv1d(
            x,
            sparseWeight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        # Explicit out_channels (rather than -1) keeps zero-sized batches reshapeable.
        return x.reshape(*leading_shape, self.out_channels)

    def set_er_mask(self, p: float) -> None:
        """
        Replace the mask with a random binary mask (random pruning).

        Args:
            p (float): Keep probability; each mask entry is independently 1 with probability ``p``.
        """
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

@param('model_params.conv_type')
def replace_layers(conv_type: str, model: nn.Module) -> nn.Module:
    """
    Replaces nn.Linear and nn.Conv2d layers in the model with corresponding masked layers.
    Skips layers whose qualified name contains "downsample" (the shortcut connections).

    nn.Conv2d layers are replaced by the class named ``conv_type`` in this module, keeping the
    original layer's channels, kernel size, stride, padding, dilation, groups, padding mode and
    bias presence. nn.Linear layers become STRConv1d (``conv_type == 'STRConv'``), MWConv1d
    (``conv_type == 'ConvMaskMW'``) or Conv1dMask otherwise. Replacements are freshly
    initialised; the original weights are not copied.

    Args:
        conv_type (str): The type of masked layer to use (e.g., 'ConvMask').
        model (nn.Module): The model in which to replace layers (modified in place).

    Returns:
        nn.Module: The model with replaced layers.

    Raises:
        ValueError: If the model has a Conv2d to replace and ``conv_type`` names nothing in this module.
    """
    conv_layer_of_type = globals().get(conv_type)
    print(conv_layer_of_type)

    # Collect first: swapping modules while iterating named_modules() would mutate the tree mid-walk.
    layers_to_replace = [
        (name, layer)
        for name, layer in model.named_modules()
        if "downsample" not in name and isinstance(layer, (nn.Linear, nn.Conv2d))
    ]

    # Fail before mutating anything instead of "'NoneType' object is not callable" halfway through.
    if conv_layer_of_type is None and any(isinstance(layer, nn.Conv2d) for _, layer in layers_to_replace):
        raise ValueError(f"Unknown conv_type {conv_type!r}: no such layer class is defined in this module.")

    for name, layer in layers_to_replace:
        *parent_path, attr_name = name.split(".")
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        if isinstance(layer, nn.Linear):
            in_features = layer.in_features
            out_features = layer.out_features
            bias = layer.bias is not None

            if conv_type == 'STRConv':
                conv_layer = STRConv1d(in_channels=in_features,
                                       out_channels=out_features,
                                       kernel_size=1,
                                       bias=bias)
            elif conv_type == 'ConvMaskMW':
                conv_layer = MWConv1d(in_channels=in_features,
                                      out_channels=out_features,
                                      kernel_size=1,
                                      bias=bias)
            else:
                conv_layer = Conv1dMask(in_features, out_features, bias)
            setattr(parent_module, attr_name, conv_layer)
        elif isinstance(layer, nn.Conv2d):
            conv_mask_layer = conv_layer_of_type(
                in_channels=layer.in_channels,
                out_channels=layer.out_channels,
                kernel_size=layer.kernel_size,
                stride=layer.stride,
                padding=layer.padding,
                # These were previously dropped, turning e.g. depthwise (groups > 1) or dilated
                # convolutions into plain dense convolutions with a different weight shape.
                dilation=layer.dilation,
                groups=layer.groups,
                padding_mode=layer.padding_mode,
                bias=layer.bias is not None,
            )
            setattr(parent_module, attr_name, conv_mask_layer)

    print(model)
    return model



@param('model_params.conv_type')
def replace_vit_layers(conv_type: str, model: nn.Module) -> nn.Module:
    """
    For a ViT, replaces every nn.Linear layer (including subclasses) with a LinearMask that has
    the same in/out features and bias presence. Replacements are freshly initialised; the
    original weights are not copied.

    Args:
        conv_type (str): Injected from ``model_params.conv_type``. Currently unused: the
            replacement is always LinearMask. Kept for interface compatibility.
        model (nn.Module): The model in which to replace layers (modified in place).

    Returns:
        nn.Module: The model with replaced layers.
    """
    # Collect first: swapping modules while iterating named_modules() would mutate the tree mid-walk.
    layers_to_replace = [
        (name, layer) for name, layer in model.named_modules() if isinstance(layer, nn.Linear)
    ]

    for name, layer in layers_to_replace:
        *parent_path, attr_name = name.split(".")
        parent_module = model
        for part in parent_path:
            parent_module = getattr(parent_module, part)

        # NOTE(review): a parent that reads a child's ``.weight`` directly instead of calling it
        # would bypass LinearMask.forward and therefore the mask — confirm for the ViT in use.
        masked_layer = LinearMask(layer.in_features, layer.out_features, bias=layer.bias is not None)
        setattr(parent_module, attr_name, masked_layer)

    print('Replaced layers with the LinearMask which can be sparsified.')
    print(model)
    return model



def sparseFunction(x, s, activation=torch.relu, f=torch.sigmoid):
    """
    Soft-threshold ``x`` element-wise with the learnable threshold ``f(s)``.

    Computes ``sign(x) * activation(|x| - f(s))``. With the default relu/sigmoid, entries whose
    magnitude is at most ``sigmoid(s)`` become zero and the rest shrink towards zero by that amount.
    """
    threshold = f(s)
    magnitude = activation(x.abs() - threshold)
    return x.sign() * magnitude

@param('prune_params.str_init_val')
def initialize_sInit(str_init_val):
    """Return a (1, 1) tensor filled with ``str_init_val`` (supplied from ``prune_params.str_init_val`` by @param)."""
    unit = torch.ones([1, 1])
    return unit * str_init_val

class STRConv(nn.Conv2d):
    """
    Conv2d layer whose weights are soft-thresholded by ``sparseFunction`` on every forward pass,
    using one learnable threshold parameter (``sparseThreshold``) mapped through sigmoid.
    Presumably the STR (soft threshold reparameterization) method — confirm against the paper
    referenced in ``forward``.

    Args:
        *args: Positional arguments for nn.Conv2d.
        **kwargs: Keyword arguments for nn.Conv2d.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.activation = torch.relu
        # Only the sigmoid threshold map is supported. The former non-sigmoid branch was
        # unreachable and would have left ``self.f`` undefined.
        self.f = torch.sigmoid
        # Threshold stored as a (1, 1) tensor, initialised from prune_params.str_init_val.
        self.sparseThreshold = nn.Parameter(initialize_sInit())

    def forward(self, x):
        # In case STR is not training for the hyperparameters given in the paper, change sparseWeight to self.sparseWeight if it is a problem of backprop.
        # However, that should not be the case according to graph computation.
        sparseWeight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f)
        # _conv_forward honours padding_mode; calling F.conv2d directly would always zero-pad.
        return self._conv_forward(x, sparseWeight, self.bias)

    def _nonzero_mask(self):
        """Return a CPU tensor shaped like ``weight``: 1 where the thresholded weight is non-zero, else 0."""
        with torch.no_grad():
            sparse_weight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f)
        return (sparse_weight != 0).to(sparse_weight.dtype).cpu()

    def getSparsity(self, f=torch.sigmoid):
        """
        Report the current sparsity of the thresholded weights.

        Args:
            f: Map applied to ``sparseThreshold`` for the reported threshold value.

        Returns:
            tuple: (percentage of zero weights, number of weights, ``f(sparseThreshold)``).
        """
        mask = self._nonzero_mask()
        return (100 - mask.mean().item() * 100), mask.numel(), f(self.sparseThreshold).item()

    def get_mask(self, f=torch.sigmoid):
        """Return the binary non-zero mask of the thresholded weights (``f`` is unused; kept for API compatibility)."""
        return self._nonzero_mask()

class STRConv1d(STRConv):
    """
    STRConv used as a drop-in replacement for nn.Linear. It is expected to be built with
    ``kernel_size=1`` and is applied to the last dimension of the input.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Tensor of shape ``(*, in_channels)``.

        Returns:
            torch.Tensor: Tensor of shape ``(*, out_channels)``.
        """
        # In case STR is not training for the hyperparameters given in the paper, change sparseWeight to self.sparseWeight if it is a problem of backprop.
        # However, that should not be the case according to graph computation.
        leading_shape = x.shape[:-1]
        # Fold leading dims into the batch axis and view each sample as a 1x1 image.
        # The previous ``squeeze()`` dropped *every* size-1 dim, so a batch of one
        # (or out_channels == 1) silently lost its axis.
        x = x.reshape(-1, x.shape[-1], 1, 1)
        sparseWeight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f)
        x = F.conv2d(
            x, sparseWeight, self.bias, self.stride
        )
        return x.reshape(*leading_shape, self.out_channels)



# class ConvMaskMW(nn.Conv2d):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         # This layer aims to reparametrize the mask in order to improve the training and convergence dynamics.

#         # We first fix a separate mask in case of PaI or LRR like methods.
#         self.mask = torch.ones_like(self.weight).to(self.weight.device)
#         self.weight.requires_grad_(False)
        

#         alpha  = 0.5

#         u = torch.sqrt(self.weight.to(self.weight.device) + torch.sqrt(self.weight.to(self.weight.device) * self.weight.to(self.weight.device) + alpha * alpha ))
        
#         # Here, m and w are the trainable parameters. Each parameter \theta = m \cdot w
#         self.m =  nn.Parameter((u + alpha/u)/np.sqrt(2), requires_grad=True)
#         self.w = nn.Parameter((u - alpha/u)/np.sqrt(2), requires_grad=True)  
#         print(torch.mean(self.m * self.w - self.weight))
#         print(torch.mean(self.m * self.m + self.w*self.w))
            
#     def forward(self, x):
       
#         # Here the m, w parameters are learnable
#         sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.w.to(self.weight.device)

#         x = F.conv2d(
#             x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups
#         )
#         return x
    
#     def Merge(self):
#         # This method combines m and w to give an effective weight, such that we can prune based on LRR or IMP
    
#         self.weight.data = self.m.to(self.weight.device) * self.w.to(self.weight.device)

#     def Takeoff(self):
#         # Reset m and w so that we gain momentum for sign flips
#         alpha = 0.5 # make schedule
#         #alpha = self.m.to(self.weight.device)*self.m.to(self.weight.device) + self.w.to(self.weight.device) * self.w.to(self.weight.device)/2
#         #alpha_m = torch.mean(alpha)
#         #alpha = torch.min(alpha_m, alpha)
#         #print(alpha)
#         x = self.m.to(self.weight.device) * self.w.to(self.weight.device)

#         u = torch.sqrt(alpha + torch.sqrt(x*x + alpha*alpha))
            
       
#         self.m.data = u#(u + alpha/u)/np.sqrt(2)
#         self.w.data = x/u# (u - alpha/u)/np.sqrt(2)

#     def Landing(self):
#         alpha = 0 # make schedule

#         x = self.m.to(self.weight.device) * self.w.to(self.weight.device)

#         u = torch.sqrt(torch.abs(x))
            
       
#         self.m.data = u
#         self.w.data = torch.sign(x) * u

#         # Reset m and w so that the weights become sparser


#     def Flip(self):
#         # Reset m and w so that we gain momentum for sign flips
#         #alpha = 0.5 # make schedule
#         alpha = self.m.to(self.weight.device)*self.m.to(self.weight.device) + self.w.to(self.weight.device) * self.w.to(self.weight.device)/2
#         alpha_m = torch.mean(alpha)
#         indicator = alpha_m > alpha
#         #print(alpha)
#         #x = self.m.to(self.weight.device) * self.w.to(self.weight.device)

#         #u = torch.sqrt(alpha + torch.sqrt(x*x + alpha*alpha))
            
       
#         self.m.data = self.m.data * (1 - 2*indicator)#(u + alpha/u)/np.sqrt(2)
#         #self.w.data = x/u# (u - alpha/u)/np.sqrt(2)


#     def set_er_mask(self, p):
#         self.mask = torch.zeros_like(self.weight).bernoulli_(p)

#     def getSparsity(self, f=torch.sigmoid):
        
#         sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.w.to(self.weight.device)
        
#         temp = sparseWeight.detach().cpu()
#         temp[temp!=0] = 1
#         return (100 - temp.mean().item()*100), temp.numel(), 0


# #class MWConv1d(ConvMaskMW):
# #    def __init__(self, *args, **kwargs):
# #        super().__init__(*args, **kwargs)

# #    def forward(self, x):
# #        x = x.unsqueeze(-1).unsqueeze(-1)
# #        sparseWeight = self.mask.to(self.weight.device) * self.m.to(self.weight.device) * self.w.to(self.weight.device)
# #        x = F.conv2d(
# #            x, sparseWeight, self.bias, self.stride
# #        )
# #        return x.squeeze()


class ConvMaskMW(nn.Conv2d):  # NOTE: using this breaks code that depends on m.m / m.w when depth > 2; those are only used in custom regularization and pruning.
    """
    Conv2d layer whose effective weight is re-parametrized as a product of ``depth`` trainable
    factors times a fixed pruning mask::

        theta = mask * params[0] * params[1] * ... * params[depth - 1]

    ``self.weight`` is frozen (``requires_grad=False``). It seeds the factors at construction and
    receives the merged product in ``Merge``.

    Args:
        *args: Positional arguments for nn.Conv2d.
        depth (int): Number of factors; filled from ``optimizer.depth`` by @param.
        alpha (float): Balance constant used by the depth-2 initialisation and ``Takeoff``.
        inbalance (bool): For depth != 2, initialise as ``(w, 1, ..., 1)`` instead of balanced
            signed roots; filled from ``optimizer.inbalance`` by @param.
        **kwargs: Keyword arguments for nn.Conv2d.
    """

    @param('optimizer.depth')
    @param('optimizer.inbalance')
    def __init__(self, *args, depth=2, alpha=0.5, inbalance=False, **kwargs):
        super().__init__(*args, **kwargs)
        # Mask for pruning.
        # NOTE(review): a plain attribute, not a registered buffer (unlike ConvMask), so it is not
        # saved in state_dict and does not follow module.to(); get_effective_weight moves it to
        # the weight's device on every call instead.
        self.mask = torch.ones_like(self.weight).to(self.weight.device)
        # The base weight is only an initial value / merge target; the factors are what train.
        self.weight.requires_grad_(False)

        self.depth = depth
        self.alpha = alpha

        w_init = self.weight.to(self.weight.device)

        # Parameter storage (ParameterList registers the factors with the module).
        self.params = nn.ParameterList()

        if depth == 2:
            # Balanced init: with u^2 = w + sqrt(w^2 + alpha^2), m * w == w_init and m^2 - w^2 == 2 * alpha.
            u = torch.sqrt(w_init + torch.sqrt(w_init * w_init + alpha * alpha))
            m = (u + alpha / u) / np.sqrt(2)
            w = (u - alpha / u) / np.sqrt(2)
            self.params.append(nn.Parameter(m, requires_grad=True))
            self.params.append(nn.Parameter(w, requires_grad=True))
        elif inbalance:
            # First factor = original weight, remaining factors = ones.
            self.params.append(nn.Parameter(w_init.clone(), requires_grad=True))
            for _ in range(1, depth):
                self.params.append(nn.Parameter(torch.ones_like(w_init), requires_grad=True))
        else:
            # Balanced signed roots: sign(w)|w|^(1/d) * (|w|^(1/d))^(d-1) == w.
            self.params.append(nn.Parameter(torch.sign(w_init) * torch.abs(w_init).pow(1/depth), requires_grad=True))
            for _ in range(1, depth):
                self.params.append(nn.Parameter(torch.abs(w_init).pow(1/depth), requires_grad=True))

        print("Init weight error:", torch.mean(self.get_effective_weight() - self.weight).item())

    def get_effective_weight(self):
        """Compute θ = mask · Πᵢ pᵢ (element-wise)."""
        eff = self.mask.to(self.weight.device)
        for p in self.params:
            eff = eff * p.to(self.weight.device)
        return eff

    def forward(self, x):
        sparseWeight = self.get_effective_weight()
        # _conv_forward honours padding_mode; calling F.conv2d directly would always zero-pad.
        return self._conv_forward(x, sparseWeight, self.bias)

    @torch.no_grad()
    def Merge(self):
        """Write the effective weight (mask included) into the frozen nn.Conv2d.weight."""
        self.weight.data = self.get_effective_weight()

    @torch.no_grad()
    def Takeoff(self):
        """
        Re-factorise the current effective weight x (sign-flip stabilisation).

        depth == 2: params = (u, x / u) with u = sqrt(alpha + sqrt(x^2 + alpha^2)), so the product stays x.
        otherwise: params = (x, 1, ..., 1).
        """
        x = self.get_effective_weight()
        if self.depth == 2:
            u = torch.sqrt(self.alpha + torch.sqrt(x * x + self.alpha * self.alpha))
            self.params[0].data = u
            self.params[1].data = x / u
        else:
            self.params[0].data = x.clone()
            for i in range(1, self.depth):
                self.params[i].data = torch.ones_like(x)

    @torch.no_grad()
    def Landing(self):
        """
        Re-factorise the effective weight x into balanced factors (intended to make params sparser).

        depth == 2: params = (u, sign(x) * u) with u = sqrt(|x| + 1e-6). The product is
        sign(x) * (|x| + 1e-6): x up to the 1e-6 stabiliser, and exactly 0 where x == 0.
        otherwise: params = (x, 1, ..., 1).
        """
        x = self.get_effective_weight()
        if self.depth == 2:
            u = torch.sqrt(torch.abs(x) + 1e-6)
            self.params[0].data = u
            self.params[1].data = torch.sign(x) * u
        else:
            self.params[0].data = x.clone()
            for i in range(1, self.depth):
                self.params[i].data = torch.ones_like(x)

    @torch.no_grad()
    def Flip(self):
        """
        Negate params[0] wherever ``alpha`` is below its mean (sign-flip regularization).

        Raises:
            NotImplementedError: If depth != 2.
        """
        if self.depth != 2:
            raise NotImplementedError("Flip is only defined for depth=2.")
        # NOTE(review): by operator precedence this is m^2 + (w^2 / 2), not (m^2 + w^2) / 2;
        # kept as-is to preserve behaviour — confirm which was intended.
        alpha = self.params[0] * self.params[0] + self.params[1] * self.params[1] / 2
        alpha_m = torch.mean(alpha)
        indicator = alpha_m > alpha
        # (1 - 2 * indicator) is -1 where indicator is True, +1 elsewhere.
        self.params[0].data = self.params[0].data * (1 - 2 * indicator)

    def set_er_mask(self, p):
        """Replace the mask with a random binary mask whose entries are 1 with probability ``p``."""
        self.mask = torch.zeros_like(self.weight).bernoulli_(p)

    @torch.no_grad()
    def getSparsity(self, f=torch.sigmoid):
        """
        Report the sparsity of the effective weight.

        ``f`` is unused; the signature mirrors STRConv.getSparsity.

        Returns:
            tuple: (percentage of zero entries, number of entries, 0).
        """
        sparseWeight = self.get_effective_weight()
        nonzero = (sparseWeight != 0).to(sparseWeight.dtype).cpu()
        return (100 - nonzero.mean().item() * 100), nonzero.numel(), 0


class MWConv1d(ConvMaskMW):
    """
    ConvMaskMW used as a drop-in replacement for nn.Linear. It is expected to be built with
    ``kernel_size=1`` and is applied to the last dimension of the input.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Tensor of shape ``(*, in_channels)``.

        Returns:
            torch.Tensor: Tensor of shape ``(*, out_channels)``.
        """
        leading_shape = x.shape[:-1]
        # Fold leading dims into the batch axis and view each sample as a 1x1 image.
        # The previous ``squeeze()`` dropped *every* size-1 dim, so a batch of one
        # (or out_channels == 1) silently lost its axis.
        x = x.reshape(-1, x.shape[-1], 1, 1)
        sparseWeight = self.get_effective_weight()
        x = F.conv2d(
            x, sparseWeight, self.bias, self.stride
        )
        return x.reshape(*leading_shape, self.out_channels)
