import os.path as osp
import argparse
import math
import time
import random
import string

import torch
from torch.utils.data import DataLoader
import numpy as np
import wandb
import os

from games_dataset import Games
from indian_village_dataset import IndianVillageGames
from yelp import Yelp
from utils import get_encoder, get_decoder, mask_diagonal, get_loss_weights, permute_features
from evaluation import eval_baseline, eval_everything, eval
from deep_graph import DeepGraph, permute_node_ordering_and_compute_covariance_matrix


"""python gnn_learn.py"""
"""python gnn_learn.py --graph_type indian_village --batch_size 1 --test_batch_size 1 --transformer_feedforward_dim 10"""

# Command-line interface for the training script. Every option has a default, so the
# script can be launched without arguments (see the usage examples above).
parser = argparse.ArgumentParser('Games')
# --- Dataset size, optimisation and split ratios ---
parser.add_argument('--n_graphs', type=int, help='Number of graphs', default=1000)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--val_ratio', type=float, help='Ratio of validation set', default=0.05)
parser.add_argument('--test_ratio', type=float, help='Ratio of test set', default=0.1)
# --- Synthetic graph / game generation ---
parser.add_argument('--n_nodes', type=int, help='Number of nodes', default=20)
parser.add_argument('--m', type=int, help='Barabasi-Albert parameter m', default=1)
# NOTE: run() overwrites n_games from the data when graph_type is indian_village or yelp.
parser.add_argument('--n_games', type=int, help='Number of games', default=50)
# --- Model capacity ---
parser.add_argument('--hidden_dim', type=int, help='Dimension of node embeddings', default=50)
parser.add_argument('--transformer_feedforward_dim', type=int, help='Dimension of transformer feedforward dim', default=100)
parser.add_argument('--encoder_dropout', type=float, help='encoder dropout', default=0.)
parser.add_argument('--target_spectral_radius', type=float, help='Target spectral radius', default=0.2)
# --- Training loop control ---
parser.add_argument('--n_epochs', type=int, help='Number of epochs', default=5000)
parser.add_argument('--patience', type=int, help='Early Stopping Patience', default=50)
parser.add_argument('--batch_size', type=int, help='Batch size', default=100)
parser.add_argument('--test_batch_size', type=int, help='Test batch size', default=10000)
parser.add_argument('--eval_every', type=int, help='Every how many epochs to run evaluation', default=1)
parser.add_argument('--num_inference_steps', type=int, help='Number of steps to carry out in inference optimization', default=0)
# --- Architecture and loss selection (values restricted via `choices`) ---
parser.add_argument('--encoder', type=str, help='Types of encoder', default="transformer", choices=["transformer", "mlp_on_nodes", "mlp_on_seq", "per_game_transformer", "column_transformer"])
parser.add_argument('--decoder', type=str, help='Types of decoder', default="mlp", choices=["dot_product", "cosine_similarity", "correlation_coefficient", "mlp"])
parser.add_argument('--alpha', type=float, help='Smoothness of marginal benefits', default=1.0)
parser.add_argument('--loss', type=str, help='Types of loss', default="bce", choices=["mse", "bce"])
parser.add_argument('--device', type=str, help='Device where to run the model', default="cuda:0")
parser.add_argument('--regenerate_data', action='store_true', help='Whether to regenerate the graphs')
parser.add_argument('--gamma', type=float, help='Coefficient B MSE', default=1.)
parser.add_argument('--inner_loop_lr', type=float, help='Inner loop lr.', default=0.01)
parser.add_argument('--eps', type=float, help='Inner loop tollerance.', default=1e-4)
parser.add_argument('--use_weighted_loss', action='store_true', help='Whether to use a weighted BCE or not.')
# NOTE: the __main__ guard forces permute_features to True unless graph_type is
# indian_village or yelp, so this flag only matters for those two graph types.
parser.add_argument('--permute_features', action='store_true', help='Whether to permute X and B')
parser.add_argument('--noise_std', type=float, help='B noise std.', default=0.)
parser.add_argument('--action_signal_to_noise_ratio', type=float, help='Signal-to-noise ration in synthetic actions', default=10)
# NOTE: run() does not read model_name; it generates a random checkpoint name instead.
parser.add_argument('--model_name', type=str, help='Model name.', default="GNN")
parser.add_argument('--transformer_num_layers', type=int, help='Number of transformer layers to use.', default=2)
# --- Data source selection ---
parser.add_argument('--graph_type', type=str, help='Type of graph', default="barabasi_albert", choices=["barabasi_albert", "erdos_renyi", "watts_strogatz", "indian_village", "yelp"])
parser.add_argument('--yelp_dump_filename', type=str, help='Name of the file with the Yelp dataset.', default="")
parser.add_argument('--yelp_top_N_graphs', type=int, help='Number of graphs to use in Yelp dataset.', default=-1)
parser.add_argument('--game_type', type=str, help='Type of game', default="linear_quadratic", choices=["linear_quadratic", "variable_cost", "linear_influence", "barik_honorio"])
parser.add_argument('--cost_distribution', type=str, help='Type of distribution to use to sample node-wise costs.', default="normal", choices=["normal", "uniform"])
parser.add_argument('--model_to_train', type=str, help='Type of model to train.', default="nugget", choices=["nugget", "deep_graph"])
parser.add_argument('--num_deep_graph_eval_runs', type=int, help='Number of deep graph evaluation runs.', default=20)

