import os
import torch
import pickle
import requests
import numpy as np
import networkx as nx
from io import BytesIO
from zipfile import ZipFile
from torch_geometric.data import HeteroData
from scipy.sparse import csr_matrix, save_npz
from typing import List, Optional, Dict, Union
from torch_geometric.utils import from_scipy_sparse_matrix

# ============================================
# DATA PREPARATION
# ============================================


def compute_digraph_from_adj_matrix(adj):
    """Build a directed graph with one edge ``i -> j`` per non-zero ``adj[i, j]``.

    Only edges are added, so an index that has no non-zero entry in its row or
    column does not appear as a node of the returned graph.

    Args:
        adj: Two-dimensional adjacency structure accepted by ``np.nonzero``.

    Returns:
        nx.DiGraph: Graph whose edges are the non-zero positions of ``adj``.
    """
    sources, targets = np.nonzero(adj)
    graph = nx.DiGraph()
    graph.add_edges_from(zip(sources, targets))
    return graph


def compute_heterodata_size(data: HeteroData) -> int:
    """Return the number of bytes held by the tensor attributes of ``data``.

    The total is the sum of ``element_size() * numel()`` over every
    ``torch.Tensor`` found in three places, visited in this order: each node
    type's attributes, each edge type's attributes, and whatever
    ``data.items()`` yields. Non-tensor attributes are ignored.

    NOTE(review): ``data.items()`` is presumably the graph-level attributes
    (e.g. ``y``) — confirm it does not repeat node/edge tensors.
    """

    def _tensor_bytes(pairs) -> int:
        # Sum the raw storage size of every tensor value among (key, value) pairs.
        return sum(
            value.element_size() * value.numel()
            for _, value in pairs
            if isinstance(value, torch.Tensor)
        )

    node_bytes = sum(_tensor_bytes(data[nt].items()) for nt in data.node_types)
    edge_bytes = sum(_tensor_bytes(data[et].items()) for et in data.edge_types)
    top_level_bytes = _tensor_bytes(data.items())

    return node_bytes + edge_bytes + top_level_bytes


def _group_spikes_by_stimulus(spikes, stimuli, window=200, num_classes=8):
    """Bin raw spikes into stimulus windows, then group the windows by stimulus class."""
    # One bucket of [time-within-window, second spike column] per stimulus slot.
    windows = [[] for _ in range(len(stimuli))]
    for spike in spikes:
        idx = int(spike[0] / window)
        # The explicit lower bound keeps a negative index from silently
        # wrapping around to the last windows of the list.
        if 0 <= idx < len(windows):
            windows[idx].append([spike[0] % window, spike[1]])
        else:
            print(f"Warning: Spike index {idx} is out of bounds and will be ignored.")

    grouped = [[] for _ in range(num_classes)]
    for stimulus, events in zip(stimuli, windows):
        # Same guard as above: a negative stimulus id must not wrap around.
        if 0 <= stimulus < num_classes:
            grouped[stimulus].append(events)
        else:
            print(
                f"Warning: Stimulus index {stimulus} is out of bounds and will be ignored."
            )
    return grouped


