import os
import torch
import logging
import pickle
import torch_sparse
import numpy as np
from scipy.cluster.vq import kmeans
from torch_geometric.graphgym.register import register_loader
from torch_geometric.graphgym.config import cfg

from .datasets.ocb_dataset import OCBDataset
from .datasets.analogenie_dataset import AnalogGenieDataset, NAME_TO_ID_PINS, node2pins, ID_TO_NAME_NODES
from ..loader.split_generator import (prepare_splits,
                                         set_dataset_splits)


def log_loaded_dataset(dataset, format, name):
    """Log a short summary of a freshly loaded PyG dataset.

    Reports the collated storage object, whether the first graph is
    undirected, the number of graphs, the average node count per graph, the
    node/edge feature dimensions, the number of tasks (when the dataset
    exposes ``num_tasks``) and a description of the label space: graph
    labels ``y`` if present, otherwise edge labels for link tasks.

    Args:
        dataset: PyG dataset whose collated storage is ``dataset.data``.
        format: dataset format name (only used in the log message).
        name: dataset name (only used in the log message).
    """
    data = dataset.data
    logging.info(f"[*] Loaded dataset '{name}' from '{format}':")
    logging.info(f"  {data}")
    logging.info(f"  undirected: {dataset[0].is_undirected()}")
    num_graphs = len(dataset)
    logging.info(f"  num graphs: {num_graphs}")

    # Prefer an explicit node count; fall back to the number of feature rows.
    if hasattr(data, 'num_nodes'):
        total_num_nodes = data.num_nodes
    elif hasattr(data, 'x'):
        total_num_nodes = data.x.size(0)
    else:
        total_num_nodes = 0
    logging.info(f"  avg num_nodes/graph: {total_num_nodes // num_graphs}")
    logging.info(f"  num node features: {dataset.num_node_features}")
    logging.info(f"  num edge features: {dataset.num_edge_features}")
    if hasattr(dataset, 'num_tasks'):
        logging.info(f"  num tasks: {dataset.num_tasks}")

    graph_labels = getattr(data, 'y', None)
    if graph_labels is not None:
        if isinstance(graph_labels, list):
            # A special case for ogbg-code2 dataset.
            logging.info("  num classes: n/a")
        elif graph_labels.numel() == graph_labels.size(0) and torch.is_floating_point(graph_labels):
            logging.info("  num classes: (appears to be a regression task)")
        else:
            logging.info(f"  num classes: {dataset.num_classes}")
        return

    # Edge/link prediction task: transductive labels take precedence.
    if hasattr(data, 'train_edge_label'):
        edge_labels = data.train_edge_label
    elif hasattr(data, 'edge_label'):
        edge_labels = data.edge_label
    else:
        return
    is_regression = edge_labels.numel() == edge_labels.size(0) and torch.is_floating_point(edge_labels)
    if is_regression:
        logging.info("  num edge classes: (probably a regression task)")
    else:
        logging.info(f"  num edge classes: {len(torch.unique(edge_labels))}")



