import argparse
import glob
import math
import os
import warnings

import torch as t
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from rich import print
from torch.utils.data import DataLoader, ConcatDataset
from tqdm import tqdm
from wandb.util import generate_id

from .args import add_train_args
from .dataset import Dataset
from .simulators import SIMULATOR
from .trainer import get_trainer

# The Procgen gym environment is noisy: silence its deprecation and user warnings.
warnings.simplefilter("ignore", DeprecationWarning)
warnings.simplefilter("ignore", UserWarning)

# Seed once at import time with a fixed value, before anything else runs.
seed_everything(1992)

def next_run_version(version_dirs):
    """Return the first unused run version given existing ``version_<n>`` paths.

    Args:
        version_dirs: Iterable of path strings. The text after the last ``_``
            in each path's final component is read as the version number.

    Returns:
        One more than the highest numeric version found, or ``0`` when no path
        carries a numeric version.
    """
    highest = -1
    for path in version_dirs:
        suffix = os.path.basename(os.path.normpath(path)).rpartition("_")[2]
        # Skip stray entries such as ``version_old``; ``int()`` would raise on them.
        if suffix.isdecimal():
            highest = max(highest, int(suffix))
    return highest + 1


def main():
    """Parse the CLI arguments, build the dataloader, and run training.

    Creates ``runs/<exp_name>/version_<n>/{checkpoints,logs}``, trains the
    module returned by ``get_trainer`` and saves its solver to
    ``trained.model`` inside that directory.

    Raises:
        FileNotFoundError: If no file matches the ``args.data`` prefix.
    """
    parser = argparse.ArgumentParser()
    add_train_args(parser)
    args = parser.parse_args()

    # Fall back to a single device when no CUDA device is visible.
    n_devices = max(1, t.cuda.device_count())

    # Round the global batch size up to a multiple of the device count so that
    # it splits evenly across devices below.
    args.batch_size = math.ceil(args.batch_size / n_devices) * n_devices
    args.environment = SIMULATOR.name
    args.wandb_id = generate_id()

    # NOTE(review): with a multi-process DDP launch each rank may execute this
    # script and pick its own version directory - confirm against the launcher used.
    args.version = next_run_version(glob.glob(f"runs/{args.exp_name}/version_*"))
    args.exp_path = f"runs/{args.exp_name}/version_{args.version}"
    os.makedirs(f"{args.exp_path}/checkpoints", exist_ok=True)
    os.makedirs(f"{args.exp_path}/logs", exist_ok=True)

    print("Loading training files...")
    training_files = sorted(glob.glob(f"{args.data}*"))
    if not training_files:
        # ConcatDataset rejects an empty list with an opaque assertion; fail
        # early with a message that names the pattern instead.
        raise FileNotFoundError(f"No training files match '{args.data}*'")

    dataloader = DataLoader(
        dataset=ConcatDataset([Dataset(file, args.n_steps) for file in tqdm(training_files)]),
        # args.batch_size is a multiple of n_devices, so this division is exact.
        batch_size=args.batch_size // n_devices,
        num_workers=8,
        shuffle=True,
        pin_memory=True,
        persistent_workers=True,
    )

    print("----------->  Summary  <-----------")
    for k, v in sorted(args.__dict__.items()):
        print(f"{k.ljust(20)} : {v}")

    module = get_trainer(args)
    trainer = Trainer(
        max_steps=args.max_iterations,
        logger=WandbLogger(
            name=args.exp_name,
            save_dir=f"{args.exp_path}/",
            project=args.project,
            id=args.wandb_id,
            config=args.__dict__,
        ),
        strategy=DDPStrategy(find_unused_parameters=False),
        accelerator="gpu" if t.cuda.is_available() else "cpu",
        devices=n_devices,
        gradient_clip_val=10,
        gradient_clip_algorithm="value",
        log_every_n_steps=1,
        callbacks=[
            RichProgressBar(),
            ModelCheckpoint(
                dirpath=f"{args.exp_path}/checkpoints/",
                every_n_train_steps=args.checkpoint_frequency,
            ),
        ],
    )

    trainer.fit(module, dataloader)
    # NOTE(review): every process that reaches this line writes the file -
    # confirm whether it should be restricted to the global-zero rank.
    module.solver.save(f"{args.exp_path}/trained.model")


if __name__ == "__main__":
    main()
