import os
import json
import random
import numpy as np
from tqdm import tqdm
from typing import Dict
from scipy.sparse import load_npz
from data_preprocess.compute_lifting import compute_lifting
from data_preprocess.compute_connectivity import compute_relations

from src.brain_topo_decoding.brain_utils import (
    compute_digraph_from_adj_matrix,
    compute_heterodata_size,
    lift_spiking_signals,
    gen_relabelled_subgraph_from_set,
    save_tensor,
    sort_dfg,
    count_cliques,
    percent_of_zero_cliques,
    extract_experiment_boundaries,
    gen_binary_functions_data,
    create_hetero_data,
    save_pickle,
)

# Module-level registry written by ``create_dac_from_sample``: maps an
# experiment index to the size (in MB) of the HeteroData sample built for it.
# A later sample for the same experiment index replaces the earlier entry.
experiment_memory_usage: Dict[int, float] = {}


def create_dac_from_sample(
    component: int,
    vol_subgraph: Dict[str, Dict],
    label: int,
    graph,
    max_dfg_dim: int,
    exp_id: int,
    save_root: str,
    dacs_path: str,
    binary_functions_dir: str = "./binary_functions",
) -> None:
    """Create and persist a single data sample for a given experiment.

    The function loads the spiking activity of one experiment, extracts the
    subgraph of one volumetric component, lifts it to a directed flag complex,
    attaches the spiking signals to it, builds a HeteroData object and writes
    all intermediate artefacts below ``save_root``.

    Parameters
    ----------
    component
        Index of the volumetric sample. It is looked up in ``vol_subgraph``
        under its string form (``str(component)``).
    vol_subgraph
        Mapping from component index (as a string) to a dictionary that holds
        the neuron GIDs of that neighbourhood under the key ``"gids"``.
    label
        Experiment-group label stored in the resulting HeteroData object.
    graph
        Full brain connectome as a directed graph.
    max_dfg_dim
        Maximum simplex dimension for the directed flag complex.
    exp_id
        Index of the spike experiment.
    save_root
        Root directory where everything is written to.
    dacs_path
        Sub-directory, relative to ``save_root``, that receives the HeteroData
        sample. The sample is written to
        ``{save_root}/{dacs_path}/{component}_dac_{exp_id}``, so passing a path
        that already contains ``save_root`` duplicates the root.
    binary_functions_dir
        Directory holding the ``spiking_activity_experiment_{exp_id}.npz``
        files. Defaults to ``"./binary_functions"``.

    Notes
    -----
    Besides writing files, the function records the memory footprint of the
    sample in the module-level ``experiment_memory_usage`` dictionary.

    NOTE(review): the subgraph, flag complex, relabelling and the memory entry
    are keyed by ``exp_id`` only. Calling this function for several components
    of the same experiment overwrites those artefacts; only the HeteroData
    file name includes the component index.
    """

    # -----------------------------------------------------------------
    # 1. Load binary functions (spiking activity) for the experiment
    # -----------------------------------------------------------------
    bfi = load_npz(
        f"{binary_functions_dir}/spiking_activity_experiment_{exp_id}.npz"
    )

    # -----------------------------------------------------------------
    # 2. Sample and relabel subgraph
    # -----------------------------------------------------------------
    gids = vol_subgraph[str(component)]["gids"]
    G_sub, relabelling = gen_relabelled_subgraph_from_set(graph, gids)

    # -----------------------------------------------------------------
    # 3. Compute Dynamical Activity Complex (DAC) + stats
    # -----------------------------------------------------------------
    dfg = sort_dfg(compute_lifting(G_sub, max_dfg_dim))
    clique_counts = count_cliques(dfg)
    dfg_feat = lift_spiking_signals(bfi, dfg, relabelling)
    # Return value is intentionally ignored; presumably the helper reports
    # the statistic itself -- TODO confirm.
    percent_of_zero_cliques(dfg_feat, clique_counts)

    # -----------------------------------------------------------------
    # 4. Build relations for HeteroData
    # -----------------------------------------------------------------
    up_adj, down_adj, boundaries, coboundaries = compute_relations(
        max_dfg_dim, dfg, signed=False
    )

    data = create_hetero_data(
        label=label,
        dfg_feat=dfg_feat,
        max_dfg_dim=max_dfg_dim,
        coboundaries=coboundaries,
        up_adjacencies=up_adj,
        down_adjacencies=down_adj,
        boundaries=boundaries,
    )

    # -----------------------------------------------------------------
    # 5. Report memory footprint and save objects
    # -----------------------------------------------------------------
    # Size is converted to MB by dividing by 1024 * 1024 (presumably the
    # helper returns bytes -- TODO confirm).
    mem_mb = compute_heterodata_size(data) / (1024 * 1024)
    experiment_memory_usage[exp_id] = mem_mb
    print(f"Experiment {exp_id} – MEM footprint: {mem_mb:.2f} MB")

    # Saving heterodata object
    save_tensor(data, f"{save_root}/{dacs_path}/{int(component)}_dac_{exp_id}")
    # Save sampled subgraph from volumetric sample
    save_tensor(G_sub, f"{save_root}/subgraphs/subgraph_{exp_id}")
    # Save directed flag complex from the sampled subgraph
    save_tensor(dfg, f"{save_root}/dfg/dfg_{exp_id}")
    # Save neuron relabelling from original -> sampled subgraph
    save_pickle(
        relabelling,
        f"{save_root}/relabellings",
        f"relabelling_{exp_id}.pkl",
    )


