import torch.nn as nn
from typing import List, Union
import torch
from torchtyping import TensorType
import numpy as np
from pydantic import BaseModel

from ..utils.getting_modules import get_module_by_name
from ..utils.setting_attr import setattr_pytorch_model
import wandb

from typing import Union


class BrainInspiredLayer(nn.Module):
    """Wrap an ``nn.Linear`` or ``nn.Conv2d`` and attach 1-D neuron coordinates.

    The wrapped layer is used unchanged in ``forward``. In addition, every
    input and output feature gets a coordinate in the open interval (0, 1),
    stored in ``in_coordinates`` and ``out_coordinates``.

    With ``m = features // fold`` the coordinates of one fold are the cell
    centres ``(k + 0.5) / m`` for ``k = 0 .. m - 1``; that pattern is repeated
    ``fold`` times so the vector has exactly ``features`` entries.

    Args:
        layer: the layer to wrap; must be an ``nn.Conv2d`` or ``nn.Linear``.
        in_fold: number of times the input coordinate pattern is repeated.
        out_fold: number of times the output coordinate pattern is repeated.

    Raises:
        AssertionError: if ``layer`` is neither ``nn.Conv2d`` nor ``nn.Linear``.
        ValueError: if a fold is not a positive divisor of its feature count.
    """

    # Concrete layer types whose weights this wrapper knows how to read.
    _SUPPORTED_LAYERS = (nn.Conv2d, nn.Linear)

    def __init__(
        self,
        layer: Union[nn.Conv2d, nn.Linear],
        in_fold: int = 1,
        out_fold: int = 1,
    ):
        super().__init__()

        # isinstance() needs a class or a tuple of classes: passing
        # typing.Union[...] raises TypeError on Python < 3.10.
        assert isinstance(
            layer, self._SUPPORTED_LAYERS
        ), f"Expected either nn.Conv2d or nn.Linear but got: {type(layer)}"
        self.layer = layer

        self.in_fold = in_fold
        self.out_fold = out_fold

        # `weight` is always 2-D here: (out_features, in_features).
        self.in_features = self.weight.shape[1]
        self.out_features = self.weight.shape[0]

        self.in_coordinates = self._fold_coordinates(self.in_features, in_fold)
        self.out_coordinates = self._fold_coordinates(self.out_features, out_fold)

    @staticmethod
    def _fold_coordinates(num_features: int, fold: int) -> torch.Tensor:
        """Return ``num_features`` cell-centre coordinates, repeated per fold."""
        if fold < 1 or num_features % fold != 0:
            # A non-dividing fold would yield a coordinate vector whose length
            # differs from the weight dimension it is meant to describe.
            raise ValueError(
                f"fold must be a positive divisor of the number of features, "
                f"but got fold={fold} for {num_features} features"
            )
        per_fold = num_features // fold
        half_cell = 1 / (2 * per_fold)
        centres = np.linspace(half_cell, 1 - half_cell, num=per_fold)
        return torch.tensor(np.tile(centres, fold), dtype=torch.float)

    @property
    def weight(self) -> torch.Tensor:
        """2-D connection strengths of the wrapped layer.

        For ``nn.Linear`` this is the layer's weight itself. For ``nn.Conv2d``
        the last two (kernel) dimensions of the weight are collapsed with
        ``torch.norm``, leaving one value per (output channel, input channel)
        pair.
        """
        if isinstance(self.layer, nn.Linear):
            return self.layer.weight
        if isinstance(self.layer, nn.Conv2d):
            # Norm over dim=(-2, -1), i.e. over the kernel-size dimensions.
            return torch.norm(self.layer.weight, dim=(-2, -1))
        raise TypeError(
            f"Expected either nn.Conv2d or nn.Linear but got: {type(self.layer)}"
        )

    @property
    def bias(self):
        """Bias of the wrapped layer (``None`` when the layer has no bias)."""
        return self.layer.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped layer to ``x`` unchanged."""
        return self.layer(x)


class BIMTConfig(BaseModel, extra="allow"):
    """
    Settings from which `BIMTLoss.from_config` builds a `BIMTLoss`.

    The model is declared with `extra="allow"`, so a config may carry fields
    beyond the four listed here.

    layer_names: list of layer name strings (the modules the loss looks up in the model)
    distance_between_nearby_layers: constant added to every distance term of the cost, something like 0.2
    scale: value by which the compute cost is scaled; `BIMTLoss.forward` returns None when this is None
    device: device string handed to `BIMTLoss`
    """

    # Names used to look the layers up inside the model.
    layer_names: List[str]
    # Constant offset added to the neuron-to-neuron distance.
    distance_between_nearby_layers: float
    # Multiplier for the cost; None is an accepted value.
    scale: Union[float, None]
    # Forwarded unchanged to BIMTLoss(device=...).
    device: str


class BIMTLoss(nn.Module):
    """Connection-cost penalty over the `BrainInspiredLayer` modules of a model.

    For every layer named in ``layer_names`` the cost adds

        sum(|W| * (weight_factor * |out_coord - in_coord| + distance_between_nearby_layers))

    where ``W`` is the layer's 2-D ``weight`` and the coordinates are the ones
    stored on the `BrainInspiredLayer`. ``forward`` returns that cost
    multiplied by ``scale``.

    Args:
        layer_names: names of the modules to penalise, in model order.
        distance_between_nearby_layers: constant added to every distance.
        scale: multiplier applied in ``forward``; ``None`` disables the loss
            (``forward`` then returns ``None``).
        device: device used for the zero cost returned when ``layer_names``
            is empty. Per-layer terms are computed on the device of the
            layer's weights.
    """

    def __init__(
        self,
        layer_names: List[str],
        distance_between_nearby_layers: float = 0.2,
        scale: Union[float, None] = 1.0,
        device="cpu",
    ):
        super().__init__()
        self.distance_between_nearby_layers = distance_between_nearby_layers

        self.layer_names = layer_names
        self.scale = scale
        self.device = device

    def init_modules_for_training(self, model: nn.Module):
        """
        Converts the layers listed in `layer_names` to BrainInspiredLayer instances.
        The difference between a normal and a brain inspired layer is that the latter
        contains data about the locations of the neurons within the model (in_coordinates, out_coordinates)

        Layers that are already BrainInspiredLayer instances are left as they are,
        so calling this method more than once on the same model is safe.

        Returns:
            nn.Module: the model, but now with modified modules
        """
        for name in self.layer_names:
            layer = get_module_by_name(module=model, name=name)
            if isinstance(layer, BrainInspiredLayer):
                # Already converted by an earlier call; wrapping it again
                # would fail the layer-type check in BrainInspiredLayer.
                continue
            new_layer = BrainInspiredLayer(layer=layer, in_fold=1, out_fold=1)
            model = setattr_pytorch_model(model=model, name=name, item=new_layer)

        return model

    def fetch_all_brain_inspired_layers(self, model: nn.Module):
        """Return the modules named in `layer_names`, in that order."""
        return [
            get_module_by_name(module=model, name=name) for name in self.layer_names
        ]

    @classmethod
    def from_config(cls, config: "BIMTConfig"):
        """Build a loss from the fields of a `BIMTConfig`."""
        return cls(
            layer_names=config.layer_names,
            distance_between_nearby_layers=config.distance_between_nearby_layers,
            scale=config.scale,
            device=config.device,
        )

    def get_compute_cost(
        self,
        model: nn.Module,
        weight_factor=1.0,
        bias_penalize=False,
        no_penalize_last=True,
    ):
        """Compute the unscaled connection cost of `model`.

        Args:
            model: model whose layers were converted with
                `init_modules_for_training`.
            weight_factor: multiplier for the coordinate distance.
            bias_penalize: when truthy, also add
                ``sum(|bias|) * distance_between_nearby_layers`` per layer.
            no_penalize_last: when true, the last listed layer uses a
                weight factor of 0.0, so only the constant
                ``distance_between_nearby_layers`` term is applied to it.

        Returns:
            torch.Tensor: scalar cost (a zero tensor when no layers are listed).

        Raises:
            AssertionError: if a listed module is not a BrainInspiredLayer, or
                if `bias_penalize` is set and a layer has no bias.
        """
        layers = self.fetch_all_brain_inspired_layers(model=model)
        if not layers:
            # Keep the return type a tensor so callers can use `.item()`.
            return torch.zeros((), device=self.device)

        last_index = len(layers) - 1
        cc = 0
        for index, brain_inspired_layer in enumerate(layers):
            assert isinstance(
                brain_inspired_layer, BrainInspiredLayer
            ), f"Expected brain_inspired_layer to be an instance of BrainInspiredLayer, but got: {type(brain_inspired_layer)}\nMake sure that you are running `model = BIMTLoss(...).init_modules_for_training(model)` before using the forward method"

            # The caller's weight_factor is left untouched; only the last
            # layer gets its own factor of 0.0.
            if no_penalize_last and index == last_index:
                layer_weight_factor = 0.0
            else:
                layer_weight_factor = weight_factor

            weight = brain_inspired_layer.weight
            # Pairwise |out - in| distances, shaped like the 2-D weight.
            dist = torch.abs(
                brain_inspired_layer.out_coordinates.unsqueeze(dim=1)
                - brain_inspired_layer.in_coordinates.unsqueeze(dim=0)
            )
            # The coordinates are plain tensors that do not follow the model
            # across devices, so move the penalty to where the weights are.
            penalty = (
                layer_weight_factor * dist + self.distance_between_nearby_layers
            ).to(weight.device)
            cc = cc + torch.sum(torch.abs(weight) * penalty)

            if bias_penalize:
                assert (
                    brain_inspired_layer.bias is not None
                ), "Expected brain_inspired_layer.bias to not be None. Make sure that you set bias=True in your layer"
                cc = cc + torch.sum(
                    torch.abs(brain_inspired_layer.bias)
                    * (self.distance_between_nearby_layers)
                )
        return cc

    def wandb_log(
        self,
        model: nn.Module,
        weight_factor=1.0,
        bias_penalize=False,
        no_penalize_last=True,
    ):
        """Log the unscaled compute cost to wandb under the key "bimt_loss"."""
        wandb.log(
            {
                "bimt_loss": self.get_compute_cost(
                    model=model,
                    weight_factor=weight_factor,
                    bias_penalize=bias_penalize,
                    no_penalize_last=no_penalize_last,
                ).item()
            }
        )

    def forward(
        self,
        model: nn.Module,
        weight_factor=1.0,
        bias_penalize=False,
        no_penalize_last=True,
    ):
        """Return ``scale * get_compute_cost(...)``, or None when scale is None."""
        if self.scale is None:
            return None
        return self.scale * self.get_compute_cost(
            model=model,
            weight_factor=weight_factor,
            bias_penalize=bias_penalize,
            no_penalize_last=no_penalize_last,
        )
