import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from sae.utils import filter_out_nosiy_activation
from utils.dataset_loader import ImageDatasetLoader
from utils.image_processor import ImageProcessor


def plot_top_reference_images(
    latent_indices, backbone, latent_names=None, n=5, root="../../", image_dir="reference_images"
):
    """Show the stored reference image of several latents, stacked one per row.

    Each latent ``latent`` is read from
    ``{root}/out/{image_dir}/{backbone}/{latent}.jpg``.

    Args:
        latent_indices: Sliceable sequence of latent ids.
        backbone: Sub-directory name under ``image_dir`` holding the images.
        latent_names: Currently unused; kept so existing callers keep working.
        n: Positive -> the first ``n`` latents; negative -> the last ``|n|``
            latents; ``0`` -> nothing is plotted.
        root: Project root that contains the ``out`` directory.
        image_dir: Directory (under ``{root}/out``) holding reference images.

    Returns:
        None. The figure is shown with ``plt.show()`` and then closed.
    """
    if n > 0:
        candidates = latent_indices[:n]
    elif n < 0:
        candidates = latent_indices[n:]
    else:
        candidates = latent_indices[:0]

    # Size the grid from what is actually available: asking for more rows than
    # there are latents would leave empty axes, and zero rows makes
    # ``plt.subplots`` raise.
    num_rows = len(candidates)
    if num_rows == 0:
        return

    # squeeze=False always yields a 2-D array of axes, so a single row can be
    # indexed the same way as many rows.
    fig, axes = plt.subplots(num_rows, 1, figsize=(12, 10), squeeze=False)
    for ax, latent in zip(axes[:, 0], candidates):
        # imshow copies the pixel data, so the file can be closed right after.
        with Image.open(f"{root}/out/{image_dir}/{backbone}/{latent}.jpg") as image:
            ax.imshow(image)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)


def plot_top_images(selected_indices, dataset, show=True):
    """Plot the selected dataset images in a two-row grid and return them.

    Images fill the grid column by column (index 0 top-left, index 1 below it,
    index 2 at the top of the next column, ...). Each image is displayed as a
    128x128 thumbnail; the returned images are the un-resized originals.

    Args:
        selected_indices: Sized iterable of row indices into ``dataset``.
        dataset: Object exposing ``features`` (supporting ``in``) and row
            indexing ``dataset[i][column]``. The image column is the first of
            ``"image"``, ``"jpg"``, ``"webp"`` found in ``dataset.features``.
            Column values may be PIL images or path strings.
        show: When true, call ``plt.show()`` before closing the figure.

    Returns:
        List of PIL images in the order of ``selected_indices`` (empty when
        no indices are given).

    Raises:
        KeyError: If the dataset has none of the supported image columns.
    """
    image_key = next((k for k in ("image", "jpg", "webp") if k in dataset.features), None)
    if image_key is None:
        raise KeyError("dataset has none of the image columns 'image', 'jpg', 'webp'")

    num_images = len(selected_indices)
    if num_images == 0:
        # A grid with zero columns cannot be created; there is nothing to draw.
        return []
    num_cols = (num_images + 1) // 2  # ceil(num_images / 2), always >= 1 here

    # squeeze=False keeps ``axes`` 2-D even when there is a single column.
    fig, axes = plt.subplots(2, num_cols, figsize=(3 * num_cols, 6), squeeze=False)
    # Hide every axis up front so the spare cell of an odd count stays blank.
    for ax in axes.flat:
        ax.axis("off")

    images = []
    for idx, selected_idx in enumerate(selected_indices):
        row, col = idx % 2, idx // 2

        # Row-first access fetches a single example instead of addressing the
        # whole image column, matching how SAEVisualizer reads the dataset.
        img = dataset[int(selected_idx)][image_key]
        if isinstance(img, str):
            img = Image.open(img)
        images.append(img)
        axes[row, col].imshow(img.resize((128, 128)))

    plt.tight_layout()
    if show:
        plt.show()
    plt.close(fig)
    return images