@register_loader('custom_master_loader')  # --> called in the load_dataset fn of graphgym
def load_dataset_master(format, name, dataset_dir):
    """
    Master loader that controls loading of all datasets, overshadowing execution
    of any default GraphGym dataset loader.

    Pipeline: load and merge the train/val/test splits, optionally add
    supernodes (conditional generation) and pin nodes (pin prediction),
    symmetrize edges, store node-type / edge marginals on ``cfg``, log a
    summary, and finally assign split indices (from ``cfg.splits`` when
    present, otherwise computed here and pickled to ``cfg.out_dir``).

    Args:
        format: dataset format name (only used for logging here)
        name: dataset name; its last '_'-separated token is the data sub-directory
        dataset_dir: path where the datasets are stored

    Returns:
        PyG dataset object with applied preprocessing and data splits
    """
    data_dir = os.path.join(dataset_dir, name.split('_')[-1])
    print(f'*** Fetching data from {data_dir} \n')
    dataset = preformat(data_dir)

    wants_supernodes = cfg.gt.get('conditional_gen', False) and cfg.gt.get('supernode', False)
    if wants_supernodes:
        dataset = prepro_add_supernodes(dataset)

    is_pin_task = hasattr(cfg.dataset, 'task_type') and cfg.dataset.task_type == "pin_prediction"
    if is_pin_task:
        dataset = preprocess_pins(dataset)

    dataset = preprocess_edges(dataset)

    if cfg.gt.node_pruning == 2:
        # One extra node class represents pruned nodes; must happen before the
        # marginals are computed since they iterate over range(nnode_types).
        cfg.dataset.nnode_types += 1

    compute_marginals(dataset)
    log_loaded_dataset(dataset, format, name)

    if not hasattr(cfg, 'splits'):
        print('** No split indices found, splitting dataset randomly between train, val and test **')
        if hasattr(dataset, 'split_idxs'):
            # Only relevant when cfg.dataset.split_mode is `standard`;
            # prepare_splits overwrites the indices otherwise.
            set_dataset_splits(dataset, dataset.split_idxs)
            del dataset.split_idxs
        prepare_splits(dataset)
        split_keys = ('train_graph_index', 'val_graph_index', 'test_graph_index')
        dump_splits([dataset.data[key] for key in split_keys])
    else:
        print('** Splitting dataset into train, val and test according to the given split indices **')
        set_dataset_splits(dataset, cfg.splits)

    return dataset


