import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader

from torch_geometric.loader import DataLoader as pygDataLoader
# import wandb
from BUDDY.hashdataset import get_hashed_train_val_test_datasets, make_train_eval_data, HashDataset

def get_loaders(args, dataset, splits, directed):
    """Build the train, train-eval, validation and test loaders.

    All three splits are turned into hashed datasets by
    ``get_hashed_train_val_test_datasets``. ELPH and BUDDY models are fed
    through ``torch.utils.data.DataLoader``; every other model uses the
    PyTorch Geometric ``DataLoader``.

    Args:
        args: run configuration; this function reads ``model``,
            ``batch_size``, ``num_workers`` and ``data_name``.
        dataset: source dataset, forwarded to the hashing helper.
        splits: mapping with ``'train'``, ``'valid'`` and ``'test'`` entries.
        directed: forwarded unchanged to the hashing helper.

    Returns:
        ``(train_loader, train_eval_loader, val_loader, test_loader)``.
    """
    train_split = splits['train']
    val_split = splits['valid']
    test_split = splits['test']
    hashed_train, hashed_val, hashed_test = get_hashed_train_val_test_datasets(
        dataset, train_split, val_split, test_split, args, directed)

    uses_hashing = args.model in {'ELPH', 'BUDDY'}
    loader_cls = DataLoader if uses_hashing else pygDataLoader

    def build(data_source, batch_size, shuffle):
        # Every loader in this function shares the same worker count.
        return loader_cls(data_source, batch_size=batch_size, shuffle=shuffle,
                          num_workers=args.num_workers)

    train_loader = build(hashed_train, args.batch_size, True)

    # Val/test edges are often sampled, so they are shuffled as well -- except
    # for the citation datasets, whose negatives are tied to specific
    # positives and must keep their order.
    keep_order = args.data_name.startswith('ogbl-citation')
    val_loader = build(hashed_val, args.batch_size, not keep_order)
    test_loader = build(hashed_test, args.batch_size, not keep_order)

    if args.data_name != 'ogbl-citation2' or not uses_hashing:
        # TODO: change this so that eval doesn't have to use the full training set
        train_eval_loader = train_loader
    else:
        # citation2 + hashing models: evaluate on a 5000-positive subsample.
        eval_subset = make_train_eval_data(args, hashed_train, train_split.num_nodes,
                                           n_pos_samples=5000)
        train_eval_loader = build(eval_subset, args.batch_size, False)

    return train_loader, train_eval_loader, val_loader, test_loader




def get_loaders_hard_neg_(args, dataset, splits, directed):
    """Build loaders for hard-negative evaluation from hashed datasets.

    Same dataset construction as ``get_loaders``, but the validation and test
    loaders use ``args.test_batch_size`` and are never shuffled. No separate
    train-eval subset is built; the training loader is reused for it.

    Args:
        args: run configuration; this function reads ``model``,
            ``batch_size``, ``test_batch_size`` and ``num_workers``.
        dataset: source dataset, forwarded to the hashing helper.
        splits: mapping with ``'train'``, ``'valid'`` and ``'test'`` entries.
        directed: forwarded unchanged to the hashing helper.

    Returns:
        ``(train_loader, train_eval_loader, val_loader, test_loader)``, where
        ``train_eval_loader`` is the same object as ``train_loader``.
    """
    train_set, val_set, test_set = get_hashed_train_val_test_datasets(
        dataset, splits['train'], splits['valid'], splits['test'], args, directed)

    if args.model in {'ELPH', 'BUDDY'}:
        loader_cls = DataLoader
    else:
        loader_cls = pygDataLoader

    train_loader = loader_cls(train_set, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers)
    # Evaluation splits keep their stored edge order.
    val_loader, test_loader = (
        loader_cls(eval_set, batch_size=args.test_batch_size, shuffle=False,
                   num_workers=args.num_workers)
        for eval_set in (val_set, test_set)
    )

    return train_loader, train_loader, val_loader, test_loader

from torch_geometric.loader import DataLoader as pygDataLoader