if __name__ == "__main__":
    # Script entry point: builds one (or more) DAC samples per spike
    # experiment from randomly chosen volumetric neighbourhoods.

    # Parameters --------------------------------------------------------------
    ADJ_MATRIX_PATH = "../mc2.npz"  # Path to the adjacency matrix
    MAX_DFG_DIM = 5  # Maximum dimension of the directed flag complex lifting
    BIN_NUMBER = 2  # Number of time bins
    SAMPLE = "325 um@4"  # Sample name (top-level key in the neighbourhood JSON)
    SPIKE_GID_SHIFT = 62693  # Shift for the neuron GIDs
    N_NEIGHBOURHOODS = 25  # Number of neighbourhoods to sample
    M_SAMPLES_PER_EXP = 1  # Number of samples per experiment
    # Data Paths --------------------------------------------------------------
    SPIKE_TRAINS_PATH = "../../brain/spike_trains.npy"
    BINARY_FUNCTIONS_DIR = "../../brain/binary_functions"
    SAVE_ROOT = "./subgraph_325um"
    # Relative to SAVE_ROOT: create_dac_from_sample builds the output path as
    # f"{save_root}/{dacs_path}/...", so including SAVE_ROOT here would nest
    # the root directory inside itself.
    SAVE_DACS = "dacs"
    NEIGHBOURHOOD_SAMPLES_PATH = "volumetric_subtribes.json"

    # -------------------------------------------------------------------------
    # 1. Load volumetric data and shift GIDs
    # -------------------------------------------------------------------------
    with open(NEIGHBOURHOOD_SAMPLES_PATH, "r") as file:
        vol_data = json.load(file)

    # Replace every GID list in place by a NumPy array shifted by
    # SPIKE_GID_SHIFT. Only values are rebound, so iterating while updating
    # is safe.
    for components in vol_data.values():
        for comp_data in components.values():
            comp_data["gids"] = np.array(comp_data["gids"]) - SPIKE_GID_SHIFT

    # -------------------------------------------------------------------------
    # 2. Ensure binary-functions directory exists
    # -------------------------------------------------------------------------
    # NOTE(review): create_dac_from_sample reads the spiking activity from
    # "./binary_functions" by default, which is not BINARY_FUNCTIONS_DIR --
    # confirm where gen_binary_functions_data writes its output.
    if not os.path.exists(BINARY_FUNCTIONS_DIR):
        gen_binary_functions_data(SPIKE_TRAINS_PATH, bin_number=BIN_NUMBER)

    # -------------------------------------------------------------------------
    # 3. Build full brain graph once
    # -------------------------------------------------------------------------
    adj = load_npz(ADJ_MATRIX_PATH).toarray().astype(int)
    G = compute_digraph_from_adj_matrix(adj)

    # -------------------------------------------------------------------------
    # 4. Determine stimuli experiment boundaries
    # -------------------------------------------------------------------------
    exp_boundaries = extract_experiment_boundaries(SPIKE_TRAINS_PATH)
    # The last boundary is used as the highest experiment index.
    n_experiments = exp_boundaries[-1] + 1

    # -------------------------------------------------------------------------
    # 5. Iterate over experiments and create samples
    # -------------------------------------------------------------------------
    for exp_id in tqdm(range(n_experiments), desc="Experiments"):
        # Draw neighbourhood IDs *without* replacement for each experiment
        neighbourhood_ids = random.sample(range(N_NEIGHBOURHOODS), M_SAMPLES_PER_EXP)

        # The label is the position of exp_id within the boundary array.
        label = int(np.searchsorted(exp_boundaries, exp_id))

        for component in neighbourhood_ids:
            create_dac_from_sample(
                component=component,
                vol_subgraph=vol_data[SAMPLE],
                label=label,
                graph=G,
                max_dfg_dim=MAX_DFG_DIM,
                exp_id=exp_id,
                save_root=SAVE_ROOT,
                dacs_path=SAVE_DACS,
            )