def compute_marginals(dataset):
    """Compute dataset-level marginals and store them on the global ``cfg``.

    Sets:
        ``cfg.node_type_pmf``: list with ``cfg.dataset.nnode_types`` entries;
            entry ``t`` is the number of nodes whose ``x[:, 0] == t`` divided by
            the total node count (dummy nodes included, see below).
        ``cfg.edge_ratio``: mean, over graphs with at least two nodes, of
            ``0.5 * num_edges / (num_nodes * (num_nodes - 1))``; 0.0 when no
            graph has two or more nodes.

    When ``cfg.gt.node_pruning == 2``, dummy nodes amounting to 20% of the real
    nodes are added to the counts, labelled ``cfg.dataset.nnode_types``.

    Args:
        dataset: iterable of PyG Data objects exposing ``x``, ``num_nodes`` and
            ``num_edges``.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    node_type_chunks = []
    n_nodes = 0
    ratio_sum, n_ratio_graphs = 0.0, 0

    for data in dataset:
        node_type_chunks.append(data.x[:, 0])
        num_nodes = data.num_nodes
        n_nodes += num_nodes
        # The ratio is undefined for graphs with fewer than two nodes
        # (the denominator would be zero), so they do not contribute.
        if num_nodes > 1:
            # NOTE(review): if edge_index is symmetric (preprocess_edges runs
            # before this in load_dataset_master) this is half of the
            # upper-triangular edge density -- confirm the 0.5 factor is intended.
            ratio_sum += 0.5 * data.num_edges / (num_nodes * (num_nodes - 1))
            n_ratio_graphs += 1

    if not node_type_chunks:
        raise ValueError("compute_marginals() requires a non-empty dataset")
    # Concatenate once instead of growing a tensor inside the loop (O(n^2)).
    node_types = torch.cat(node_type_chunks)

    # When learning node number, add marginal sampling prob of dummy nodes.
    if cfg.gt.node_pruning == 2:
        # Compute the dummy count ONCE, from the real node count; previously the
        # count was re-derived from the already-extended tensor, so n_nodes was
        # inflated by 0.24N instead of 0.2N.
        n_dummy = int(0.2 * len(node_types))
        # NOTE(review): this label equals cfg.dataset.nnode_types, which lies
        # outside range(cfg.dataset.nnode_types) used for the PMF below, so the
        # dummy mass is not reported in node_type_pmf -- confirm whether the
        # pruned class should be nnode_types - 1.
        dummy_types = node_types.new_full((n_dummy,), fill_value=cfg.dataset.nnode_types)
        node_types = torch.cat([node_types, dummy_types])
        n_nodes += n_dummy

    uniques, counts = torch.unique(node_types, return_counts=True)
    node_probs = counts / n_nodes
    cfg.node_type_pmf = [node_probs[uniques == t].item() if t in uniques else 0
                         for t in range(cfg.dataset.nnode_types)]
    # Average over the graphs that contributed (previously divided by the last
    # enumerate index, i.e. n - 1, and crashed for a single graph).
    cfg.edge_ratio = ratio_sum / n_ratio_graphs if n_ratio_graphs else 0.0


def dump_splits(splits):
    r"""Pickle the split indices to ``<cfg.out_dir>/splits``.

    The output directory is created when missing; an existing ``splits`` file
    is overwritten.

    Args:
        splits (list): List of train, val and test indices
    """
    out_dir = cfg.out_dir
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'splits'), 'wb') as handle:
        pickle.dump(splits, handle)


def preprocess_edges(dataset):
    """
    Make every graph's edges undirected and attach auxiliary edge sets.

    For each Data object:
      * ``edge_index`` becomes the de-duplicated symmetric closure of the
        original edges (both (u, v) and (v, u) are present).
      * ``edge_attr`` holds one value per new edge. The value is copied from the
        first original edge that connects the same unordered node pair, so the
        original ``edge_attr`` must hold a single value per edge.
      * ``triu_edge_index`` holds the N * (N - 1) / 2 node pairs of the strict
        upper-triangular part of the adjacency matrix (N = number of nodes).
      * For OCB datasets, column 1 of ``x`` is saved as ``x_features``. Unless the
        dataset root mentions "AnalogGenie", ``x`` keeps only column 0.
      * For pin prediction, ``triu_learnable_edge_attr`` is 1 on the
        upper-triangular pairs present in ``learnable_edge_index``, else 0.
      * For conditional generation with supernodes, ``triu_learnable_edge_attr``
        is False for pairs whose column is the last node. That node is presumably
        the supernode appended by ``prepro_add_supernodes``.

    Args:
        dataset (Dataset): A collection of PyTorch Geometric Data objects.

    Returns:
        Dataset: The same dataset object, re-collated from the updated graphs.
    """
    data_list = []
    for data in dataset:
        num_nodes = data.num_nodes

        # All strict upper-triangular node pairs (row < col).
        row, col = torch.triu_indices(num_nodes, num_nodes, offset=1)
        all_connections = torch.stack((row, col), dim=0)

        # Directed -> undirected: canonical (min, max) form of every original edge.
        src, dst = data.edge_index
        symmetric_src = torch.min(src, dst)
        symmetric_dst = torch.max(src, dst)
        symmetric_edges = torch.cat([
            torch.stack((symmetric_src, symmetric_dst), dim=0),
            torch.stack((symmetric_dst, symmetric_src), dim=0)
        ], dim=1)
        # Remove duplicates. reshape(-1, 2) keeps the shape [2, 0] for an edgeless
        # graph, whereas `tensor.new([]).T` would produce a 1-D tensor.
        unique_pairs = list(set(map(tuple, symmetric_edges.T.tolist())))
        symmetric_edges = torch.tensor(
            unique_pairs, dtype=data.edge_index.dtype, device=data.edge_index.device
        ).reshape(-1, 2).T  # -> shape [2, num_unique_directed_edges]

        # Attributes: map each unordered pair to the FIRST original edge that
        # connects it. This is an O(E) dict lookup instead of the former O(E^2)
        # scan over an object array of sets.
        first_edge_of_pair = {}
        for edge_id, pair in enumerate(zip(symmetric_src.tolist(), symmetric_dst.tolist())):
            first_edge_of_pair.setdefault(pair, edge_id)
        attr = data.edge_attr.new_zeros(symmetric_edges.size(1))
        for k, (u, v) in enumerate(unique_pairs):
            attr[k] = data.edge_attr[first_edge_of_pair[(min(u, v), max(u, v))]].item()

        # Update the edge_index and edge_attr
        data.edge_index = symmetric_edges
        data.edge_attr = attr
        data.triu_edge_index = all_connections

        # Save node features as a separate batch argument `x_features`, and keep only node type as `data.x`.
        # This is done only for ocb.
        if cfg.dataset.name.startswith('ocb'):
            data.x_features = data.x[:, [1]]
            # NOTE(review): purpose of this root-name check is unclear -- confirm.
            if "AnalogGenie" not in dataset.root:
                data.x = data.x[:, [0]]

        # Finally, if we are doing pin prediction or use supernodes, add a triu_learnable_edge_attr attribute to the graph
        if cfg.dataset.get("task_type", '') == 'pin_prediction':
            # Coalesce the union with op="max": a pair is flagged 1 iff it is learnable.
            # Assumes every learnable edge has row < col, so the coalesced index
            # coincides with triu_edge_index and the values align with it.
            adj_triu_size, n_learnable_edges = data.triu_edge_index.size(1), data.learnable_edge_index.size(1)
            _, triu_learnable_edge_attr = torch_sparse.coalesce(
                torch.cat([data.triu_edge_index, data.learnable_edge_index], dim=1),
                torch.cat([data.x.new_zeros(adj_triu_size), data.x.new_ones(n_learnable_edges)]), data.num_nodes, data.num_nodes,
                op="max"
            )
            data.triu_learnable_edge_attr = triu_learnable_edge_attr

        if cfg.gt.get('conditional_gen', False) and cfg.gt.get('supernode', False):
            data.triu_learnable_edge_attr = data.triu_edge_index[1] != (data.num_nodes - 1)

        data_list.append(data)

    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset


def preprocess_data_y(dataset):
    """
    Encodes y features of a circuit dataset using RBF functions. The centroids of the RBFs are found using K-means separately on
    each feature.

    Side effects: stores ``cfg.y_std`` (per-feature std used for scaling) and
    ``cfg.kmeans_centroids`` (per-feature centroid lists), and attaches a
    ``c_init`` tensor to every graph before re-collating the dataset.

    NOTE(review): in this module the only call site is commented out in
    load_dataset_master.
    """

    ys = []
    for i in range(len(dataset)):
        ys.append(dataset[i].y)
    # NOTE(review): row i of `features` is later matched with graph i, which
    # assumes each data.y has shape (1, n_features) -- confirm.
    features = np.concatenate(ys) # Shape ds_len, n_features
    # A feature with zero std would produce inf/nan here.
    y_std = features.std(axis=0)
    features = features / y_std
    n_feats = features.shape[-1]
    cfg.y_std = y_std.tolist()

    # Get RBF centroids (one 1-D tensor of n_rbf_centroids values per feature)
    kmeans_centroids = get_kmeans_centroids_per_feature(features, cfg.gt.n_rbf_centroids)
    cfg.kmeans_centroids = [feat_centroids.tolist() for feat_centroids in kmeans_centroids]

    # Flatten and repeat for distance calculation.
    # Both flat tensors are laid out as [feature][sample][centroid], so after the
    # view below distances[f, j, c] = |y_normed[j, f] - centroid[f][c]|.
    flat_centroids = torch.cat([feat_centroids.repeat(len(features)) for feat_centroids in kmeans_centroids])
    flat_y_normed = torch.cat([torch.tensor(features[:, i]).repeat_interleave(cfg.gt.n_rbf_centroids) for i in range(n_feats)])
    distances = torch.abs(flat_y_normed - flat_centroids).view(n_feats, len(features), cfg.gt.n_rbf_centroids) # L1 (absolute) distance; L2 not used

    # Kernels with per-feature temperature coefficients.
    # Distances are absolute values (not squared), so this is exp(-|d| / (2 * temp)),
    # a Laplacian-shaped kernel rather than a Gaussian.
    # `temp` has 3 hard-coded entries and broadcasts over dim 0 of `distances`,
    # so this only works for n_feats == 3 (or 1).
    temp = torch.tensor([0.5, 1, 0.1])
    unnorm_rbf = torch.exp(-distances * 0.5 / temp[:, None, None])
    # Normalize over centroids so each (feature, sample) row sums to 1.
    rbf = unnorm_rbf / unnorm_rbf.sum(dim=-1, keepdim=True)

    # Update data
    data_list = []
    for i, data in enumerate(dataset):
        # Add c for the conditioning feature only: rbf[[d], i] has shape
        # (1, n_rbf_centroids), so c_init has shape (1, 1, n_rbf_centroids).
        data.c_init = rbf[[cfg.gt.conditional_dim], i].unsqueeze(0).float().to(data.y.device) # all features: rbf[:, i].unsqueeze(0).float().to(data.y.device)
        data_list.append(data)

    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)

    return dataset


def get_kmeans_centroids_per_feature(features, n_centroids, iter=0, max_attempts=100,
                                     feat_scale=(1.03, 1.06, 1.45)):
    """Find ``n_centroids`` 1-D k-means centroids for every feature column.

    scipy's ``kmeans`` removes centroids that end up with no observations, so
    it can return fewer centroids than requested. To compensate, column ``i``
    is clustered with ``int(n_centroids * feat_scale[i])`` clusters. The call is
    re-run until at least ``n_centroids`` centroids come back, and the first
    ``n_centroids`` of them are kept.

    Args:
        features: array of shape (num_samples, num_features).
        n_centroids: number of centroids to keep per feature.
        iter: unused; kept for backward compatibility (the attempt counter was
            always reset per feature).
        max_attempts: maximum number of k-means runs per feature. This bounds
            the retry loop, which previously never terminated when a column
            could not yield enough distinct centroids (e.g. a constant column).
        feat_scale: per-feature over-provisioning factor for the number of
            clusters requested. It needs at least one entry per feature column,
            otherwise an IndexError is raised.

    Returns:
        list of 1-D torch tensors, one per feature column, each holding at most
        ``n_centroids`` centroids.

    Raises:
        RuntimeError: if a feature does not yield ``n_centroids`` centroids
            within ``max_attempts`` runs.
    """
    kmeans_centroids = []
    for i in range(features.shape[-1]):
        n_requested = int(n_centroids * feat_scale[i])
        feat_centroids = []
        attempts = 0
        while len(feat_centroids) < n_centroids:
            if attempts >= max_attempts:
                raise RuntimeError(
                    f'k-means returned only {len(feat_centroids)} of {n_centroids} centroids '
                    f'for feature #{i} after {attempts} attempts '
                    f'({n_requested} clusters requested per run)'
                )
            feat_centroids = kmeans(features[:, i], n_requested)[0]
            attempts += 1
        print(f'~~ RBF centroids of feat #{i} found after {attempts} k-means steps ~~')
        kmeans_centroids.append(torch.tensor(feat_centroids[:n_centroids]))

    return kmeans_centroids


def prepro_add_supernodes(dataset):
    """Append one conditioning "supernode" row to ``x`` of every graph.

    The supernode's first column is a weighted sum of ``y[0]``. With
    ``cfg.gnn.n_spec == 1`` the weights are a one-hot vector selecting
    ``cfg.gnn.spec_dim``; otherwise they are ``[1, n_bins, n_bins ** 2]``. The
    remaining columns of the row are set to 50. Each graph also gets
    ``supernode_x_index``, the row index of the new node (the last row).
    Graphs are cloned before modification and the dataset is re-collated.
    """
    def spec_weights():
        if cfg.gnn.n_spec != 1:
            return torch.tensor([1, cfg.gnn.n_bins, cfg.gnn.n_bins ** 2])
        one_hot = torch.tensor([0, 0, 0])
        one_hot[cfg.gnn.spec_dim] = 1
        return one_hot

    def with_supernode(data):
        graph = data.clone()
        spec_code = (graph.y[0] * spec_weights()).sum()
        supernode_row = spec_code.repeat(3)[None, :]
        # Original note: 50 "will be scaled down to zero" downstream (not visible here).
        supernode_row[0, 1:] = 50
        graph.x = torch.cat([graph.x, supernode_row])
        graph.supernode_x_index = torch.tensor(len(graph.x) - 1)
        return graph

    data_list = [with_supernode(data) for data in dataset]

    # Update dataset
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset


def preprocess_pins(dataset):
    """
    Expand every graph with explicit pin nodes for pin prediction.

    ``graph.y`` is iterated as ``(parent_node, neighbors)`` pairs. For every
    parent node, one new node is appended per pin listed in ``node2pins`` for
    the parent's type; each new node copies the parent's ``x[:, 1]`` value.
    Edges parent -> pin and (ground-truth) neighbor -> pin are added.
    ``learnable_edge_index`` collects every (neighbor, pin) pair of each
    parent, i.e. the candidates on which the loss is applied. Afterwards the
    direct parent <-> neighbor edges, and any edge whose attribute is not
    positive, are removed. Only edges with row < col are kept, because the flip
    is done in preprocess_edges. Finally, the type column of the original nodes
    is remapped into the pin id space via ``NAME_TO_ID_PINS``.

    Graphs with an empty ``y`` are dropped. If the dataset carries
    ``split_idxs`` (as set by ``join_dataset_splits``), those indices are
    remapped so they keep pointing at the same graphs. Previously they silently
    shifted onto the wrong graphs, or out of range.

    This moves the pin preprocessing from training time to data loading time.

    Args:
        dataset (Dataset): A collection of PyTorch Geometric Data objects.

    Returns:
        Dataset: The same dataset object, re-collated from the processed graphs.
    """
    data_list = []
    num_input_graphs = len(dataset)
    kept_indices = []  # positions (in the input dataset) of the graphs kept

    for graph_idx, data in enumerate(dataset):
        graph = data.clone()
        edge_index_for_loss = []
        x_to_append = []
        num_nodes = graph.num_nodes
        parents_to_neighbors = []

        if len(graph.y) == 0:
            continue

        for parent_node, neighbors in graph.y:

            parent_node_typename = ID_TO_NAME_NODES[graph.x[parent_node, 0].item()]
            pin_types = torch.tensor([NAME_TO_ID_PINS[tname] for tname in node2pins[parent_node_typename]])
            # New pin nodes get consecutive indices after all nodes seen so far.
            pin_indices = torch.from_numpy(np.arange(num_nodes, num_nodes + len(pin_types)))
            num_nodes += len(pin_indices)

            # Append to graph.x: [pin type, parent's feature column 1]
            x_to_append.append(torch.stack([pin_types, torch.full((len(pin_types),), fill_value=graph.x[parent_node, 1].item())]).T)

            # Add edges from parent node to pins and from pins to neighbors
            parent_to_pins = torch.stack([torch.full((len(pin_indices),), fill_value=parent_node), pin_indices])
            unique_neighbors = torch.tensor(np.unique(neighbors))
            # /!\ The following line must be uncommented if training classification pin assigment model -> to fix
            # pins_to_neighbors = torch.stack([unique_neighbors.repeat_interleave(len(pin_indices)), pin_indices.repeat(len(unique_neighbors))])
            # Add only GT edges from pins to neighbors (assumes len(neighbors) == number of pins)
            pins_to_neighbors = torch.stack([torch.tensor(neighbors), pin_indices])
            new_edges = torch.cat([parent_to_pins, pins_to_neighbors], dim=1) # Flip will be done in preprocess_edges

            # Edges between parent node and neighbors will be suppressed
            parents_to_neighbors.append(torch.stack([torch.tensor(parent_node).repeat(len(unique_neighbors)),
                                                     unique_neighbors]).to(graph.edge_index.device))

            graph.edge_index = torch.cat([graph.edge_index, new_edges.to(graph.edge_index.device)], dim=1)
            graph.edge_attr = torch.cat([graph.edge_attr, graph.edge_attr.new_ones(new_edges.size(1))])

            # Learnable edges: every (neighbor, pin) pair of this parent. The
            # coalesced values (1 on GT pairs, 0 otherwise) are currently unused.
            nnode_coalesce = (pin_indices.max() + 1).item()
            pins_to_all_neighbors = torch.stack([unique_neighbors.repeat_interleave(len(pin_indices)), pin_indices.repeat(len(unique_neighbors))])
            learnable_edges, _ = torch_sparse.coalesce(
                torch.cat([pins_to_all_neighbors, pins_to_neighbors], dim=1),
                torch.cat([torch.zeros(pins_to_all_neighbors.size(1)), torch.ones(len(neighbors))], dim=0),
                nnode_coalesce, nnode_coalesce,
                op="max"
            )
            # Append edge indices on which the loss will be applied
            edge_index_for_loss.append(learnable_edges)

        # Then, out of the for loop
        edge_index_for_loss = torch.cat(edge_index_for_loss, dim=1)
        graph.learnable_edge_index = edge_index_for_loss.to(graph.x.device)

        # Suppress edges between parent node and neighbors: union with
        # zero-valued copies (both directions) and coalesce with op="min", then
        # keep positive-valued, upper-triangular entries only.
        parents_to_neighbors = torch.cat(parents_to_neighbors, dim=1)
        nnode_coalesce = (graph.edge_index.max() + 1).item()
        new_idx, new_attr = torch_sparse.coalesce(
            torch.cat([graph.edge_index, torch.cat([parents_to_neighbors, parents_to_neighbors.flip(dims=[0])], dim=1)], dim=1),
            torch.cat([graph.edge_attr, torch.zeros(2 * parents_to_neighbors.size(1)).to(graph.edge_index.device)], dim=0),
            nnode_coalesce, nnode_coalesce,
            op="min"
        )
        keep_indices = (new_attr > 0) & (new_idx[0, :] < new_idx[1, :])
        graph.edge_index = new_idx[:, keep_indices]
        graph.edge_attr = new_attr[keep_indices]

        # Update x: remap original node types into the pin id space, then append pin rows
        new_types = torch.tensor([NAME_TO_ID_PINS[ID_TO_NAME_NODES[t.item()]] for t in graph.x[:, 0]])[:, None]
        new_x = torch.cat([new_types, graph.x[:, [1]]], dim=1)
        graph.x = torch.cat([new_x, torch.cat(x_to_append, dim=0)], dim=0).to(graph.x.device)

        data_list.append(graph)
        kept_indices.append(graph_idx)

    # Dropped graphs shift the positions of all later graphs, so split indices
    # computed on the unfiltered dataset must be remapped.
    if len(kept_indices) != num_input_graphs and hasattr(dataset, 'split_idxs'):
        new_position = {old: new for new, old in enumerate(kept_indices)}
        dataset.split_idxs = [
            [new_position[int(old)] for old in split if int(old) in new_position]
            for split in dataset.split_idxs
        ]

    # Update dataset
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset


# def add_all_connections(dataset, special_value=0):
#     """
#     Updates edge_index and edge_attr of each Data object in the dataset
#     to include all possible directed connections, assigning a 
#     special value to edges that were not originally connected.

