from .core import downscale_upscale_loss
from ...grid.two_dimensional import BaseGrid2dLinear, BaseGrid2dConv

from typing import List
import torch.nn as nn
from typing import Union
from ...utils.grid_size import find_rectangle_dimensions
from ...bimt.loss import BrainInspiredLayer


class LaplacianPyramidLoss:
    """Multi-scale downscale/upscale loss over a layer's output units laid out on a 2D grid.

    The layer's output units (the first dimension of its weight, i.e.
    ``out_features`` for ``nn.Linear`` / ``out_channels`` for ``nn.Conv2d``)
    are arranged on a ``height x width`` grid chosen by
    ``find_rectangle_dimensions`` and wrapped in a grid container.
    :meth:`get_loss` sums ``downscale_upscale_loss`` over every
    ``(factor_h, factor_w)`` pair, one term per pyramid level.

    Args:
        layer: Layer to regularize. A ``BrainInspiredLayer`` is unwrapped to
            the layer stored in its ``.layer`` attribute.
        device: Device forwarded to the grid container.
        factor_w: Width scale factor for each level. Defaults to ``[2.0]``.
        factor_h: Height scale factor for each level. Must have the same
            length as ``factor_w``. Defaults to ``[2.0]``.

    Raises:
        TypeError: If ``layer`` (after unwrapping) is neither ``nn.Linear``
            nor ``nn.Conv2d``.
        ValueError: If ``factor_w`` and ``factor_h`` have different lengths.
    """

    def __init__(
        self,
        layer: Union[nn.Conv2d, nn.Linear, BrainInspiredLayer],
        device: str,
        factor_w: Union[List[float], None] = None,
        factor_h: Union[List[float], None] = None,
    ):
        # Use None sentinels and copy the inputs so instances never share
        # (and can never mutate) a common default list.
        factor_w = [2.0] if factor_w is None else list(factor_w)
        factor_h = [2.0] if factor_h is None else list(factor_h)

        # A brain-inspired layer wraps the actual linear/conv layer.
        if isinstance(layer, BrainInspiredLayer):
            layer = layer.layer

        # Check the type before touching ``layer.weight``, so unsupported
        # inputs raise the intended TypeError rather than an AttributeError.
        if not isinstance(layer, (nn.Linear, nn.Conv2d)):
            raise TypeError(
                f"Expected layer to be one of nn.Linear or nn.Conv2d but got: {type(layer)}"
            )

        # Width and height factors are consumed pairwise. A length mismatch
        # would make zip() silently drop pyramid levels, so reject it.
        if len(factor_w) != len(factor_h):
            raise ValueError(
                f"factor_w and factor_h must have the same length, "
                f"got {len(factor_w)} and {len(factor_h)}"
            )

        output_size = layer.weight.shape[0]
        grid_size = find_rectangle_dimensions(area=output_size)

        if isinstance(layer, nn.Linear):
            self.grid_container = BaseGrid2dLinear(
                linear_layer=layer,
                height=grid_size.height,
                width=grid_size.width,
                device=device,
            )
        else:  # nn.Conv2d, guaranteed by the type check above
            self.grid_container = BaseGrid2dConv(
                conv_layer=layer,
                height=grid_size.height,
                width=grid_size.width,
                device=device,
            )

        self.factor_w = factor_w
        self.factor_h = factor_h
        self.device = device

    def get_loss(self):
        """Return the sum of ``downscale_upscale_loss`` over all scale-factor pairs.

        Level ``i`` uses ``factor_h[i]`` for height and ``factor_w[i]`` for
        width. If no factors are configured, the result is ``0``.
        """
        grid = self.grid_container.grid
        # Unpack in the same (height, width) order as the zipped lists, so
        # each factor is applied to its own axis.
        return sum(
            downscale_upscale_loss(grid=grid, factor_w=fw, factor_h=fh)
            for fh, fw in zip(self.factor_h, self.factor_w)
        )
