import os.path as osp
import argparse
import math
import time
import random
import string

import torch
from torch.utils.data import DataLoader
import numpy as np
import wandb
import os
import pandas as pd

from games_dataset import Games
from indian_village_dataset import IndianVillageGames
from yelp import Yelp
from evaluation import get_linear_quadratic_optimization_roc_auc


"""python gnn_learn.py"""
"""python gnn_learn.py --graph_type indian_village --batch_size 1 --test_batch_size 1 --transformer_feedforward_dim 10"""

# Command-line interface. Argument names, order, types, defaults and choices
# are part of the script's external contract; only help-text typos are fixed.
parser = argparse.ArgumentParser('Games')

# --- Dataset size and train/val/test split ---
parser.add_argument('--n_graphs', type=int, help='Number of graphs', default=1000)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--val_ratio', type=float, help='Ratio of validation set', default=0.05)
parser.add_argument('--test_ratio', type=float, help='Ratio of test set', default=0.1)
parser.add_argument('--n_nodes', type=int, help='Number of nodes', default=20)
parser.add_argument('--m', type=int, help='Barabasi-Albert parameter m', default=1)
parser.add_argument('--n_games', type=int, help='Number of games', default=50)

# --- Model size and training schedule ---
parser.add_argument('--hidden_dim', type=int, help='Dimension of node embeddings', default=50)
parser.add_argument('--transformer_feedforward_dim', type=int, help='Dimension of transformer feedforward dim', default=100)
parser.add_argument('--encoder_dropout', type=float, help='encoder dropout', default=0.)
parser.add_argument('--target_spectral_radius', type=float, help='Target spectral radius', default=0.2)
parser.add_argument('--n_epochs', type=int, help='Number of epochs', default=5000)
parser.add_argument('--patience', type=int, help='Early Stopping Patience', default=50)
parser.add_argument('--batch_size', type=int, help='Batch size', default=100)
parser.add_argument('--test_batch_size', type=int, help='Test batch size', default=10000)
parser.add_argument('--eval_every', type=int, help='Every how many epochs to run evaluation', default=1)
parser.add_argument('--num_inference_steps', type=int, help='Number of steps to carry out in inference optimization', default=0)

# --- Architecture, loss and optimisation details ---
parser.add_argument('--encoder', type=str, help='Types of encoder', default="transformer", choices=["transformer", "mlp_on_nodes", "mlp_on_seq", "per_game_transformer", "column_transformer"])
parser.add_argument('--decoder', type=str, help='Types of decoder', default="mlp", choices=["dot_product", "cosine_similarity", "correlation_coefficient", "mlp"])
parser.add_argument('--alpha', type=float, help='Smoothness of marginal benefits', default=1.0)
parser.add_argument('--loss', type=str, help='Types of loss', default="bce", choices=["mse", "bce"])
parser.add_argument('--device', type=str, help='Device where to run the model', default="cuda:0")
parser.add_argument('--regenerate_data', action='store_true', help='Whether to regenerate the graphs')
parser.add_argument('--gamma', type=float, help='Coefficient B MSE', default=1.)
parser.add_argument('--inner_loop_lr', type=float, help='Inner loop lr.', default=0.01)
# Help text previously read "tollerance".
parser.add_argument('--eps', type=float, help='Inner loop tolerance.', default=1e-4)
parser.add_argument('--use_weighted_loss', action='store_true', help='Whether to use a weighted BCE or not.')
parser.add_argument('--permute_features', action='store_true', help='Whether to permute X and B')
parser.add_argument('--noise_std', type=float, help='B noise std.', default=0.)
# Help text previously read "ration".
parser.add_argument('--action_signal_to_noise_ratio', type=float, help='Signal-to-noise ratio in synthetic actions', default=10)
parser.add_argument('--model_name', type=str, help='Model name.', default="GNN")
parser.add_argument('--transformer_num_layers', type=int, help='Number of transformer layers to use.', default=2)

# --- Graph / game source selection ---
parser.add_argument('--graph_type', type=str, help='Type of graph', default="barabasi_albert", choices=["barabasi_albert", "erdos_renyi", "watts_strogatz", "indian_village", "yelp"])
parser.add_argument('--yelp_dump_filename', type=str, help='Name of the file with the Yelp dataset.', default="")
parser.add_argument('--yelp_top_N_graphs', type=int, help='Number of graphs to use in Yelp dataset.', default=-1)
parser.add_argument('--game_type', type=str, help='Type of game', default="linear_quadratic", choices=["linear_quadratic", "variable_cost", "linear_influence", "barik_honorio"])
parser.add_argument('--cost_distribution', type=str, help='Type of distribution to use to sample node-wise costs.', default="normal", choices=["normal", "uniform"])
parser.add_argument('--model_to_train', type=str, help='Type of model to train.', default="nugget", choices=["nugget", "deep_graph"])
parser.add_argument('--num_deep_graph_eval_runs', type=int, help='Number of deep graph evaluation runs.', default=1)

