from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
from torch_geometric.data import Batch
from wandb import Image

from src.utils.backbone_utils import safe_log
from src.utils.sparse_utils import sparse_sum


def make_entropy_per_class_plot(
    trajectory: Float[Tensor, "n_steps n_edges"], data: Batch
) -> Image:
    """Draw one entropy-vs-step curve for each parent-count class.

    At every step of the trajectory the per-class entropies are obtained from
    ``compute_entropy_per_class``; each class ``K`` is drawn as its own line,
    colored along the viridis colormap.

    :param trajectory: Edge probabilities recorded at each generation step.
    :param data: PyG batch object.
    :return: WandB Image of the rendered figure.
    """
    n_steps, _n_edges = trajectory.shape

    # Group the per-step values by class key: {K: [entropy at step 0, 1, ...]}.
    curves = defaultdict(list)
    for step_probs in trajectory:
        per_class = compute_entropy_per_class(step_probs, data)
        for n_parents, value in per_class.items():
            curves[n_parents].append(value)

    # The x-axis is shared by all curves, so build it once.
    steps = np.arange(0, n_steps)
    palette = plt.cm.viridis(np.linspace(0, 1, len(curves)))

    fig, ax = plt.subplots(figsize=(10, 10))
    for line_color, (n_parents, values) in zip(palette, curves.items()):
        ax.plot(steps, values, color=line_color, label=f"K={n_parents}")
    ax.set_xlabel("Step")
    ax.set_ylabel("Entropy")
    ax.set_ylim(0, 1)
    ax.legend()

    # Convert to a WandB image before closing so the figure can be released.
    image = Image(fig)
    plt.close(fig)
    return image


def make_entropy_plot(trajectory: Float[Tensor, "n_steps n_edges"], data: Batch) -> Image:
    """Draw the average normalized entropy as a function of the generation step.

    Note that this switches the global matplotlib style to ``"bmh"`` via
    ``plt.style.use``; the style is not restored afterwards.

    :param trajectory: Edge probabilities recorded at each generation step.
    :param data: PyG batch object.
    :return: WandB Image of the rendered figure.
    """
    n_steps, _n_edges = trajectory.shape

    # One scalar entropy value per step of the trajectory.
    entropy = np.array([compute_entropy(trajectory[i], data) for i in range(n_steps)])
    step = np.arange(0, n_steps)

    plt.style.use("bmh")
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(step, entropy)
    ax.set_xlabel("Step")
    ax.set_ylabel("Entropy")
    ax.set_ylim(0, 1)

    # Convert to a WandB image before closing so the figure can be released.
    image = Image(fig)
    plt.close(fig)
    return image


def compute_entropy(edge_attr_pred: Float[Tensor, "n_edges"], data: Batch) -> float:
    """Return the mean normalized entropy of the predicted edge distributions.

    The entropy ``-sum(p * log p)`` is aggregated with ``sparse_sum`` over
    ``data.edge_index`` (presumably one value per node -- confirm against the
    helper) and divided by ``log(K)``, where ``K`` is ``data.n_parents``.
    Entries with ``K < 2`` are left out, since ``log(K)`` is zero or undefined
    for them.

    :param edge_attr_pred: Predicted edge probabilities.
    :param data: PyG batch object providing ``n_parents`` and ``edge_index``.
    :return: Mean of the normalized entropies over entries with ``K >= 2``.
    """
    n_parents = data.n_parents
    probs = edge_attr_pred

    weighted_log = probs * safe_log(probs)
    entropy = -sparse_sum(data.edge_index, weighted_log, 1)

    # Only entries with at least two parents have a non-trivial distribution.
    has_choice = n_parents >= 2
    normalized = entropy[has_choice] / torch.log(n_parents[has_choice])
    return normalized.mean().item()


def compute_entropy_per_class(
    edge_attr_pred: Float[Tensor, "n_edges"], data: Batch
) -> dict:
    """Compute the average normalized entropy of the edge probability distributions in a graph.
    Compute it separately for nodes with a different number of parents.

    Only classes with at least two parents are reported: for ``K < 2`` the
    normalizer ``log(K)`` is zero or undefined.

    :param edge_attr_pred: Predicted edge probabilities.
    :param data: PyG batch object providing ``n_parents`` and ``edge_index``.
    :return: Dict mapping each parent count ``K >= 2`` present in the batch to the
        mean normalized entropy of the entries with that parent count.
    """
    K = data.n_parents
    # Select classes by value rather than by position: slicing off the first two
    # unique values silently drops valid classes when the batch has no entries
    # with 0 or 1 parents.
    K_unique = torch.unique(K)
    K_unique = K_unique[K_unique >= 2]

    p = edge_attr_pred
    log_p = safe_log(p)
    entropy = -sparse_sum(data.edge_index, p * log_p, 1)

    entropy_per_class_dict = {}
    for k in K_unique:
        class_mask = K == k
        # Normalize by log(K) so that a uniform distribution has entropy 1.
        entropy_norm = entropy[class_mask] / torch.log(K[class_mask])
        entropy_per_class_dict[k.item()] = entropy_norm.mean().item()
    return entropy_per_class_dict