#     Number of edges is N*(N-1), where N is the number of nodes.
    
#     Args:
#         dataset (Dataset): A collection of PyTorch Geometric Data objects.
#         special_value (float): The value to assign to edges that are not 
#                                originally connected.
                               
#     Returns:
#         Dataset: The dataset with updated edge_index and edge_attr.
#     """
#     data_list = []
#     for data in dataset:
#         num_nodes = data.num_nodes
#         # Generate all possible directed connections
#         row, col = torch.meshgrid(torch.arange(num_nodes), torch.arange(num_nodes))
#         all_connections = torch.stack((row.flatten(), col.flatten()), dim=0)
#         # Remove self-loops
#         mask = all_connections[0] != all_connections[1]
#         all_connections = all_connections[:, mask]
#         # Create a tensor to track if an edge exists in the original edge_index
#         is_original_edge = torch.zeros(all_connections.size(1), dtype=torch.bool)
#         for i in range(data.edge_index.size(1)):
#             src, dst = data.edge_index[:, i]
#             edge_idx = (all_connections[0] == src) & (all_connections[1] == dst)
#             is_original_edge |= edge_idx
#         # Create the updated edge_attr
#         if data.edge_attr is not None:
#             # Create a tensor of dimension all_connections.size(1) filled with the special value
#             updated_edge_attr = torch.full((all_connections.size(1),), special_value, dtype=data.edge_attr.dtype)
#             # Copy the original edge attributes to the appropriate positions
#             updated_edge_attr[is_original_edge] = data.edge_attr
#         else:
#             # Initialize all edge attributes to the special value
#             updated_edge_attr = torch.full((all_connections.size(1), 1), special_value, dtype=torch.float)
        
