import einops
import torch.nn.functional as F
from torchtyping import TensorType

from einops import rearrange


def downscale_upscale_loss(
    grid: TensorType, factor_w: float, factor_h: float
) -> TensorType:
    """Return a loss measuring what a bilinear down/up-scale round trip destroys.

    The grid is bilinearly resized by ``1 / factor_h`` along its first axis and
    ``1 / factor_w`` along its second axis, then bilinearly resized back to its
    original spatial size. The loss is ``1 - mean cosine similarity`` between
    the original and the round-tripped embedding vector at every grid cell, so
    a grid that survives the round trip unchanged yields a loss close to 0.

    Args:
        grid: 3d tensor of shape ``(h, w, e)``; the last axis holds the
            per-cell embedding that is compared with cosine similarity.
        factor_w: Downscale factor along the width axis; must satisfy
            ``0 < factor_w <= w``.
        factor_h: Downscale factor along the height axis; must satisfy
            ``0 < factor_h <= h``.

    Returns:
        A scalar (0-dim) tensor holding the loss.

    Raises:
        AssertionError: If ``grid`` is not 3d or a factor is outside its
            valid range.
    """
    assert grid.ndim == 3, "Expected grid to be a 3d tensor of shape (h, w, e)"
    height, width = grid.shape[0], grid.shape[1]

    # Validate against the real spatial extent of the (h, w, e) input. A factor
    # larger than the extent would ask interpolate for an empty output, and a
    # non-positive factor has no meaningful reciprocal scale.
    assert (
        0 < factor_h <= height
    ), f"Expected factor_h to be in (0, h] with h = {height} but got: {factor_h}"
    assert (
        0 < factor_w <= width
    ), f"Expected factor_w to be in (0, w] with w = {width} but got: {factor_w}"

    # F.interpolate expects (batch, channels, h, w): move the embedding axis to
    # the front and add a singleton batch axis.
    batched = grid.permute(2, 0, 1).unsqueeze(0)

    # Downscale, then upscale back to the original spatial size.
    downscaled = F.interpolate(
        batched,
        scale_factor=(1 / factor_h, 1 / factor_w),
        mode="bilinear",
        align_corners=False,
    )
    upscaled = F.interpolate(
        downscaled, size=(height, width), mode="bilinear", align_corners=False
    )

    # Compare per-cell embeddings along the channel axis; averaging over every
    # cell makes flattening the spatial axes first unnecessary.
    return 1 - F.cosine_similarity(batched, upscaled, dim=1).mean()