# Parsed at module level (i.e. on import as well as on direct execution); the resulting
# `args` namespace is the global consumed by the __main__ guard at the bottom of the file.
args = parser.parse_args()

class Model(torch.nn.Module):
    """Encoder/decoder pair: the encoder maps ``X`` to ``Z``, the decoder maps ``Z`` to ``A``.

    Both sub-modules are obtained from the project factories ``get_encoder`` and
    ``get_decoder``; this class only wires them together.
    """

    def __init__(self, encoder_type, decoder_type, n_nodes, n_games, hidden_dim, dropout=0., transformer_num_layers=1, transformer_feedforward_dim=100):
        """Build the encoder and decoder selected by ``encoder_type`` / ``decoder_type``.

        All remaining arguments are forwarded positionally to ``get_encoder``;
        ``get_decoder`` only receives ``decoder_type`` and ``hidden_dim``.
        """
        super().__init__()
        encoder_args = (
            encoder_type,
            n_nodes,
            n_games,
            hidden_dim,
            dropout,
            transformer_num_layers,
            transformer_feedforward_dim,
        )
        self.encoder = get_encoder(*encoder_args)
        self.decoder = get_decoder(decoder_type, hidden_dim)

    def forward(self, X):
        """Return the decoder output for the encoded input ``X``."""
        return self.decoder(self.encoder(X))

