import os
import json
import numpy as np
from typing import Dict
from scipy.sparse import load_npz
from torch_geometric.data import HeteroData
from torch_geometric.utils import from_scipy_sparse_matrix
from data_preprocess.compute_lifting import compute_lifting
from data_preprocess.compute_connectivity import compute_relations

from src.brain_topo_decoding.brain_utils import (
    compute_digraph_from_adj_matrix,
    compute_heterodata_size,
    lift_spiking_signals,
    gen_relabelled_subgraph_from_set,
    save_tensor,
    sort_dfg,
    count_cliques,
    extract_experiment_boundaries,
    gen_binary_functions_data,
    save_pickle,
)

# Per-experiment HeteroData size in MB, keyed by experiment id; populated by
# process_exp_in_vol_sample as each experiment is processed.
experiment_memory_usage: Dict[int, float] = dict()


def process_exp_in_vol_sample(
    exp_id: int,
    label: int,
    dfg,
    cliques_count,
    relabelling,
    save_root,
    binary_functions_dir: str = "./binary_functions",
):
    """Lift one experiment's spiking activity onto the clique complex and save it.

    Loads the binary spiking functions of experiment ``exp_id``, lifts them
    onto the cliques of ``dfg`` via ``lift_spiking_signals``, prints the share
    of cliques carrying a non-zero feature per dimension, packs the lifted
    features into a ``HeteroData`` (one node type per key of the lifted
    features, with ``y = label``), logs its size and saves it to
    ``{save_root}/dynamics/dynamic{exp_id}``.

    Args:
        exp_id: Experiment index; selects the file
            ``spiking_activity_experiment_{exp_id}.npz``.
        label: Target stored as ``data.y``.
        dfg: Clique lifting forwarded to ``lift_spiking_signals``;
            ``len(dfg)`` bounds the dimensions reported in the statistics.
            NOTE(review): presumably indexed by dimension — confirm.
        cliques_count: Mapping from (int) dimension to number of cliques.
        relabelling: Node relabelling forwarded to ``lift_spiking_signals``.
        save_root: Output directory for the memory log and the saved data.
        binary_functions_dir: Directory holding the per-experiment ``.npz``
            files. Defaults to the previously hard-coded ``./binary_functions``
            so existing callers are unaffected.

    Side effects:
        Creates ``{save_root}/dynamics`` if missing, appends one line to
        ``{save_root}/experiment_memory_usage.txt`` and records the size (MB)
        in the module-level ``experiment_memory_usage`` dict.
    """
    bfi = load_npz(
        os.path.join(
            binary_functions_dir, f"spiking_activity_experiment_{exp_id}.npz"
        )
    )

    # Lift the spiking signals onto the cliques.
    dfg_feat = lift_spiking_signals(bfi, dfg, relabelling)

    # Percentage of cliques with at least one non-zero feature, per dimension
    # (dimension 0 is not reported). -1 flags a dimension with no cliques.
    # NOTE(review): features are looked up with str(dim) while clique counts
    # use int dim; if lift_spiking_signals returns int keys the percentage is
    # always 0 — confirm the key type.
    for dim in range(1, len(dfg)):
        dfg_feat_d = dfg_feat.get(str(dim), [])
        cliques_d = cliques_count.get(dim, 0)
        percent = (
            (sum(any(feat) for feat in dfg_feat_d) / cliques_d) * 100
            if cliques_d > 0
            else -1
        )
        print(
            f"Experiment {exp_id} - % of non-zero cliques at dim {dim}:"
            f" {percent:.2f}"
        )

    # Build the heterogeneous data structure: one node type per lifted key.
    data = HeteroData()
    data.y = label

    for key, feat in dfg_feat.items():
        data[str(key)].x = feat

    total_bytes = compute_heterodata_size(data)
    size_in_mb = total_bytes / (1024 * 1024)
    print(f"Heterodata size: {size_in_mb:.2f} MB")

    # Make sure the output directories exist before writing into them
    # (creates save_root as well).
    os.makedirs(os.path.join(save_root, "dynamics"), exist_ok=True)

    # Append this experiment's size to the memory-usage log.
    log_file = f"{save_root}/experiment_memory_usage.txt"
    with open(log_file, "a", encoding="utf-8") as log:
        log.write(f"Experiment {exp_id}: {size_in_mb:.2f} MB\n")

    experiment_memory_usage[exp_id] = size_in_mb

    save_tensor(data, f"{save_root}/dynamics/dynamic{exp_id}")

    # Report success only after the data has actually been written.
    print(f"Experiment {exp_id} SAVED - Memory: {size_in_mb:.2f} MB")


