import os
import torch
import numpy as np
from scipy.sparse import coo_array
from torch_geometric.data import HeteroData
from sklearn.preprocessing import normalize
from torch_geometric.utils import from_scipy_sparse_matrix
from utils.utils import get_dataset
from data_preprocess.compute_lifting import (
    get_folder_from_dset,
    compute_and_save_lifting,
)


def define_basis_order(all_simplices):
    """
    Build, for every dimension, a lookup table from simplex to integer ID.

    Parameters:
        all_simplices (list of lists of lists of ints):
                all_simplices[d] holds the simplices of dimension d,
                each given as a sequence of vertex labels.

    Returns:
        simplices_id_maps (list of dicts):
                one dict per entry of all_simplices, keyed by the simplex
                converted to a tuple. A single-vertex simplex is mapped to its
                own vertex label; any other simplex receives a consecutive ID
                (starting at 0) in order of first appearance. Repeated
                simplices keep the ID assigned at their first occurrence.
    """
    id_maps = []
    for dim_simplices in all_simplices:
        ids = {}
        next_id = 0
        for simplex in dim_simplices:
            vertices = tuple(simplex)
            if vertices in ids:
                # Already registered: duplicates do not consume a new ID
                continue
            n_vertices = len(simplex)
            if n_vertices == 1:
                # Vertices are identified by their own label
                ids[vertices] = simplex[0]
            else:
                ids[vertices] = next_id
            if n_vertices > 1:
                next_id += 1
        id_maps.append(ids)
    return id_maps


def compute_relations(max_dim, all_simplices):
    """
    Build the connectivity relations of a simplicial complex up to max_dim.

    Parameters:
        max_dim (int): Maximum dimension of simplices to consider.
        all_simplices (list of lists of lists of ints):
                all_simplices[d] holds the simplices of dimension d,
                each given as a sequence of vertex labels.

    Returns:
        up_adjacencies (dict: int -> list of sparse matrices):
                keyed by dimension d in 0..max_dim-1; transposed up-adjacency
                matrices between d-simplices.
        down_adjacencies (dict: int -> list of sparse matrices):
                keyed by dimension d in 1..max_dim; transposed down-adjacency
                matrices between d-simplices.
        boundaries (dict: int -> sparse matrix):
                keyed by dimension d in 1..max_dim; transpose of the summed
                face maps, i.e. rows index d-simplices and columns index
                (d-1)-simplices.
        coboundaries (dict: int -> sparse matrix):
                keyed by dimension d in 0..max_dim-1; the summed face maps,
                i.e. rows index d-simplices and columns index (d+1)-simplices.
    """
    id_maps = define_basis_order(all_simplices)

    boundaries = {}
    coboundaries = {}
    up_adjacencies = {}
    down_adjacencies = {}

    for upper in range(1, max_dim + 1):
        lower = upper - 1

        face_maps = compute_face_maps(
            all_simplices[upper], id_maps[lower], id_maps[upper], upper
        )
        # Summing the per-position face maps gives the full incidence matrix
        incidence = sum(face_maps)

        down_adjs, up_adjs = compute_up_down_adjacencies(face_maps, upper)

        # The relations are stored transposed because MessagePassing expects
        # source -> target lists
        coboundaries[lower] = incidence
        boundaries[upper] = incidence.T
        down_adjacencies[upper] = [matrix.T for matrix in down_adjs]
        up_adjacencies[lower] = [matrix.T for matrix in up_adjs]

    return up_adjacencies, down_adjacencies, boundaries, coboundaries


