from typing import Union, List
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchtyping import TensorType
from pydantic import BaseModel, Extra, Field
import wandb
import json
from einops import rearrange
from ...bimt.loss import BrainInspiredLayer
from ...utils.hook import ForwardHook
from ...utils.getting_modules import get_module_by_name
from ...utils.grid_size import find_rectangle_dimensions
from ...utils.correlation import pearsonr


class SinglePairConfig(BaseModel, extra=Extra.forbid):
    """
    Configuration for one pair of layers whose outputs are compared by the
    cross-layer correlation loss.

    - `layer1_name`: name of the first layer in the pair; looked up on the model
      with `get_module_by_name` when the loss is built from a config.
    - `layer2_name`: name of the second layer in the pair; looked up the same way.
    - `downsample_factor`: the height and width of the neuron grids are
      floor-divided by this factor (never going below 1) before the
      correlation between the two layers is computed.
    - `scale`: either float or None. A float multiplies this pair's loss in
      `CrossLayerCorrelationLoss.forward`. None means the pair is skipped in
      `forward` (so its loss is not backpropagated); it is still included in
      `CrossLayerCorrelationLoss.compute`, which is what gets logged.
    """

    layer1_name: str
    layer2_name: str
    downsample_factor: float
    scale: Union[None, float]


class CrossLayerCorrelationLossConfig(BaseModel, extra=Extra.forbid):
    """
    Top-level configuration for `CrossLayerCorrelationLoss`.

    `pair_configs`: list of SinglePairConfig instances, one per pair of layers
    that should be compared.
    """

    pair_configs: List[SinglePairConfig]

    def save_json(self, filename: str):
        """
        Write this config to `filename` as indented JSON.

        The file is opened as UTF-8 explicitly so the result does not depend on
        the platform's default text encoding.
        """
        with open(filename, "w", encoding="utf-8") as file:
            json.dump(self.dict(), file, indent=4)

    @classmethod
    def from_json(cls, filename: str):
        """
        Load and validate a config from the JSON file at `filename`.

        The file is read as UTF-8. Files written by `save_json` are plain ASCII
        (json.dump escapes non-ASCII characters by default), so they load the
        same way regardless of the encoding they were originally written with.
        """
        with open(filename, "r", encoding="utf-8") as file:
            json_data = json.load(file)
        return cls.parse_obj(json_data)


class SinglePairHookContainer:
    """
    Holds the forward hooks for one pair of layers, together with the settings
    (`scale`, `downsample_factor`) used when computing the loss for that pair.
    """

    # A plain tuple of classes: isinstance() rejects a subscripted
    # typing.Union on Python versions before 3.10, a tuple works everywhere.
    _SUPPORTED_LAYER_TYPES = (nn.Linear, nn.Conv2d, BrainInspiredLayer)

    def __init__(
        self,
        layer1: Union[nn.Linear, nn.Conv2d, BrainInspiredLayer],
        layer2: Union[nn.Linear, nn.Conv2d, BrainInspiredLayer],
        downsample_factor: Union[None, float] = None,
        scale: Union[None, float] = None,
    ):
        """
        - `layer1`, `layer2`: the two layers to hook. A BrainInspiredLayer is
          unwrapped to its `.layer` attribute before the hook is attached.
        - `downsample_factor`: stored as-is for use by the loss; may be None.
        - `scale`: stored as-is for use by the loss; may be None.

        Raises AssertionError if either layer is not one of the supported types.
        """
        self.validate_inputs(layer1=layer1, layer2=layer2)

        # Hook the wrapped module itself rather than the wrapper.
        if isinstance(layer1, BrainInspiredLayer):
            layer1 = layer1.layer
        if isinstance(layer2, BrainInspiredLayer):
            layer2 = layer2.layer

        self.layer1_hook = ForwardHook(module=layer1)
        self.layer2_hook = ForwardHook(module=layer2)
        self.scale = scale
        self.downsample_factor = downsample_factor

    def validate_inputs(self, layer1, layer2):
        """Assert that both layers are nn.Linear, nn.Conv2d or BrainInspiredLayer."""
        for layer in (layer1, layer2):
            assert isinstance(
                layer, self._SUPPORTED_LAYER_TYPES
            ), f"Expected nn.Linear or nn.Conv2d or BrainInspiredLayer but got: {layer}"


