import torch
import torch.nn as nn
from einops import rearrange
import numpy as np
from PIL import Image
import PIL
from ..utils import normalize_tensor
from ..configs import NesimConfig, NeighbourhoodCosineSimilarity
from ..losses.nesim_loss import NesimLoss
from ..losses.ring.core import get_power_spectrum


class WeightsPowerSpectrumViewer:
    """Render the power spectrum of one layer's grid state for a series of checkpoints.

    Indexing the viewer with ``idx`` loads ``checkpoint_filenames[idx]`` into
    ``model``, reads the grid state for ``layer_name`` from a ``NesimLoss``
    configured with a single ``NeighbourhoodCosineSimilarity`` entry, computes
    its power spectrum averaged over the leading (embedding) axis, and returns
    it as a PIL image of size ``resize_width`` x ``resize_height``.
    """

    @torch.no_grad()
    def __init__(
        self,
        model: nn.Module,
        checkpoint_filenames: list,
        layer_name: str,
        device="cuda:0",
        resize_height=256,
        resize_width=256,
        load_from_brain_inspired_layers: bool = True,
    ):
        """
        Args:
            model: Network whose weights are replaced by each checkpoint. It is
                moved to ``device`` and switched to eval mode.
            checkpoint_filenames: Paths passed to ``torch.load``, one per item.
            layer_name: Name of the layer whose grid state is visualised.
            device: Device for the model and for the ``NesimLoss``.
            resize_height: Height in pixels of the returned image.
            resize_width: Width in pixels of the returned image.
            load_from_brain_inspired_layers: If True, checkpoint keys are
                rewritten before loading: a leading ``"model."`` is removed and
                every ``".layer."`` is collapsed to ``"."``.
        """
        self.model = model.to(device).eval()
        self.checkpoint_filenames = checkpoint_filenames
        self.layer_name = layer_name
        self.load_from_brain_inspired_layers = load_from_brain_inspired_layers

        self.resize_height = resize_height
        self.resize_width = resize_width

        # scale=None: this loss is only used to read grid states, not to train.
        nesim_config = NesimConfig(
            layer_wise_configs=[
                NeighbourhoodCosineSimilarity(layer_name=self.layer_name, scale=None)
            ]
        )

        self.nesim_loss = NesimLoss(
            model=self.model, config=nesim_config, device=device
        )

    def __len__(self):
        """Return the number of checkpoints available."""
        return len(self.checkpoint_filenames)

    @staticmethod
    def _fix_brain_inspired_key(key: str) -> str:
        """Map a checkpoint key saved from brain-inspired wrapper layers to a plain model key."""
        # Strip only a *leading* "model." -- a substring replace would also
        # mangle names such as "submodel.weight" into "subweight".
        prefix = "model."
        if key.startswith(prefix):
            key = key[len(prefix):]
        # Wrapper layers store the wrapped module under ".layer.".
        return key.replace(".layer.", ".")

    @torch.no_grad()
    def __getitem__(self, idx):
        """Load checkpoint ``idx`` and return its averaged power spectrum as a PIL image.

        The spectrum is scaled by ``normalize_tensor`` (presumably into [0, 1];
        TODO confirm), multiplied by 255, cast to uint8, and resized with
        nearest-neighbour resampling.
        """
        # map_location="cpu" lets GPU-saved checkpoints load on any machine;
        # load_state_dict then copies the tensors onto the model's own device.
        state_dict = torch.load(self.checkpoint_filenames[idx], map_location="cpu")

        if self.load_from_brain_inspired_layers:
            state_dict = {
                self._fix_brain_inspired_key(key): value
                for key, value in state_dict.items()
            }

        self.model.load_state_dict(state_dict)
        grid = self.nesim_loss.get_all_grid_states()[
            "neighbourhood_cossim_loss_" + self.layer_name
        ].detach()
        grid = rearrange(grid, "h w e -> e h w")
        spectrum = get_power_spectrum(cortical_sheet=grid).mean(0)
        pixels = (normalize_tensor(spectrum).cpu().numpy() * 255).astype(np.uint8)
        # NOTE(review): if `pixels` is 2-D, Image.fromarray yields a
        # single-channel ("L") image, not RGB; call .convert("RGB") if needed.
        # PIL's resize takes the target size as (width, height).
        return Image.fromarray(pixels).resize(
            (self.resize_width, self.resize_height),
            resample=PIL.Image.Resampling.NEAREST,
        )