#         # Update the edge_index and edge_attr
#         data.edge_index = all_connections
#         data.edge_attr = updated_edge_attr
#         data.num_edges = all_connections.size(1)
#         data_list.append(data)
#     dataset._indices = None
#     dataset._data_list = data_list
#     dataset.data, dataset.slices = dataset.collate(data_list)
#     return dataset

# def compute_indegree_histogram(dataset):
#     """Compute histogram of in-degree of nodes needed for PNAConv.

#     Args:
#         dataset: PyG Dataset object

#     Returns:
#         List where i-th value is the number of nodes with in-degree equal to `i`
#     """
#     from torch_geometric.utils import degree

#     deg = torch.zeros(1000, dtype=torch.long)
#     max_degree = 0
#     for data in dataset:
#         d = degree(data.edge_index[1],
#                    num_nodes=data.num_nodes, dtype=torch.long)
#         max_degree = max(max_degree, d.max().item())
#         deg += torch.bincount(d, minlength=deg.numel())
#     return deg.numpy().tolist()[:max_degree + 1]


def preformat_ocb(dataset_dir):
    """Load the OCB train/val/test splits from ``dataset_dir`` and merge them.

    Dataset options (pins, large_idx, version, n_bins) are read from ``cfg``.
    Returns the joint dataset produced by ``join_dataset_splits``.
    """
    split_datasets = []
    for split in ('train', 'val', 'test'):
        split_datasets.append(
            OCBDataset(root=dataset_dir, split=split, use_pins=cfg.dataset.use_pins,
                       large_idx=cfg.dataset.get('large_idx', (False, '')),
                       version=cfg.dataset.version, n_bins=cfg.gnn.n_bins)
        )
    return join_dataset_splits(split_datasets)


