import os
import yaml
from ml_collections import ConfigDict

import torch
import torch_geometric
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset
import numpy as np

from utils.split import SplitManager, node_induced_subgraph
from utils.storage import TensorHash

from robust_diffusion.data import SparseGraph
from robust_diffusion.helper import utils as robust_utils
from robust_diffusion.train import train


from robust_diffusion.data import SparseGraph, largest_connected_components
from torch_geometric.utils import to_scipy_sparse_matrix


def load_dataset(dataset_name, dataset_root, device="cpu"):
    """Load a citation dataset restricted to its largest connected component.

    ``"cora"`` and ``"citeseer"`` are loaded with ``Planetoid``;
    ``"cora_ml"`` and ``"pubmed"`` with ``CitationFull``. The graph is then
    reduced to its largest connected component (made undirected) using
    ``largest_connected_components`` and ``x``, ``edge_index`` and ``y`` of the
    returned data object are replaced by the reduced versions.

    Args:
        dataset_name: Name of the dataset to load.
        dataset_root: Directory passed as ``root`` to the torch_geometric dataset.
        device: Device the data object and its new ``x``, ``edge_index`` and
            ``y`` tensors are placed on.

    Returns:
        ``(dataset, dataset_info)`` where ``dataset`` is the torch_geometric data
        object and ``dataset_info`` is a ``ConfigDict`` with ``n_features``,
        ``n_classes``, ``n_nodes`` (after the reduction) and ``dataset_name``.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name not in ["cora", "cora_ml", "pubmed", "citeseer"]:
        raise NotImplementedError(
            f"Load method {dataset_name} not implemented yet.")

    if dataset_name in ["cora_ml", "pubmed"]:
        dataset_obj = torch_geometric.datasets.CitationFull(
            root=dataset_root, name=dataset_name)
    else:
        dataset_obj = torch_geometric.datasets.Planetoid(
            root=dataset_root, name=dataset_name)
    dataset = dataset_obj[0].to(device)

    dataset_info = ConfigDict()
    dataset_info.n_features = dataset_obj.num_features
    dataset_info.n_classes = dataset_obj.num_classes
    dataset_info.dataset_name = dataset_name

    # Preprocess with the largest connected component. SparseGraph works on
    # host data, so attributes and labels are both handed over as numpy arrays.
    attr_matrix = dataset.x.cpu().numpy()
    adj_matrix = to_scipy_sparse_matrix(dataset.edge_index, num_nodes=dataset.x.shape[0])
    labels = dataset.y.cpu().numpy()
    graphs = SparseGraph(adj_matrix, attr_matrix, labels)

    prep_graphs = largest_connected_components(sparse_graph=graphs, n_components=1, make_undirected=True)

    # The replacement tensors are freshly created (on the CPU), so move them
    # explicitly; otherwise the requested `device` would be silently ignored.
    dataset.x = torch.tensor(prep_graphs.attr_matrix).to(device)
    dataset.edge_index = torch.tensor(prep_graphs.get_edgeid_to_idx_array().T, dtype=torch.long).to(device)
    dataset.y = torch.as_tensor(prep_graphs.labels).to(device)

    # Node count after the largest-connected-component reduction.
    dataset_info.n_nodes = dataset.x.shape[0]
    return dataset, dataset_info

def check_dataset_valid(split_record, training_split, validation_split, training_split_type, 
                        validation_split_type, test_split, test_split_type, splits_root):
    """Check whether a saved split record was created with the given configuration.

    Args:
        split_record: File name of the split record inside ``splits_root``.
        training_split: Expected ``"training_split"`` config value.
        validation_split: Expected ``"validation_split"`` config value.
        training_split_type: Expected ``"training_split_type"`` config value.
        validation_split_type: Expected ``"validation_split_type"`` config value.
        test_split: Expected ``"test_split"`` config value.
        test_split_type: Expected ``"test_split_type"`` config value.
        splits_root: Directory containing the split records.

    Returns:
        ``True`` if every expected value equals the value stored under the
        matching key of the record's ``"config"`` dict (a missing entry counts
        as ``None``), ``False`` otherwise.
    """
    data = torch.load(os.path.join(splits_root, split_record))
    data_config = data.get("config", dict())
    expected = {
        "training_split": training_split,
        "validation_split": validation_split,
        "training_split_type": training_split_type,
        "validation_split_type": validation_split_type,
        "test_split": test_split,
        "test_split_type": test_split_type,
    }
    # Look each entry up by its key *name*. Using the value as the key
    # (e.g. data_config.get(20)) could never match a saved record.
    return all(data_config.get(key, None) == value for key, value in expected.items())

def dataset_split_name(dataset_name, split_name):
    """Return the file name of a split record: ``"<dataset_name>-<split_name>.pt"``."""
    return "{}-{}.pt".format(dataset_name, split_name)

def make_dataset_splits(dataset_name, 
                        training_split=None, validation_split=None, training_split_type=None, validation_split_type=None, 
                        test_split=None, test_split_type=None, 
                        inductive=False, default_dataset_configs=None, dataset_root="data", splits_root="splits", device="cpu"):
    """Create a new train/validation/test split of a dataset, save it and return its tensors.

    A split size left as ``None`` is taken, together with its type, from
    ``default_dataset_configs`` (keys ``"training_split"``,
    ``"training_split_type"``, ``"validation_split"``, ...).

    Budgets passed to ``SplitManager.alloc``:
      * training / validation: ``*_split`` with ``budget_allocated='per_class'``
        when the type is ``"stratified"``, else ``*_split * n_classes`` with
        ``budget_allocated='overall'``;
      * test: ``int(test_split * n_nodes)`` when the type is ``"fraction"``,
        else ``test_split``; always ``budget_allocated='overall'``, stratified
        only for type ``"stratified"``.
    Nodes in none of the three splits form the unlabeled set.

    With ``inductive=True`` the training graph is induced on training+unlabeled
    nodes, the validation graph on training+validation+unlabeled nodes and the
    test graph is the full graph. Otherwise all three are the full graph.

    The split indices and config are saved (best effort, errors are printed) to
    ``<splits_root>/<dataset_name>-<hash>.pt``.

    Raises:
        ValueError: If a split size is ``None`` and no ``default_dataset_configs``
            is given.

    Returns:
        A dict with ``training_attr``, ``training_adj``, ``validation_attr``,
        ``validation_adj``, ``labels``, ``training_idx``, ``validation_idx``,
        ``test_attr``, ``test_adj``, ``unlabeled_mask``, ``test_mask``,
        ``dataset_info``, ``split_name`` and ``config``.
    """
    # Defaults are only consulted for split sizes that were not given.
    needs_defaults = training_split is None or validation_split is None or test_split is None
    if default_dataset_configs is None and needs_defaults:
        raise ValueError("default_dataset_configs must be provided when a split size is not given")
    
    # Loading the dataset, creating splits, and saving them (for both transductive and inductive)
    dataset, dataset_info = load_dataset(dataset_name, dataset_root)

    if training_split is None:
        training_split = default_dataset_configs.get("training_split")
        training_split_type = default_dataset_configs.get("training_split_type")
    if validation_split is None:
        validation_split = default_dataset_configs.get("validation_split")
        validation_split_type = default_dataset_configs.get("validation_split_type")
    if test_split is None:
        test_split = default_dataset_configs.get("test_split")
        test_split_type = default_dataset_configs.get("test_split_type")

    # A "fraction" test split is converted to an absolute node count; any other
    # type is used as given (previously it left these names undefined).
    if test_split_type == "fraction":
        test_split_final = int(test_split * dataset_info.n_nodes)
        test_split_type_final = "overall"
    else:
        test_split_final = test_split
        test_split_type_final = test_split_type

    split = SplitManager(dataset)

    training_stratified = training_split_type == "stratified"
    training_mask = split.alloc(
        budget=training_split if training_stratified else training_split * dataset_info.n_classes,
        budget_allocated="per_class" if training_stratified else "overall",
        stratified=training_stratified)
    training_idx = training_mask.nonzero(as_tuple=True)[0]

    # The validation budget must be derived from the validation settings
    # (it used to reuse training_split / training_split_type).
    validation_stratified = validation_split_type == "stratified"
    validation_mask = split.alloc(
        budget=validation_split if validation_stratified else validation_split * dataset_info.n_classes,
        budget_allocated="per_class" if validation_stratified else "overall",
        stratified=validation_stratified)
    validation_idx = validation_mask.nonzero(as_tuple=True)[0]

    test_mask = split.alloc(
        budget=test_split_final, budget_allocated="overall",
        stratified=test_split_type_final == "stratified")

    unlabeled_mask = ~(training_mask | validation_mask | test_mask)

    if inductive:
        training_dataset = node_induced_subgraph(dataset, training_mask | unlabeled_mask)
        validation_dataset = node_induced_subgraph(dataset, training_mask | validation_mask | unlabeled_mask)
        test_dataset = dataset
    else:
        training_dataset = dataset
        validation_dataset = dataset
        test_dataset = dataset

    def _to_tensors(data):
        """Convert a graph data object to (dense float attributes, sparse adjacency) on `device`."""
        graph = SparseGraph(
            adj_matrix=torch_geometric.utils.to_scipy_sparse_matrix(
                data.edge_index, num_nodes=dataset_info.n_nodes),
            attr_matrix=data.x.cpu().numpy(),
            labels=data.y)
        attr = torch.FloatTensor(graph.attr_matrix).to(device)
        adj = robust_utils.sparse_tensor(graph.adj_matrix.tocoo()).to(device)
        return attr, adj

    training_attr, training_adj = _to_tensors(training_dataset)
    validation_attr, validation_adj = _to_tensors(validation_dataset)

    labels = training_dataset.y.to(device)

    # `~inductive` on a Python bool is -1 / -2 (always truthy), so the
    # inductive branch was unreachable. Transductively the test graph is the
    # training graph; inductively it is the full graph, converted like the others.
    if inductive:
        test_attr, test_adj = _to_tensors(test_dataset)
    else:
        test_attr, test_adj = training_attr, training_adj

    split_dict = {"training": training_idx, "validation": validation_idx,
                  "unlabeled": unlabeled_mask.nonzero(as_tuple=True)[0], "test": test_mask.nonzero(as_tuple=True)[0],
                  "config": {
                        "training_split": training_split, "training_split_type": training_split_type,
                        "validation_split": validation_split, "validation_split_type": validation_split_type,
                        "test_split": test_split, "test_split_type": test_split_type
                  }}
    split_name = TensorHash.hash_tensor_dict(split_dict)
    try:
        os.makedirs(splits_root, exist_ok=True)
        torch.save(split_dict, os.path.join(splits_root, dataset_split_name(dataset_name, split_name)))
    except (OSError, RuntimeError) as e:
        # Best effort: the split is still returned even if it cannot be persisted.
        # os.makedirs reports failures as OSError, which was previously uncaught.
        print(f"Error saving the split: {e}")

    return {
        "training_attr": training_attr, 
        "training_adj": training_adj, 
        "validation_attr": validation_attr,
        "validation_adj": validation_adj,
        "labels": labels, 
        "training_idx": training_idx, 
        "validation_idx": validation_idx, 
        "test_attr": test_attr, 
        "test_adj": test_adj, 
        "unlabeled_mask": unlabeled_mask,
        "test_mask": test_mask, 
        "dataset_info": dataset_info, 
        "split_name": split_name,
        "config": split_dict["config"]
    }

def load_dataset_splits(dataset_name, split_name, inductive=False, dataset_root="data", splits_root="splits", device="cpu"):
    """Rebuild the tensors of a previously saved split of a dataset.

    Loads the dataset with ``load_dataset`` and the record
    ``<splits_root>/<dataset_name>-<split_name>.pt``. It turns the stored
    ``"training"``/``"validation"``/``"unlabeled"``/``"test"`` indices back into
    boolean masks. It asserts that the masks cover every node and that
    training, validation and test are pairwise disjoint. The training,
    validation and test graphs are then converted into dense attribute and
    sparse adjacency tensors on ``device``.

    With ``inductive=True`` the training graph is induced on training+unlabeled
    nodes and the validation graph on training+validation+unlabeled nodes; the
    test graph is always the full graph.

    Returns:
        A dict with ``training_attr``, ``training_adj``, ``validation_attr``,
        ``validation_adj``, ``labels``, ``training_idx``, ``validation_idx``,
        ``test_attr``, ``test_adj``, ``unlabeled_mask``, ``test_mask``,
        ``dataset_info``, ``split_name`` and ``config``.
    """
    dataset, dataset_info = load_dataset(dataset_name, dataset_root)
    record_path = os.path.join(splits_root, dataset_split_name(dataset_name, split_name))
    split_dict = torch.load(record_path)

    def as_mask(key):
        # Boolean node mask with True at the indices stored under `key`.
        mask = torch.zeros(dataset_info.n_nodes, dtype=torch.bool)
        mask[split_dict[key]] = True
        return mask

    training_mask, validation_mask, unlabeled_mask, test_mask = [
        as_mask(key) for key in ("training", "validation", "unlabeled", "test")
    ]

    # Every node must belong to at least one of the subsets ...
    assert (training_mask | validation_mask | unlabeled_mask | test_mask).min().item()
    # ... and the labelled subsets must not overlap.
    for first, second in ((training_mask, validation_mask),
                          (training_mask, test_mask),
                          (validation_mask, test_mask)):
        assert not (first & second).max().item()

    if not inductive:
        training_dataset = validation_dataset = dataset
    else:
        training_dataset = node_induced_subgraph(dataset, training_mask | unlabeled_mask)
        validation_dataset = node_induced_subgraph(dataset, training_mask | validation_mask | unlabeled_mask)
    test_dataset = dataset

    def to_tensors(data):
        # Go through SparseGraph, then emit (dense attributes, sparse adjacency) on `device`.
        sparse = SparseGraph(
            adj_matrix=torch_geometric.utils.to_scipy_sparse_matrix(
                data.edge_index, num_nodes=dataset_info.n_nodes),
            attr_matrix=data.x.cpu().numpy(),
            labels=data.y)
        attr = torch.FloatTensor(sparse.attr_matrix).to(device)
        adj = robust_utils.sparse_tensor(sparse.adj_matrix.tocoo()).to(device)
        return attr, adj

    training_attr, training_adj = to_tensors(training_dataset)
    validation_attr, validation_adj = to_tensors(validation_dataset)
    labels = training_dataset.y.to(device)
    test_attr, test_adj = to_tensors(test_dataset)

    result = dict(
        training_attr=training_attr,
        training_adj=training_adj,
        validation_attr=validation_attr,
        validation_adj=validation_adj,
        labels=labels,
        training_idx=split_dict["training"],
        validation_idx=split_dict["validation"],
        test_attr=test_attr,
        test_adj=test_adj,
        unlabeled_mask=unlabeled_mask,
        test_mask=test_mask,
        dataset_info=dataset_info,
        split_name=split_name,
        config=split_dict.get("config", dict()),
    )
    return result


def make_arxiv_dataset_splits(dataset_name, splits_root="splits", device="cpu", dataset_root="datasets/"):
    """Save the official OGB split of ``ogbn-arxiv`` as a split record.

    The record holds the OGB ``train``/``valid``/``test`` indices under
    ``"training"``/``"validation"``/``"test"``. It also holds a random
    ``"suffix"`` tensor that is hashed with them, presumably so that each call
    yields a distinct split name. It is written (best effort, errors are
    printed) to ``<splits_root>/ogbn_arxiv-<hash>.pt``.

    Args:
        dataset_name: Must be ``"ogbn-arxiv"``.
        splits_root: Directory the split record is written to.
        device: Unused; kept for signature parity with the other split helpers.
        dataset_root: Root directory of the OGB dataset (defaults to the
            previously hard-coded ``"datasets/"``).

    Returns:
        A dict with ``training_idx``, ``validation_idx``, ``test_idx`` and
        ``split_name``.
    """
    assert dataset_name == "ogbn-arxiv"
    dataset = PygNodePropPredDataset(name=dataset_name, root=dataset_root, transform=T.Compose([T.ToUndirected(), T.ToSparseTensor()]))

    split_idx = dataset.get_idx_split()
    train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]

    suffix = torch.tensor(np.random.randint(0, 100_000_000))
    split_dict = {"training": train_idx, "validation": valid_idx,
                  "test": test_idx, "suffix": suffix}
    split_name = TensorHash.hash_tensor_dict(split_dict)
    try:
        os.makedirs(splits_root, exist_ok=True)
        torch.save(split_dict, os.path.join(splits_root, dataset_split_name(dataset_name.replace("-", "_"), split_name)))
    except (OSError, RuntimeError) as e:
        # Best effort: os.makedirs reports failures as OSError, which the
        # RuntimeError-only handler let escape.
        print(f"Error saving the split: {e}")

    return {
        "training_idx": train_idx, 
        "validation_idx": valid_idx,
        "test_idx": test_idx,
        "split_name": split_name,
    }

def load_arxiv_dataset_splits(dataset_name, split_code, inductive, dataset_root="datasets/", splits_root="splits", device="cpu"):
    """Rebuild the tensors of a saved ``ogbn-arxiv`` split.

    Loads ``ogbn-arxiv`` (made undirected) and the split record
    ``<splits_root>/ogbn_arxiv-<split_code>.pt``. It turns the stored indices
    into boolean masks and asserts that they cover every node and that
    training, validation and test are pairwise disjoint. The training,
    validation and test graphs are then converted into dense attribute and
    sparse adjacency tensors on ``device``.

    With ``inductive=True`` the training graph is induced on training+unlabeled
    nodes and the validation graph on training+validation+unlabeled nodes; the
    test graph is always the full graph.

    Returns:
        A dict with ``training_attr``, ``training_adj``, ``validation_attr``,
        ``validation_adj``, ``test_attr``, ``test_adj``, ``labels`` (``y``
        squeezed), ``training_idx``, ``validation_idx``, ``unlabeled_mask``,
        ``test_mask``, ``split_name``, ``dataset_info`` and ``config`` (empty if
        the record has none).
    """
    assert dataset_name == "ogbn-arxiv"
    # Only symmetrise: T.ToSparseTensor() would replace `edge_index` by `adj_t`,
    # but the conversion below needs `edge_index`.
    dataset = PygNodePropPredDataset(name=dataset_name, root=dataset_root, transform=T.ToUndirected())
    # Work on the transformed Data object. The transform is applied when the
    # dataset is indexed; passing the Dataset object itself on (as before)
    # skipped ToUndirected for the graph structure.
    graph = dataset[0]
    num_nodes = graph.num_nodes
    num_features = graph.x.shape[1]
    dataset_info = ConfigDict()
    dataset_info.n_features = num_features
    dataset_info.n_classes = len(graph.y.unique())
    dataset_info.n_nodes = num_nodes
    dataset_info.dataset_name = dataset_name

    split_dict = torch.load(os.path.join(splits_root, dataset_split_name(dataset_name.replace("-", "_"), split_code)))
    training_mask = torch.zeros(num_nodes, dtype=torch.bool)
    training_mask[split_dict["training"]] = True
    validation_mask = torch.zeros(num_nodes, dtype=torch.bool)
    validation_mask[split_dict["validation"]] = True
    # No "unlabeled" entry is read from the record, so this mask stays all False.
    unlabeled_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask[split_dict["test"]] = True

    assert (training_mask | validation_mask | unlabeled_mask | test_mask).min().item()
    assert (training_mask & validation_mask).max().item() == False
    assert (training_mask & test_mask).max().item() == False
    assert (validation_mask & test_mask).max().item() == False

    if inductive:
        training_dataset = node_induced_subgraph(graph, training_mask | unlabeled_mask)
        validation_dataset = node_induced_subgraph(graph, training_mask | validation_mask | unlabeled_mask)
        test_dataset = graph
    else:
        training_dataset = graph
        validation_dataset = graph
        test_dataset = graph

    def _to_tensors(data):
        """Convert a graph data object to (dense float attributes, sparse adjacency) on `device`."""
        sparse_graph = SparseGraph(
            adj_matrix=torch_geometric.utils.to_scipy_sparse_matrix(
                data.edge_index, num_nodes=num_nodes),
            attr_matrix=data.x.cpu().numpy(),
            labels=data.y)
        attr = torch.FloatTensor(sparse_graph.attr_matrix).to(device)
        adj = robust_utils.sparse_tensor(sparse_graph.adj_matrix.tocoo()).to(device)
        return attr, adj

    training_attr, training_adj = _to_tensors(training_dataset)
    validation_attr, validation_adj = _to_tensors(validation_dataset)
    test_attr, test_adj = _to_tensors(test_dataset)

    # OGB stores labels with a trailing singleton dimension; squeeze it away.
    labels = training_dataset.y.squeeze().to(device)

    return {
        "training_attr": training_attr, 
        "training_adj": training_adj, 
        "validation_attr": validation_attr,
        "validation_adj": validation_adj,
        "test_attr": test_attr, 
        "test_adj": test_adj, 
        "labels": labels, 
        "training_idx": split_dict["training"], 
        "validation_idx": split_dict["validation"], 
        "unlabeled_mask": unlabeled_mask,
        "test_mask": test_mask, 
        "split_name": split_code,
        "dataset_info": dataset_info,
        "config": split_dict.get("config", dict())
    }