import torch
from torch_geometric.datasets import Planetoid, Amazon, Coauthor, Flickr, Reddit
import torch_geometric.transforms as T
from torch_geometric.transforms import NormalizeFeatures, LargestConnectedComponents, Compose
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset
from torch_geometric.utils import to_undirected

def split_indices(label, train_size=20, val_size=30):
    """Create a random, class-balanced train/validation/test split.

    For every class id in ``0 .. label.max()``, the positions holding that
    class are shuffled; the first ``train_size`` go to training, the next
    ``val_size`` to validation and everything left over to test.

    Args:
        label: Tensor of integer class ids.
        train_size: Number of training indices drawn from each class.
        val_size: Number of validation indices drawn from each class.

    Returns:
        A tuple ``(train_index, val_index, test_index)`` of 1-D index
        tensors living on the CPU. Indices are grouped by class in
        ascending class order.

    Raises:
        ValueError: If a class has fewer than ``train_size + val_size``
            members.
    """
    train_parts, val_parts, test_parts = [], [], []
    for j in range(label.max().item() + 1):
        index = torch.where(label == j)[0]
        if len(index) < train_size + val_size:
            raise ValueError(f"Not enough samples in class {j} to meet the requirement of {train_size} training and {val_size} validation nodes.")
        index = index[torch.randperm(len(index))]
        train_parts.append(index[:train_size])
        val_parts.append(index[train_size:train_size + val_size])
        test_parts.append(index[train_size + val_size:])

    # Concatenating the per-class chunks keeps the integer index dtype even
    # when a split is empty (an empty Python list would become a float
    # tensor, which cannot be used for indexing) and avoids converting the
    # indices one element at a time.
    return tuple(
        torch.cat(parts).cpu()
        for parts in (train_parts, val_parts, test_parts)
    )

def split_indices_planetoid(label, train_size=20, total_val_size=500, total_test_size=1000):
    """Create a random split with per-class training and global val/test sizes.

    ``train_size`` indices are drawn from every class id in
    ``0 .. label.max()``. All indices not used for training are pooled and
    shuffled; the first ``total_val_size`` of the pool become the validation
    set and the following ``total_test_size`` the test set. Any indices left
    in the pool after that are not assigned to a split.

    Args:
        label: Tensor of integer class ids.
        train_size: Number of training indices drawn from each class.
        total_val_size: Total number of validation indices (over all classes).
        total_test_size: Total number of test indices (over all classes).

    Returns:
        A tuple ``(train_index, val_index, test_index)`` of 1-D index
        tensors living on the CPU.

    Raises:
        ValueError: If a class has fewer than ``train_size`` members, or if
            fewer than ``total_val_size + total_test_size`` indices remain
            after the training indices were taken.
    """
    train_parts, remaining_parts = [], []

    # Step 1: Collect train_size samples per class for training, and gather
    # the remaining indices.
    for j in range(label.max().item() + 1):
        index = torch.where(label == j)[0]
        if len(index) < train_size:
            raise ValueError(f"Not enough samples in class {j} to meet the requirement of {train_size} training nodes.")
        index = index[torch.randperm(len(index))]
        train_parts.append(index[:train_size])
        remaining_parts.append(index[train_size:])

    # Concatenating the per-class chunks keeps the integer index dtype even
    # when a chunk list holds only empty tensors.
    train_index = torch.cat(train_parts).cpu()
    remaining_index = torch.cat(remaining_parts).cpu()

    # Step 2: Shuffle the remaining indices so val/test are class-agnostic.
    remaining_index = remaining_index[torch.randperm(len(remaining_index))]

    # Step 3: Take the validation and test sets from the shuffled pool.
    if len(remaining_index) < total_val_size + total_test_size:
        raise ValueError(f"Not enough remaining samples to meet the requirement of {total_val_size} validation and {total_test_size} test nodes.")

    val_index = remaining_index[:total_val_size]
    test_index = remaining_index[total_val_size:total_val_size + total_test_size]

    return train_index, val_index, test_index

def set_masks(data, train_index, val_index, test_index):
    """Attach boolean train/val/test masks to ``data`` in place.

    Each mask has ``data.num_nodes`` entries, all ``False`` except at the
    positions given by the matching index argument.

    Args:
        data: Object exposing ``num_nodes``; receives the attributes
            ``train_mask``, ``val_mask`` and ``test_mask``.
        train_index: Positions to mark in ``train_mask``.
        val_index: Positions to mark in ``val_mask``.
        test_index: Positions to mark in ``test_mask``.
    """
    mask_specs = (
        ("train_mask", train_index),
        ("val_mask", val_index),
        ("test_mask", test_index),
    )
    for attr_name, positions in mask_specs:
        # Attach the all-False mask first, then flip the selected entries on
        # the attached tensor.
        setattr(data, attr_name, torch.zeros(data.num_nodes, dtype=torch.bool))
        getattr(data, attr_name)[positions] = True

