# adapted from https://github.com/johannbrehmer/ToyJetsShower/blob/master/run2DShower.py

import os
import pickle
import random
from pathlib import Path

import numpy as np
import torch

from src.data.ginkgo.inv_mass_ginkgo import Simulator


def simulate(root: Path, jet_type: str, n_samples: int, max_leaves: int):
    """Simulate a batch of jets and store each one as a hierarchical tensor file.

    Steps:
      1. ``root`` and ``root / f"{jet_type}_{n_samples}_{max_leaves}"`` are created.
      2. ``simulate_2d_shower`` is run.
      3. The resulting ``f"{jet_type}_{n_samples}_{max_leaves}.pkl"`` file is read
         from ``root`` and then deleted.
      4. Every jet is converted with ``transform_to_hierarchical`` and written to
         ``<run dir>/jet{i}.pt`` via ``torch.save``.

    Args:
        root: output directory.
        jet_type: jet configuration name forwarded to ``simulate_2d_shower``.
        n_samples: number of jets to simulate.
        max_leaves: maximum number of leaves per jet.
    """
    run_name = f"{jet_type}_{n_samples}_{max_leaves}"

    os.makedirs(root, exist_ok=True)
    run_dir = root / run_name
    os.makedirs(run_dir, exist_ok=True)

    simulate_2d_shower(root, jet_type, n_samples, max_leaves)

    # Intermediate pickle that simulate_2d_shower's ``simulator.save`` is expected
    # to leave in ``root``; it is consumed and removed here.
    pickle_path = root / f"{run_name}.pkl"
    with open(pickle_path, "rb") as handle:
        raw_jets = pickle.load(handle)
    os.remove(pickle_path)

    # Convert everything first, so nothing is written if any conversion fails.
    converted = list(map(transform_to_hierarchical, raw_jets))

    for index, jet in enumerate(converted):
        torch.save(jet, run_dir / f"jet{index}.pt")


def simulate_2d_shower(root: Path, jet_type: str, n_samples: int, max_leaves: int):
    """Run the Ginkgo ``Simulator`` for one jet configuration and save the jets.

    Args:
        root: directory handed to ``Simulator.save``.
        jet_type: one of ``"Wjets"``, ``"QCDjets"``, ``"TrellisMw300"``,
            ``"TrellisMw01"`` or ``"TrellisMw01B"``. It selects the simulator
            parameters and the rate tensor passed to the simulator.
        n_samples: number of jets to simulate (``num_samples``).
        max_leaves: maximum number of leaves per jet (``maxLeaves``).

    The jets are saved through ``simulator.save`` under the name
    ``f"{jet_type}_{n_samples}_{max_leaves}"``.

    Raises:
        ValueError: if ``jet_type`` is not one of the supported configurations.
    """
    rate = torch.tensor([2.0, 2.0])  # default; some configurations override it
    if jet_type == "Wjets":
        params = dict(
            jet_p=torch.tensor([500., 400.]),
            pt_cut=torch.tensor(0.04),
            Delta_0=torch.tensor(0.4),
            M_hard=torch.tensor(80.),
        )
    elif jet_type == "QCDjets":
        jet_P = 500.
        jet_M = 400.
        jetdir = np.array([1, 1, 1])
        jetvec = jet_P * jetdir / np.linalg.norm(jetdir)
        # Presumably an (E, px, py, pz) 4-vector with E = sqrt(|p|^2 + M^2).
        # NOTE(review): this is a 4-element NumPy array, while every other branch
        # passes a 2-element torch tensor. Confirm which form Simulator expects.
        jet4vec = np.concatenate(([np.sqrt(jet_P**2 + jet_M**2)], jetvec))
        params = dict(
            jet_p=jet4vec,
            pt_cut=torch.tensor(0.04),
            Delta_0=torch.tensor(0.4),
            M_hard=None,
        )
    elif jet_type == "TrellisMw300":
        rate = torch.tensor([10.0, 10.0])
        params = dict(
            jet_p=torch.tensor([0., 0.]),
            pt_cut=torch.tensor(0.08),
            Delta_0=torch.tensor(60.),
            M_hard=torch.tensor(300.),
        )
    elif jet_type == "TrellisMw01":
        rate = torch.tensor([2.2, 2.2])
        # NOTE(review): the name suggests M_hard = 0.1 (as in TrellisMw01B), but
        # this configuration uses 300. Kept as-is; confirm it is intended.
        params = dict(
            jet_p=torch.tensor([0., 0.]),
            pt_cut=torch.tensor(0.006),
            Delta_0=torch.tensor(60.),
            M_hard=torch.tensor(300.),
        )
    elif jet_type == "TrellisMw01B":
        params = dict(
            jet_p=torch.tensor([0., 0.]),
            pt_cut=torch.tensor(0.008),
            Delta_0=torch.tensor(60.),
            M_hard=torch.tensor(0.1),
        )
    else:
        raise ValueError(
            "Please choose a valid jet type "
            "(Wjets, QCDjets, TrellisMw300, TrellisMw01 or TrellisMw01B), "
            f"got {jet_type!r}"
        )

    # Arguments shared by every configuration.
    simulator = Simulator(
        **params,
        num_samples=n_samples,
        minLeaves=2,
        maxLeaves=max_leaves,
    )

    jet_list = simulator(rate)
    simulator.save(jet_list, root, f"{jet_type}_{n_samples}_{max_leaves}")


