import os
import torch
import pytorch_lightning as pl
from transformers import logging as hf_logging
from mStream.callback import build_callbacks
from mStream.config import parse_args, instantiate_from_config
from mGPT.data.build_data import build_data
from mGPT.models.build_model import build_model
from mGPT.utils.logger import setup_logger
from mGPT.utils.load_checkpoint import (
    load_pretrained,
    load_pretrained_lora,
    load_pretrained_vae,
)
from collections import OrderedDict


def _build_pl_loggers(cfg):
    """Instantiate the Lightning metric loggers named in ``cfg.LOGGER.TYPE``.

    Each name ``x`` is resolved to the config node ``cfg.LOGGER.<X>``
    (upper-cased) and built with ``instantiate_from_config``. The ``"wandb"``
    logger is skipped when ``cfg.LOGGER.WANDB.params.project`` is empty.

    Returns:
        list: the instantiated loggers, in ``cfg.LOGGER.TYPE`` order.
    """
    pl_loggers = []
    for logger_name in cfg.LOGGER.TYPE:
        # W&B is only enabled when a project name is configured; other logger
        # types no longer depend on the W&B settings.
        if logger_name == "wandb" and not cfg.LOGGER.WANDB.params.project:
            continue
        # getattr performs the same attribute lookup the old eval() did,
        # without executing text that comes from the config.
        logger_cfg = getattr(cfg.LOGGER, logger_name.upper())
        pl_loggers.append(instantiate_from_config(logger_cfg))
    return pl_loggers


def _select_precision(cfg):
    """Return the ``precision`` argument for ``pl.Trainer``.

    Currently always ``None`` (Lightning's default). bf16 is a candidate for
    non-VAE stages on GPUs that support it; uncomment the assignment below to
    enable ``"bf16-true"``.
    """
    precision = None
    # Check is_available() first: on CPU-only hosts some torch versions raise
    # from is_bf16_supported() instead of returning False.
    if (
        cfg.TRAIN.STAGE != "vae"
        and torch.cuda.is_available()
        and torch.cuda.is_bf16_supported()
    ):
        # precision = "bf16-true"
        pass
    return precision


def _select_strategy(cfg):
    """Return the ``strategy`` argument for ``pl.Trainer`` from ``cfg.DEVICE``.

    Multiple devices use plain DDP for every training stage; a single device
    lets Lightning choose (``"auto"``).
    """
    # "ddp_find_unused_parameters_true" was previously considered for the VAE
    # stage and for rec-only models; switch to it if DDP reports parameters
    # that received no gradient.
    if len(cfg.DEVICE) > 1:
        return "ddp"
    return "auto"


def main():
    """Run a full training job from the parsed ``train`` config.

    Parses the config, seeds RNGs, builds metric loggers, callbacks, the
    datamodule and the model, loads any pretrained weights, and then calls
    ``pl.Trainer.fit``. Outputs go to ``cfg.FOLDER_EXP``.
    """
    # Parse config file
    cfg = parse_args(phase="train")

    # Set up logger
    logger = setup_logger(cfg, phase="train")

    # Set random seed for reproducibility
    pl.seed_everything(cfg.SEED_VALUE)

    # Silence tokenizer-parallelism warnings and Hugging Face logs unless debugging
    if not cfg.DEBUG:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        hf_logging.set_verbosity_error()
    # Share tensors between processes through the file system rather than
    # through file descriptors.
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Initialize metric loggers
    pl_loggers = _build_pl_loggers(cfg)

    # Initialize callbacks
    callbacks = build_callbacks(cfg, phase="train")
    logger.info("Callbacks initialized")

    # Initialize dataset
    datamodule = build_data(cfg)
    logger.info(f"Data {datamodule.name} initialized")

    # Initialize model
    model = build_model(cfg, datamodule)
    logger.debug(model)
    logger.info(f"Model {cfg.model.target} initialized")

    # Initialize PyTorch Lightning trainer
    trainer = pl.Trainer(
        default_root_dir=cfg.FOLDER_EXP,
        max_epochs=cfg.TRAIN.END_EPOCH,
        precision=_select_precision(cfg),
        logger=pl_loggers,
        callbacks=callbacks,
        # NOTE(review): the config key says "STEPS" but Lightning interprets
        # this argument as an epoch interval — confirm the intended unit.
        check_val_every_n_epoch=cfg.LOGGER.VAL_EVERY_STEPS,
        accelerator=cfg.ACCELERATOR,
        devices=cfg.DEVICE,
        num_nodes=cfg.NUM_NODES,
        strategy=_select_strategy(cfg),
        log_every_n_steps=1,
        benchmark=False,
        deterministic=False,
        gradient_clip_val=cfg.TRAIN.GRADIENT_CLIP,
    )
    logger.info("Trainer initialized")

    # Load pretrained weights: general checkpoint first, then the VAE.
    load_pretrained(cfg, model)
    load_pretrained_vae(cfg, model)

    # Optional: initialize LoRA adapters before loading LoRA weights.
    # if hasattr(model, "lm"):
    #     model.lm._init_lora()
    load_pretrained_lora(cfg, model)

    # Optional: compile with PyTorch 2 (previously tried for the VAE stage only).
    # model = torch.compile(model, mode="reduce-overhead")

    # Start training. When RESUME is set, Lightning restores the run from the
    # checkpoint at cfg.TRAIN.PRETRAINED; ckpt_path=None (fit's default)
    # starts a fresh run.
    ckpt_path = cfg.TRAIN.PRETRAINED if cfg.TRAIN.RESUME else None
    trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)

    # Training ends
    logger.info(f"The outputs of this experiment are stored in {cfg.FOLDER_EXP}")
    logger.info("Training ends!")


if __name__ == "__main__":
    main()
