import os.path as osp
import argparse
import math
import time
import random
import string

import torch
from torch.utils.data import DataLoader
import numpy as np
import wandb
import os

from games_dataset import Games
from indian_village_dataset import IndianVillageGames
from yelp import Yelp
from utils import get_encoder, get_decoder, mask_diagonal, get_loss_weights, permute_features
from evaluation import eval_baseline, eval_everything, eval, get_barik_honorio_roc_auc
from deep_graph import DeepGraph, permute_node_ordering_and_compute_covariance_matrix

"""python gnn_learn_all.py"""
"""python gnn_learn_all.py --graph_type indian_village --batch_size 1 --test_batch_size 1"""

# Command-line interface of the training script. Options are registered in a
# fixed order, which is also the order shown by `--help`.
parser = argparse.ArgumentParser(prog='Games')

parser.add_argument('--n_graphs', type=int, default=1000, help='Number of graphs')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument('--val_ratio', type=float, default=0.05, help='Ratio of validation set')
parser.add_argument('--test_ratio', type=float, default=0.1, help='Ratio of test set')
parser.add_argument('--n_nodes', type=int, default=20, help='Number of nodes')
parser.add_argument('--m', type=int, default=1, help='Barabasi-Albert parameter m')
parser.add_argument('--n_games', type=int, default=50, help='Number of games')
# NOTE: the script entry point overrides --hidden_dim with 100 when the
# encoder is 'per_game_transformer'.
parser.add_argument('--hidden_dim', type=int, default=50, help='Dimension of node embeddings')
parser.add_argument('--transformer_feedforward_dim', type=int, default=100,
                    help='Dimension of transformer feedforward dim')
parser.add_argument('--encoder_dropout', type=float, default=0., help='encoder dropout')
parser.add_argument('--target_spectral_radius', type=float, default=0.2, help='Target spectral radius')
parser.add_argument('--n_epochs', type=int, default=5000, help='Number of epochs')
parser.add_argument('--patience', type=int, default=50, help='Early Stopping Patience')
parser.add_argument('--batch_size', type=int, default=100, help='Batch size')
parser.add_argument('--test_batch_size', type=int, default=10000, help='Test batch size')
parser.add_argument('--eval_every', type=int, default=1, help='Every how many epochs to run evaluation')
parser.add_argument('--num_inference_steps', type=int, default=0,
                    help='Number of steps to carry out in inference optimization')
parser.add_argument(
    '--encoder',
    type=str,
    default="per_game_transformer",
    choices=["transformer", "mlp_on_nodes", "mlp_on_seq", "per_game_transformer", "column_transformer"],
    help='Types of encoder',
)
parser.add_argument(
    '--decoder',
    type=str,
    default="mlp",
    choices=["dot_product", "cosine_similarity", "correlation_coefficient", "mlp"],
    help='Types of decoder',
)
parser.add_argument('--alpha', type=float, default=1.0, help='Smoothness of marginal benefits')
parser.add_argument('--loss', type=str, default="bce", choices=["mse", "bce"], help='Types of loss')
parser.add_argument('--device', type=str, default="cuda:0", help='Device where to run the model')
parser.add_argument('--regenerate_data', action='store_true', help='Whether to regenerate the graphs')
parser.add_argument('--gamma', type=float, default=1., help='Coefficient B MSE')
parser.add_argument('--inner_loop_lr', type=float, default=0.01, help='Inner loop lr.')
parser.add_argument('--eps', type=float, default=1e-4, help='Inner loop tollerance.')
parser.add_argument('--use_weighted_loss', action='store_true', help='Whether to use a weighted BCE or not.')
# NOTE: the script entry point forces this flag on for every graph type other
# than 'indian_village' and 'yelp'.
parser.add_argument('--permute_features', action='store_true', help='Whether to permute X and B')
parser.add_argument('--noise_std', type=float, default=0., help='B noise std.')
parser.add_argument('--action_signal_to_noise_ratio', type=float, default=10,
                    help='Signal-to-noise ration in synthetic actions')
parser.add_argument('--model_name', type=str, default="GNN", help='Model name.')
parser.add_argument('--transformer_num_layers', type=int, default=2,
                    help='Number of transformer layers to use.')
parser.add_argument(
    '--graph_type',
    type=str,
    default="barabasi_albert",
    choices=["barabasi_albert", "erdos_renyi", "watts_strogatz", "indian_village", "yelp"],
    help='Type of graph',
)
parser.add_argument('--yelp_dump_filename', type=str, default="",
                    help='Name of the file with the Yelp dataset.')
parser.add_argument('--yelp_top_N_graphs', type=int, default=-1,
                    help='Number of graphs to use in Yelp dataset.')
parser.add_argument(
    '--game_type',
    type=str,
    default="linear_quadratic",
    choices=["linear_quadratic", "variable_cost", "linear_influence", "barik_honorio"],
    help='Type of game',
)
parser.add_argument('--cost_distribution', type=str, default="normal", choices=["normal", "uniform"],
                    help='Type of distribution to use to sample node-wise costs.')