def transform_to_hierarchical(data, rng=None):
    """Reorder a binary tree into a leaves-first, bottom-up node layout.

    ``data["tree"]`` is a 2-column array indexed as
    ``tree[node] -> (left_child, right_child)``. A node whose two children are
    both ``-1`` is a leaf, and node 0 is the root. ``data["content"]`` is indexed
    with lists of node ids, so it must support fancy indexing (e.g. a NumPy
    array). Presumably it holds one feature row per node.

    The tree is traversed from the root, taking the next node uniformly at
    random from the frontier. The output rows are ordered as:
      * every leaf (in reverse visiting order), followed by
      * every internal node in reverse visiting order.
    As a result, each row's parent has a larger row index, and the root is the
    last row.

    Args:
        data: mapping with ``"tree"`` and ``"content"`` entries.
        rng: optional object with a ``randint`` method (e.g. ``random.Random``)
            used for the random traversal. Defaults to the global ``random``
            module, which reproduces the previous behaviour.

    Returns:
        dict with
          ``"X"``: float tensor of the reordered ``content`` rows;
          ``"A"``: long tensor where ``A[i]`` is the row index of the parent of
          row ``i``. The root (last row) is its own parent.
    """
    randint = (random if rng is None else rng).randint
    tree = data["tree"]
    content = data["content"]

    n_nodes, _ = tree.shape
    parent_of = [0] * n_nodes  # original node id -> original parent id (root -> 0)
    leaf_nodes = []            # visiting order; reversed below
    parent_nodes = [0]         # root first; reversed below so it ends up last

    # int() normalises NumPy / tensor scalars so they hash by value as dict keys.
    root_left, root_right = (int(c) for c in tree[0])
    # A childless root (single-node tree) has nothing to traverse. Its -1
    # entries must not be treated as node ids.
    if root_left == -1 and root_right == -1:
        node_queue = []
    else:
        node_queue = [root_left, root_right]

    while node_queue:
        node = node_queue.pop(randint(0, len(node_queue) - 1))
        left_child, right_child = (int(c) for c in tree[node])
        if left_child == -1 and right_child == -1:
            leaf_nodes.append(node)
            continue
        parent_nodes.append(node)
        parent_of[left_child] = node
        parent_of[right_child] = node
        node_queue.extend((left_child, right_child))

    # Children are always visited after their parent, so reversing the visiting
    # order places every child before its parent and the root last.
    leaf_nodes.reverse()
    parent_nodes.reverse()

    node_order = leaf_nodes + parent_nodes
    new_index = {node: i for i, node in enumerate(node_order)}
    parent_reference = [new_index[parent_of[node]] for node in node_order]

    X_leaf = torch.tensor(content[leaf_nodes], dtype=torch.float)
    X_parent = torch.tensor(content[parent_nodes], dtype=torch.float)
    X = torch.cat([X_leaf, X_parent], dim=0)
    A = torch.tensor(parent_reference, dtype=torch.long)
    return {"X": X, "A": A}
