import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
from PIL import Image
import PIL
from ..utils.pca import get_tensor_principal_components
from ..utils import normalize_tensor
from ..configs import NesimConfig, NeighbourhoodCosineSimilarity
from ..losses.nesim_loss import NesimLoss


class WeightsPCAViewer:
    """Render one layer's NeSim grid as an RGB image for a list of checkpoints.

    At construction time a 3-component PCA is fitted on the grid exposed by the
    layer's grid container. ``viewer[idx]`` then loads
    ``checkpoint_filenames[idx]`` into the model, projects the layer's current
    grid state onto that fixed 3-component basis, and returns the three
    projected channels as the R, G, B channels of a PIL image.

    Args:
        model: module containing ``layer_name``. It is moved to ``device`` and
            put in eval mode.
        checkpoint_filenames: paths passed to ``torch.load``, one per frame.
        layer_name: name of the layer whose grid is visualised.
        device: device for the model; loaded checkpoints are mapped onto it.
        resize_height: output image height in pixels (nearest-neighbour resize).
        resize_width: output image width in pixels (nearest-neighbour resize).
        scale_by_magnitude: if True, multiply each projected cell by the L1 norm
            of that cell's vector before normalisation.
        load_from_brain_inspired_layers: if True, rewrite checkpoint keys with
            ``_rename_brain_inspired_keys`` before loading them.
    """

    @torch.no_grad()
    def __init__(
        self,
        model: nn.Module,
        checkpoint_filenames: list,
        layer_name: str,
        device="cuda:0",
        resize_height=256,
        resize_width=256,
        scale_by_magnitude: bool = True,
        load_from_brain_inspired_layers: bool = True,
    ):
        # Kept so __getitem__ can map checkpoint tensors onto the model's device.
        self.device = device
        self.model = model.to(device).eval()
        self.checkpoint_filenames = checkpoint_filenames
        self.layer_name = layer_name
        self.scale_by_magnitude = scale_by_magnitude
        self.load_from_brain_inspired_layers = load_from_brain_inspired_layers

        self.resize_height = resize_height
        self.resize_width = resize_width

        # A single neighbourhood-cosine-similarity config on the target layer.
        # In this class the loss object is only used to reach that layer's grid.
        nesim_config = NesimConfig(
            layer_wise_configs=[
                NeighbourhoodCosineSimilarity(layer_name=self.layer_name, scale=None)
            ]
        )

        self.nesim_loss = NesimLoss(
            model=self.model, config=nesim_config, device=device
        )

        # assumes grid is (height, width, embedding) as the rearrange pattern
        # below implies — TODO confirm against grid_container
        grid = self.nesim_loss.layer_handlers[0].layer_loss.grid_container.grid
        self.grid_height = grid.shape[0]
        self.grid_width = grid.shape[1]

        # Flatten so that each grid cell is one sample (row) for the PCA.
        grid = rearrange(grid, "h w e -> (h w) e")

        # 3 components -> R, G, B channels of the output image.
        self.pca = get_tensor_principal_components(grid, n_components=3)

    def __len__(self):
        """Return the number of checkpoints (frames) available."""
        return len(self.checkpoint_filenames)

    @staticmethod
    def _rename_brain_inspired_keys(state_dict: dict) -> dict:
        """Return a copy of ``state_dict`` with keys renamed for the bare model.

        Every ``"model."`` substring is removed, then every ``".layer."`` is
        collapsed to ``"."``. Plain ``str.replace`` is used, so occurrences
        anywhere in a key are affected, not only prefixes.
        """
        return {
            key.replace("model.", "").replace(".layer.", "."): value
            for key, value in state_dict.items()
        }

    @torch.no_grad()
    def __getitem__(self, idx):
        """Load checkpoint ``idx`` into the model and return its grid as an image.

        Args:
            idx: index into ``checkpoint_filenames``.

        Returns:
            RGB ``PIL.Image.Image`` with width ``resize_width`` and height
            ``resize_height``.

        Raises:
            IndexError: propagated from ``checkpoint_filenames[idx]`` when out
                of range, which also lets ``for img in viewer`` terminate.
        """
        # map_location lets checkpoints saved on another device (e.g. a GPU that
        # is not present here) load directly onto the model's device.
        state_dict = torch.load(
            self.checkpoint_filenames[idx], map_location=self.device
        )

        if self.load_from_brain_inspired_layers:
            state_dict = self._rename_brain_inspired_keys(state_dict)

        self.model.load_state_dict(state_dict)
        grid = self.nesim_loss.get_all_grid_states()[
            "neighbourhood_cossim_loss_" + self.layer_name
        ].detach()

        # Low-dimensional projection onto the PCA basis fitted in __init__.
        projected_grid = grid @ self.pca.components

        if self.scale_by_magnitude:
            # L1 norm over axis 2 (the embedding axis, assuming an (h, w, e) grid).
            grid_magnitude = torch.linalg.norm(grid, dim=2, ord=1)
            projected_grid = projected_grid * grid_magnitude.unsqueeze(-1)

        # presumably normalize_tensor maps values into [0, 1] — TODO confirm
        pixels = (normalize_tensor(projected_grid).cpu().numpy() * 255).astype(
            np.uint8
        )
        # PIL's Image.resize takes the target size as (width, height).
        return Image.fromarray(pixels).resize(
            (self.resize_width, self.resize_height),
            resample=PIL.Image.Resampling.NEAREST,
        )
