import os.path as osp
from random import randint
from argparse import ArgumentParser

from models import models
from pl_wrappers import wrappers
from data import data as datasets
from loaders import loaders

from torch_geometric import seed_everything
from torch_geometric.data import LightningNodeData

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import loggers as pl_loggers
import pytorch_lightning as pl

from datetime import datetime
from torchvision import transforms as T

import torch

# Absolute directory containing this script; the dataset, log and checkpoint
# paths built in this file are joined onto it. abspath() guards against an
# empty/relative dirname when __file__ itself is relative (e.g. older Pythons).
DIR = osp.dirname(osp.abspath(__file__))


def main(experiment, config, args):
    """Build data, model and trainer from ``config``, then train and test.

    Args:
        experiment: Experiment name. Used as the TensorBoard run name and as
            the prefix of the checkpoint directory name.
        config: Parsed config mapping; the ``sampler``, ``model``,
            ``lightning`` and ``trainer`` sections are read.
        args: Parsed CLI namespace; ``seed``, ``dataset``, ``root`` and
            ``test_from_weights`` are read.

    If ``args.test_from_weights`` is given, that state dict is loaded into the
    model, the test split is evaluated, and the function returns without
    training. Otherwise the model is fit, its raw ``state_dict`` is written to
    ``<checkpoint dir>/weights``, and the checkpoint with the highest
    ``val_acc`` is evaluated on the test split.
    """
    import os  # only needed to create the checkpoint directory below

    seed_everything(args.seed)

    dataset = datasets[args.dataset](
        root=osp.join(DIR, args.root, args.dataset)
    )
    data = dataset[0]

    loader = loaders[config["sampler"]["name"]]
    # NOTE(review): TrivialAugmentWide is a torchvision image augmentation;
    # confirm the project loader's `transform` really expects it here.
    train_transform = T.TrivialAugmentWide()

    # NOTE(review): the keyword is `num_worker` (singular) and is forwarded
    # as-is to the project loader; confirm it matches the loader's signature.
    train_loader = loader(data, shuffle=True, **config["sampler"]["train"],
                          num_worker=32, transform=train_transform)
    val_loader = loader(data, shuffle=False, **config["sampler"]["val"],
                        num_worker=32)
    test_loader = loader(data, shuffle=False, **config["sampler"]["test"],
                         num_worker=32)

    model_name = config["model"]["name"]
    if model_name == "Net":
        model_kwargs = {"num_features": dataset.num_features}
    else:
        model_kwargs = config["model"]["kwargs"]

    model = models[model_name](config["model"],
                               dataset.num_classes,
                               **model_kwargs)

    # Hack to get RoBERTa test results: start from externally saved weights.
    if args.test_from_weights is not None:
        # map_location="cpu" lets weights saved on a GPU load on any host;
        # load_state_dict copies the values into the model's own parameters.
        state_dict = torch.load(args.test_from_weights, map_location="cpu")
        model.load_state_dict(state_dict)
        print("\n\n--------TEST MODEL LOADED-----------\n\n")

    pl_model = wrappers[config["lightning"]["name"]](model, **config["lightning"])

    # One timestamp shared by the logger version and the checkpoint dir name.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    tb_logger = pl_loggers.TensorBoardLogger(
        save_dir=osp.join(DIR, config["trainer"]["dir"]),
        name=experiment,
        version=f"{stamp}_{args.seed}",
    )

    checkpoint_name = f"{experiment}_{args.seed}-{stamp}"
    checkpoint_dir = osp.join(DIR, "../checkpoints", checkpoint_name)
    checkpoint_callback = ModelCheckpoint(
        monitor="val_acc",
        dirpath=checkpoint_dir,
        filename="{epoch:02d}-{val_acc:.3f}-{val_loss:.2f}",
        save_last=True,
        mode="max",
    )
    trainer = pl.Trainer(
        callbacks=[checkpoint_callback],
        accelerator=config["trainer"]["accelerator"],
        devices=config["trainer"]["devices"],
        logger=tb_logger,
        replace_sampler_ddp=False,
        strategy='ddp',
        max_epochs=config["trainer"]["max_epochs"],
        enable_progress_bar=True,
    )

    # Again more dodgy coding to test RoBERTa: evaluate the loaded weights only.
    if args.test_from_weights is not None:
        print("\n\nSKIPPING TRAINING TO TEST\n\n")
        trainer.test(pl_model, dataloaders=test_loader)
        # Return rather than quit(): quit() is a site-module convenience for
        # the interactive shell and is not guaranteed to exist (`python -S`).
        return

    # Use the `experiment` parameter, not the script-level global `expr_name`,
    # so main() also works when called from outside this script.
    print(f'Running {experiment} with seed value {args.seed}')
    print(f'Saving models to {checkpoint_dir}')
    trainer.fit(pl_model, train_loader, val_loader)

    # With strategy='ddp' every process reaches this point; only rank 0 writes
    # the weights file so the processes don't race on the same path.
    if trainer.is_global_zero:
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
            pl_model.model.state_dict(),
            osp.join(checkpoint_dir, "weights"),
        )

    # Evaluate the checkpoint with the best val_acc (what ModelCheckpoint is
    # configured to track); fall back to last.ckpt if no best path exists.
    best_path = (checkpoint_callback.best_model_path
                 or osp.join(checkpoint_dir, "last.ckpt"))
    best_model = pl_model.load_from_checkpoint(best_path)
    trainer.test(best_model, dataloaders=test_loader)


if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the YAML config and run
    # main(). `ArgumentParser` and `osp` come from the module-level imports.
    import yaml

    parser = ArgumentParser()
    parser.add_argument('--config', required=True, help="Config file")
    parser.add_argument('--dataset', required=True, help="Dataset to run on")
    parser.add_argument('--root', required=True, help="Root directory for dataset")
    parser.add_argument('--seed', required=True, type=int, help="Seed to run within")
    parser.add_argument('--test_from_weights', default=None, help="Test on provided torch weights (for roBERTa)")
    args = parser.parse_args()

    # Read the config, then close the file before the long training run.
    with open(osp.abspath(args.config), 'r', encoding='utf-8') as config_file:
        config = yaml.safe_load(config_file)

    # Experiment name is "<dataset>.<config file name without extension>",
    # e.g. --dataset cora --config configs/gcn.yaml -> "cora.gcn".
    # basename/splitext are separator-portable and keep extension-less names.
    config_stem = osp.splitext(osp.basename(args.config))[0]
    expr_name = args.dataset + "." + config_stem
    main(expr_name, config, args)

