import os
import pickle
import random
from pathlib import Path

import numpy as np
import torch

from src.data.ginkgo.inv_mass_ginkgo import Simulator


def simulate(
    root: Path,
    n_samples: int,
    jet_type: str,
    min_leaves: int,
    max_leaves: int,
    w_rate: float,
    qcd_rate: float,
    qcd_mass: float,
    pt_min_sqrt: float,
):
    """Simulate a Ginkgo jet dataset and store each jet as a hierarchical tree.

    Creates ``root`` and ``root / <dataset name>`` and runs
    :func:`simulate_inv_mass_ginkgo`. It then reads the pickle expected at
    ``root / "<dataset name>.pkl"`` and deletes that pickle. Finally it writes
    one ``jet<i>.pt`` file per jet into the dataset directory, each holding
    the output of :func:`transform_to_hierarchical`.

    The dataset name comes from :func:`get_dataset_name` with the same
    arguments, so the directory and pickle names always agree.
    """
    config = (
        n_samples,
        jet_type,
        min_leaves,
        max_leaves,
        w_rate,
        qcd_rate,
        qcd_mass,
        pt_min_sqrt,
    )

    os.makedirs(root, exist_ok=True)
    dataset_name = get_dataset_name(*config)
    output_dir = root / dataset_name
    os.makedirs(output_dir, exist_ok=True)

    simulate_inv_mass_ginkgo(root, *config)

    # The pickle is expected to be the one just written by the simulator call
    # above (a local file), and it is removed as soon as it has been loaded.
    raw_path = root / f"{dataset_name}.pkl"
    with open(raw_path, "rb") as raw_file:
        raw_jets = pickle.load(raw_file)
    os.remove(raw_path)

    # Convert every jet before writing anything, so a failing conversion
    # leaves no partial set of .pt files behind.
    converted = list(map(transform_to_hierarchical, raw_jets))

    for jet_index, jet in enumerate(converted):
        torch.save(jet, output_dir / f"jet{jet_index}.pt")


def get_dataset_name(
    n_samples: int,
    jet_type: str,
    min_leaves: int,
    max_leaves: int,
    w_rate: float,
    qcd_rate: float,
    qcd_mass: float,
    pt_min_sqrt: float,
):
    """Return the identifier used for a simulated dataset's files.

    The name has the form
    ``"<jet_type>-<n_samples>-<min_leaves>-<max_leaves>-<pt>-<qcd_rate>-<last>"``.
    ``pt_min_sqrt``, ``qcd_rate`` and the final component are rounded to one
    decimal place. ``<last>`` is ``w_rate`` when ``jet_type == "W"`` and
    ``qcd_mass`` for any other jet type. Whichever of the two is unused is
    never inspected, so it may be any value (e.g. ``None``).
    """
    qcd_tag = str(round(qcd_rate, 1))
    pt_tag = round(pt_min_sqrt, 1)
    last_tag = round(w_rate, 1) if jet_type == "W" else round(qcd_mass, 1)
    return f"{jet_type}-{n_samples}-{min_leaves}-{max_leaves}-{pt_tag}-{qcd_tag}-{last_tag}"