# NOTE: run() overwrites args.model_to_train while looping over both models.
parser.add_argument('--model_to_train', type=str, default="nugget", choices=["nugget", "deep_graph"],
                    help='Type of model to train.')
parser.add_argument('--num_deep_graph_eval_runs', type=int, default=1,
                    help='Number of deep graph evaluation runs.')

# Parsed at import time; the resulting namespace is consumed by the script
# entry point at the bottom of the file.
args = parser.parse_args()

class Model(torch.nn.Module):
    """Encoder/decoder network: the encoder output is fed straight into the decoder.

    The encoder and decoder are built by the ``get_encoder`` / ``get_decoder``
    factories from their type names.
    """

    def __init__(self, encoder_type, decoder_type, n_nodes, n_games, hidden_dim, dropout=0., transformer_num_layers=1, transformer_feedforward_dim=100):
        """Build the encoder and decoder sub-modules.

        Args:
            encoder_type: Encoder name forwarded to ``get_encoder``.
            decoder_type: Decoder name forwarded to ``get_decoder``.
            n_nodes: Forwarded to ``get_encoder``.
            n_games: Forwarded to ``get_encoder``.
            hidden_dim: Forwarded to both factories.
            dropout: Forwarded to ``get_encoder``.
            transformer_num_layers: Forwarded to ``get_encoder``.
            transformer_feedforward_dim: Forwarded to ``get_encoder``.
        """
        super().__init__()
        # The decoder is asked to be permutation invariant only when it is
        # paired with the 'per_game_transformer' encoder.
        needs_permutation_invariance = encoder_type == 'per_game_transformer'
        self.encoder = get_encoder(
            encoder_type,
            n_nodes,
            n_games,
            hidden_dim,
            dropout,
            transformer_num_layers,
            transformer_feedforward_dim,
        )
        self.decoder = get_decoder(
            decoder_type,
            hidden_dim,
            permutation_invariant=needs_permutation_invariance,
        )

    def forward(self, X):
        """Encode ``X`` and return the decoder's output for the encoding."""
        return self.decoder(self.encoder(X))

