import os
import argparse
import logging
import numpy as np

import wandb
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelSummary, ModelCheckpoint

from src.utils import use_best_hyperparams, get_available_accelerator
from src.data_loading import get_dataset, get_dataset_split
from src.full_batch.model import get_model, LightingFullBatchModelWrapper
from src.full_batch.dataset import FullBatchGraphDataset
from src.arguments import args
    

def run(args):
    """Train and evaluate the configured model over ``args.num_runs`` splits.

    For every run a fresh model, Lightning wrapper, callbacks and trainer are
    built, the model is fitted on the full-batch loader, and the best
    checkpoint (by ``val_acc``) is evaluated with ``trainer.test``. Per-run and
    aggregated accuracies are logged to Weights & Biases.

    Args:
        args: Namespace-like configuration object. The attributes read here
            are ``dataset``, ``dataset_directory``, ``undirected``,
            ``self_loops``, ``transpose``, ``num_runs``, ``lr``,
            ``weight_decay``, ``metric``, ``patience``,
            ``checkpoint_directory``, ``num_epochs``, ``profiler``,
            ``gpu_idx`` and ``model``. As a side effect ``num_features``,
            ``num_nodes`` and ``num_classes`` are written onto it before the
            model is built.

    Returns:
        dict: ``test_acc_mean``, ``test_acc_std``, ``val_acc_mean`` and
        ``val_acc_std`` computed with NumPy over all runs.
    """
    torch.manual_seed(0)
    wandb.init(project="scalable_graff", entity="graph_neural_diffusion", config=args)

    dataset, evaluator = get_dataset(
        name=args.dataset,
        root_dir=args.dataset_directory,
        undirected=args.undirected,
        self_loops=args.self_loops,
        transpose=args.transpose,
    )
    data = dataset.data

    # Dataset dimensions are the same for every run (neither `data` nor
    # `dataset` is rebound below), so they are copied onto `args` once.
    args.num_features = data.num_features
    args.num_nodes = data.num_nodes
    args.num_classes = dataset.num_classes

    # makedirs(exist_ok=True) creates missing parent directories and avoids
    # the check-then-create race of `os.path.exists` + `os.mkdir`.
    checkpoint_dir = f"{args.checkpoint_directory}/"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Pre-bound so the "alpha" logging after the loop cannot raise NameError
    # when `args.num_runs` is 0.
    model = None
    val_accs, test_accs = [], []
    for num_run in range(args.num_runs):
        train_mask, val_mask, test_mask = get_dataset_split(
            args.dataset, data, args.dataset_directory, num_run
        )

        model = get_model(args)
        lit_model = LightingFullBatchModelWrapper(
            model=model,
            lr=args.lr,
            weight_decay=args.weight_decay,
            evaluator=evaluator,
            train_mask=train_mask,
            val_mask=val_mask,
            test_mask=test_mask,
            metric=args.metric,
        )

        early_stopping_callback = EarlyStopping(
            monitor="val_acc", mode="max", patience=args.patience
        )
        model_summary_callback = ModelSummary(max_depth=-1)
        # NOTE(review): every run uses the same checkpoint filename (the wandb
        # run id); confirm that collisions across runs are handled as intended.
        model_checkpoint_callback = ModelCheckpoint(
            monitor="val_acc",
            mode="max",
            dirpath=checkpoint_dir,
            filename=str(wandb.run.id),
        )
        trainer = pl.Trainer(
            # A WandbLogger is only attached when a single run is requested.
            logger=WandbLogger(project="scalable_gnn") if args.num_runs == 1 else None,
            log_every_n_steps=1,
            max_epochs=args.num_epochs,
            callbacks=[
                early_stopping_callback,
                model_summary_callback,
                model_checkpoint_callback,
            ],
            profiler="simple" if args.profiler else None,
            accelerator=get_available_accelerator(),
            devices=[args.gpu_idx],
        )

        # The loader yields one item per epoch; the collate function unwraps
        # the single-element batch so the module receives the item itself.
        train_dataset = FullBatchGraphDataset(data)
        loader = DataLoader(train_dataset, batch_size=1, collate_fn=lambda batch: batch[0])
        trainer.fit(model=lit_model, train_dataloaders=loader)

        val_acc = model_checkpoint_callback.best_model_score.item()  # Best val accuracy
        test_acc = trainer.test(ckpt_path="best", dataloaders=loader)[0]["test_acc"]
        wandb.log({"run_test_acc": test_acc, "run_val_acc": val_acc})
        test_accs.append(test_acc)
        val_accs.append(val_acc)

    results = {
        "test_acc_mean": np.mean(test_accs),
        "test_acc_std": np.std(test_accs),
        "val_acc_mean": np.mean(val_accs),
        "val_acc_std": np.std(val_accs),
    }
    wandb.log(results)
    # `model` is the one built in the last run; skipped when no run happened.
    if args.model == "gnn" and model is not None:
        wandb.log({"alpha": model.alpha.item()})

    wandb.finish()

    print(f"Test Acc: {np.mean(test_accs)} +- {np.std(test_accs)}")
    return results


if __name__ == "__main__":
    # Script entry point: when requested, swap the parsed arguments for the
    # result of `use_best_hyperparams` for the selected dataset, then launch.
    if args.use_best_hyperparams:
        args = use_best_hyperparams(args, args.dataset)
    run(args)
