from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
from torch_geometric.data import Batch
from torchmetrics import MeanMetric
from wandb import Image

from src.models.eval.likelihood_invM import (
    reevaluate_jet_logLH,
    split_logLH_with_stop_nonstop_prob,
)
from src.utils.sparse_utils import sparse_mode


def compute_llhs(
    edge_attr_pred: Float[Tensor, "n_edges"],
    data: Batch,
    pt_min_sqrt: float,
    qcd_rate: float,
) -> Tuple[
    List[Float[np.ndarray, "n_nodes"]], List[Float[np.ndarray, "n_nodes"]], List[int]
]:
    """Per-node log-likelihoods of the predicted and target jets in a batch.

    Graphs whose predicted edges do not yield a jet (``make_jets`` gives ``None``) are dropped.
    The three returned lists stay aligned with each other, but not necessarily with the graphs
    of ``data``.

    :param edge_attr_pred: The predicted edge attributes.
    :param data: The PyG batch object.
    :param pt_min_sqrt: The minimum pT for a jet.
    :param qcd_rate: The QCD rate.
    :return: The log-likelihoods of the predicted and target jets and the number of nodes in each
        kept graph.
    """
    predicted = make_jets(edge_attr_pred, data)
    targets = make_jets(data.edge_attr_target, data)
    # Node count per graph, from the batch's CSR-style offsets.
    sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()

    llhs_pred, llhs_target, n = [], [], []
    for pred_jet, target_jet, size in zip(predicted, targets, sizes):
        if pred_jet is not None:
            llhs_pred.append(compute_llh(pred_jet, pt_min_sqrt, qcd_rate))
            llhs_target.append(compute_llh(target_jet, pt_min_sqrt, qcd_rate))
            n.append(size)
    return llhs_pred, llhs_target, n


def compute_llhs_helper(
    edge_attr_pred: Float[Tensor, "n_edges"],
    data: Batch,
    pt_min_sqrt: float,
    qcd_rate: float,
) -> Tuple[List[float], List[float], List[int]]:
    """Compute the total log-likelihood of every predicted and target jet in the batch.

    Unlike ``compute_llhs``, no graph is dropped. A predicted jet whose edges do not form a
    valid tree (``make_jets`` gives ``None``) is assigned the sentinel total ``-1e9``.

    :param edge_attr_pred: The predicted edge attributes.
    :param data: The PyG batch object.
    :param pt_min_sqrt: The minimum pT for a jet.
    :param qcd_rate: The QCD rate.
    :return: Per-graph summed log-likelihoods of the predicted jets, per-graph summed
        log-likelihoods of the target jets, and the number of nodes in each graph.
    """
    jets_pred = make_jets(edge_attr_pred, data)
    jets_target = make_jets(data.edge_attr_target, data)
    graph_sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
    llhs_pred = []
    llhs_target = []
    n = []
    for jet_pred, jet_target, graph_size in zip(jets_pred, jets_target, graph_sizes):
        if jet_pred is None:
            # Sentinel for an invalid predicted tree; downstream code filters totals below
            # -1e8 / -1e6.
            llh_pred = -1e9
        else:
            llh_pred = compute_llh(jet_pred, pt_min_sqrt, qcd_rate).sum().item()
        llh_target = compute_llh(jet_target, pt_min_sqrt, qcd_rate).sum().item()
        llhs_pred.append(llh_pred)
        llhs_target.append(llh_target)
        n.append(graph_size)
    return llhs_pred, llhs_target, n


def compute_valid_llh_tree(llhs_pred: List[Float[np.ndarray, "n_nodes"]]) -> float:
    """Compute the fraction of jets whose per-node log-likelihoods are all finite.

    Each jet contributes a single unit-weight sample to a ``MeanMetric``: 1 if every entry
    is finite, 0 otherwise. The metric's computed value is returned as-is (torchmetrics
    returns a scalar tensor).

    :param llhs_pred: The predicted per-node log-likelihoods, one array per jet.
    :return: The fraction of jets with an entirely finite log-likelihood.
    """
    metric = MeanMetric()
    for jet_llh in llhs_pred:
        all_finite = np.isfinite(jet_llh).all()
        metric.update(all_finite)
    return metric.compute()


