import os.path as osp
import argparse
import math
import time

import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric import utils
import torch_geometric.transforms as T
from torch.utils.data import DataLoader
import numpy as np
import wandb

from games_dataset import Games
from utils import compute_roc_auc_score, correlation_baseline_score, get_encoder, mask_diagonal
from evaluation import eval, eval_inference
from decoder import DotProductDecoder

"""python gnn/gnn_learn.py --regenerate_data --num_inference_steps 100 --n_graphs 1000"""

# Command-line configuration of the experiment. Every option has a default, so
# the script can be launched without any argument.
parser = argparse.ArgumentParser('Games')
parser.add_argument('--n_graphs', type=int, help='Number of graphs', default=100)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--val_ratio', type=float, help='Ratio of validation set', default=0.05)
parser.add_argument('--test_ratio', type=float, help='Ratio of test set', default=0.1)
parser.add_argument('--n_nodes', type=int, help='Number of nodes', default=10)
parser.add_argument('--m', type=int, help='Barabasi-Albert parameter m', default=2)
parser.add_argument('--n_games', type=int, help='Number of games', default=50)
parser.add_argument('--hidden_dim', type=int, help='Dimension of node embeddings', default=50)
parser.add_argument('--target_spectral_radius', type=float, help='Target spectral radius', default=0.5)
parser.add_argument('--n_epochs', type=int, help='Number of epochs', default=500)
parser.add_argument('--patience', type=int, help='Early Stopping Patience', default=5)
parser.add_argument('--batch_size', type=int, help='Batch size', default=100)
parser.add_argument('--test_batch_size', type=int, help='Test batch size', default=10000)
parser.add_argument('--eval_every', type=int, help='Every how many epochs to run evaluation', default=10)
parser.add_argument('--num_inference_steps', type=int, help='Number of steps to carry out in inference optimization', default=200)
# The help text used to be a copy of the one for --benefits; it now describes the encoder.
parser.add_argument('--encoder', type=str, help='Type of encoder', default="transformer", choices=["transformer", "mlp_on_nodes", "mlp_on_seq"])
parser.add_argument('--benefits', type=str, help='Types of marginal benefits', default="independent", choices=["independent", "homophilous"])
parser.add_argument('--loss', type=str, help='Types of loss', default="bce", choices=["mse", "bce"])
parser.add_argument('--regenerate_data', action='store_true', help='Whether to regenerate the graphs')
args = parser.parse_args()

# Start the Weights & Biases run, passing the parsed arguments as its config.
wandb.init(project="network_games", config=args)

def eval_everything(epoch, epoch_loss, start):
    """Evaluate the model on every split, then print and log the metrics.

    Uses the module-level ``model``, ``train_eval_loader``, ``val_loader``,
    ``test_loader``, ``device`` and ``args``. ``eval`` here is the function
    imported from ``evaluation``, not the Python builtin.

    Whenever ``epoch + 1`` is a multiple of ``args.eval_every`` (this includes
    ``epoch == -1``), the inference-based evaluation is run on every split too.

    Args:
        epoch: Zero-based epoch index; it is reported as ``epoch + 1``.
        epoch_loss: Training loss to report, as a number or a 0-dim tensor.
        start: ``time.time()`` timestamp from which the elapsed time is measured.
    """
    # Report a plain Python number so that the logs never hold on to a tensor.
    epoch_loss = float(epoch_loss)

    train_roc_auc = eval(model, train_eval_loader, device)
    val_roc_auc = eval(model, val_loader, device)
    test_roc_auc = eval(model, test_loader, device)

    print(f"Epoch {epoch + 1} --- Loss: {epoch_loss:.4f}, train ROC_AUC: {train_roc_auc:.4f}, val ROC_AUC: {val_roc_auc:.4f}, test ROC_AUC: {test_roc_auc:.4f}. It took {time.time() - start:.2f}s")
    wandb.log({"loss": epoch_loss, "train_roc_auc": train_roc_auc, "val_roc_auc": val_roc_auc, "test_roc_auc": test_roc_auc})

    if (epoch + 1) % args.eval_every == 0:
        # Time the inference evaluation on its own, without rebinding `start`.
        inference_start = time.time()
        inferred_train_roc_auc = eval_inference(model=model, data_loader=train_eval_loader, num_steps=args.num_inference_steps, device=device, name="train")
        inferred_val_roc_auc = eval_inference(model=model, data_loader=val_loader, num_steps=args.num_inference_steps, device=device, name="val")
        inferred_test_roc_auc = eval_inference(model=model, data_loader=test_loader, num_steps=args.num_inference_steps, device=device, name="test")
        print(f"Train inferred ROC_AUC: {inferred_train_roc_auc:.4f}, val ROC_AUC: {inferred_val_roc_auc:.4f}, test ROC_AUC: {inferred_test_roc_auc:.4f}. It took {time.time() - inference_start:.2f}s")
        wandb.log({"inferred_train_roc_auc": inferred_train_roc_auc, "inferred_val_roc_auc": inferred_val_roc_auc, "inferred_test_roc_auc": inferred_test_roc_auc})