# Parsed at module level, so this runs (and reads sys.argv) whenever the module
# is imported or executed; the resulting namespace is passed to run() by the
# __main__ guard at the bottom of the file.
args = parser.parse_args()

def run(args):
    """Evaluate the linear-quadratic optimization baseline on the test split.

    Builds the dataset selected by ``args.graph_type``, keeps the trailing
    test portion, looks up the baseline hyperparameters for this configuration
    in ``lin_quad_opt_hyperparams.csv`` (resolved relative to the current
    working directory) and logs the resulting ROC-AUC statistics to Weights &
    Biases.

    Args:
        args: Parsed command-line namespace produced by the module-level
            ``parser``. ``args.n_games`` is overwritten with the dataset's
            own value for the ``indian_village`` and ``yelp`` graph types.

    Raises:
        IndexError: If the hyperparameter CSV has no row whose ``graph_type``,
            ``alpha`` and ``target_spectral_radius`` match ``args``.
    """
    wandb.init(project="network_games", config=args)

    # Resolve the on-disk location of the dataset, relative to this file.
    if args.graph_type == 'indian_village':
        dataset = 'indian_networks'
    elif args.graph_type == 'yelp':
        dataset = os.path.join('Yelp/dumps', args.yelp_dump_filename)
    else:
        dataset = 'games'
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)

    if args.graph_type == 'indian_village':
        dataset = IndianVillageGames(path)
        # Real-world datasets fix the number of games; read it from the data.
        args.n_games = dataset[0]['X'].shape[1]
    elif args.graph_type == 'yelp':
        dataset = Yelp(path, top_N_graphs=args.yelp_top_N_graphs)
        args.n_games = dataset[0]['X'].shape[1]
    else:
        dataset = Games(path, n_graphs=args.n_graphs, n_nodes=args.n_nodes, m=args.m, n_games=args.n_games,
                    target_spectral_radius=args.target_spectral_radius, alpha=args.alpha,
                    signal_to_noise_ratio=args.action_signal_to_noise_ratio, game_type=args.game_type,
                    regenerate_data=args.regenerate_data, graph_type=args.graph_type, cost_distribution=args.cost_distribution)

    # Split datasets: only the test portion (everything after train + val) is used here.
    train_ratio = 1 - args.val_ratio - args.test_ratio
    n_train_samples = math.floor(len(dataset) * train_ratio)
    n_val_samples = math.floor(len(dataset) * args.val_ratio)

    test_dataset = dataset[n_train_samples + n_val_samples:]

    # Select the hyperparameter row for THIS configuration. Previously the
    # filtered result was discarded and alpha1/alpha2 were always read from
    # the first row of the whole file, regardless of the configuration.
    hyperparams_df = pd.read_csv('lin_quad_opt_hyperparams.csv')
    matching_rows = hyperparams_df.loc[
        (hyperparams_df['graph_type'] == args.graph_type)
        & (hyperparams_df['alpha'] == args.alpha)
        & (hyperparams_df['target_spectral_radius'] == args.target_spectral_radius)
    ]
    if matching_rows.empty:
        # Same exception type the original lookup raised on a miss, with a usable message.
        raise IndexError(
            "No row in lin_quad_opt_hyperparams.csv for "
            f"graph_type={args.graph_type!r}, alpha={args.alpha!r}, "
            f"target_spectral_radius={args.target_spectral_radius!r}"
        )

    beta = args.target_spectral_radius
    # alpha1 / alpha2 are stored as base-10 exponents in the CSV.
    theta1 = 10 ** matching_rows["alpha1"].to_numpy()[0]
    theta2 = 10 ** matching_rows["alpha2"].to_numpy()[0]
    test_lin_quad_opt_roc_auc_mean, test_lin_quad_opt_roc_auc_std = get_linear_quadratic_optimization_roc_auc(
        test_dataset, beta, theta1, theta2, smooth=(args.alpha == 1))

    results = {
        "lin_quad_opt_test_roc_auc_mean": test_lin_quad_opt_roc_auc_mean,
        "lin_quad_opt_test_roc_auc_std": test_lin_quad_opt_roc_auc_std
    }

    wandb.log(results)
    wandb.finish()

# Script entry point: run the baseline evaluation with the module-level
# parsed arguments only when this file is executed directly.
if __name__ == "__main__":
    run(args)