def preprocess_edges(data, edge_ratio):
    """Randomly keep a fraction of the graph's undirected edges.

    Directed entries of ``data.edge_index`` that connect the same unordered
    node pair are treated as one undirected edge, so they are kept or
    dropped together. A copy of the unmodified ``edge_index`` is always
    stored on ``data.original_edge_index``.

    Args:
        data: Graph object with an ``edge_index`` tensor of shape
            ``(2, num_edges)``. It is modified in place.
        edge_ratio: Fraction of undirected edges to keep. The kept count is
            ``int(edge_ratio * number_of_undirected_edges)``. Values of
            ``1.0`` or more keep every edge.

    Returns:
        The same ``data`` object, with ``edge_index`` restricted to the
        selected edges (relative order of the kept columns is preserved).

    Raises:
        ValueError: If ``edge_ratio`` is negative.
    """
    # A negative ratio would turn into a negative slice bound below and
    # silently keep most of the edges, so reject it up front.
    if edge_ratio < 0:
        raise ValueError(f"edge_ratio must be non-negative, got {edge_ratio}.")

    # Store the original edge_index before preprocessing
    data.original_edge_index = data.edge_index.clone().detach()

    if edge_ratio >= 1.0:
        return data

    # NOTE(review): only edge_index is filtered; any per-edge attributes that
    # may be stored on `data` are left untouched -- confirm none are in use.
    num_edges = data.edge_index.size(1)

    # Group the column positions of directed edges by their unordered node
    # pair. Dict insertion order (first occurrence) defines the numbering of
    # the undirected edges used for the random selection below.
    undirected_edge_to_indices = {}
    for idx, (u, v) in enumerate(data.edge_index.t().tolist()):
        undirected_edge = (u, v) if u <= v else (v, u)
        undirected_edge_to_indices.setdefault(undirected_edge, []).append(idx)
    edge_groups = list(undirected_edge_to_indices.values())

    # Randomly select which undirected edges survive.
    num_keep_undirected_edges = int(edge_ratio * len(edge_groups))
    selected_groups = torch.randperm(len(edge_groups))[:num_keep_undirected_edges].tolist()

    # Collect the column positions of every direction of the selected edges.
    keep_indices = [idx for group in selected_groups for idx in edge_groups[group]]

    # Explicit long dtype keeps the index valid when nothing is selected.
    edge_mask = torch.zeros(num_edges, dtype=torch.bool)
    edge_mask[torch.tensor(keep_indices, dtype=torch.long)] = True

    # Update the edge_index with the selected edges
    data.edge_index = data.edge_index[:, edge_mask]

    return data

def load_dataset(args, data_dir='./data'):
    """Load a node-classification graph, attach split masks and thin its edges.

    Args:
        args: Object providing ``dataset`` (dataset name, matched
            case-insensitively) and ``edge_ratio``; ``architecture`` is also
            read when the dataset is ``computers`` or ``photo``.
        data_dir: Root directory handed to the dataset classes.

    Returns:
        The first graph of the dataset with ``num_classes`` and ``num_nodes``
        set and with its edges processed by ``preprocess_edges``. Train/val/
        test masks are (re)generated here for every supported dataset except
        ``reddit`` and ``flickr``, whose graphs are returned with whatever
        masks the dataset object provides.

    Raises:
        ValueError: If the dataset name is not one of the supported names.
    """
    name = args.dataset.lower()
    edge_ratio = args.edge_ratio

    planetoid_names = ('cora', 'citeseer', 'pubmed')
    amazon_names = ('computers', 'photo')
    coauthor_names = ('cs', 'physics')
    arxiv_name = 'ogbn-arxiv'

    # Instantiate the dataset object that matches the requested name.
    if name in planetoid_names:
        dataset = Planetoid(
            root=data_dir,
            name=name.capitalize(),
            force_reload=True,
            transform=LargestConnectedComponents(),
        )
    elif name in amazon_names:
        # GIN additionally gets feature normalization, since it is not
        # trainable without it.
        amazon_transform = (
            Compose([LargestConnectedComponents(), NormalizeFeatures()])
            if args.architecture == "gin"
            else LargestConnectedComponents()
        )
        dataset = Amazon(
            root=data_dir,
            name=name.capitalize(),
            force_reload=True,
            transform=amazon_transform,
        )
    elif name in coauthor_names:
        dataset = Coauthor(
            root=data_dir,
            name=name.capitalize(),
            force_reload=True,
            transform=LargestConnectedComponents(),
        )
    elif name == arxiv_name:
        dataset = PygNodePropPredDataset(root=data_dir, name=name)
    elif name == 'reddit':
        dataset = Reddit(data_dir + "/Reddit")
    elif name == 'flickr':
        dataset = Flickr(data_dir + "/Flickr")
    else:
        raise ValueError(f"Unknown dataset: {name}")

    data = dataset[0]
    data.num_classes = dataset.num_classes
    feature_shape = data.x.size()
    print(feature_shape[0], feature_shape[1])
    data.num_nodes = feature_shape[0]

    # Pick the split strategy; datasets without a branch keep their own masks.
    split = None
    if name in planetoid_names:
        split = split_indices_planetoid(data.y)
    elif name == arxiv_name:
        split_idx = dataset.get_idx_split()
        split = (split_idx['train'], split_idx['valid'], split_idx['test'])
    elif name in amazon_names or name in coauthor_names:
        split = split_indices(data.y)
    if split is not None:
        set_masks(data, *split)

    # ogbn-arxiv is made undirected by adding the reverse edges.
    if name == arxiv_name:
        data.edge_index = to_undirected(data.edge_index)

    # Drop the (1 - edge_ratio) share of edges from the graph.
    return preprocess_edges(data, edge_ratio)