import os.path as osp
import argparse
import math
import time
import random
import string

import torch
import torch_geometric.transforms as T
from torch.utils.data import DataLoader
import numpy as np
import wandb
import os

from games_dataset import Games
#from indian_village_dataset import IndianVillageGames
#from yelp import Yelp
from utils import get_encoder, get_decoder, mask_diagonal, get_loss_weights#, permute_features
#from evaluation import eval_baseline, eval_everything, eval

# ---------------------------------------------------------------------------
# Experiment configuration.
#
# Every value below is forwarded unchanged, by keyword, to the `Games`
# dataset constructor; the precise meaning of each setting is defined by
# `Games` (see games_dataset), not by this script.
# ---------------------------------------------------------------------------

# Size settings.
n_graphs = 100
n_nodes = 20
n_games = 50
m = 1  # NOTE(review): meaning is not visible from this file; confirm in Games

# Game settings.
game_type = "linear_quadratic"
alpha = 1
target_spectral_radius = 0.8
marginal_benefits_noise_variance = 0.1

# Graph settings (passed to Games as `graph_type`).
graph = "erdos_renyi"

# Passed as `regenerate_data`; presumably forces the data to be rebuilt
# instead of reusing a previous run's output; TODO confirm against Games.
regenerate_data = True

# Build the dataset. The first positional argument is the string
# '../data/tmp' (presumably the dataset's root directory; being relative, it
# would depend on the working directory the script is launched from).
# No transform and no cost distribution are supplied.
dataset = Games(
    '../data/tmp',
    n_graphs=n_graphs,
    n_nodes=n_nodes,
    m=m,
    n_games=n_games,
    target_spectral_radius=target_spectral_radius,
    alpha=alpha,
    marginal_benefits_noise_variance=marginal_benefits_noise_variance,
    game_type=game_type,
    graph_type=graph,
    regenerate_data=regenerate_data,
    transform=None,
    cost_distribution=None,
)