import torch
import torch.nn as nn
import logging

from timm.layers import Mlp

_logger = logging.getLogger(__name__)


class AdaptOp(nn.Module):
    """Bottleneck MLP adapter that transforms a parent module's output.

    The adapter is ``Linear(dim -> int(dim / downscale)) -> act_layer() ->
    Linear(-> dim)``. When ``residual`` is true, its result is added back onto
    the parent output; otherwise its result is returned on its own.

    Args:
        dim: Feature size the linear layers read and write (last dimension).
        downscale: Divisor applied to ``dim`` to get the bottleneck width.
        bias: Whether both linear layers carry a bias term.
        residual: If true, ``forward`` returns ``parent_output + adapter(parent_output)``.
        small_init: If true, call :meth:`reset_parameters` right after construction.
        init_scale: Weight standard deviation used by :meth:`reset_parameters`.
        act_layer: Activation class, instantiated with no arguments.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, downscale, bias, residual=True, small_init=False, init_scale=1e-6, act_layer=nn.GELU, **kwargs):
        super().__init__()

        # Configuration; read back by reset_parameters() and __repr__().
        self.dim = dim
        self.residual = residual
        self.downscale = downscale
        self.use_bias = bias
        self.init_scale = init_scale

        bottleneck = int(dim / downscale)
        # Build in down -> act -> up order so parameter init consumes RNG in
        # the same sequence as a single-expression construction would.
        down_proj = nn.Linear(dim, bottleneck, bias=bias)
        activation = act_layer()
        up_proj = nn.Linear(bottleneck, dim, bias=bias)
        self.op = nn.Sequential(down_proj, activation, up_proj)

        _logger.info(f"MLP Adapt residual: {residual}")

        if small_init:
            self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every linear weight from N(0, init_scale); zero biases if present."""
        linear_layers = (module for module in self.op if isinstance(module, nn.Linear))
        for linear in linear_layers:
            _logger.info(f"Resetting parameters for {linear} with mean=0, scale={self.init_scale}")
            nn.init.normal_(linear.weight, mean=0., std=self.init_scale)
            if self.use_bias:
                nn.init.zeros_(linear.bias)

    def forward(self, parent_input, parent_output, **kwargs):
        """Apply the adapter to ``parent_output``; ``parent_input`` is unused."""
        adapted = self.op(parent_output)
        return parent_output + adapted if self.residual else adapted

    def __repr__(self):
        fields = f"dim={self.dim}, downscale={self.downscale}, bias={self.use_bias}, residual={self.residual}"
        return f"AdaptOp({fields})"