def run(args):
    """Train and evaluate both graph-inference models and log results to wandb.

    The dataset selected by ``args.graph_type`` is loaded and split
    sequentially (no shuffling) into train / validation / test parts, the
    baselines are evaluated once, and then the ``nugget`` model and the
    ``deep_graph`` model are trained one after the other. For each model the
    weights with the best validation ROC AUC are checkpointed, reloaded after
    training and scored on the test split with batch size 1.

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``args.n_games`` is overwritten from the data for the
            ``indian_village`` and ``yelp`` graph types, and
            ``args.model_to_train`` is overwritten with the model currently
            being trained (the value given on the command line is ignored).

    Returns:
        None. Results are printed and sent to wandb.
    """
    wandb.init(project="network_games", config=args)
    model_name = ''.join(random.choices(string.ascii_uppercase + string.digits, k=15))
    model_path = f"../data/models/{model_name}.pt"
    # torch.save does not create missing parent directories; make sure the
    # checkpoint directory exists so the first save cannot abort the run.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    device = torch.device(args.device)

    if args.graph_type == 'indian_village':
        dataset_dir = 'indian_networks'
    elif args.graph_type == 'yelp':
        dataset_dir = os.path.join('Yelp/dumps', args.yelp_dump_filename)
    else:
        dataset_dir = 'games'
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset_dir)

    if args.graph_type == 'indian_village':
        dataset = IndianVillageGames(path)
        # The number of games comes from the data, not from the CLI.
        args.n_games = dataset[0]['X'].shape[1]
    elif args.graph_type == 'yelp':
        dataset = Yelp(path, top_N_graphs=args.yelp_top_N_graphs)
        args.n_games = dataset[0]['X'].shape[1]
    else:
        dataset = Games(path, n_graphs=args.n_graphs, n_nodes=args.n_nodes, m=args.m, n_games=args.n_games,
                        target_spectral_radius=args.target_spectral_radius, alpha=args.alpha,
                        signal_to_noise_ratio=args.action_signal_to_noise_ratio, game_type=args.game_type,
                        regenerate_data=args.regenerate_data, graph_type=args.graph_type,
                        cost_distribution=args.cost_distribution)

    # Split datasets.
    train_ratio = 1 - args.val_ratio - args.test_ratio
    n_train_samples = math.floor(len(dataset) * train_ratio)
    n_val_samples = math.floor(len(dataset) * args.val_ratio)

    train_dataset = dataset[:n_train_samples]
    val_dataset = dataset[n_train_samples:n_train_samples + n_val_samples]
    test_dataset = dataset[n_train_samples + n_val_samples:]

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    train_eval_loader = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)

    baseline_results = eval_baseline(train_dataset, val_dataset, test_dataset, args.game_type)
    models_results = {}

    # Model
    for model_to_train in ['nugget', 'deep_graph']:
        args.model_to_train = model_to_train
        print(f'Training {args.model_to_train}')

        if args.model_to_train == 'deep_graph':
            # DeepGraph is sized for the largest graph in the whole dataset.
            max_num_nodes = max((data["A"].shape[0] for data in dataset), default=0)
            model = DeepGraph(max_num_nodes, 50, kernel_size=3).to(device)
        else:
            model = Model(encoder_type=args.encoder, decoder_type=args.decoder, n_nodes=args.n_nodes,
                          n_games=args.n_games,
                          hidden_dim=args.hidden_dim, dropout=args.encoder_dropout,
                          transformer_num_layers=args.transformer_num_layers,
                          transformer_feedforward_dim=args.transformer_feedforward_dim).to(device)

        print(f"Model has {sum([p.numel() for p in model.parameters()])} parameters")
        optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr)

        pos_weight = get_loss_weights(args.use_weighted_loss, train_loader)
        loss_criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight.to(device)) if args.loss == "bce" else torch.nn.MSELoss()
        wandb.watch(model, loss_criterion, log="all", log_freq=10)

        print("Starting training")

        val_roc_aucs, test_roc_aucs = [], []
        for epoch in range(args.n_epochs):
            start = time.time()
            model.train()
            epoch_loss = 0.0

            for data in train_loader:
                optimizer.zero_grad()
                X, B, A = data["X"].to(device), data["B"].to(device), data["A"].to(device)

                if args.permute_features:
                    X, B = permute_features(X, B)

                if args.model_to_train == 'deep_graph':
                    X, C, A, perm = permute_node_ordering_and_compute_covariance_matrix(X, A)
                    A_pred = model(C)
                else:
                    A_pred = model(X)
                # The dot product encoder will always produce high scores for the diagonal (dot product of a node with itself). But by construction, we
                # do not have self-edges. Therefore, we mask out the diagonal (not doing it also causes instability in the training)
                loss = loss_criterion(mask_diagonal(A_pred).reshape(-1), mask_diagonal(A).reshape(-1))

                loss.backward()
                optimizer.step()
                # Accumulate a plain Python float: summing the loss tensor
                # itself would keep every batch's autograd history referenced
                # until the end of the epoch and hand a tensor to wandb.
                epoch_loss += loss.item()

            epoch_loss /= len(train_loader)
            print(f"Epoch {epoch + 1} --- Loss: {epoch_loss:.4f}. It took {time.time() - start:.2f}s")

            train_roc_auc, val_roc_auc, test_roc_auc, *_ = eval_everything(model=model, train_eval_loader=train_eval_loader,
                                                                           val_loader=val_loader, test_loader=test_loader,
                                                                           device=device, args=args)

            # Store best model
            if not val_roc_aucs or val_roc_auc > max(val_roc_aucs):
                torch.save(model.state_dict(), model_path)

            val_roc_aucs.append(val_roc_auc)
            test_roc_aucs.append(test_roc_auc)
            print(f"Train ROC_AUC: {train_roc_auc:.4f}, val ROC_AUC: {val_roc_auc:.4f}, test ROC_AUC: {test_roc_auc:.4f}. It took {time.time() - start:.2f}s")

            model_results = {f"loss_{args.model_to_train}": epoch_loss, f"train_roc_auc_{args.model_to_train}": train_roc_auc, f"val_roc_auc_{args.model_to_train}": val_roc_auc, f"test_roc_auc_{args.model_to_train}": test_roc_auc, f"epoch_{args.model_to_train}": epoch}
            wandb.log(model_results)

            # Stop once the best validation score of the last `patience`
            # epochs is below the best validation score seen overall.
            if epoch > args.patience and max(val_roc_aucs[-args.patience:]) < max(val_roc_aucs):
                print("Early stopping")
                break

        print("Evaluating best model")
        # The same checkpoint path is reused for both models; the first model
        # has already been evaluated by the time the second one overwrites it.
        model.load_state_dict(torch.load(model_path))
        model.eval()
        final_test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
        model_test_roc_auc_mean, model_test_roc_auc_std = eval(model=model, data_loader=final_test_loader, device=device, args=args)
        models_results[model_to_train] = {f"{args.model_to_train}_test_roc_auc_mean": model_test_roc_auc_mean, f"{args.model_to_train}_test_roc_auc_std": model_test_roc_auc_std}

    wandb.log({**models_results["nugget"], **models_results["deep_graph"], **baseline_results})
    wandb.finish()

if __name__ == "__main__":
    # Feature permutation is forced on for every graph type except
    # 'indian_village' and 'yelp', regardless of the --permute_features flag.
    if args.graph_type != "indian_village" and args.graph_type != "yelp":
        args.permute_features = True

    # The per-game transformer always runs with hidden_dim 100; any value
    # passed via --hidden_dim is overridden.
    if args.encoder == 'per_game_transformer':
        args.hidden_dim = 100

    run(args)