from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
from torch_geometric.data import Batch
from torchmetrics import MeanMetric
from wandb import Image

from src.utils.sparse_utils import sparse_mode


def make_valid_binary_tree_per_graph_size_plot(
    edge_attr_pred: Float[Tensor, "n_edges"], data: Batch
) -> Image:
    """Plot the fraction of generated jets that are binary trees per number of nodes.

    A graph counts as a valid binary tree when every one of its nodes has either
    zero or exactly two predicted children (its entry in the mode histogram).

    :param edge_attr_pred: Predicted edge probabilities.
    :param data: PyG batch object.
    :return: WandB Image of a bar chart with the number of nodes per graph on the
        x-axis and, for graphs of that size, the fraction that are valid binary
        trees on the y-axis (fixed to [0, 1]).
    """
    mode = sparse_mode(data.edge_index, edge_attr_pred)
    # How often each node is selected by ``sparse_mode``, i.e. its child count.
    mode_counts = torch.bincount(mode, minlength=data.num_nodes)
    # NOTE(review): this clears only batch-global index 0, not the first node of
    # every graph in the batch — confirm this matches how ``sparse_mode`` encodes
    # nodes without a parent.
    mode_counts[0] = 0

    # Per-graph validity in one pass: count offending nodes per graph instead of
    # building a full-batch mask (and forcing a host sync) for every graph.
    invalid_node = (mode_counts != 0) & (mode_counts != 2)
    n_invalid = torch.bincount(data.batch[invalid_node], minlength=data.num_graphs)
    graph_sizes = torch.bincount(data.batch, minlength=data.num_graphs)

    # Group the per-graph flags by graph size after a single device -> host copy.
    flags_by_size = defaultdict(list)
    for size, is_tree in zip(graph_sizes.tolist(), (n_invalid == 0).tolist()):
        flags_by_size[size].append(is_tree)

    sizes = sorted(flags_by_size)
    x = np.array(sizes)
    y = np.array([np.mean(flags_by_size[size]) for size in sizes])

    plt.style.use("bmh")
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.bar(x, y)
    ax.set_xlabel("Number of nodes")
    ax.set_ylabel("Valid Binary Tree")
    ax.set_ylim(0, 1)

    # Always release the figure, even if rendering it into an Image fails.
    try:
        return Image(fig)
    finally:
        plt.close(fig)


def make_valid_binary_parent_plot(
    trajectory: Float[Tensor, "n_steps n_edges"], data: Batch
) -> Image:
    """Plot, for every step of a trajectory, the fraction of parents with two children.

    :param trajectory: Predicted edge probabilities, one row per step.
    :param data: PyG batch object shared by all steps.
    :return: WandB Image of a line plot (x = step index, y-axis fixed to [0, 1]).
    """
    num_steps, _ = trajectory.shape  # unpacking also insists on a 2-D trajectory
    fractions = np.array(
        [compute_valid_binary_parent(trajectory[step], data) for step in range(num_steps)]
    )
    steps = np.arange(num_steps)

    plt.style.use("bmh")
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(steps, fractions)
    ax.set(xlabel="Step", ylabel="Valid Parents", ylim=(0, 1))

    rendered = Image(fig)
    plt.close(fig)
    return rendered


def make_valid_binary_tree_plot(
    trajectory: Float[Tensor, "n_steps n_edges"], data: Batch
) -> Image:
    """Plot, for every step of a trajectory, the fraction of jets that are binary trees.

    :param trajectory: Predicted edge probabilities, one row per step.
    :param data: PyG batch object shared by all steps.
    :return: WandB Image of a line plot (x = step index, y-axis fixed to [0, 1]).
    """
    num_steps, _ = trajectory.shape  # unpacking also insists on a 2-D trajectory
    fractions = np.array(
        [compute_valid_binary_tree(trajectory[step], data) for step in range(num_steps)]
    )
    steps = np.arange(num_steps)

    plt.style.use("bmh")
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(steps, fractions)
    ax.set(xlabel="Step", ylabel="Valid Binary Tree", ylim=(0, 1))

    rendered = Image(fig)
    plt.close(fig)
    return rendered


def compute_valid_binary_parent(edge_attr_pred: Float[Tensor, "n_edges"], data: Batch) -> float:
    """Return the share of parent nodes whose predicted child count is exactly two.

    Child counts are the histogram of ``sparse_mode`` over all nodes of the batch.
    If ``data.parent_mask`` selects no nodes the ratio is 0/0, i.e. NaN.

    :param edge_attr_pred: Predicted edge probabilities.
    :param data: PyG batch object; ``data.parent_mask`` selects the parent nodes.
    :return: Fraction of parent nodes that have exactly two children.
    """
    children_per_node = torch.bincount(
        sparse_mode(data.edge_index, edge_attr_pred), minlength=data.num_nodes
    )
    parent_children = children_per_node[data.parent_mask]
    return ((parent_children == 2).sum() / data.parent_mask.sum()).item()


def compute_valid_binary_tree(edge_attr_pred: Float[Tensor, "n_edges"], data: Batch) -> float:
    """Compute the fraction of generated jets that are binary trees.

    A graph counts as a valid binary tree when every one of its nodes has either
    zero or exactly two predicted children (its entry in the mode histogram).

    :param edge_attr_pred: Predicted edge probabilities.
    :param data: PyG batch object.
    :return: Fraction of graphs that are binary trees.
    """
    mode = sparse_mode(data.edge_index, edge_attr_pred)
    # How often each node is selected by ``sparse_mode``, i.e. its child count.
    mode_counts = torch.bincount(mode, minlength=data.num_nodes)
    # NOTE(review): this clears only batch-global index 0, not the first node of
    # every graph in the batch — confirm this matches how ``sparse_mode`` encodes
    # nodes without a parent.
    mode_counts[0] = 0

    # Count offending nodes per graph in one pass instead of masking the whole
    # batch once per graph; a graph is valid iff it has no offending node.
    invalid_node = (mode_counts != 0) & (mode_counts != 2)
    n_invalid = torch.bincount(data.batch[invalid_node], minlength=data.num_graphs)

    # MeanMetric is kept (fed all per-graph flags at once) so the reduction is
    # performed by torchmetrics exactly as before, including any sync it does
    # on ``compute``.
    metric = MeanMetric().to(edge_attr_pred.device)
    metric.update(n_invalid == 0)
    return metric.compute().item()