def run(args):
    """Train the selected model, keep the best checkpoint by validation ROC AUC, and report test results.

    Steps:
      1. Start a wandb run and pick a random checkpoint file name under ``../data/models``
         (relative to the current working directory).
      2. Load the dataset selected by ``args.graph_type`` and split it contiguously into
         train / validation / test according to ``args.val_ratio`` and ``args.test_ratio``.
      3. Build either the encoder/decoder ``Model`` (``nugget``) or ``DeepGraph``.
      4. Train for up to ``args.n_epochs`` epochs, evaluating after every epoch, saving the
         weights whenever validation ROC AUC improves, and stopping early once the last
         ``args.patience`` epochs contain no new best.
      5. Reload the best checkpoint, evaluate it on the test set with batch size 1, and log
         the result together with the baseline metrics.

    Args:
        args: Parsed command-line namespace. ``args.n_games`` is overwritten from the data
            when ``args.graph_type`` is ``indian_village`` or ``yelp``.

    Side effects:
        Writes a checkpoint file, logs to wandb, and prints progress to stdout.
    """
    wandb.init(project="network_games", config=args)
    model_name = ''.join(random.choices(string.ascii_uppercase + string.digits, k=15))
    model_path = f"../data/models/{model_name}.pt"
    # torch.save does not create missing parent directories, so make sure the
    # checkpoint folder exists before the first save.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    device = torch.device(args.device)

    # Resolve the on-disk location of the dataset, relative to this source file.
    if args.graph_type == 'indian_village':
        dataset_dir = 'indian_networks'
    elif args.graph_type == 'yelp':
        dataset_dir = os.path.join('Yelp/dumps', args.yelp_dump_filename)
    else:
        dataset_dir = 'games'
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset_dir)

    if args.graph_type == 'indian_village':
        dataset = IndianVillageGames(path)
        # For these datasets the number of games is read from the data, not the CLI.
        args.n_games = dataset[0]['X'].shape[1]
    elif args.graph_type == 'yelp':
        dataset = Yelp(path, top_N_graphs=args.yelp_top_N_graphs)
        args.n_games = dataset[0]['X'].shape[1]
    else:
        dataset = Games(path, n_graphs=args.n_graphs, n_nodes=args.n_nodes, m=args.m, n_games=args.n_games,
                        target_spectral_radius=args.target_spectral_radius, alpha=args.alpha,
                        signal_to_noise_ratio=args.action_signal_to_noise_ratio, game_type=args.game_type,
                        regenerate_data=args.regenerate_data, graph_type=args.graph_type,
                        cost_distribution=args.cost_distribution)

    # Split datasets (contiguous slices; the test split takes whatever remains).
    train_ratio = 1 - args.val_ratio - args.test_ratio
    n_train_samples = math.floor(len(dataset) * train_ratio)
    n_val_samples = math.floor(len(dataset) * args.val_ratio)

    train_dataset = dataset[:n_train_samples]
    val_dataset = dataset[n_train_samples:n_train_samples + n_val_samples]
    test_dataset = dataset[n_train_samples + n_val_samples:]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    train_eval_loader = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)

    # Model
    if args.model_to_train == 'nugget':
        model = Model(encoder_type=args.encoder, decoder_type=args.decoder, n_nodes=args.n_nodes, n_games=args.n_games,
                      hidden_dim=args.hidden_dim, dropout=args.encoder_dropout,
                      transformer_num_layers=args.transformer_num_layers,
                      transformer_feedforward_dim=args.transformer_feedforward_dim).to(device)
    elif args.model_to_train == 'deep_graph':
        # DeepGraph is sized by the largest graph in the dataset (0 if the dataset is empty).
        max_num_nodes = max((data["A"].shape[0] for data in dataset), default=0)
        model = DeepGraph(max_num_nodes, 50, kernel_size=3).to(device)
    print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    pos_weight = get_loss_weights(args.use_weighted_loss, train_loader)
    loss_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device)) if args.loss == "bce" else torch.nn.MSELoss()
    wandb.watch(model, loss_criterion, log="all", log_freq=10)

    baseline_results = eval_baseline(train_dataset, val_dataset, test_dataset)

    print("Starting training")

    val_roc_aucs, test_roc_aucs = [], []
    for epoch in range(args.n_epochs):
        start = time.time()
        model.train()
        epoch_loss = 0.0

        for data in train_loader:
            optimizer.zero_grad()
            X, B, A = data["X"].to(device), data["B"].to(device), data["A"].to(device)

            if args.permute_features:
                X, B = permute_features(X, B)

            if args.model_to_train == 'deep_graph':
                # Only the covariance matrix and the (re-ordered) target are needed here.
                _, C, A, _ = permute_node_ordering_and_compute_covariance_matrix(X, A)
                A_pred = model(C)
            elif args.model_to_train == 'nugget':
                A_pred = model(X)
            # The dot product encoder will always produce high scores for the diagonal (dot product
            # of a node with itself). But by construction, we do not have self-edges. Therefore, we
            # mask out the diagonal (not doing it also causes instability in the training).
            loss = loss_criterion(mask_diagonal(A_pred).reshape(-1), mask_diagonal(A).reshape(-1))

            loss.backward()
            optimizer.step()
            # Accumulate a plain Python float: adding the tensor itself would keep every
            # batch's autograd history alive until the end of the epoch.
            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch + 1} --- Loss: {epoch_loss:.4f}. It took {time.time() - start:.2f}s")

        train_roc_auc, val_roc_auc, test_roc_auc, *_ = eval_everything(model=model, train_eval_loader=train_eval_loader,
                                                                       val_loader=val_loader, test_loader=test_loader,
                                                                       device=device, args=args)

        # Store best model (checked before appending, so the comparison is against earlier epochs only).
        if not val_roc_aucs or val_roc_auc > max(val_roc_aucs):
            torch.save(model.state_dict(), model_path)

        val_roc_aucs.append(val_roc_auc)
        test_roc_aucs.append(test_roc_auc)
        print(f"Train ROC_AUC: {train_roc_auc:.4f}, val ROC_AUC: {val_roc_auc:.4f}, test ROC_AUC: {test_roc_auc:.4f}. It took {time.time() - start:.2f}s")

        model_results = {"loss": epoch_loss, "train_roc_auc": train_roc_auc, "val_roc_auc": val_roc_auc, "test_roc_auc": test_roc_auc, "epoch": epoch}
        wandb.log(model_results)

        # Stop once the best validation score of the last `patience` epochs is below the all-time best.
        if epoch > args.patience and max(val_roc_aucs[-args.patience:]) < max(val_roc_aucs):
            print("Early stopping")
            break

    print("Evaluating best model")
    model.load_state_dict(torch.load(model_path))
    model.eval()
    # Batch size 1 so that `eval` sees one graph per batch for the final mean/std report.
    final_test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    model_test_roc_auc_mean, model_test_roc_auc_std = eval(model=model, data_loader=final_test_loader, device=device, args=args)
    model_results = {"model_test_roc_auc_mean": model_test_roc_auc_mean, "model_test_roc_auc_std": model_test_roc_auc_std}
    wandb.log({**model_results, **baseline_results})

    print(f"Baseline test ROC_AUC: {baseline_results['correlation_test_roc_auc_mean']:.4f}+-{baseline_results['correlation_test_roc_auc_std']:.4f}")
    print(f"Model test ROC_AUC   : {model_test_roc_auc_mean:.4f}+-{model_test_roc_auc_std:.4f}")

    wandb.finish()

if __name__ == "__main__":
    # Feature permutation is always enabled except for the two graph types
    # listed here, where the command-line flag is respected as given.
    keeps_cli_flag = args.graph_type in ("indian_village", "yelp")
    if not keeps_cli_flag:
        args.permute_features = True

    run(args)