from ...grid.two_dimensional import BaseGrid2dLinear, BaseGrid2dConv
import torch.nn as nn
from typing import Union
from ...utils.grid_size import find_rectangle_dimensions
from ...bimt.loss import BrainInspiredLayer
from .core import ring_loss, ring_loss_1d
from ...grid.one_dimensional import BaseGrid1dLinear, BaseGrid1dConv


class RingLoss1D:
    """
    Penalizes frequency components of the cortical sheet which are
    outside of `freq_outer` and within `freq_inner`.

    The cortical sheet is held by a 1D grid container (``BaseGrid1dLinear``
    or ``BaseGrid1dConv``) built around the given layer; the loss itself is
    computed by ``ring_loss_1d``.

    Args:
        layer: an ``nn.Linear`` or ``nn.Conv2d``, or a ``BrainInspiredLayer``
            wrapping one (its ``.layer`` attribute is used).
        freq_inner: inner frequency bound, forwarded to ``ring_loss_1d``.
        freq_outer: outer frequency bound, forwarded to ``ring_loss_1d``.

    Raises:
        TypeError: if the (unwrapped) layer is neither ``nn.Linear`` nor
            ``nn.Conv2d``.
    """

    def __init__(
        self,
        layer: Union[nn.Conv2d, nn.Linear, BrainInspiredLayer],
        freq_inner: float,
        freq_outer: float,
    ):
        ## if it's a brain inspired layer, then access the actual linear/conv layer within
        if isinstance(layer, BrainInspiredLayer):
            layer = layer.layer

        if isinstance(layer, nn.Linear):
            self.grid_container = BaseGrid1dLinear(linear_layer=layer)

        elif isinstance(layer, nn.Conv2d):
            self.grid_container = BaseGrid1dConv(conv_layer=layer)
        else:
            raise TypeError(
                f"Expected layer to be one of nn.Linear or nn.Conv2d but got: {type(layer)}"
            )

        self.freq_inner = freq_inner
        self.freq_outer = freq_outer

    def get_loss(self):
        """Return ``ring_loss_1d`` evaluated on the grid container's current ``grid``."""
        return ring_loss_1d(
            cortical_sheet=self.grid_container.grid,
            freq_inner=self.freq_inner,
            freq_outer=self.freq_outer,
        )


class RingLoss:
    """
    Penalizes frequency components of the cortical sheet which are
    outside of `radius_outer` and within `radius_inner`.

    The layer's outputs are laid out on a 2D grid whose height and width
    come from ``find_rectangle_dimensions(area=layer.weight.shape[0])``; the
    grid is held by ``BaseGrid2dLinear`` / ``BaseGrid2dConv`` and the loss
    itself is computed by ``ring_loss``.

    Args:
        layer: an ``nn.Linear`` or ``nn.Conv2d``, or a ``BrainInspiredLayer``
            wrapping one (its ``.layer`` attribute is used).
        device: forwarded to the grid container.
        radius_inner: inner radius, forwarded to ``ring_loss``.
        radius_outer: outer radius, forwarded to ``ring_loss``.

    Raises:
        TypeError: if the (unwrapped) layer is neither ``nn.Linear`` nor
            ``nn.Conv2d``.
    """

    def __init__(
        self,
        layer: Union[nn.Conv2d, nn.Linear, BrainInspiredLayer],
        device: str,
        radius_inner: float,
        radius_outer: float,
    ):
        ## if it's a brain inspired layer, then access the actual linear/conv layer within
        if isinstance(layer, BrainInspiredLayer):
            layer = layer.layer

        # Validate before touching `layer.weight`, so unsupported layers get a
        # clear TypeError instead of an AttributeError.
        if not isinstance(layer, (nn.Linear, nn.Conv2d)):
            raise TypeError(
                f"Expected layer to be one of nn.Linear or nn.Conv2d but got: {type(layer)}"
            )

        # weight.shape[0] is out_features for nn.Linear and out_channels for
        # nn.Conv2d. It was previously only computed in the Linear branch,
        # which left the Conv2d branch reading an unbound name.
        output_size = layer.weight.shape[0]
        grid_size = find_rectangle_dimensions(area=output_size)

        if isinstance(layer, nn.Linear):
            self.grid_container = BaseGrid2dLinear(
                linear_layer=layer,
                height=grid_size.height,
                width=grid_size.width,
                device=device,
            )
        else:
            self.grid_container = BaseGrid2dConv(
                conv_layer=layer,
                height=grid_size.height,
                width=grid_size.width,
                device=device,
            )

        self.device = device
        self.radius_inner = radius_inner
        self.radius_outer = radius_outer

    def get_loss(self):
        """Return ``ring_loss`` evaluated on the grid container's current ``grid``."""
        return ring_loss(
            cortical_sheet=self.grid_container.grid,
            radius_inner=self.radius_inner,
            radius_outer=self.radius_outer,
        )