def convert_data(
    url="https://zenodo.org/record/4290212/files/input_data.zip",
    output_path="../brain/spike_trains.npy",
):
    """Download the raw recordings and save them grouped by stimulus class.

    The archive at ``url`` must contain ``input_data/raw_spikes.npy`` (rows
    whose first column is a spike time and whose second column is stored
    alongside it) and ``input_data/stim_stream.npy`` (one stimulus id per
    window). Spikes are cut into consecutive windows of 200 time units; each
    window is then appended to the list of its stimulus id (0-7).
    Out-of-range spikes and stimuli are reported and skipped.

    Args:
        url (str): Location of the zip archive to download.
        output_path (str): Destination ``.npy`` file. Its parent directory is
            created if it does not exist.

    Returns:
        None. The result is written to ``output_path`` as a one-dimensional
        object array with one entry (a list of windows) per stimulus class.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    print(f"Downloading data from {url}")
    # A timeout prevents an unresponsive server from blocking forever.
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()

    # Read both arrays straight from the in-memory archive (nothing is extracted to disk).
    with ZipFile(BytesIO(resp.content)) as archive:
        with archive.open("input_data/raw_spikes.npy") as raw_f:
            X = np.load(raw_f)
        with archive.open("input_data/stim_stream.npy") as stim_f:
            Y = np.load(stim_f)

    grouped = _group_spikes_by_stimulus(X, Y)

    # Fill a 1-D object array element by element: np.array(nested, dtype=object)
    # picks its shape from the data and can turn regular nesting into a
    # multi-dimensional array.
    output = np.empty(len(grouped), dtype=object)
    for class_id, experiments in enumerate(grouped):
        output[class_id] = experiments

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(output_path, output, allow_pickle=True)


def extract_spiking_activity(
    spiketrains, num_neurons, bin_number, timebin, start_time, spike_gid_shift
):
    """
    Extracts spiking activity as a list of sparse binary matrices (one per experiment).

    A spike at time ``s`` falls in bin ``t`` when
    ``t * timebin + start_time < s <= (t + 1) * timebin + start_time``.
    Entry ``[n, t]`` of a matrix is 1 if neuron ``n`` spiked at least once in
    bin ``t``; repeated spikes of the same neuron in one bin are collapsed so
    the matrix stays binary. Spikes whose shifted neuron id falls outside
    ``[0, num_neurons)`` are dropped, and an experiment without spikes yields
    an all-zero matrix.

    Each matrix is also written to
    ``./binary_functions_{bin_number}/spiking_activity_experiment_{i}.npz``;
    the directory is created when needed.

    Parameters:
    - spiketrains: Per-class lists of experiments; each experiment is a
      sequence of spike events ``[time, neuron_id]``.
    - num_neurons: Total number of neurons (number of matrix rows).
    - bin_number: Total number of time bins (number of matrix columns).
    - timebin: Bin size, in the same unit as the spike times.
    - start_time: Start time of the first bin.
    - spike_gid_shift: ID offset subtracted from every neuron id.

    Returns:
    - A list of sparse (num_neurons x bin_number) matrices, ordered by class
      and then by experiment.
    """
    spiking_matrices = []

    for class_experiments in spiketrains:
        for experiment in class_experiments:
            events = np.asarray(experiment, dtype=float)
            if events.size == 0:
                # An empty experiment converts to a 1-D array; give it two
                # columns so the slicing below still works.
                events = events.reshape(0, 2)

            times = events[:, 0]
            neuron_ids = (events[:, 1] - spike_gid_shift).astype(int)

            row_parts = []
            col_parts = []
            for t in range(bin_number):
                lower = t * timebin + start_time
                upper = (t + 1) * timebin + start_time
                in_bin = (times > lower) & (times <= upper)

                # np.unique drops repeated spikes of one neuron in this bin;
                # otherwise the CSR constructor would sum them to values > 1.
                spikers = np.unique(neuron_ids[in_bin])
                spikers = spikers[(spikers >= 0) & (spikers < num_neurons)]

                row_parts.append(spikers)
                col_parts.append(np.full(spikers.size, t, dtype=int))

            if row_parts:
                rows = np.concatenate(row_parts)
                cols = np.concatenate(col_parts)
            else:
                rows = np.empty(0, dtype=int)
                cols = np.empty(0, dtype=int)

            sparse_matrix = csr_matrix(
                (np.ones(rows.size, dtype=np.int64), (rows, cols)),
                shape=(num_neurons, bin_number),
            )
            spiking_matrices.append(sparse_matrix)

    # Save all sparse matrices
    if spiking_matrices:
        os.makedirs(f"./binary_functions_{bin_number}", exist_ok=True)
    for i, matrix in enumerate(spiking_matrices):
        save_npz(
            f"./binary_functions_{bin_number}/spiking_activity_experiment_{i}.npz",
            matrix,
        )

    return spiking_matrices


# ============================================
# DAC LIFTING
# ============================================


def check_all_simplices_firing(
    bf: csr_matrix,
    simplices: List[List[int]],
    node_mapping: Optional[Dict[int, int]] = None,
) -> torch.Tensor:
    """
    Builds one firing row per simplex: the column-wise minimum of the rows of
    ``bf`` that belong to the simplex. For a 0/1 matrix this minimum is a
    logical AND, i.e. the entry is 1 exactly when every neuron of the simplex
    fired in that time bin.

    Parameters
    ----------
    bf : csr_matrix
        Sparse binary matrix of shape (num_neurons, num_timebins).
    simplices : List[List[int]]
        Simplices (cliques), each given as a list of neuron indices.
    node_mapping : dict, optional
        When given, every neuron index of a simplex is translated through this
        dictionary before it is used as a row index into ``bf``.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (num_simplices, num_timebins).
    """
    n_bins = bf.shape[1]
    # uint8 scratch buffer: one row per simplex, filled in below.
    fired = np.zeros((len(simplices), n_bins), dtype=np.uint8)

    for row, simplex in enumerate(simplices):
        if node_mapping is None:
            members = simplex
        else:
            members = [node_mapping[vertex] for vertex in simplex]

        # Minimum over the selected rows == AND for binary data.
        column_min = bf[members, :].min(axis=0)
        fired[row] = column_min.toarray().ravel()

    return torch.tensor(fired, dtype=torch.float32)


def lift_spiking_signals(
    bf: csr_matrix, dfg: List[List[int]], node_mapping: Optional[Dict[int, int]] = None
) -> Dict[str, torch.Tensor]:
    """
    Lifts spiking signals from individual neurons to higher-order simplices (cliques).

    Parameters
    ----------
    bf : csr_matrix
        Sparse binary matrix of spiking activity (shape: num_neurons x num_timebins).
    dfg : List[List[int]]
        Simplices grouped by dimension. Entry 0 is not read here (dimension-0
        features come from ``bf``); entry ``d >= 1`` is the list of
        ``d``-dimensional simplices, each a list of neuron indices.
    node_mapping : dict, optional
        Mapping from original neuron indices to new indices. If provided, only
        the rows of ``bf`` whose index is a key are kept, and row ``old`` is
        moved to row ``node_mapping[old]``. Simplices in ``dfg`` are then
        expected to use the new indices.

    Returns
    -------
    Dict[str, torch.Tensor]
        Key ``"0"`` maps to the dense (possibly relabelled) spiking matrix;
        key ``str(d)`` for ``d >= 1`` maps to the feature matrix produced by
        ``check_all_simplices_firing`` for the ``d``-dimensional simplices.
        All tensors are float32.
    """
    bf_relabelled = bf
    inverse_mapping = None

    if node_mapping is not None:
        # Original neuron indices to keep, in a fixed (sorted) order.
        keys = np.asarray(sorted(node_mapping), dtype=np.int64)
        # local_to_new[i] is the new label of the i-th kept row.
        local_to_new = np.array([node_mapping[k] for k in keys], dtype=np.int64)

        # Take rows, columns and values from one COO view so they always line
        # up; pairing ``.data`` with ``.nonzero()`` breaks when the matrix
        # stores explicit zeros, because ``nonzero()`` filters them out.
        coo = bf[keys, :].tocoo()
        num_rows = int(local_to_new.max()) + 1 if local_to_new.size else 0
        bf_relabelled = csr_matrix(
            (coo.data, (local_to_new[coo.row], coo.col)),
            shape=(num_rows, bf.shape[1]),
        )

        # Simplices use new labels while ``bf`` is indexed by original ones,
        # so the higher-order features need the new -> original direction.
        # Built once here rather than once per dimension.
        inverse_mapping = {new: old for old, new in node_mapping.items()}

    feat_dfg: Dict[str, torch.Tensor] = {
        "0": torch.tensor(bf_relabelled.toarray(), dtype=torch.float32)
    }

    for dim, simplices in enumerate(dfg[1:], start=1):
        feat_dfg[str(dim)] = check_all_simplices_firing(
            bf, simplices, node_mapping=inverse_mapping
        )

    return feat_dfg


# ============================================
# TOPOLOGY UTILS
# ============================================


def gen_relabelled_subgraph_from_set(G, neurons_set):
    """Extract the subgraph induced by ``neurons_set`` and renumber its nodes.

    Nodes are relabelled ``0, 1, 2, ...`` in the order in which the copied
    subgraph iterates over them.

    Args:
        G: Graph providing ``subgraph``.
        neurons_set: Nodes of ``G`` to keep.

    Returns:
        tuple: ``(G_sub, neuron_relabelling)`` where ``G_sub`` is the
        relabelled subgraph and ``neuron_relabelling`` maps each original
        label to its new integer label.
    """
    induced = G.subgraph(neurons_set).copy()
    neuron_relabelling = {
        original: new for new, original in enumerate(induced.nodes())
    }
    relabelled = nx.relabel_nodes(induced, neuron_relabelling)

    return relabelled, neuron_relabelling


def sort_dfg(dfg):
    """Sort the simplices of every dimension >= 1 in place and return ``dfg``.

    Entry 0 is left untouched. Each other entry is replaced by a new sorted
    list of its elements; the container itself is mutated and returned.
    """
    for level in range(1, len(dfg)):
        ordered = list(dfg[level])
        ordered.sort()
        dfg[level] = ordered
    return dfg


def count_cliques(dfg):
    """Count the simplices stored at each dimension, print them, and return the counts.

    Args:
        dfg: Container indexable by ``0 .. len(dfg) - 1`` whose entries
            support ``len``.

    Returns:
        dict: Maps each dimension ``d`` to ``len(dfg[d])``.
    """
    clique_counts = {}
    for dim in range(len(dfg)):
        clique_counts[dim] = len(dfg[dim])

    # Report only after all counts have been collected.
    for dim in clique_counts:
        print(f"Number of {dim}-simplices: {clique_counts[dim]}")

    return clique_counts


def percent_of_zero_cliques(dfg_feat, clique_counts):
    """Print, for every dimension >= 1, the percentage of cliques with a non-zero feature.

    Despite the function name, the printed figure is the share of cliques
    whose feature row contains at least one truthy entry. Dimensions whose
    count in ``clique_counts`` is missing or not positive are reported as
    ``-1``.

    Args:
        dfg_feat (dict): Maps ``str(dim)`` to an iterable of feature rows.
        clique_counts (dict): Maps integer ``dim`` to the number of cliques.

    Returns:
        None. Results are only printed.
    """
    for dim in range(1, len(dfg_feat)):
        features = dfg_feat.get(str(dim), [])
        total = clique_counts.get(dim, 0)

        if total > 0:
            active = 0
            for feat in features:
                if any(feat):
                    active += 1
            percent = (active / total) * 100
        else:
            # Sentinel for "no cliques counted at this dimension".
            percent = -1

        print(f"% Non-zero cliques at dim {dim}: {percent:.2f}")


# ============================================
# HETERODATA OBJECT CREATION
# ============================================


def create_hetero_data(
    label,
    dfg_feat,
    max_dfg_dim,
    coboundaries,
    up_adjacencies,
    down_adjacencies,
    boundaries,
):
    """
    Constructs a HeteroData object with hierarchical relations.

    Node types are the string dimensions ``"0" .. str(max_dfg_dim)``. For a
    dimension ``i`` the following edge types are created, each only when the
    corresponding sparse matrix has at least one stored entry:

    - ``(i, "c_a", i + 1)`` from ``coboundaries[i]`` (``i < max_dfg_dim``)
    - ``(i, "u_a", i)`` from the sum of ``up_adjacencies[i]`` and
      ``(i, "u_{j}", i)`` from its ``j``-th matrix (``i < max_dfg_dim``)
    - ``(i, "d_{j}", i)`` from the ``j``-th matrix of ``down_adjacencies[i]``
      and ``(i, "d_a", i)`` from their sum (``i > 0``)
    - ``(i, "b_a", i - 1)`` from ``boundaries[i]`` (``i > 0``)

    An empty adjacency list simply produces no aggregated relation.

    Args:
        label (Tensor): Target labels, stored as ``data.y``.
        dfg_feat (dict): Feature dictionary mapping dimension to feature tensors.
        max_dfg_dim (int): Maximum dimension of the structure.
        coboundaries (list): List of sparse matrices for coboundary relations.
        up_adjacencies (list): List of lists of sparse matrices for upward adjacency relations.
        down_adjacencies (list): List of lists of sparse matrices for downward adjacency relations.
        boundaries (list): List of sparse matrices for boundary relations.

    Returns:
        HeteroData: A heterogeneous graph structure.
    """
    data = HeteroData()
    data.y = label

    # Assign node features
    for key, feat in dfg_feat.items():
        data[str(key)].x = feat

    def _add_relation(edge_type, matrix):
        # Register the relation unless the matrix stores no entries.
        if matrix.nnz > 0:
            data[edge_type].edge_index = from_scipy_sparse_matrix(matrix)[0]

    def _add_aggregate(edge_type, matrices):
        # sum([]) is the int 0, which has no ``nnz``; skip empty lists.
        # The sum is computed once and reused for both the test and the edges.
        if len(matrices) > 0:
            _add_relation(edge_type, sum(matrices))

    # Assign hierarchical relations
    for i in range(max_dfg_dim + 1):
        dim = str(i)

        # Coboundary and upward adjacency relations (not for the top dimension)
        if i != max_dfg_dim:
            _add_relation((dim, "c_a", str(i + 1)), coboundaries[i])
            _add_aggregate((dim, "u_a", dim), up_adjacencies[i])
            for j, adj in enumerate(up_adjacencies[i]):
                _add_relation((dim, f"u_{j}", dim), adj)

        # Boundary and downward adjacency relations (not for dimension 0)
        if i != 0:
            for j, adj in enumerate(down_adjacencies[i]):
                _add_relation((dim, f"d_{j}", dim), adj)
            _add_aggregate((dim, "d_a", dim), down_adjacencies[i])
            _add_relation((dim, "b_a", str(i - 1)), boundaries[i])

    return data


# ============================================
# SAVE AND LOAD DATA
# ============================================


def save_tensor(tensor, path):
    """Save ``tensor`` to ``path`` with ``torch.save``, creating parent directories.

    Args:
        tensor: Object to serialise.
        path (str): Destination file. It may be a bare filename, in which case
            the file is written to the current working directory.
    """
    parent = os.path.dirname(path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(tensor, path)


def save_pickle(obj, save_path, filename):
    """
    Pickles ``obj`` to ``<save_path>/<filename>.pkl``.

    The directory is created first if it is missing, and an existing file of
    the same name is overwritten.

    Args:
        obj (any): The object to be saved (e.g., dictionary, list, etc.).
        save_path (str): The directory where the file should be saved.
        filename (str): The name of the pickle file (without extension).

    Returns:
        str: The full path of the saved file.
    """
    os.makedirs(save_path, exist_ok=True)

    target = os.path.join(save_path, "{}.pkl".format(filename))
    with open(target, "wb") as handle:
        pickle.dump(obj, handle)

    return target


def extract_experiment_boundaries(spike_trains_path):
    """Return the cumulative number of experiments per stimulus group.

    Args:
        spike_trains_path (str): Path of the saved spike-train array. If it
            does not exist, ``convert_data()`` is called first.

    Returns:
        np.ndarray: ``cumsum`` of the group sizes, i.e. entry ``k`` is the
        number of experiments contained in groups ``0 .. k``.
    """
    # NOTE(review): convert_data() is called without arguments, so it writes
    # to its own default location, which is not necessarily spike_trains_path.
    if not os.path.exists(spike_trains_path):
        convert_data()

    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    spiketrains = np.load(spike_trains_path, allow_pickle=True, encoding="latin1")

    exp_group_sizes = list(map(len, spiketrains))
    print("Structure of spike_trains:", exp_group_sizes)

    return np.cumsum(exp_group_sizes)


def gen_binary_functions_data(
    spike_trains_path,
    total_neurons=31346,
    bin_number=2,
    timebin=25,
    start_time=10,
    spike_gid_shift=62693,
):
    """Load the spike trains and generate the per-experiment spiking matrices.

    The matrices are produced by ``extract_spiking_activity``; this function
    only reports progress and returns nothing.

    Args:
        spike_trains_path (str): Path of the saved spike-train array. If it
            does not exist, ``convert_data()`` is called first.
        total_neurons (int): Number of neurons passed on as ``num_neurons``.
        bin_number (int): Number of time bins.
        timebin (int): Bin size.
        start_time (int): Start time of the first bin.
        spike_gid_shift (int): ID offset for neuron indexing.
    """
    # NOTE(review): convert_data() is called without arguments, so it writes
    # to its own default location, which is not necessarily spike_trains_path.
    if not os.path.exists(spike_trains_path):
        convert_data()

    # allow_pickle=True unpickles arbitrary objects: only load trusted files.
    spiketrains = np.load(spike_trains_path, allow_pickle=True, encoding="latin1")
    print("Spike trains loaded successfully!")
    group_sizes = list(map(len, spiketrains))
    print("Structure of spike_trains:", group_sizes)

    # Extract sparse matrices
    spiking_matrices = extract_spiking_activity(
        spiketrains=spiketrains,
        num_neurons=total_neurons,
        bin_number=bin_number,
        timebin=timebin,
        start_time=start_time,
        spike_gid_shift=spike_gid_shift,
    )

    print(f"Extracted {len(spiking_matrices)} sparse matrices (one per experiment).")