def compute_valid_llh_parent(llhs_pred: List[Float[np.ndarray, "n_nodes"]]) -> float:
    """Compute the fraction of nonzero per-node log-likelihood entries that are finite.

    Entries equal to 0 are excluded. They are presumably the leaves, which carry no split
    term; confirm against ``reevaluate_jet_logLH``. The remaining entries of every jet are
    pooled, so the result is ``sum(finite entries) / sum(nonzero entries)`` over all jets.

    :param llhs_pred: The predicted per-node log-likelihoods, one array per jet.
    :return: The pooled fraction of finite nonzero entries (value of ``MeanMetric.compute``).
    """
    valid_parent = MeanMetric()
    for llh in llhs_pred:
        parent_llh = llh[llh != 0]
        if parent_llh.size == 0:
            # Nothing to count; np.mean of an empty array would only produce a NaN.
            continue
        finite_llh_mask = np.isfinite(parent_llh)
        # Weight each jet's fraction by its number of entries, not its number of finite
        # entries. Otherwise fully-invalid jets get weight 0 and the metric is biased upward.
        valid_parent.update(np.mean(finite_llh_mask), parent_llh.size)
    return valid_parent.compute()


def compute_llh_fraction(
    llhs_pred: List[Float[np.ndarray, "n_nodes"]],
    llhs_target: List[Float[np.ndarray, "n_nodes"]],
) -> float:
    """Compute the mean, over jets, of total target log-likelihood / total predicted one.

    The following jets are skipped:

    * jets where either total is below ``-1e8`` (e.g. a ``-inf`` entry);
    * jets whose predicted total is exactly 0, where the ratio is undefined.

    :param llhs_pred: The predicted per-node log-likelihoods, one array per jet.
    :param llhs_target: The target per-node log-likelihoods, aligned with ``llhs_pred``.
    :return: The mean ratio (value of ``MeanMetric.compute``).
    """
    llh_fraction = MeanMetric()
    for llh_pred, llh_target in zip(llhs_pred, llhs_target):
        total_pred = sum(llh_pred)
        total_target = sum(llh_target)
        if total_pred < -1e8 or total_target < -1e8:
            continue
        if total_pred == 0:
            # e.g. a single-node jet: avoid 0/0 (NaN) or x/0 (inf / ZeroDivisionError).
            continue
        llh_fraction.update(total_target / total_pred)
    return llh_fraction.compute()


def compute_llh_mean(llhs_pred: List[Float[np.ndarray, "n_nodes"]]) -> float:
    """Compute the mean of all finite, nonzero per-node log-likelihoods, pooled over jets.

    Each jet's mean is weighted by its number of finite nonzero entries, which makes the
    result the mean over the pooled entries. Jets without such entries are skipped.

    :param llhs_pred: The predicted per-node log-likelihoods, one array per jet.
    :return: The pooled mean (value of ``MeanMetric.compute``).
    """
    llh_mean = MeanMetric()
    for llh_pred in llhs_pred:
        nonzero_llh = llh_pred[llh_pred != 0]
        finite_llh = nonzero_llh[np.isfinite(nonzero_llh)]
        if finite_llh.size == 0:
            # np.mean of an empty array would only produce a NaN (and a RuntimeWarning).
            continue
        llh_mean.update(np.mean(finite_llh), finite_llh.size)
    return llh_mean.compute()


