import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
from PIL import Image
import PIL
from ..utils.pca import get_tensor_principal_components
from ..utils import normalize_tensor
from ..configs import NesimConfig, NeighbourhoodCosineSimilarity
from ..losses.nesim_loss import NesimLoss


class WeightsMagnitudeViewer:
    """Render the per-position weight magnitude of one layer, one image per checkpoint.

    Indexing the viewer (``viewer[i]``) loads ``checkpoint_filenames[i]`` into
    ``model``, fetches the ``"neighbourhood_cossim_loss_<layer_name>"`` entry of
    ``NesimLoss.get_all_grid_states()``, takes its L1 norm over dim 2 and returns
    an RGB ``PIL.Image``: red marks high magnitude, blue marks low magnitude.
    """

    @torch.no_grad()
    def __init__(
        self,
        model: nn.Module,
        checkpoint_filenames: list,
        layer_name: str,
        device="cuda:0",
        resize_height=256,
        resize_width=256,
    ):
        """
        Args:
            model: network matching the state dicts in ``checkpoint_filenames``.
                It is moved to ``device`` and put in eval mode.
            checkpoint_filenames: paths to state-dict files, one per frame.
                Must be non-empty: the first one is loaded immediately.
            layer_name: name of the layer whose grid is visualised.
            device: device the model is moved to and checkpoints are mapped onto.
            resize_height: height in pixels of the returned images.
            resize_width: width in pixels of the returned images.
        """
        self.device = device
        self.model = model.to(device).eval()
        self.checkpoint_filenames = checkpoint_filenames
        self.layer_name = layer_name

        self.resize_height = resize_height
        self.resize_width = resize_width

        # Sanity check: load_state_dict is strict by default, so a checkpoint
        # that does not match the model fails here rather than at first access.
        self._load_checkpoint(self.checkpoint_filenames[0])

        nesim_config = NesimConfig(
            layer_wise_configs=[
                NeighbourhoodCosineSimilarity(layer_name=self.layer_name, scale=None)
            ]
        )

        self.nesim_loss = NesimLoss(
            model=self.model, config=nesim_config, device=device
        )

    def _load_checkpoint(self, filename):
        """Load the state dict stored at ``filename`` into ``self.model``.

        ``map_location`` remaps tensors onto ``self.device`` so checkpoints saved
        on a different device (another GPU, or CPU) can still be loaded.
        """
        state_dict = torch.load(filename, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def _output_size(self):
        """Return the output image size in the order PIL expects: ``(width, height)``."""
        return (self.resize_width, self.resize_height)

    @staticmethod
    def _magnitude_to_rgb(magnitude):
        """Map a ``(..., 1)`` magnitude tensor to a ``uint8`` RGB numpy array.

        Red = magnitude, green = 0, blue = 1 - magnitude, so high values render
        red and low values render blue. Values are clamped to [0, 1] first:
        out-of-range floats would otherwise wrap around when cast to ``uint8``.
        """
        magnitude = magnitude.clamp(0.0, 1.0)
        rgb = torch.cat(
            [
                magnitude,  # red
                torch.zeros_like(magnitude),  # green has no effect
                1 - magnitude,  # blue
            ],
            dim=-1,
        )
        return (rgb.cpu().numpy() * 255).astype(np.uint8)

    def __len__(self):
        """Number of checkpoints (frames) available."""
        return len(self.checkpoint_filenames)

    @torch.no_grad()
    def __getitem__(self, idx):
        """Return the magnitude map for checkpoint ``idx`` as an RGB PIL image.

        Red marks high magnitude, blue marks low magnitude. The image is resized
        to ``resize_width`` x ``resize_height`` with nearest-neighbour sampling.

        Raises:
            IndexError: if ``idx`` is out of range for ``checkpoint_filenames``.
        """
        self._load_checkpoint(self.checkpoint_filenames[idx])
        # NOTE(review): assumes nesim_loss reads the live weights of self.model
        # (it was built with model=self.model), so the grid reflects the
        # checkpoint just loaded — confirm against NesimLoss.
        grid = self.nesim_loss.get_all_grid_states()[
            "neighbourhood_cossim_loss_" + self.layer_name
        ].detach()

        # L1 norm over dim 2, kept as a trailing channel axis.
        # Assumes grid is (height, width, n_weights) — TODO confirm.
        grid_magnitude = torch.linalg.norm(grid, dim=2, ord=1).unsqueeze(-1)
        # presumably rescales to [0, 1]; _magnitude_to_rgb clamps regardless.
        grid_magnitude = normalize_tensor(grid_magnitude)

        return Image.fromarray(self._magnitude_to_rgb(grid_magnitude)).resize(
            self._output_size(),
            resample=PIL.Image.Resampling.NEAREST,
        )
