import os
from functools import partial
from typing import cast

import matplotlib.pyplot as plt
import torch
from matplotlib import patches
from sentence_transformers import SentenceTransformer

from datasets import Dataset
from mow.common.data import (
    indexing,
    prepare_batch_data,
    prepare_graph_representation,
)
from mow.modules.routers import GraphRouter
from mow.scripts.train_router import TrainRouterConfig


def _load_test_split(
    path,
    *,
    name: str,
    dataset_idx: int,
    sentence_transformer: SentenceTransformer,
    num_examples: int,
) -> Dataset:
    """Load, index, featurize and shuffle the ``test`` split under *path*.

    *path* must support ``/`` (it comes from ``config.datasets``). Only the
    first ``num_examples`` rows are kept, or fewer if the split is smaller.
    The returned dataset is formatted as PyTorch tensors restricted to the
    columns fed to the router.
    """
    split = Dataset.load_from_disk(path / "test")
    # Clamp so we never request more rows than the split contains.
    split = split.take(min(num_examples, len(split)))
    split = split.map(
        partial(indexing, dataset_idx=dataset_idx),
        desc=f"Indexing {name} dataset",
    )
    split = split.map(
        partial(
            prepare_graph_representation,
            sentence_transformer=sentence_transformer,
        ),
        batched=False,
        desc=f"Mapping {name} dataset",
    )
    # Shuffle after taking the subset, so the prototype/example halves used
    # later form a random partition of the same rows.
    split = split.shuffle(seed=42)
    split.set_format(
        type="pt",
        columns=[
            "context",
            "nodes",
            "adjacency_matrix",
            "relation_matrix",
            "labels",
        ],
    )
    return split


def _prototype_example_similarity(
    per_dataset: list[torch.Tensor],
) -> torch.Tensor:
    """Return the prototype-vs-example cosine-similarity matrix.

    Each tensor in *per_dataset* holds one dataset's outputs, with rows along
    dim 0. The rows are split in two with ``torch.chunk``:

    - the first half is averaged into that dataset's "prototype";
    - the second half is averaged into its "example" vector.

    Entry ``[i, j]`` of the result is the cosine similarity, over the last
    dim, between prototype ``i`` and example ``j``. Computing the means per
    dataset means datasets may contribute different numbers of rows.

    Raises:
        ValueError: if any tensor has fewer than two rows to split.
    """
    prototypes = []
    examples = []
    for idx, rows in enumerate(per_dataset):
        if rows.shape[0] < 2:
            raise ValueError(
                f"dataset #{idx} produced {rows.shape[0]} row(s); at least 2 "
                "are needed to split into prototypes and examples"
            )
        first_half, second_half = torch.chunk(rows, 2, dim=0)
        prototypes.append(first_half.mean(dim=0))
        examples.append(second_half.mean(dim=0))

    p = torch.stack(prototypes)
    ex = torch.stack(examples)
    # Broadcast (D, 1, ...) against (1, D, ...) to get every pair (i, j).
    return torch.nn.functional.cosine_similarity(
        p.unsqueeze(1), ex.unsqueeze(0), dim=-1
    )


def _draw_similarity_matrix(
    sims: torch.Tensor, labels: list[str], path: str
) -> None:
    """Save *sims* as a labelled heatmap at *path*, outlining the diagonal."""
    values = sims.detach().cpu().numpy()
    n = len(labels)

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(values, cmap="viridis", interpolation="nearest")
        plt.colorbar()

        # Red box around each diagonal cell: prototype i vs. the examples
        # drawn from the same dataset i.
        ax = plt.gca()
        for i in range(n):
            ax.add_patch(
                patches.Rectangle(
                    (i - 0.5, i - 0.5),
                    1,
                    1,
                    fill=False,
                    edgecolor="red",
                    linewidth=2,
                )
            )

        plt.title("Cosine Similarity between Prototypes and Examples")
        plt.xlabel("Examples")
        plt.ylabel("Prototypes")
        plt.xticks(range(n), labels, rotation=45)
        plt.yticks(range(n), labels)
        plt.tight_layout()
        plt.savefig(path, dpi=300)
    finally:
        # Close even if saving fails, so figures don't pile up in pyplot.
        plt.close(fig)


def plot_router_results(
    config: TrainRouterConfig,
    output: str,
    *,
    color_with_class: bool = False,
    num_examples: int = 200,
):
    """Plot prototype-vs-example cosine-similarity heatmaps for a router.

    Loads the ``best`` GraphRouter checkpoint from
    ``config.train_config.output_dir``. Runs it on up to *num_examples* rows
    of the ``test`` split of every dataset in ``config.datasets``. Writes
    these files into *output*:

    - ``logits.png``, built from the router logits;
    - ``embeddings_layer_{k}.png``, one per embedding layer.

    Args:
        config: Training configuration. Provides the checkpoint directory,
            the sentence-transformer model name and the dataset paths.
        output: Directory for the PNG files; created if missing.
        color_with_class: Currently not used by this function; accepted for
            interface compatibility.
        num_examples: Maximum number of test rows taken per dataset.

    Raises:
        ValueError: if ``config.datasets`` is empty, or if a dataset yields
            fewer than two rows.
    """
    if not config.datasets:
        raise ValueError("config.datasets is empty; nothing to plot")

    router = cast(
        GraphRouter,
        GraphRouter.from_pretrained(config.train_config.output_dir / "best"),
    )
    router.eval()

    sentence_transformer = SentenceTransformer(
        config.sentence_transformer_model
    )
    sentence_transformer.to(
        torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    sentence_transformer.eval()

    names = list(config.datasets)
    splits = [
        _load_test_split(
            path,
            name=name,
            dataset_idx=i,
            sentence_transformer=sentence_transformer,
            num_examples=num_examples,
        )
        for i, (name, path) in enumerate(config.datasets.items())
    ]

    # eval() does not disable autograd. no_grad() keeps the forward passes
    # from recording a graph we never backpropagate through.
    with torch.no_grad():
        outputs = [router(**prepare_batch_data(split)) for split in splits]
    logits = [o["logits"] for o in outputs]
    embeddings = [o["embedding"] for o in outputs]

    os.makedirs(output, exist_ok=True)

    _draw_similarity_matrix(
        _prototype_example_similarity(logits),
        names,
        os.path.join(output, "logits.png"),
    )
    # Per dataset, the layer index is taken to be dim 1 of the embedding
    # tensor (the original indexed dim 2 of the dataset-stacked tensor).
    for layer in range(embeddings[0].shape[1]):
        _draw_similarity_matrix(
            _prototype_example_similarity([e[:, layer] for e in embeddings]),
            names,
            os.path.join(output, f"embeddings_layer_{layer}.png"),
        )
