import os
import pathlib
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pl_bolts.datamodules import ImagenetDataModule
from models.lit_vit import LitVit


def _count_requested_gpus(gpus):
    """Return how many GPUs a ``--gpus`` style value asks for.

    Args:
        gpus: ``None``, an int (``-1`` meaning every visible CUDA device),
            a string such as ``"0,1"`` or ``"2"``, or a list/tuple of
            device ids.

    Returns:
        The number of requested GPUs; ``0`` when none are requested.
    """
    if gpus is None:
        return 0
    if isinstance(gpus, (list, tuple)):
        return len(gpus)
    if isinstance(gpus, str):
        if not gpus.strip():
            return 0
        if "," in gpus:
            # Comma-separated device ids: one GPU per non-empty entry.
            return sum(1 for part in gpus.split(",") if part.strip())
        # NOTE(review): a bare integer string is treated as a count here;
        # confirm this matches the installed Lightning version.
        gpus = int(gpus)
    if gpus == -1:
        return torch.cuda.device_count()
    return max(int(gpus), 0)


def cli_main(args):
    """Build the data module, trainer and model from ``args``, then fit and test.

    Args:
        args: Parsed argument namespace. Reads ``seed``, ``accelerator``,
            ``gpus``, ``batch_size`` and ``data_dir``; the whole namespace
            is also handed to ``pl.Trainer.from_argparse_args`` and
            ``LitVit``. ``batch_size``, ``logger`` and
            ``checkpoint_callback`` are modified in place.
    """
    pl.seed_everything(args.seed)

    # Scale batch size in case of ddp. Only divide when more than one GPU is
    # actually requested: ``args.gpus`` may be None (flag omitted), 0, or a
    # comma-separated string, all of which broke a plain division.
    if args.accelerator == "ddp":
        num_gpus = _count_requested_gpus(args.gpus)
        if num_gpus > 1:
            # Never let the per-process batch size collapse to zero.
            args.batch_size = max(1, int(args.batch_size // num_gpus))

    imagenet_dm = ImagenetDataModule(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
    )

    # The trainer must be built while args still carries the logger and the
    # checkpoint callback; they are cleared right after.
    trainer = pl.Trainer.from_argparse_args(args)

    # Avoid saving logger and callbacks as hparams
    args.logger = None
    args.checkpoint_callback = None

    model = LitVit(args)
    model.datamodule = imagenet_dm

    print(args)
    print(model)
    trainer.fit(model, imagenet_dm)
    trainer.test(model, datamodule=imagenet_dm)


def build_args():
    """Parse command-line arguments and attach checkpointing and logging objects.

    Combines the script's own ``--data_dir`` option with the options added
    by ``LitVit.add_model_specific_args`` and ``pl.Trainer.add_argparse_args``,
    applies this script's defaults, and parses ``sys.argv``.

    Returns:
        The parsed namespace, with ``checkpoint_callback`` set to a
        ``ModelCheckpoint`` writing to ``<root>/checkpoints`` and ``logger``
        set to a ``TensorBoardLogger`` rooted at ``<root>``. ``<root>`` is
        ``args.default_root_dir``, or the current working directory when
        that option was not given.
    """
    parser = ArgumentParser()

    # client arguments
    parser.add_argument(
        "--data_dir",
        type=str,
        help="Path to ImageNet data.",
    )

    # model arguments
    parser = LitVit.add_model_specific_args(parser)

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        accelerator="ddp",
        max_epochs=40,
        dataset="imagenet",
    )

    args = parser.parse_args()

    # ``default_root_dir`` is None unless passed explicitly, and
    # ``pathlib.Path(None)`` raises TypeError, so fall back to the cwd.
    root_dir = args.default_root_dir
    if root_dir is None:
        root_dir = os.getcwd()

    # configure checkpointing in checkpoint_dir
    checkpoint_dir = pathlib.Path(root_dir) / "checkpoints"
    # exist_ok avoids a failure when the directory already exists or is
    # created by another process between a check and the mkdir call.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args.checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=checkpoint_dir,
        save_top_k=-1,
        verbose=True,
        period=10,
    )

    # configure logger
    args.logger = pl.loggers.TensorBoardLogger(root_dir, name="")

    return args


def run_cli():
    """Command-line entry point: parse the arguments, then run training."""
    cli_main(build_args())


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    run_cli()