def preformat_analogenie(dataset_dir):
    """Load the AnalogGenie train/val/test splits from ``dataset_dir`` and merge them.

    The ``pins`` and ``pin_prediction`` options are read from ``cfg.dataset``.
    Returns the joint dataset produced by ``join_dataset_splits``.
    """
    split_datasets = [
        AnalogGenieDataset(root=dataset_dir, split=split, pins=cfg.dataset.pins,
                           pin_prediction=cfg.dataset.pin_prediction)
        for split in ('train', 'val', 'test')
    ]
    return join_dataset_splits(split_datasets)

def preformat(dataset_dir):
    """
    Load the dataset stored in ``dataset_dir``, dispatching on ``cfg.dataset.name``.

    Names starting with 'AnalogGenie' use the AnalogGenie loader; names
    starting with 'ocb' use the OCB loader.

    Args:
        dataset_dir (str): The directory where the dataset is stored.

    Returns:
        Dataset: The joint train/val/test dataset.

    Raises:
        ValueError: if the dataset name matches neither prefix.
    """
    dataset_name = cfg.dataset.name
    if dataset_name.startswith('AnalogGenie'):
        loader = preformat_analogenie
    elif dataset_name.startswith('ocb'):
        loader = preformat_ocb
    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")
    return loader(dataset_dir)

def join_dataset_splits(datasets):
    """Join train, val, test datasets into one dataset object.

    The graphs of all three datasets (fetched with ``.get``, in order) are
    collated into the FIRST dataset, which is modified in place and returned.

    Args:
        datasets: list of 3 PyG datasets to merge (train, val, test)

    Returns:
        joint dataset with a `split_idxs` attribute: three lists holding the
        positions of the train, val and test graphs in the joint dataset

    Raises:
        ValueError: if ``datasets`` does not contain exactly three datasets
            (previously an ``assert``, which is stripped under ``python -O``).
    """
    if len(datasets) != 3:
        raise ValueError(f"Expecting train, val, test datasets, got {len(datasets)}")

    sizes = [len(ds) for ds in datasets]
    data_list = [ds.get(i) for ds, n in zip(datasets, sizes) for i in range(n)]

    joint = datasets[0]
    joint._indices = None
    joint._data_list = data_list
    joint.data, joint.slices = joint.collate(data_list)

    # Contiguous index ranges, one per split, in train/val/test order.
    split_idxs = []
    start = 0
    for n in sizes:
        split_idxs.append(list(range(start, start + n)))
        start += n
    joint.split_idxs = split_idxs

    return joint