def get_loaders_hard_neg(args, splits, root='./elph_'):
    """Build loaders over splits that already carry labelled (hard-negative) edges.

    For each split, the positive and negative supervision edges are extracted
    with ``get_pos_neg_edges`` and wrapped in a ``HashDataset``. Validation and
    test loaders use ``args.test_batch_size`` and are never shuffled. The
    training loader is reused as the train-eval loader.

    Args:
        args: run configuration; this function reads ``data_name``, ``model``,
            ``batch_size``, ``test_batch_size`` and ``num_workers``.
        splits: mapping with ``'train'``, ``'valid'`` and ``'test'`` entries.
        root: path prefix handed to ``HashDataset``, where its caches are
            stored. Defaults to the previous hard-coded ``'./elph_'``.

    Returns:
        ``(train_loader, train_eval_loader, val_loader, test_loader)``, where
        ``train_eval_loader`` is the same object as ``train_loader``.
    """
    train_data, val_data, test_data = splits['train'], splits['valid'], splits['test']

    # Edge coalescing is only requested for ogbl-collab.
    use_coalesce = args.data_name == 'ogbl-collab'

    # Extract positive and negative supervision edges from each split.
    pos_train_edge, neg_train_edge = get_pos_neg_edges(train_data)
    pos_val_edge, neg_val_edge = get_pos_neg_edges(val_data)
    pos_test_edge, neg_test_edge = get_pos_neg_edges(test_data)

    # Build one hashed dataset per split, all sharing the same cache root.
    train_dataset = HashDataset(root, 'train', train_data, pos_train_edge, neg_train_edge, args,
                                use_coalesce=use_coalesce)
    val_dataset = HashDataset(root, 'valid', val_data, pos_val_edge, neg_val_edge, args,
                              use_coalesce=use_coalesce)
    test_dataset = HashDataset(root, 'test', test_data, pos_test_edge, neg_test_edge, args,
                               use_coalesce=use_coalesce)

    # ELPH/BUDDY use the plain torch DataLoader; other models use the PyG one.
    dl = DataLoader if args.model in {'ELPH', 'BUDDY'} else pygDataLoader

    train_loader = dl(train_dataset, batch_size=args.batch_size, shuffle=True,
                      num_workers=args.num_workers)
    val_loader = dl(val_dataset, batch_size=args.test_batch_size, shuffle=False,
                    num_workers=args.num_workers)
    test_loader = dl(test_dataset, batch_size=args.test_batch_size, shuffle=False,
                     num_workers=args.num_workers)

    train_eval_loader = train_loader

    return train_loader, train_eval_loader, val_loader, test_loader


def get_pos_neg_edges(data, sample_frac=1):
    """Split the labelled supervision edges of ``data`` into positives and negatives.

    Columns of ``data['edge_label_index']`` whose ``data['edge_label']`` equals
    1 are positive edges; those equal to 0 are negative edges. Both sets are
    returned transposed (one edge per row) on the device of ``data.edge_index``.

    Args:
        data: split object supporting ``data.edge_index`` and item access to
            ``'edge_label_index'`` / ``'edge_label'`` (presumably a PyG
            ``Data``; confirm against callers).
        sample_frac: fraction of edges to keep. ``1`` keeps everything;
            otherwise a deterministic random subset (seed 123) is kept.

    Returns:
        ``(pos_edges, neg_edges)``, each of shape ``(num_edges, 2)``.
    """
    device = data.edge_index.device
    edge_index = data['edge_label_index'].to(device)
    labels = data['edge_label'].to(device)
    pos_edges = edge_index[:, labels == 1].t()
    neg_edges = edge_index[:, labels == 0].t()
    if sample_frac != 1:
        # A private RandomState seeded with 123 yields the same permutation the
        # old ``np.random.seed(123)`` call did, without resetting numpy's
        # global RNG for the rest of the program.
        rng = np.random.RandomState(123)
        n_pos = pos_edges.shape[0]
        pos_perm = rng.permutation(n_pos)[:int(sample_frac * n_pos)]
        n_neg = neg_edges.shape[0]
        if n_neg == n_pos:
            # Equal counts: reuse the positive permutation so row i of the
            # negatives stays aligned with row i of the positives.
            neg_perm = pos_perm
        else:
            # Unequal counts (e.g. hard negatives): the positive permutation
            # could index past the end of neg_edges, so sample independently.
            neg_perm = rng.permutation(n_neg)[:int(sample_frac * n_neg)]
        pos_edges = pos_edges[torch.as_tensor(pos_perm, dtype=torch.long, device=device), :]
        neg_edges = neg_edges[torch.as_tensor(neg_perm, dtype=torch.long, device=device), :]
    return pos_edges.to(device), neg_edges.to(device)