import torch
import torch.nn as nn
import numpy as np
from typing import Union
from ...grid.two_dimensional import BaseGrid2dLinear, BaseGrid2dConv
from ...utils.grid_size import find_rectangle_dimensions
from ...bimt.loss import BrainInspiredLayer
import einops


class NeighbourhoodCosineSimilarityLoss:
    """
    Loss encouraging neighbouring neurons on a 2D grid to point in similar directions.

    The output units of ``layer`` are laid out on a ``height x width`` grid
    (dimensions from ``find_rectangle_dimensions``) by a grid container
    (``BaseGrid2dLinear`` / ``BaseGrid2dConv``). Its ``grid`` attribute is indexed
    as ``grid[row, col, :]``, i.e. one vector per neuron. Presumably that vector
    is the neuron's weights; confirm against the BaseGrid2d* classes. The loss is
    ``1 - mean cosine similarity`` between each neuron and the neurons in its
    neighbourhood.
    """

    def __init__(
        self,
        layer: Union[nn.Conv2d, nn.Linear, BrainInspiredLayer],
        device: str,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 0,
    ):
        """
        Args:
            layer: The layer whose output units are arranged on the grid. A
                ``BrainInspiredLayer`` is unwrapped to its inner ``.layer``.
            device: Passed to the grid container. Also used as the device of
                the accumulator in ``get_loss_original``.
            kernel_size: Side length of the neighbourhood window. The loop
                version uses a radius of ``kernel_size // 2``.
            stride: Window stride (used by ``get_loss_fast`` only).
            padding: Zero padding around the grid (used by ``get_loss_fast``
                only). Padded positions never contribute to the loss.

        Raises:
            TypeError: If the (unwrapped) layer is neither ``nn.Linear`` nor
                ``nn.Conv2d``.
        """
        # Unwrap first so the neuron count is read from the actual
        # nn.Linear / nn.Conv2d. The wrapper is not guaranteed to proxy `.weight`.
        if isinstance(layer, BrainInspiredLayer):
            layer = layer.layer

        # Validate before touching `.weight`, so unsupported layers get the
        # intended TypeError rather than an AttributeError.
        if not isinstance(layer, (nn.Linear, nn.Conv2d)):
            raise TypeError(
                f"Expected layer to be one of nn.Linear or nn.Conv2d but got: {type(layer)}"
            )

        # weight.shape[0] is out_features (Linear) / out_channels (Conv2d).
        grid_size = find_rectangle_dimensions(area=layer.weight.shape[0])

        if isinstance(layer, nn.Linear):
            self.grid_container = BaseGrid2dLinear(
                linear_layer=layer,
                height=grid_size.height,
                width=grid_size.width,
                device=device,
            )
        else:
            self.grid_container = BaseGrid2dConv(
                conv_layer=layer,
                height=grid_size.height,
                width=grid_size.width,
                device=device,
            )

        self.device = device
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def get_loss_original(self):
        """
        Reference (loop-based) implementation of the neighbourhood loss.

        For every neuron (i, j), its cosine similarity to each neuron within
        ``kernel_size // 2`` rows/columns is averaged. The neighbourhood
        includes the neuron itself and is clipped at the grid edges; for the
        default ``kernel_size=3`` it is a 3x3 neighbourhood. The loss is ``1 -``
        the mean of these per-neuron averages. The Python double loop makes
        this slow, but it is simple and serves as ground truth for
        ``get_loss_fast``.

        Returns:
            torch.Tensor: The calculated (scalar) loss.
        """
        grid = self.grid_container.grid
        height = self.grid_container.height
        width = self.grid_container.width
        radius = self.kernel_size // 2

        # NOTE: height comes before width in indexing
        per_neuron_means = torch.zeros(height, width, device=self.device)
        for i in range(height):
            for j in range(width):
                # +1 because slice ends are exclusive; min/max clip at the edges.
                neighbourhood = grid[
                    max(i - radius, 0) : min(i + radius + 1, height),
                    max(j - radius, 0) : min(j + radius + 1, width),
                    :,
                ]
                # Broadcast the centre vector against every neighbour along dim 2.
                cossims = torch.nn.functional.cosine_similarity(
                    neighbourhood,
                    grid[i, j, :].unsqueeze(0).unsqueeze(0),
                    dim=2,
                )
                per_neuron_means[i, j] = cossims.mean()

        return 1.0 - per_neuron_means.mean()

    def get_loss_fast(self):
        """
        Vectorised neighbourhood loss built on ``torch.nn.functional.unfold``.

        How the window scan works:
        - A ``kernel_size x kernel_size`` window slides over the grid, which is
          zero-padded by ``padding`` and traversed with step ``stride``.
        - In each window, the centre neuron is compared with every neuron in
          that same window, itself included.
        - Zero-padded positions are excluded from the window's mean.
        - Windows whose centre falls on padding are skipped.

        The loss is ``1 -`` the mean of the per-window means.

        With ``kernel_size=3, stride=1, padding=1`` this equals
        ``get_loss_original()``. With ``padding=0`` (the constructor default),
        only neurons whose whole window lies inside the grid act as centres.

        Returns:
            torch.Tensor: The calculated (scalar) loss.
        """
        grid = self.grid_container.grid  # indexed (height, width, embedding)
        k = self.kernel_size
        # Flat index of the window centre among the k * k window positions.
        centre = (k // 2) * k + k // 2
        unfold_kwargs = dict(kernel_size=k, stride=self.stride, padding=self.padding)

        # unfold expects (N, C, H, W); the embedding dimension acts as channels.
        image = einops.rearrange(grid, "h w e -> e h w").unsqueeze(0)
        # unfold -> (1, e * k * k, num_windows). Rows are channel-major
        # (row = c * k * k + ki * k + kj), so split them as (e kk).
        patches = torch.nn.functional.unfold(image, **unfold_kwargs)
        patches = einops.rearrange(
            patches[0], "(e kk) windows -> windows kk e", e=grid.shape[-1]
        )

        # Per window: similarity of the centre vector with each window position.
        # The centre slice keeps dim 1 so it broadcasts over the k * k positions.
        cossims = torch.nn.functional.cosine_similarity(
            patches, patches[:, centre : centre + 1, :], dim=-1
        )  # (num_windows, k * k)

        # Unfolding an all-ones image marks real grid cells with 1 and zero
        # padding with 0, using exactly the same window layout as `patches`.
        ones = torch.ones(
            1, 1, grid.shape[0], grid.shape[1],
            dtype=cossims.dtype, device=cossims.device,
        )
        valid = torch.nn.functional.unfold(ones, **unfold_kwargs)[0].transpose(0, 1)

        # Drop windows centred on padding (possible with even kernels or large
        # padding). Every kept window has >= 1 valid position, so no 0/0.
        keep = valid[:, centre] > 0
        cossims, valid = cossims[keep], valid[keep]
        window_means = (cossims * valid).sum(dim=1) / valid.sum(dim=1)
        return 1.0 - window_means.mean()

    def get_loss(self):
        """Return the loss computed by ``get_loss_fast``."""
        return self.get_loss_fast()
