import os
import pathlib
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from models.lit_vit import LitVit


def _count_devices(gpus):
    """Return the number of GPUs designated by a Trainer-style ``gpus`` value.

    Handles the shapes the value can take after argument parsing:
    ``None`` / empty string (no GPUs), a non-negative int (device count),
    ``-1`` (every visible CUDA device), a comma-separated string of device
    ids such as ``"0,1"``, or a list/tuple of device ids.

    NOTE(review): these forms follow the Lightning 1.x ``gpus`` flag
    conventions -- confirm against the installed Lightning version.
    """
    if gpus is None:
        return 0
    if isinstance(gpus, str):
        spec = gpus.strip()
        if not spec:
            return 0
        if "," in spec:
            # "0,1" or "0," -> count the non-empty device ids.
            return len([tok for tok in spec.split(",") if tok.strip()])
        gpus = int(spec)
    if isinstance(gpus, (list, tuple)):
        return len(gpus)
    if gpus == -1:
        return torch.cuda.device_count()
    return max(int(gpus), 0)


def cli_main(args):
    """Train and then test a ``LitVit`` model on CIFAR-10.

    Args:
        args: Parsed argument namespace. This function reads ``seed``,
            ``accelerator``, ``gpus`` and ``batch_size`` directly; the whole
            namespace is also handed to ``pl.Trainer.from_argparse_args``
            and to ``LitVit``.

    Side effects:
        Mutates ``args``: ``batch_size`` may be rescaled for DDP, and
        ``logger`` / ``checkpoint_callback`` are reset to ``None`` once the
        trainer has been built.
    """
    pl.seed_everything(args.seed)

    train_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    # Evaluation data gets no augmentation, only tensor conversion and
    # normalization.
    test_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    # Scale batch size in case of ddp: the requested batch size is split
    # across the participating devices. ``args.gpus`` may be None, 0, -1, a
    # string or a list, so count devices defensively instead of dividing by
    # it directly, and never let the per-process batch size drop to zero.
    if args.accelerator == "ddp":
        num_devices = _count_devices(args.gpus)
        if num_devices > 0:
            args.batch_size = max(1, int(args.batch_size) // num_devices)

    cifar10_dm = CIFAR10DataModule(
        batch_size=args.batch_size,
        train_transforms=train_transforms,
        test_transforms=test_transforms,
        val_transforms=test_transforms,
    )

    # The trainer must be built before logger/callback are cleared below,
    # because it picks them up from ``args``.
    trainer = pl.Trainer.from_argparse_args(args)

    # Avoid saving logger and callbacks as hparams
    args.logger = None
    args.checkpoint_callback = None

    model = LitVit(args)
    model.datamodule = cifar10_dm

    print(args)
    trainer.fit(model, cifar10_dm)
    trainer.test(model, datamodule=cifar10_dm)


def build_args():
    """Parse the command line and attach checkpointing and logging objects.

    Combines the model arguments registered by
    ``LitVit.add_model_specific_args`` with the ``pl.Trainer`` arguments,
    applies this script's defaults, and parses ``sys.argv``.

    Returns:
        The parsed namespace, extended with ``checkpoint_callback`` (a
        ``ModelCheckpoint`` writing to ``<root>/checkpoints``) and ``logger``
        (a ``TensorBoardLogger`` rooted at ``<root>``), where ``<root>`` is
        ``args.default_root_dir`` or the current working directory when that
        option was not given.

    Side effects:
        Creates the checkpoint directory if it does not exist.
    """
    parser = ArgumentParser()

    # model arguments
    parser = LitVit.add_model_specific_args(parser)

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        accelerator="ddp",
        max_epochs=40,
        dataset="cifar10",
    )

    args = parser.parse_args()

    # ``default_root_dir`` is None unless passed on the command line; fall
    # back to the working directory instead of crashing in ``pathlib.Path``.
    root_dir = os.getcwd() if args.default_root_dir is None else args.default_root_dir

    # configure checkpointing in checkpoint_dir
    checkpoint_dir = pathlib.Path(root_dir) / "checkpoints"
    # exist_ok avoids a race when several processes run this script at once.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): per the ModelCheckpoint docs of this Lightning
    # generation, save_top_k=-1 keeps every checkpoint and period=10 spaces
    # them 10 epochs apart -- confirm against the installed version.
    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        save_top_k=-1,
        verbose=True,
        period=10,
    )

    # configure logger
    args.logger = pl.loggers.TensorBoardLogger(root_dir, name="")

    return args


def run_cli():
    """Build the run configuration from the command line and pass it to ``cli_main``."""
    cli_main(build_args())


# Launch the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_cli()