class SAEVisualizer:
    """Inspect SAE latents through the dataset images that activate them most.

    On construction the visualizer loads a dataset, the pre-computed SAE
    statistics (``*_sae_stats.h5``) for both ImageNet/train and the requested
    dataset/split, and the ImageNet class-name list.
    """

    def __init__(self, sae, vit, cfg, root, sae_name, device, dataset_name="imagenet", split="train"):
        """Load the dataset, the SAE statistics and the class names.

        Args:
            sae: SAE model, forwarded to ``ImageProcessor.get_sae_activations``.
            vit: Vision backbone, forwarded to ``ImageProcessor``.
            cfg: Config providing ``block_layer``, ``module_name`` and
                ``class_token``.
            root: Project root containing ``out/feature_data`` and
                ``configs/classnames``.
            sae_name: Sub-directory of ``out/feature_data`` for this SAE.
            device: Device handed to ``ImageProcessor.process_model_inputs``.
            dataset_name: Dataset to load and whose statistics to read.
            split: Dataset split to load and whose statistics to read.
        """
        self.sae = sae
        self.vit = vit
        self.cfg = cfg
        self.device = device
        self.dataset = ImageDatasetLoader.load_dataset(dataset_name, seed=1, split=split)

        feat_data_root = f"{root}/out/feature_data"
        # The ImageNet/train statistics are always loaded; their mean
        # activations are what the noise filter in get_segmentation_mask uses.
        self.max_act_imgs_imagenet, self.mean_acts_imagenet = self.get_max_acts_and_images(
            feat_data_root=feat_data_root, sae_name=sae_name
        )
        self.max_act_imgs, self.mean_acts = self.get_max_acts_and_images(
            feat_data_root=feat_data_root, sae_name=sae_name, dataset_name=dataset_name, split=split
        )
        self.classnames = self.get_class_names(root)

    def get_class_names(self, root):
        """Read ``{root}/configs/classnames/imagenet_classnames.txt``.

        Every line is split on single spaces; the first token is dropped and
        the remaining tokens, re-joined with spaces, form the class name.

        Returns:
            List of class names, one per line of the file.
        """
        filename = f"{root}/configs/classnames/imagenet_classnames.txt"
        with open(filename, "r", encoding="utf-8") as file:
            return [" ".join(line.strip().split(" ")[1:]) for line in file]

    def get_max_acts_and_images(
        self, feat_data_root: str, sae_name: str, dataset_name: str = "imagenet", split: str = "train"
    ) -> tuple[np.ndarray, np.ndarray]:
        """Load per-latent statistics from ``{split}_sae_stats.h5``.

        Args:
            feat_data_root: Directory containing one sub-directory per SAE.
            sae_name: SAE sub-directory name.
            dataset_name: Dataset sub-directory name.
            split: Split prefix of the statistics file.

        Returns:
            ``(max_act_imgs, mean_acts)``: the ``max_activating_image_indices``
            dataset cast to int, and the ``sae_mean_acts`` dataset (divided by
            ``sae_sparsity`` when its maximum exceeds 1).
        """
        stats_path = f"{feat_data_root}/{sae_name}/{dataset_name}/{split}_sae_stats.h5"
        # Open read-only explicitly: the statistics are never modified here.
        with h5py.File(stats_path, "r") as hf:
            max_act_imgs = hf["max_activating_image_indices"][:].astype(int)
            mean_acts = hf["sae_mean_acts"][:]
            # NOTE(review): a maximum above 1 is treated as "not yet divided
            # by sparsity" - presumably a storage-format heuristic; confirm.
            if mean_acts.max() > 1:
                sparsity = hf["sae_sparsity"][:]
                mean_acts /= sparsity

        return max_act_imgs, mean_acts

    def get_max_activating_images_and_labels(self, latent_idx, image_key="image", label_key="label"):
        """Fetch the max-activating images (and labels) of one latent.

        Args:
            latent_idx: Index into ``self.max_act_imgs``.
            image_key: Preferred image column. If it is not among the dataset
                features, ``"image"``, ``"jpg"`` and ``"webp"`` are tried in
                that order.
            label_key: Column holding the label. Rows lacking it are skipped
                in ``labels``.

        Returns:
            ``(images, labels)``. ``labels`` can be shorter than ``images``
            (or empty) when the dataset has no label column.

        Raises:
            KeyError: If no image column is found in the dataset features.
        """
        features = self.dataset.features
        column = next((k for k in (image_key, "image", "jpg", "webp") if k in features), None)
        if column is None:
            raise KeyError("dataset has none of the image columns 'image', 'jpg', 'webp'")

        images = []
        labels = []
        for i in self.max_act_imgs[latent_idx]:
            # Fetch the row once; every row access may decode the image.
            row = self.dataset[int(i)]
            images.append(row[column])
            try:
                labels.append(row[label_key])
            except KeyError:
                # Unlabelled datasets are supported: simply collect no label.
                continue

        return images, labels

    def get_segmentation_mask(self, images, latent_idx: int, base_opacity: int, resize_size=224, filter_noise=True):
        """Darken the image regions where the given latent does not fire.

        The images are run through the backbone and the SAE, the per-token
        activation of ``latent_idx`` is reshaped to a square map (dropping
        token 0), upsampled to ``resize_size`` and min-max normalised. Pixels
        whose normalised value is exactly 0 are scaled by ``base_opacity / 255``.

        Args:
            images: PIL images or image paths.
            latent_idx: Latent whose activation map is visualised.
            base_opacity: Brightness (0-255) applied to non-activating pixels.
            resize_size: Side length of the square output images.
            filter_noise: When true, pass the activations through
                ``filter_out_nosiy_activation`` with the ImageNet mean
                activations first.

        Returns:
            List of RGBA PIL images of size ``resize_size`` x ``resize_size``.
        """
        batch_dict = {"image": []}
        for image in images:
            if isinstance(image, str):
                image = Image.open(image)
            if image.mode == "L":
                image = image.convert("RGB")
            batch_dict["image"].append(image)

        inputs = ImageProcessor.process_model_inputs(batch_dict, self.vit, self.device)
        sae_act = ImageProcessor.get_sae_activations(
            self.sae, self.vit, inputs, self.cfg.block_layer, self.cfg.module_name, self.cfg.class_token, get_mean=False
        )
        if filter_noise:
            sae_act = filter_out_nosiy_activation(self.mean_acts_imagenet, sae_act.detach())
        selected_act = sae_act[:, :, latent_idx]
        # Token 0 is excluded below; the remaining tokens must form a square.
        feature_size = int(np.sqrt(selected_act.shape[1] - 1))

        # as_tensor + detach + float accepts arrays as well as tensors on any
        # device / dtype / grad state, unlike the legacy torch.Tensor(...) call.
        patch_act = torch.as_tensor(selected_act[:, 1:]).detach().float()
        masks = patch_act.reshape(sae_act.shape[0], 1, feature_size, feature_size)
        masks = (
            torch.nn.functional.interpolate(masks, (resize_size, resize_size), mode="bilinear").squeeze(1).cpu().numpy()
        )

        masked_images = []
        for i, image in enumerate(batch_dict["image"]):
            # Force three channels so palette / CMYK / alpha images also yield
            # an (H, W, 3) array; the model inputs above are left untouched.
            image_array = np.array(image.convert("RGB").resize((resize_size, resize_size)))
            mask = (masks[i] - masks[i].min()) / (masks[i].max() - masks[i].min() + 1e-10)

            rgba_overlay = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
            rgba_overlay[..., :3] = image_array

            darkened_image = (image_array * (base_opacity / 255)).astype(np.uint8)
            inactive = mask == 0
            rgba_overlay[inactive, :3] = darkened_image[inactive]
            rgba_overlay[..., 3] = 255

            masked_images.append(Image.fromarray(rgba_overlay))

        return masked_images

    def plot_images(
        self,
        images,
        labels=None,
        top_k=5,
        show_plot=True,
    ):
        """Plot up to ``top_k`` images in rows of at most five.

        Args:
            images: Sequence of PIL images; each shown one is resized to
                224x224.
            labels: Optional sequence of class indices. Image ``i`` is titled
                ``"{label} {class name}"`` when ``labels[i]`` exists.
            top_k: Maximum number of images to show.
            show_plot: When true, call ``plt.show()`` before closing.

        Returns:
            The (already closed) matplotlib figure.

        Raises:
            ValueError: If there is nothing to plot (``top_k < 1`` or no
                images).
        """
        num_shown = min(top_k, len(images))
        if num_shown < 1:
            raise ValueError("plot_images needs at least one image and top_k >= 1")

        shown = [img.resize((224, 224)) for img in images[:num_shown]]
        num_cols = min(num_shown, 5)
        num_rows = (num_shown + num_cols - 1) // num_cols
        # squeeze=False keeps ``axes`` an array even for a single subplot.
        fig, axes = plt.subplots(num_rows, num_cols, figsize=(4.5 * num_cols, 5 * num_rows), squeeze=False)
        axes = axes.flatten()
        # Switch every axis off first so unused grid cells stay blank.
        for ax in axes:
            ax.axis("off")

        for i, (ax, img) in enumerate(zip(axes, shown)):
            ax.imshow(img)
            if labels is not None and i < len(labels):
                class_name = self.classnames[int(labels[i])]
                ax.set_title(f"{labels[i]} {class_name}", fontsize=25)
        plt.tight_layout()

        if show_plot:
            plt.show()
        plt.close(fig)

        return fig

    def get_top_images(
        self,
        latent_idx: int,
        top_k=5,
        show_seg_mask=False,
        base_opacity=50,
        with_label=True,
        show_plot=True,
        filter_noise=True,
    ):
        """Plot the top max-activating images of a latent.

        Args:
            latent_idx: Latent to visualise.
            top_k: Number of images to plot.
            show_seg_mask: When true, plot the darkened overlays from
                ``get_segmentation_mask`` instead of the raw images.
            base_opacity: Forwarded to ``get_segmentation_mask``.
            with_label: When false, no titles are drawn.
            show_plot: Forwarded to ``plot_images``.
            filter_noise: Forwarded to ``get_segmentation_mask``.

        Returns:
            The matplotlib figure produced by ``plot_images``.
        """
        images, labels = self.get_max_activating_images_and_labels(latent_idx)

        if not with_label:
            labels = None

        if show_seg_mask:
            # Only the images that will be displayed are run through the model.
            images = self.get_segmentation_mask(images[:top_k], latent_idx, base_opacity, filter_noise=filter_noise)

        return self.plot_images(
            images,
            labels,
            top_k=top_k,
            show_plot=show_plot,
        )