def compute_face_maps(simplex_list, lower_simplex_id_map, simplex_id_map, dim):
    """
    Compute the face maps of the simplices of a given dimension.

    The i-th face map links every simplex to the face obtained by deleting
    its i-th vertex.

    Parameters:
        simplex_list (list of lists of ints):
                List of simplices of the current dimension.
        lower_simplex_id_map (dict):
                Mapping from lower-dimensional simplices (as tuples) to their IDs.
        simplex_id_map (dict):
                Mapping from current simplices (as tuples) to their IDs.
        dim (int): Dimension of the simplices.

    Returns:
        Bs (list of sparse matrices):
                dim + 1 COO matrices of shape
                (len(lower_simplex_id_map), len(simplex_id_map)).
                Bs[i] holds a 1 at (face ID, simplex ID) for every simplex in
                simplex_list, where the face is the simplex without its i-th
                vertex.

    Raises:
        KeyError: if a simplex or one of its faces is missing from the
                corresponding ID map.
    """
    n_lower = len(lower_simplex_id_map)
    n_simplices = len(simplex_id_map)
    shape = (n_lower, n_simplices)

    # array_idx[i] collects the [row, col] = [face ID, simplex ID] pairs of Bs[i]
    array_idx = [[] for _ in range(dim + 1)]

    for simplex in simplex_list:
        vertices = tuple(simplex)
        simplex_id = simplex_id_map[vertices]
        for i in range(dim + 1):
            # Drop the i-th vertex with plain tuple slicing. This keeps the
            # vertex labels as they are (np.concatenate with an empty slice
            # would silently upcast them to float64).
            lower_face = vertices[:i] + vertices[i + 1 :]
            array_idx[i].append([lower_simplex_id_map[lower_face], simplex_id])

    return [
        (
            coo_array((np.ones(len(idx)), np.array(idx).T), shape=shape)
            if idx
            # No simplices of this dimension: return an empty matrix of the right shape
            else coo_array(shape)
        )
        for idx in array_idx
    ]


def compute_up_down_adjacencies(B, dim):
    """
    Derive directed up and down adjacencies from the face maps of one dimension.

    Parameters:
        B (list of sparse matrices):
                The dim + 1 face maps of the current dimension.
        dim (int): Dimension of the simplices.
    Returns:
        down_adjs (list of sparse matrices):
                B[j].T @ B[i] for every ordered pair (i, j), including i == j,
                listed with i as the outer index: (dim + 1) ** 2 matrices.
        up_adjs (list of sparse matrices):
                B[j] @ B[i].T for every ordered pair (i, j) with i != j,
                listed with i as the outer index: (dim + 1) * dim matrices.

        In every returned matrix the diagonal is removed and explicit zeros
        are dropped; the remaining values are the raw products (they are not
        binarized).
    """
    n_face_maps = dim + 1
    down_adjs = []
    up_adjs = []

    for i in range(n_face_maps):
        for j in range(n_face_maps):
            down_adjs.append(B[j].T @ B[i])
            if i != j:
                up_adjs.append(B[j] @ B[i].T)

    for matrix in down_adjs + up_adjs:
        # Drop self-loops, then purge the zeros that are now stored explicitly
        matrix.setdiag(0)
        matrix.eliminate_zeros()

    return down_adjs, up_adjs


# Datasets whose name starts with one of these prefixes carry edge-level
# features (equi_edge_attr / inv_edge_attr) instead of node features.
_EDGE_FEATURE_DATASET_PREFIXES = (
    "traffic-anaheim",
    "traffic-barcelona",
    "traffic-chicago",
    "traffic-sioux-falls",
    "traffic-winnipeg",
    "electrical-circuits",
)


def _load_all_simplices(dset, folder, max_dim, root_dir):
    """Load the saved simplices, computing the lifting if it is missing or too shallow."""
    simplices_path = f"{root_dir}/{folder}/{folder}_all_simplices"

    all_simplices = None
    if os.path.exists(simplices_path):
        all_simplices = torch.load(simplices_path)

    # (Re)compute when nothing is saved yet, or when the saved lifting does not
    # reach the requested order of simplices.
    if all_simplices is None or len(all_simplices) < max_dim + 1:
        compute_and_save_lifting(
            max_dim=max_dim, num_threads=8, dset=dset, path=root_dir, flagser=False
        )
        all_simplices = torch.load(simplices_path)

    return all_simplices


def _init_simplex_features(n_simplices, feat_init):
    """Return an (n_simplices, 1) feature tensor built with the requested method."""
    if feat_init == "zeros":
        return torch.zeros([n_simplices, 1])
    if feat_init == "random":
        return torch.rand([n_simplices, 1])
    raise ValueError(
        f"Unknown feat_init {feat_init!r}; expected 'zeros' or 'random'"
    )


def _to_edge_index(relation):
    """Convert a scipy sparse relation into a PyG edge_index tensor."""
    return from_scipy_sparse_matrix(relation)[0]