def _build_relations_dict(max_dim, up_adj, down_adj, boundaries, coboundaries):
    """Convert the sparse relation matrices of the lifting into edge indices.

    Keys are ``(source_type, relation, target_type)`` string triples, as used
    by ``HeteroData``; values are the edge index (first element) returned by
    ``from_scipy_sparse_matrix``. For each dimension ``i`` in ``0..max_dim``:

    * ``i < max_dim``: ``"c_a"`` (coboundary ``i -> i+1``), ``"u_a"`` (sum of
      all up-adjacency matrices) and ``"u_{j}"`` (j-th up-adjacency matrix).
    * ``i > 0``: ``"d_{j}"`` (j-th down-adjacency matrix), ``"d_a"`` (sum of
      all down-adjacency matrices) and ``"b_a"`` (boundary ``i -> i-1``).

    Args:
        max_dim: Highest dimension of the lifting.
        up_adj, down_adj: Per-dimension sequences of sparse adjacency matrices.
        boundaries, coboundaries: Per-dimension sparse incidence matrices.

    Returns:
        Dict mapping relation triples to edge-index tensors.
    """
    relations_dict = {}

    for i in range(max_dim + 1):
        if i != max_dim:
            relations_dict[(str(i), "c_a", str(i + 1))] = from_scipy_sparse_matrix(
                coboundaries[i]
            )[0]
            # Python's sum starts from 0; 0 + sparse matrix yields the matrix.
            relations_dict[(str(i), "u_a", str(i))] = from_scipy_sparse_matrix(
                sum(up_adj[i])
            )[0]
            for j, up_mat in enumerate(up_adj[i]):
                relations_dict[(str(i), f"u_{j}", str(i))] = from_scipy_sparse_matrix(
                    up_mat
                )[0]

        if i != 0:
            for j, down_mat in enumerate(down_adj[i]):
                relations_dict[(str(i), f"d_{j}", str(i))] = from_scipy_sparse_matrix(
                    down_mat
                )[0]
            relations_dict[(str(i), "d_a", str(i))] = from_scipy_sparse_matrix(
                sum(down_adj[i])
            )[0]
            relations_dict[(str(i), "b_a", str(i - 1))] = from_scipy_sparse_matrix(
                boundaries[i]
            )[0]

    return relations_dict


if __name__ == "__main__":

    # -------------------------------------------------------------------------
    # Parameters (customize as needed)
    ADJ_MATRIX_PATH = "../mc2.npz"  # Path to the adjacency matrix
    MAX_DFG_DIM = 5  # Maximum dimension of the directed flag complex lifting
    BIN_NUMBER = 2  # Number of time bins
    SPIKE_GID_SHIFT = 62693  # Shift for the neuron GIDs
    # Highest dimension for which relations are exported. The relation
    # matrices come from a lifting of dimension MAX_DFG_DIM, so the two must
    # agree; deriving one from the other keeps them from drifting apart.
    MAX_DIM_DFC = MAX_DFG_DIM
    RADIUS = "175 um"  # Radius key of the volumetric sample in VOL_SAMPLES_PATH
    COMPONENT = "8"  # Component key within that radius

    # -------------------------------------------------------------------------
    # Define paths
    SPIKE_TRAINS_PATH = "../../brain/spike_trains.npy"
    BINARY_FUNCTIONS_DIR = "../../brain/binary_functions"
    SAVE_ROOT = "./8_175um"
    VOL_SAMPLES_PATH = "volumetric_components.json"

    # -------------------------------------------------------------------------
    # 1. Load volumetric data and shift GIDs
    # -------------------------------------------------------------------------
    with open(VOL_SAMPLES_PATH, "r", encoding="utf-8") as file:
        vol_data = json.load(file)

    # Shift every component's GIDs in place so they index the adjacency matrix.
    for components in vol_data.values():
        for comp_data in components.values():
            comp_data["gids"] = np.array(comp_data["gids"]) - SPIKE_GID_SHIFT

    # -------------------------------------------------------------------------
    # 2. Ensure binary-functions directory exists
    # -------------------------------------------------------------------------
    # NOTE(review): this check uses BINARY_FUNCTIONS_DIR, but
    # process_exp_in_vol_sample loads its inputs from ./binary_functions;
    # confirm where gen_binary_functions_data writes its output.
    if not os.path.exists(BINARY_FUNCTIONS_DIR):
        gen_binary_functions_data(SPIKE_TRAINS_PATH, bin_number=BIN_NUMBER)

    # -------------------------------------------------------------------------
    # 3. Build structural volumetric sample subgraph
    # -------------------------------------------------------------------------
    adj = load_npz(ADJ_MATRIX_PATH).toarray().astype(int)
    graph = compute_digraph_from_adj_matrix(adj)
    # The dense matrix is only needed to build the graph; drop our reference
    # so it can be freed before the memory-hungry lifting below.
    del adj
    vol_gids = vol_data[RADIUS][COMPONENT]["gids"]
    G_sub, relabelling = gen_relabelled_subgraph_from_set(graph, vol_gids)

    save_pickle(relabelling, SAVE_ROOT, "relabelling.pkl")

    # -------------------------------------------------------------------------
    # 4. Compute the directed flag complex lifting + clique statistics
    # -------------------------------------------------------------------------
    dfg = sort_dfg(compute_lifting(G_sub, MAX_DFG_DIM))
    clique_counts = count_cliques(dfg)

    save_pickle(dfg, SAVE_ROOT, "dfg.pkl")

    # -------------------------------------------------------------------------
    # 5. Build and save relations for HeteroData
    # -------------------------------------------------------------------------
    up_adj, down_adj, boundaries, coboundaries = compute_relations(
        MAX_DFG_DIM, dfg, signed=False
    )
    relations_dict = _build_relations_dict(
        MAX_DIM_DFC, up_adj, down_adj, boundaries, coboundaries
    )

    # Save as a PyTorch file
    save_tensor(relations_dict, f"{SAVE_ROOT}/relations.pt")

    # -------------------------------------------------------------------------
    # 6. Determine stimuli experiment boundaries and process each experiment
    # -------------------------------------------------------------------------
    # NOTE(review): assumes exp_boundaries is sorted and its last entry is the
    # id of the last experiment — confirm against extract_experiment_boundaries.
    exp_boundaries = extract_experiment_boundaries(SPIKE_TRAINS_PATH)
    n_experiments = exp_boundaries[-1] + 1

    for exp_id in range(n_experiments):
        # Label = index of the first boundary >= exp_id (searchsorted, left).
        process_exp_in_vol_sample(
            exp_id,
            np.searchsorted(exp_boundaries, exp_id),
            dfg,
            clique_counts,
            relabelling,
            SAVE_ROOT,
        )
