import os
import warnings
import json

import torch
from pytorch_lightning import (Trainer, seed_everything)
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import (WandbLogger, CSVLogger)

from equislt.args import parse_args_train
from equislt.methods import TRAIN_METHODS
from equislt.data import prepare_data


def main():
    args = parse_args_train()
    seed_everything(args.seed)

    # DATA
    device = None
    if args.gpus is not None:
        try:
            device = torch.device('cuda:' + str(args.gpus[0]))
        except TypeError:
            device = torch.device('cuda:' + str(args.gpus))
    train_loader, val_loader, test_loader, datametadata = prepare_data(
        args.dataset,
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        device=device,
    )
    for k, v in datametadata.items():
        setattr(args, k, v)
    print('Dataset metadata:\n', datametadata)

    # METHOD & ARCHITECTURE
    MethodClass = TRAIN_METHODS[args.method]
    model = MethodClass(**vars(args))

    callbacks = []

    # LOGGING
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    if args.wandb:
        os.makedirs(os.path.join(args.checkpoint_dir, 'wandb'), exist_ok=True)
        wandb_logger = WandbLogger(
            name=args.name,
            id=args.name,
            save_dir=args.checkpoint_dir,
            project=args.project,
            entity=args.entity,
            offline=args.offline,
            group=args.group,
            job_type='train'
        )
        wandb_logger.watch(model, log="gradients", log_freq=100)
        wandb_logger.log_hyperparams(args)

        # lr logging
        #  lr_monitor = LearningRateMonitor(logging_interval="epoch")
        #  callbacks.append(lr_monitor)
    else:
        csv_logger = CSVLogger(save_dir=args.checkpoint_dir, name='logs')
        csv_logger.log_hyperparams(args)

    # CHECKPOINTING
    if args.save_checkpoint:
        json_path = os.path.join(args.checkpoint_dir, "args.json")
        with open(json_path, 'w') as f:
            json.dump(vars(args), f, default=lambda o: "<not serializable>")

        ckpt = ModelCheckpoint(
            dirpath=args.checkpoint_dir,
            filename='best',
            save_last=True, save_top_k=1,
            monitor='val_acc', mode='max',
            auto_insert_metric_name=False,
            save_weights_only=False,
            every_n_epochs=args.checkpoint_frequency,
            save_on_train_epoch_end=False
        )
        callbacks.append(ckpt)

    # RESUMING
    ckpt_path = None
    if args.resume_from_checkpoint is not None:
        ckpt_path = args.resume_from_checkpoint
        del args.resume_from_checkpoint

    # TRAINING LOOP
    trainer = Trainer.from_argparse_args(
        args,
        logger=wandb_logger if args.wandb else csv_logger,
        callbacks=callbacks,
        enable_checkpointing=True,
    )

    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)

    val_acc = float(trainer.callback_metrics["val_acc"])
    print(f"Training concluded with {val_acc:.3f} validation accuracy.")

    best_ckpt_dir = os.path.join(args.checkpoint_dir, 'best.ckpt')
    trainer.test(ckpt_path=best_ckpt_dir, dataloaders=test_loader)


if __name__ == "__main__":
    main()