def create_heterodata_object(
    dset: str, max_dim: int, feat_init: str = "zeros", root_dir="./data"
):
    """
    Create an object of the class HeteroData with all the info of the
    graph lifted into a directed SC. This will be the input for a
    HeteroConv neural network.

    Node types are the simplex dimensions as strings ("0", "1", ...). For each
    dimension i the following edge types are stored as edge_index:
        (i, "c_a", i+1)  coboundary                          (i < max_dim)
        (i, "u_a", i)    sum of all directed up adjacencies  (i < max_dim)
        (i, "u_{j}", i)  j-th directed up adjacency          (i < max_dim)
        (i, "d_{j}", i)  j-th directed down adjacency        (i > 0)
        (i, "d_a", i)    sum of all directed down adjacencies (i > 0)
        (i, "b_a", i-1)  boundary                            (i > 0)

    Parameters:

        dset (str): dataset name
        max_dim (int): maximum dimension of simplices of interest (e.g., 2 for triangles)
        feat_init (str): Method for feature initialization of the simplices
                that have no dataset features: "zeros" or "random"
        root_dir (str): Root directory for the dataset

    Returns:
        data (HeteroData): HeteroData object containing the lifted graph

    Raises:
        ValueError: if features have to be initialized and feat_init is
                neither "zeros" nor "random".
    """

    # Load graph
    folder = get_folder_from_dset(dset)
    dataset = get_dataset(name=dset, root_dir=root_dir)

    # Load pre-computed simplices (computing the lifting when needed)
    all_simplices = _load_all_simplices(dset, folder, max_dim, root_dir)

    # Replace the edges from the original graph to ensure that they are listed in the correct order
    all_simplices[1] = dataset.edge_index.T.tolist()

    (
        up_adjacencies,
        down_adjacencies,
        boundaries,
        coboundaries,
    ) = compute_relations(max_dim, all_simplices)

    # Create new data object
    data = HeteroData()

    # Add features to the HeteroData object
    if dset.startswith(_EDGE_FEATURE_DATASET_PREFIXES):
        # Edge features are available
        data["1"].x = torch.cat((dataset.equi_edge_attr, dataset.inv_edge_attr), dim=1)
        dims_to_init = [i for i in range(max_dim + 1) if i != 1]
    else:
        # Node features are available
        data["0"].x = dataset.x
        dims_to_init = range(1, max_dim + 1)

    # Initialize features for the simplices without dataset features
    for i in dims_to_init:
        data[str(i)].x = _init_simplex_features(len(all_simplices[i]), feat_init)

    # save the computed adjacencies as edge_index in the HeteroData object
    for i in range(max_dim + 1):
        src = str(i)

        if i != max_dim:
            data[src, "c_a", str(i + 1)].edge_index = _to_edge_index(coboundaries[i])

            # The undirected up adjacency is the sum of all directed adjacencies
            data[src, "u_a", src].edge_index = _to_edge_index(sum(up_adjacencies[i]))

            for j, adj in enumerate(up_adjacencies[i]):
                data[src, f"u_{j}", src].edge_index = _to_edge_index(adj)

        if i != 0:
            for j, adj in enumerate(down_adjacencies[i]):
                data[src, f"d_{j}", src].edge_index = _to_edge_index(adj)

            # The undirected down adjacency is the sum of all directed adjacencies
            data[src, "d_a", src].edge_index = _to_edge_index(sum(down_adjacencies[i]))

            data[src, "b_a", str(i - 1)].edge_index = _to_edge_index(boundaries[i])

    return data


def save_heterodata_object(dset, max_dim, feat_init="zeros", root_dir="./data/"):
    """
    Create the HeteroData object of a dataset and save it to disk.

    The object is written with torch.save to
    "{root_dir}/{folder}/{folder}_heterodata_{feat_init}", where folder is
    given by get_folder_from_dset(dset).

    Parameters:
        dset (str): dataset name
        max_dim (int): maximum dimension of simplices of interest (e.g., 2 for triangles)
        feat_init (str): Method for feature initialization
        root_dir (str): Root directory for the dataset; used both to build
                the object and to save it

    """
    print("*** Creating the HeteroData object ***")
    # Forward root_dir so the object is built from the same directory it is
    # saved to (otherwise the creation step falls back to its own default).
    data = create_heterodata_object(
        dset, max_dim=max_dim, feat_init=feat_init, root_dir=root_dir
    )
    folder = get_folder_from_dset(dset)
    torch.save(data, f"{root_dir}/{folder}/{folder}_heterodata_{feat_init}")
