import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
from einops import rearrange
from PIL import Image
from tqdm import trange

from custom_colbert.interpretability.plot_utils import plot_patches
from custom_colbert.interpretability.processor import ColPaliProcessor
from custom_colbert.interpretability.torch_utils import normalize_attention_map_per_query_token
from custom_colbert.interpretability.vit_configs import VIT_CONFIG
from custom_colbert.models.paligemma_colbert_architecture import ColPali

# Default root directory for generated plots. `generate_interpretability_plots` creates a
# uuid4-named subdirectory under it when the caller does not provide a `savedir`.
OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")


@dataclass
class InterpretabilityInput:
    """
    Container grouping a text query, an image, and a pair of token indices.

    NOTE(review): this class is not used in the visible part of this module. The exact
    semantics of the two index fields (e.g. whether `end_idx_token` is inclusive) are not
    established here and should be confirmed against the callers.
    """

    query: str  # text query
    image: Image.Image  # PIL image associated with the query
    start_idx_token: int  # presumably the first query-token index of interest -- TODO confirm
    end_idx_token: int  # presumably the last query-token index of interest -- TODO confirm


def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> None:
    """
    Save one patch-level similarity plot per query token for a (query, image) pair.

    The query and the image are each embedded with `model.forward`. The dot product between
    every text-token embedding and every image-patch embedding gives a per-token map over the
    patch grid, which is normalized with `normalize_attention_map_per_query_token` and drawn
    on top of a square-resized copy of the image with `plot_patches`. One figure per token is
    written to `savedir` as `token_<idx>.png`; the first and last token positions are skipped.

    Args:
        model: ColPali model. It must have exactly one active adapter, and its
            `config.name_or_path` must be a key of `VIT_CONFIG`.
        processor: Processor used to preprocess the query and the image and to recover the
            text tokens used in the plot titles.
        query: Text query.
        image: Input image. The original image is passed to the processor; the square-resized
            copy is only used as the plot background.
        savedir: Output directory, created if missing. If falsy (`None` or an empty string), a
            new uuid4-named directory under `OUTDIR_INTERPRETABILITY` is used.
        add_special_prompt_to_doc: Forwarded to `processor.process_image` as
            `add_special_prompt`. When True, only the first `processor.processor.image_seq_length`
            positions of the image output are kept before reshaping to the patch grid.

    Raises:
        ValueError: If the model does not have exactly one active adapter, or if
            `model.config.name_or_path` is not in `VIT_CONFIG`.
    """
    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")

    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    else:
        # Normalize to `Path` so that str, Path and other path-like inputs all support `.mkdir` and `/`.
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square (used as the plot background only)
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward passes (no gradients needed)
    # NOTE: `output_image` will have shape:
    # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
    # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
    with torch.no_grad():
        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
        output_image = model.forward(**asdict(input_image_processed))

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[
            :, : processor.processor.image_seq_length, :
        ]  # (1, n_patch_x * n_patch_y, hidden_dim)

    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = normalize_attention_map_per_query_token(
        attention_map
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)

    # Get text token information
    # NOTE(review): the tokens are obtained by re-tokenizing the decoded input ids; this assumes
    # the round trip yields exactly one token per input id -- confirm for the tokenizer in use.
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
        token_label = text_tokens[token_idx]
        fig, _ = plot_patches(
            input_image_square,
            vit_config.patch_size,
            vit_config.resolution,
            patch_opacities=attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
        )

        # pyplot keeps every figure alive until it is explicitly closed, so always release it,
        # even when titling or saving fails.
        try:
            fig.suptitle(f"Token #{token_idx}: `{token_label}`", color="white", fontsize=14)
            savepath = savedir / f"token_{token_idx}.png"
            fig.savefig(savepath)
            print(f"Saved attention map for token {token_idx} (`{token_label}`) to `{savepath}`.\n")
        finally:
            plt.close(fig)