class LinearNewOp(nn.Module):
    """Single ``nn.Linear`` layer with optional initialisation from given tensors.

    Args:
        in_features: Input feature size of the linear layer.
        out_features: Output feature size of the linear layer.
        bias: Whether the linear layer has a bias term.
        init_mean_weights: Tensor copied into the layer weight when given.
        init_mean_bias: Tensor copied into the layer bias; required (asserted)
            when ``bias`` is true and initialisation runs.
        init_scale: Only logged at present. Initialisation runs only when both
            this and ``init_mean_weights`` are given.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_features, out_features, bias, init_mean_weights=None, init_mean_bias=None, init_scale=None, **kwargs):
        super().__init__()
        self.op = nn.Linear(in_features, out_features, bias=bias)

        wants_init = init_mean_weights is not None and init_scale is not None
        if not wants_init:
            return
        if bias:
            assert init_mean_bias is not None
        self.reset_parameters(init_mean_weights, init_mean_bias, init_scale)

    def reset_parameters(self, init_mean_weights: torch.Tensor, init_mean_bias: torch.Tensor, init_scale: float):
        """Overwrite the layer's weight (and bias, if any) with the given tensors.

        No noise is added; ``init_scale`` only appears in the log message.
        """
        _logger.info(f"Resetting parameters for {self.op}. Noise Scale={init_scale}")
        with torch.no_grad():
            self.op.weight.copy_(init_mean_weights)
            if self.op.bias is not None:
                self.op.bias.copy_(init_mean_bias)

    def forward(self, x, **kwargs):
        """Apply the wrapped linear layer to ``x``."""
        return self.op(x)

    def __repr__(self):
        inner = str(self.op)
        return f"LinearNewOp({inner})"


class SkipOp(nn.Module):
    """Op that discards its input and returns an all-zero tensor.

    The output has shape ``(*x.shape[:-1], out_features)`` and the same dtype
    and device as ``x``. Only 2-D ``(B, C)`` and 3-D ``(B, N, C)`` inputs are
    accepted, and ``C`` must equal ``in_features``.

    Args:
        in_features: Expected size of the last dimension of the input.
        out_features: Size of the last dimension of the zero output.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_features, out_features, **kwargs):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

    def forward(self, x: torch.Tensor, **kwargs):
        """Return zeros shaped like ``x`` with the last dim set to ``out_features``.

        Raises:
            ValueError: if ``x`` is not 2-D (plain MLP) or 3-D (transformer style).
        """
        if x.ndim not in (2, 3):
            raise ValueError(f"Number of dimensions {x.ndim} not supported.")
        assert x.shape[-1] == self.in_features
        # Allocate a fresh tensor rather than zeroing ``x`` in place. In-place
        # zeroing would clobber the caller's tensor, fail on leaf tensors that
        # require grad, and turn inf/NaN inputs into NaN. ``new_zeros`` also
        # keeps x's dtype (e.g. fp16/bf16) and device; gradients are not tracked.
        return x.new_zeros((*x.shape[:-1], self.out_features))


class MlpNewOp(nn.Module):
    """timm ``Mlp`` (layers ``fc1`` and ``fc2``) with optional copied initialisation.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden width, forwarded to timm's ``Mlp`` as
            ``hidden_features``.
        out_features: Output feature size.
        bias: Bias flag forwarded to timm's ``Mlp``.
        init_mean_weights: Optional mapping with keys ``"fc1"`` and ``"fc2"``
            holding weight tensors to copy into the corresponding layers.
        init_mean_bias: Mapping with the same keys holding bias tensors;
            required (asserted) when ``bias`` is true and initialisation runs.
        init_scale: Only logged at present. Initialisation runs only when both
            this and ``init_mean_weights`` are given.
        act_layer: Activation class forwarded to timm's ``Mlp``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_features, hidden_features, out_features, bias, init_mean_weights=None, init_mean_bias=None, init_scale=None, act_layer=nn.GELU, **kwargs):

        super().__init__()

        self.op = Mlp(
            in_features,
            # Honour the caller's hidden width; it was previously dropped in
            # favour of a hard-coded None.
            hidden_features=hidden_features,
            out_features=out_features,
            act_layer=act_layer,
            bias=bias,
        )

        if init_mean_weights is not None and init_scale is not None:
            if bias:
                assert init_mean_bias is not None
            self.reset_parameters(init_mean_weights, init_mean_bias, init_scale)

    def reset_parameters(self, init_mean_weights: "dict[str, torch.Tensor]", init_mean_bias: "dict[str, torch.Tensor] | None", init_scale: float):
        """Copy ``fc1``/``fc2`` weights (and biases, where present) from the mappings.

        No noise is added; ``init_scale`` only appears in the log message.
        """
        _logger.info(f"Resetting parameters for {self.op}. Noise Scale={init_scale}")
        with torch.no_grad():
            for name in ("fc1", "fc2"):
                layer = getattr(self.op, name)
                layer.weight.copy_(init_mean_weights[name])
                # The Mlp container has no ``bias`` attribute of its own, so
                # check each layer's bias rather than ``self.op.bias``, which
                # raised AttributeError.
                if layer.bias is not None:
                    layer.bias.copy_(init_mean_bias[name])

    def forward(self, x, **kwargs):
        """Apply the wrapped timm ``Mlp`` to ``x``."""
        return self.op(x)

    def __repr__(self):
        return f"MlpNewOp({str(self.op)})"