def make_llh_plot(
    llhs_pred: List[Float[np.ndarray, "n_nodes"]],
    llhs_target: List[Float[np.ndarray, "n_nodes"]],
    n: List[int],
) -> Optional[Image]:
    """Scatter each jet's total predicted log-likelihood against its total target one.

    Jets where either total is below ``-1e6`` are left out. Points are coloured by the jet's
    node count from ``n``, and the diagonal ``y = x`` is drawn for reference.

    :param llhs_pred: The predicted per-node log-likelihoods, one array per jet.
    :param llhs_target: The target per-node log-likelihoods, aligned with ``llhs_pred``.
    :param n: The number of nodes of each jet, aligned with ``llhs_pred``.
    :return: The figure wrapped in a ``wandb.Image``, or ``None`` if no jet passes the cut.
    """
    llhs = []
    for llh_pred_nodes, llh_target_nodes, n_nodes in zip(llhs_pred, llhs_target, n):
        llh_pred = sum(llh_pred_nodes)
        if llh_pred < -1e6:
            continue
        llh_target = sum(llh_target_nodes)
        if llh_target < -1e6:
            continue
        llhs.append([llh_target, llh_pred, n_nodes])
    if not llhs:
        return None
    # Rows: 0 = target total, 1 = predicted total, 2 = node count.
    llhs = np.array(llhs).T
    fig, axes = plt.subplots(nrows=1, ncols=1)
    fig.set_size_inches(8, 7)

    cl1 = axes.scatter(llhs[0], llhs[1], c=llhs[2], marker="X", s=5)
    cb = fig.colorbar(cl1, ax=axes, fraction=0.15, shrink=1.0, aspect=20)
    cb.set_label(label="Number of Nodes", size=20)
    cb.ax.tick_params(labelsize=15)

    min_llh = np.min(llhs[:2, :])
    max_llh = np.max(llhs[:2, :])
    x = np.linspace(min_llh, max_llh, 1000)
    axes.plot(x, x, color="black", linestyle="--")

    axes.grid(which="both", axis="both", linestyle="--")
    axes.set_xlabel(r"$\log p(x, H_{Truth})$", fontsize=20)
    axes.set_ylabel(r"$\log p(x, H_{Generated})$", fontsize=20)
    image = Image(fig)
    # Release the figure; the Image wrapper has already been created from it.
    plt.close(fig)
    return image


def make_dense_adjacency(
    parent_graph: Float[Tensor, "n_edges"],
) -> Tensor:
    """Build a dense child -> parent adjacency matrix from a vector of parent indices.

    Entry ``i`` of ``parent_graph`` is the parent of node ``i``. The matrix has one more
    node than ``parent_graph`` has entries; that extra last node gets no outgoing edge.

    :param parent_graph: Integer parent index of every node except the last one.
    :return: An ``int32`` matrix of shape ``(len + 1, len + 1)`` on the same device as
        ``parent_graph``, with ``adj[i, parent_graph[i]] = 1`` and zeros elsewhere.
    """
    num_nodes = parent_graph.shape[0] + 1
    # Allocate on the input's device so advanced indexing works for CUDA inputs too.
    device = parent_graph.device

    adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.int32, device=device)
    rows = torch.arange(parent_graph.shape[0], device=device)
    adj_matrix[rows, parent_graph] = 1
    return adj_matrix


def make_jets(
    edge_attr: Float[Tensor, "n_edges"], data: Batch
) -> List[Dict[str, np.array]]:
    """Build one jet per graph of a PyG batch from edge attributes.

    ``sparse_mode`` turns ``edge_attr`` into one (global) parent index per node. Presumably
    this is the most likely parent under the edge attributes; confirm in
    ``src.utils.sparse_utils``. Indices are shifted by each graph's offset so they become
    local to that graph.

    :param edge_attr: The edge attributes.
    :param data: The PyG batch object.
    :return: One entry per graph: the dict built by ``make_jet``, or ``None`` when the
        edges do not form a valid tree.
    """
    local_parent = sparse_mode(data.edge_index, edge_attr) - data.ptr[:-1][data.batch]
    leaf_mask = ~data.parent_mask

    jets = []
    for graph in range(data.num_graphs):
        in_graph = data.batch == graph
        leaves = data.x[leaf_mask & in_graph]
        # make_jet expects parents for every node except the last one of the graph, which
        # it treats as the root.
        jets.append(make_jet(leaves, local_parent[in_graph][:-1]))
    return jets