def simulate_inv_mass_ginkgo(
    root: Path,
    n_samples: int,
    jet_type: str,
    min_leaves: int,
    max_leaves: int,
    w_rate: float,
    qcd_rate: float,
    qcd_mass: float,
    pt_min_sqrt: float,
):
    """Run the Ginkgo invariant-mass simulator and save the raw jets.

    Args:
        root: directory passed to ``Simulator.save``.
        n_samples: number of jets requested from the simulator.
        jet_type: ``"W"`` or ``"QCD"``; selects the rate pair and the starting
            squared mass.
        min_leaves: lower leaf-count bound forwarded to the simulator.
        max_leaves: upper leaf-count bound forwarded to the simulator.
        w_rate: first rate for W jets (ignored for QCD jets).
        qcd_rate: used for both rates of a QCD jet and as the second rate of
            a W jet.
        qcd_mass: starting mass of QCD jets (ignored for W jets).
        pt_min_sqrt: square root of the ``pt_cut`` passed to the simulator.

    Raises:
        ValueError: if ``jet_type`` is neither ``"W"`` nor ``"QCD"``.
    """
    # Cast caller-supplied numbers to float so every tensor gets torch's
    # default floating dtype. Without this, an int argument (e.g. qcd_mass=30)
    # would yield an int64 tensor, unlike the always-float W branch.
    if jet_type == "W":
        rate = torch.tensor([float(w_rate), float(qcd_rate)])
        # presumably the W boson mass (~80 GeV) squared — TODO confirm units
        M2start = torch.tensor(80.0**2)
    elif jet_type == "QCD":
        rate = torch.tensor([float(qcd_rate), float(qcd_rate)])
        M2start = torch.tensor(float(qcd_mass) ** 2)
    else:
        raise ValueError("Choose a valid jet type between W or QCD")

    # Jet 4-vector: first component sqrt(|p|^2 + M^2), followed by a
    # 3-momentum of magnitude 400 along the (1, 1, 1) direction.
    jetM = np.sqrt(M2start.numpy())
    jetdir = np.array([1, 1, 1])
    jetP = 400.0
    jetvec = jetP * jetdir / np.linalg.norm(jetdir)
    jet4vec = np.concatenate(([np.sqrt(jetP**2 + jetM**2)], jetvec))
    simulator = Simulator(
        jet_p=jet4vec,
        pt_cut=torch.tensor(float(pt_min_sqrt) ** 2),
        Delta_0=M2start,
        M_hard=jetM,
        num_samples=n_samples,
        minLeaves=min_leaves,
        maxLeaves=max_leaves,
    )
    jet_list = simulator(rate)
    dataset_name = get_dataset_name(
        n_samples, jet_type, min_leaves, max_leaves, w_rate, qcd_rate, qcd_mass, pt_min_sqrt
    )
    simulator.save(jet_list, root, dataset_name)


def transform_to_hierarchical(data):
    """Reorder a binary jet tree into leaves-first node features and parent links.

    Args:
        data: mapping with two keys.
            ``"tree"``: a 2-D array of shape ``(n_nodes, 2)`` holding the
            left/right child index of every node. Node 0 is the root, and a
            row whose entries sum to -2 (i.e. ``[-1, -1]``) marks a leaf.
            ``"content"``: per-node features indexable by a list of node ids
            (presumably a NumPy array; TODO confirm against the simulator
            output).

    Returns:
        dict with two keys.
            ``"X"``: ``torch.float`` tensor of node features, with all leaves
            first, the internal nodes after them, and the root last.
            ``"A"``: ``torch.long`` tensor where ``A[i]`` is the row in ``X``
            of the parent of row ``i``; the root points to itself.

    The traversal draws the next node uniformly from the frontier using the
    global :mod:`random` module, so the ordering differs between calls unless
    ``random.seed`` is fixed. A tree whose root is itself a leaf yields a
    single row with ``A == [0]``.
    """
    tree = data["tree"]
    content = data["content"]

    n_nodes, _ = tree.shape
    # parent_reference[node] = original id of node's parent; the root maps to itself.
    parent_reference = [0] * n_nodes

    # Nodes are recorded in discovery order and reversed once at the end:
    # appending is O(1), whereas inserting at the front is O(n) per node.
    root_left, root_right = tree[0]
    if root_left + root_right == -2:
        # Single-node tree: the root is a leaf, so there is nothing to expand.
        found_leaves, found_parents, node_queue = [0], [], []
    else:
        found_leaves, found_parents, node_queue = [], [0], tree[0].tolist()

    while node_queue:
        node = node_queue.pop(random.randint(0, len(node_queue) - 1))
        left_child, right_child = tree[node]
        if left_child + right_child == -2:
            found_leaves.append(node)
            continue
        found_parents.append(node)
        parent_reference[left_child] = node
        parent_reference[right_child] = node
        # Extend in place rather than rebuilding the list on every step.
        node_queue.extend((left_child, right_child))

    # Most recently discovered first; the root (first in found_parents) ends last.
    node_order = found_leaves[::-1] + found_parents[::-1]
    new_index = {node: position for position, node in enumerate(node_order)}

    X = torch.tensor(content[node_order], dtype=torch.float)
    A = torch.tensor(
        [new_index[parent_reference[node]] for node in node_order], dtype=torch.long
    )
    return {"X": X, "A": A}
