import csv
from sklearn.preprocessing import OneHotEncoder
import pickle

import torch
import numpy as np

class Yelp(torch.utils.data.Dataset):
    """In-memory dataset of graph samples loaded from a pickle file.

    The file at ``path`` must unpickle to a pair ``(dataset, perm)``:

    * ``dataset`` is a sequence of ``(adj, x)`` samples, where ``adj``
      exposes ``.toarray()`` (e.g. a SciPy sparse matrix) and ``x`` is
      array-like (presumably one row per node -- confirm against the code
      that writes the pickle).
    * ``perm`` is a sequence of indices into ``dataset``.

    Every item returned by ``__getitem__`` is a dict with the keys:

    * ``"X"``: ``x`` as a float32 tensor.
    * ``"A"``: the dense form of ``adj`` as a float32 tensor.
    * ``"A_correlation"``: ``np.corrcoef(x)`` (correlation between the rows
      of ``x``) as a float32 tensor with its diagonal set to zero.
    * ``"B"``: a zero tensor with the same shape and dtype as ``"X"``.
    * ``"beta"``: the integer ``0``.

    Args:
        path: Location of the pickle file.
        top_N_graphs: If not ``-1``, keep only the samples indexed by the
            first ``top_N_graphs`` entries of ``perm``, in an order shuffled
            deterministically by ``seed``. With ``-1`` the whole ``dataset``
            is used in its stored order.
        seed: Seed of the ``RandomState`` used for that shuffle.
    """

    def __init__(self, path, top_N_graphs=-1, seed=0):
        super().__init__()

        self.path = path

        # NOTE(security): unpickling can execute arbitrary code embedded in
        # the file, so only load files that come from a trusted source.
        with open(path, 'rb') as f:
            dataset, perm = pickle.load(f)

        if top_N_graphs != -1:
            idx_samples_to_use = perm[:top_N_graphs]
            # Shuffles in place; a dedicated RandomState keeps the order
            # reproducible without touching NumPy's global random state.
            np.random.RandomState(seed=seed).shuffle(idx_samples_to_use)
            dataset = [dataset[k] for k in idx_samples_to_use]

        self.data_list = [self._build_item(adj, x) for adj, x in dataset]

    @staticmethod
    def _build_item(adj, x):
        """Convert one raw ``(adj, x)`` sample into the dict of tensors."""
        features = torch.FloatTensor(x)

        # np.corrcoef collapses to a scalar when x has a single row;
        # atleast_2d keeps the result a (1, 1) matrix so the conversion and
        # the diagonal fill below still work for single-row samples.
        # NOTE: rows of x with zero variance produce NaN entries here.
        correlation = torch.as_tensor(
            np.atleast_2d(np.corrcoef(x)), dtype=torch.float32
        )
        correlation.fill_diagonal_(0.0)  # drop self-correlation

        return {
            "X": features,
            # adj is stored as a sparse matrix, so densify it first.
            "A": torch.FloatTensor(adj.toarray()),
            "A_correlation": correlation,
            "B": torch.zeros_like(features),
            "beta": 0,
        }

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.data_list[idx]

    def __len__(self):
        """Return the number of samples held in memory."""
        return len(self.data_list)