def make_jet(
    X_child: Float[Tensor, "n_leaves d"], y: Float[Tensor, "n_nodes"]
) -> Optional[Dict[str, np.array]]:
    """Create a single jet from the leaf features and the parent index of every other node.

    Node layout:

    * nodes ``0 .. n_leaves - 1`` are the leaves (the rows of ``X_child``);
    * nodes ``n_leaves .. 2 * n_leaves - 2`` are the internal nodes;
    * the last internal node is the root.

    ``y[i]`` is the parent of node ``i`` for every node except the root. The edges are
    accepted only if they form a full binary tree rooted at the last node:

    * every internal node reached from the root has exactly two children;
    * no leaf has children;
    * every internal node is reached.

    :param X_child: The features of the leaf nodes in the graph.
    :param y: The parent index of each non-root node.
    :return: A dictionary with three entries, or ``None`` if the edges are not a valid
        binary tree:

        * ``tree``: ``(n_nodes, 2)`` int64 array of child indices, ``-1`` rows for leaves;
        * ``content``: leaf features plus summed parent features, as a numpy array;
        * ``root_id``: the index of the root node.
    """
    n_leaves, n_features = X_child.shape
    n_parents = n_leaves - 1
    n_nodes = n_leaves + n_parents
    root_id = n_nodes - 1
    edges = y.cpu().tolist()

    # Invert the child -> parent list once. Rescanning it for every visited node (as
    # find_children does) made the traversal quadratic. Children keep ascending order.
    children_of: Dict[int, List[int]] = {}
    for child, parent in enumerate(edges):
        children_of.setdefault(parent, []).append(child)

    visited_parents = 0
    tree = -1 * np.ones((n_nodes, 2))
    # Each node has exactly one parent entry, so no node is reached twice. The visiting
    # order does not affect the result, so a plain stack suffices.
    stack = [root_id]
    while stack:
        node = stack.pop()
        children = children_of.get(node, [])
        if node < n_leaves:
            # A leaf with children would be counted as a parent below. The parent count
            # could then match while an internal node keeps a -1 row, which makes
            # compute_parent_content fail.
            if children:
                return None
            continue
        if len(children) != 2:
            return None
        left_child, right_child = children
        stack += [left_child, right_child]
        tree[node, 0] = left_child
        tree[node, 1] = right_child
        visited_parents += 1
    if visited_parents != n_parents:
        # Some internal nodes are not reachable from the root.
        return None
    tree = tree.astype(np.int64)

    X_parent = torch.zeros((n_parents, n_features), device=X_child.device)
    X = torch.cat((X_child, X_parent), dim=0)
    content = compute_parent_content(tree, X)
    return {"tree": tree, "content": content, "root_id": root_id}


def find_children(edges: List[int], node: int) -> List[int]:
    """Return the indices whose parent entry equals ``node``, in ascending order.

    :param edges: Parent index of each node; ``edges[i]`` is the parent of node ``i``.
    :param node: The node whose children are requested.
    :return: The child node indices (empty if ``node`` has no children).
    """
    children = []
    for idx, parent in enumerate(edges):
        if node == parent:
            children.append(idx)
    return children


def compute_parent_content(
    tree: np.array, content: Float[Tensor, "n_leaves d"]
) -> np.array:
    """Fill each parent row of ``content`` with the sum of its two children's rows.

    The first ``(n_nodes + 1) // 2`` rows are taken as already-known leaves. Parents are
    filled repeatedly: at each step, the lowest-indexed pending parent whose children are
    both filled is chosen. ``content`` is modified in place. If no pending parent ever
    becomes ready, an ``IndexError`` is raised.

    :param tree: ``tree[i]`` holds the two child indices of node ``i``.
    :param content: The features of all nodes; parent rows are overwritten.
    :return: The completed content, as a numpy array on the CPU.
    """
    total_nodes, _ = content.shape
    leaf_count = (total_nodes + 1) // 2

    filled = list(range(leaf_count))
    pending = list(range(leaf_count, total_nodes))
    while pending:
        ready = np.isin(tree[pending], filled).all(axis=1)
        parent = pending[np.flatnonzero(ready)[0]]
        left, right = tree[parent]
        content[parent] = content[left] + content[right]
        filled.append(parent)
        pending.remove(parent)
    return content.cpu().numpy()


def compute_llh(
    jet: Dict[str, np.array], pt_min_sqrt: float, qcd_rate: float
) -> Float[np.ndarray, "n_nodes"]:
    """Return the per-node log-likelihood of ``jet`` under the Ginkgo model.

    The jet is passed to ``reevaluate_jet_logLH`` with these settings:

    * ``delta_min = pt_min_sqrt**2``;
    * ``qcd_rate`` for both ``Lambda`` and ``LambdaRoot``;
    * the stop/non-stop split likelihood.

    The ``"logLH"`` entry of the result is returned.

    :param jet: The jet configuration containing tree structure and node contents.
    :param pt_min_sqrt: Parameter of Ginkgo simulator.
    :param qcd_rate: Parameter of Ginkgo simulator.
    :return: The log-likelihood of the jet configuration for each node.
    """
    delta_min = pt_min_sqrt**2
    evaluated = reevaluate_jet_logLH(
        jet,
        delta_min=delta_min,
        Lambda=qcd_rate,
        LambdaRoot=qcd_rate,
        split_log_LH=split_logLH_with_stop_nonstop_prob,
    )
    return evaluated["logLH"]