class Model(torch.nn.Module):
    """Encoder followed by a dot-product decoder.

    The encoder is obtained from ``get_encoder`` and the decoder is a
    ``DotProductDecoder``; ``forward`` chains the two.
    """

    def __init__(self, encoder_type, n_nodes, n_games, hidden_dim):
        """Build the encoder selected by ``encoder_type`` and the decoder.

        Args:
            encoder_type: Encoder selector forwarded to ``get_encoder``.
            n_nodes: Number of nodes, forwarded to ``get_encoder`` and stored.
            n_games: Number of games, forwarded to ``get_encoder`` and stored.
            hidden_dim: Hidden dimension forwarded to ``get_encoder``.
        """
        super().__init__()
        self.n_nodes = n_nodes
        self.n_games = n_games
        self.encoder = get_encoder(encoder_type, n_nodes, n_games, hidden_dim)
        self.decoder = DotProductDecoder()

    def forward(self, x, b):
        """Encode ``(x, b)`` and return the decoder's output for that encoding."""
        return self.decoder(self.encoder(x, b))

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The data is stored under ../data/games relative to this file.
dataset_name = 'games'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset_name)
dataset = Games(
    path,
    n_graphs=args.n_graphs,
    n_nodes=args.n_nodes,
    m=args.m,
    n_games=args.n_games,
    target_spectral_radius=args.target_spectral_radius,
    homophilous_marginal_benefits=args.benefits == "homophilous",
    transform=T.NormalizeFeatures(),
    regenerate_data=args.regenerate_data,
)
# Keep at most the requested number of graphs.
dataset = dataset[:args.n_graphs]

# Split datasets: contiguous train / validation / test slices, in that order.
# The test split takes whatever is left after the train and validation slices.
train_ratio = 1 - args.val_ratio - args.test_ratio
n_train_samples = math.floor(len(dataset) * train_ratio)
n_val_samples = math.floor(len(dataset) * args.val_ratio)

train_dataset = dataset[:n_train_samples]
val_dataset = dataset[n_train_samples:n_train_samples + n_val_samples]
test_dataset = dataset[n_train_samples + n_val_samples:]

# Only the loader used for optimisation shuffles. The evaluation loaders keep a
# fixed order so that evaluation is deterministic and does not draw from the
# global random number generator between training epochs.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
train_eval_loader = DataLoader(train_dataset, batch_size=args.test_batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)

# Build the network, report its size, then create the optimizer and the loss.
model = Model(
    encoder_type=args.encoder,
    n_nodes=args.n_nodes,
    n_games=args.n_games,
    hidden_dim=args.hidden_dim,
).to(device)
n_parameters = sum(p.numel() for p in model.parameters())
print(f"Model has {n_parameters} parameters")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# "bce" selects BCEWithLogitsLoss; the only other choice, "mse", selects MSELoss.
if args.loss == "bce":
    loss_critereon = torch.nn.BCEWithLogitsLoss()
else:
    loss_critereon = torch.nn.MSELoss()
wandb.watch(model, loss_critereon, log="all", log_freq=10)

def _mean_correlation_baseline(graphs):
    """Return the mean ``correlation_baseline_score`` over ``graphs``.

    Each graph is scored from its ``"x"`` entry against the dense adjacency
    matrix built from its ``"edge_index"`` entry.
    """
    scores = [
        correlation_baseline_score(graph["x"], utils.to_dense_adj(graph["edge_index"]))
        for graph in graphs
    ]
    return np.mean(scores)


# Correlation baseline on each split, printed for comparison with the model.
train_baseline_roc_auc = _mean_correlation_baseline(train_dataset)
val_baseline_roc_auc = _mean_correlation_baseline(val_dataset)
test_baseline_roc_auc = _mean_correlation_baseline(test_dataset)
print(f"Correlation baseline --- train ROC_AUC:{train_baseline_roc_auc:.4f}, val ROC_AUC:{val_baseline_roc_auc:.4f}, test ROC_AUC:{test_baseline_roc_auc:.4f}")

# Evaluate model before training, as a reference point (reported as epoch 0 with
# a loss of 0). `start` must be a real timestamp: eval_everything prints
# `time.time() - start` as the elapsed time, so passing 0 would report the
# number of seconds since the Unix epoch.
eval_everything(epoch=-1, epoch_loss=0, start=time.time())
print("Starting training")

# Training loop: one optimisation pass over train_loader per epoch, followed by
# an evaluation of every split.
for epoch in range(args.n_epochs):
    start = time.time()
    epoch_loss = 0.0

    # Make sure the model is in training mode at the start of every epoch.
    # NOTE(review): needed if the evaluation helpers switch the model to eval
    # mode without restoring it; harmless otherwise. Confirm against evaluation.py.
    model.train()

    for data in train_loader:
        optimizer.zero_grad()
        x, b, A = data["x"].to(device), data["b"].to(device), data["adj"].to(device)
        A_pred = model(x, b)

        # The dot product decoder will always produce high scores for the diagonal (dot product of a node with itself). But by construction, we
        # do not have self-edges. Therefore, we mask out the diagonal (not doing it also causes instability in the training)
        loss = loss_critereon(mask_diagonal(A_pred).reshape(-1), mask_diagonal(A).reshape(-1))
        loss.backward()
        optimizer.step()

        # Accumulate a plain float: adding the loss tensor itself would keep
        # autograd history alive across iterations.
        epoch_loss += loss.item()

    # Mean loss per batch for this epoch.
    epoch_loss /= len(train_loader)

    eval_everything(epoch=epoch, epoch_loss=epoch_loss, start=start)

wandb.finish()