class CrossLayerCorrelationLoss(nn.Module):
    """
    Loss that measures how "functionally aligned" pairs of layers are.

    we know that 2 layers are "functionally aligned" if their outputs are similar along the
    last dim for linear layers or the channel dim for conv layers.

    The outputs are read from the forward hooks held by each
    SinglePairHookContainer, so a forward pass of the hooked model must have
    run before `compute` / `forward` / `wandb_log` are called.
    """

    def __init__(self, pair_hook_containers: List[SinglePairHookContainer]):
        """`pair_hook_containers`: one SinglePairHookContainer per layer pair."""
        super().__init__()
        self.pair_hook_containers = pair_hook_containers

    @classmethod
    def from_config(cls, model: nn.Module, config: CrossLayerCorrelationLossConfig):
        """
        Build the loss for `model` from `config`.

        For every entry in `config.pair_configs` both layers are looked up by
        name on `model` and wrapped in a SinglePairHookContainer.
        """
        pair_hook_containers = []
        for pair_config in config.pair_configs:
            layer1 = get_module_by_name(module=model, name=pair_config.layer1_name)
            layer2 = get_module_by_name(module=model, name=pair_config.layer2_name)

            pair_hook_containers.append(
                SinglePairHookContainer(
                    layer1=layer1,
                    layer2=layer2,
                    scale=pair_config.scale,
                    downsample_factor=pair_config.downsample_factor,
                )
            )

        return cls(pair_hook_containers=pair_hook_containers)

    def take_mean_along_height_width_if_4d(
        self,
        x: Union[
            TensorType["batch", "num_output_neurons"],
            TensorType["batch", "num_channels", "h", "w"],
        ],
    ):
        """
        Reduce a 4d tensor to 2d by averaging over its last two dims; return a
        2d tensor unchanged.

        Raises AssertionError if `x` is neither 2d nor 4d.
        """
        assert (
            x.ndim == 4 or x.ndim == 2
        ), "Expected inpuy to be either a 2d tensor or a 4d tensor"

        if x.ndim == 4:
            ## nchw -> nc
            x = x.mean(-1).mean(-1)

        assert x.ndim == 2

        return x

    def two_dimensional_wiring_loss(
        self,
        y1: TensorType["batch", "num_output_neurons_1"],
        y2: TensorType["batch", "num_output_neurons_2"],
        downsample_factor=2.0,
    ):
        """
        Return `1 - mean correlation` between two layers' outputs after laying
        each layer's neurons out on a 2d grid and resizing both grids to a
        common shape.

        - `y1`, `y2`: 2d outputs with the same batch size.
        - `downsample_factor`: if not None, the common grid's height and width
          are floor-divided by this factor (never below 1) before comparing.

        Raises AssertionError if the batch sizes differ.
        """
        assert (
            y1.shape[0] == y2.shape[0]
        ), "Expected both y1 and y2 to have the same batch size"

        # Lay every layer's neurons out on an (h, w) grid with
        # h * w == number of neurons, then add a channel dim for F.interpolate:
        #   grid_y1.shape: (batch, 1, h1, w1)
        #   grid_y2.shape: (batch, 1, h2, w2)
        grid_size_y1 = find_rectangle_dimensions(area=y1.shape[1])
        grid_size_y2 = find_rectangle_dimensions(area=y2.shape[1])

        grid_y1 = rearrange(
            y1, "batch (h w) -> batch h w", h=grid_size_y1.height, w=grid_size_y1.width
        ).unsqueeze(1)

        grid_y2 = rearrange(
            y2, "batch (h w) -> batch h w", h=grid_size_y2.height, w=grid_size_y2.width
        ).unsqueeze(1)

        h1, w1 = grid_y1.shape[2], grid_y1.shape[3]
        h2, w2 = grid_y2.shape[2], grid_y2.shape[3]

        # Shrink whichever grid is wider so both share the smaller width.
        target_w = min(w1, w2)
        if w1 > target_w:
            grid_y1 = F.interpolate(grid_y1, size=(h1, target_w), mode="bilinear")
        elif w2 > target_w:
            grid_y2 = F.interpolate(grid_y2, size=(h2, target_w), mode="bilinear")

        # Then shrink whichever grid is taller so both share the smaller height.
        target_h = min(h1, h2)
        if h1 > target_h:
            grid_y1 = F.interpolate(
                grid_y1, size=(target_h, target_w), mode="bilinear"
            )
        elif h2 > target_h:
            grid_y2 = F.interpolate(
                grid_y2, size=(target_h, target_w), mode="bilinear"
            )

        ## make sure they're of the same shape after resizing
        assert grid_y1.shape == grid_y2.shape, f"{grid_y1.shape} {grid_y2.shape}"

        if downsample_factor is not None:
            # Both grids have the same shape here, so one target size serves both.
            downsampled_size = (
                max(1, int(grid_y1.shape[2] // downsample_factor)),
                max(1, int(grid_y1.shape[3] // downsample_factor)),
            )
            grid_y1 = F.interpolate(grid_y1, size=downsampled_size, mode="bilinear")
            grid_y2 = F.interpolate(grid_y2, size=downsampled_size, mode="bilinear")

        grid_y1_1d = rearrange(grid_y1.squeeze(1), "b h w -> b (h w)")
        grid_y2_1d = rearrange(grid_y2.squeeze(1), "b h w -> b (h w)")

        # NOTE(review): pearsonr presumably returns one correlation per sample,
        # which is then averaged over the batch -- confirm against utils.correlation.
        return 1 - pearsonr(x=grid_y1_1d, y=grid_y2_1d).mean()

    def compute_for_single_pair(self, pair_hook_container: SinglePairHookContainer):
        """
        Compute the (unscaled) loss for one pair from its hooks' stored outputs.

        Raises AssertionError if either hook has no output yet or if the two
        outputs have different batch sizes.
        """
        assert (
            pair_hook_container.layer1_hook.output is not None
        ), "Did you run a forward pass on the model yet?"
        assert (
            pair_hook_container.layer2_hook.output is not None
        ), "Did you run a forward pass on the model yet?"

        y1 = self.take_mean_along_height_width_if_4d(
            x=pair_hook_container.layer1_hook.output
        )
        y2 = self.take_mean_along_height_width_if_4d(
            x=pair_hook_container.layer2_hook.output
        )

        assert y1.shape[0] == y2.shape[0], "Expected batch sizes to match :("

        return self.two_dimensional_wiring_loss(
            y1=y1, y2=y2, downsample_factor=pair_hook_container.downsample_factor
        )

    def compute(self):
        """
        Return the unscaled sum of the losses of every pair, including pairs
        whose `scale` is None. With no pairs this is the int 0.
        """
        losses = [
            self.compute_for_single_pair(pair_hook_container=pair_hook_container)
            for pair_hook_container in self.pair_hook_containers
        ]
        return sum(losses)

    def forward(self):
        """
        Return the sum of `scale * loss` over every pair whose `scale` is not
        None, or None if there is no such pair.
        """
        losses = []

        for pair_hook_container in self.pair_hook_containers:
            # scale=None marks a pair that must not contribute to training.
            if pair_hook_container.scale is None:
                continue

            loss = self.compute_for_single_pair(
                pair_hook_container=pair_hook_container
            )
            losses.append(pair_hook_container.scale * loss)

        if not losses:
            return None
        return sum(losses)

    def wandb_log(self):
        """Log the unscaled total loss to wandb as `cross_layer_correlation_loss`."""
        # Logging only needs the value, so skip building an autograd graph.
        with torch.no_grad():
            total = self.compute()

        # float() handles both a one-element tensor and the plain int 0 that
        # compute() returns when there are no pairs (which has no .item()).
        wandb.log({"cross_layer_correlation_loss": float(total)})