def compute_llh_w(
    jet: Dict[str, np.array], pt_min_sqrt: float, qcd_rate: float
) -> Float[np.ndarray, "n_nodes"]:
    """Return the per-node log-likelihood of ``jet`` with a doubled root rate.

    This is identical to ``compute_llh`` except that ``reevaluate_jet_logLH`` receives
    ``LambdaRoot = 2 * qcd_rate``. ``Lambda`` is still ``qcd_rate``, and
    ``delta_min = pt_min_sqrt**2``.

    :param jet: The jet configuration containing tree structure and node contents.
    :param pt_min_sqrt: Parameter of Ginkgo simulator.
    :param qcd_rate: Parameter of Ginkgo simulator.
    :return: The log-likelihood of the jet configuration for each node.
    """
    delta_min = pt_min_sqrt**2
    root_rate = 2 * qcd_rate
    evaluated = reevaluate_jet_logLH(
        jet,
        delta_min=delta_min,
        Lambda=qcd_rate,
        LambdaRoot=root_rate,
        split_log_LH=split_logLH_with_stop_nonstop_prob,
    )
    return evaluated["logLH"]


def compute_llhs_helper_w(
    edge_attr_pred: Float[Tensor, "n_edges"],
    data: Batch,
    pt_min_sqrt: float,
    qcd_rate: float,
) -> Tuple[List[float], List[float], List[int]]:
    """Compute the total ``compute_llh_w`` log-likelihood of every predicted and target jet.

    No graph is dropped. A predicted jet whose edges do not form a valid tree
    (``make_jets`` gives ``None``) is assigned the sentinel total ``-1e9``.

    :param edge_attr_pred: The predicted edge attributes.
    :param data: The PyG batch object.
    :param pt_min_sqrt: The minimum pT for a jet.
    :param qcd_rate: The QCD rate.
    :return: Per-graph summed log-likelihoods of the predicted jets, per-graph summed
        log-likelihoods of the target jets, and the number of nodes in each graph.
    """
    jets_pred = make_jets(edge_attr_pred, data)
    jets_target = make_jets(data.edge_attr_target, data)
    graph_sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()
    llhs_pred = []
    llhs_target = []
    n = []
    for jet_pred, jet_target, graph_size in zip(jets_pred, jets_target, graph_sizes):
        if jet_pred is None:
            # Sentinel for an invalid predicted tree; downstream code filters totals below
            # -1e8 / -1e6.
            llh_pred = -1e9
        else:
            llh_pred = compute_llh_w(jet_pred, pt_min_sqrt, qcd_rate).sum().item()
        llh_target = compute_llh_w(jet_target, pt_min_sqrt, qcd_rate).sum().item()
        llhs_pred.append(llh_pred)
        llhs_target.append(llh_target)
        n.append(graph_size)
    return llhs_pred, llhs_target, n


def compute_llhs_w(
    edge_attr_pred: Float[Tensor, "n_edges"],
    data: Batch,
    pt_min_sqrt: float,
    qcd_rate: float,
) -> Tuple[
    List[Float[np.ndarray, "n_nodes"]], List[Float[np.ndarray, "n_nodes"]], List[int]
]:
    """Per-node ``compute_llh_w`` log-likelihoods of the predicted and target jets.

    Graphs whose predicted edges do not yield a jet (``make_jets`` gives ``None``) are dropped.
    The three returned lists stay aligned with each other, but not necessarily with the graphs
    of ``data``.

    :param edge_attr_pred: The predicted edge attributes.
    :param data: The PyG batch object.
    :param pt_min_sqrt: The minimum pT for a jet.
    :param qcd_rate: The QCD rate.
    :return: The log-likelihoods of the predicted and target jets and the number of nodes in each
        kept graph.
    """
    predicted = make_jets(edge_attr_pred, data)
    targets = make_jets(data.edge_attr_target, data)
    # Node count per graph, from the batch's CSR-style offsets.
    sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()

    llhs_pred, llhs_target, n = [], [], []
    for pred_jet, target_jet, size in zip(predicted, targets, sizes):
        if pred_jet is not None:
            llhs_pred.append(compute_llh_w(pred_jet, pt_min_sqrt, qcd_rate))
            llhs_target.append(compute_llh_w(target_jet, pt_min_sqrt, qcd_rate))
            n.append(size)
    return llhs_pred, llhs_target, n
