"""
MT3 baseline training.
To use random order, use `dataset.dataset_2_random`; otherwise, use `dataset.dataset_2`.
"""

import os

from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader

import torch
import pytorch_lightning as pl
import os

import hydra


def _count_trainable_parameters(module):
    """Return the total element count of ``module``'s parameters that require grad."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def _strip_prefix(key, prefix):
    """Return ``key`` without a leading ``prefix``; later occurrences are untouched."""
    return key[len(prefix):] if key.startswith(prefix) else key


def _unwrap_compiled(module):
    """Return the module wrapped by ``torch.compile``, or ``module`` if not wrapped."""
    return getattr(module, "_orig_mod", module)


def _lightweight_checkpoint_path(path):
    """Return ``path`` with ``_lightweight`` inserted before its file extension."""
    root, ext = os.path.splitext(path)
    return f"{root}_lightweight{ext}"


@hydra.main(config_path="config", config_name="config")
def main(cfg):
    """Train the configured model and export its weights as a ``.pt`` file.

    Reads the following entries from ``cfg``: ``seed``, ``use_prompt``,
    ``model``, ``optim``, ``model_type``, ``dataset_type``,
    ``modelcheckpoint``, ``trainer``, ``dataset.train``, ``dataset.val``,
    ``dataset.collate_fn``, ``dataloader.train``, ``dataloader.val``,
    ``path`` and ``use_lightweight_checkpoint``.

    Checkpoint handling when ``cfg.path`` is set:

    * ``use_lightweight_checkpoint`` true: the weights are extracted from the
      checkpoint, written next to it as ``<name>_lightweight<ext>``, loaded
      into the model, and training starts without resuming trainer state.
    * path ends with ``.ckpt``: validate, then resume training from it.
    * path ends with ``.pth``: load the weights non-strictly into
      ``model.model`` and train.

    After training, the state dict (with the leading ``model.`` prefix
    removed from its keys) is saved to
    ``<hydra output dir>/<model_type>_<dataset_type>/version_0/checkpoints/last.pt``.

    Raises:
        AssertionError: if ``cfg.model_type`` is not the last component of
            ``cfg.model._target_``.
        ValueError: if ``cfg.path`` is set, the lightweight flag is off, and
            the path ends with neither ``.ckpt`` nor ``.pth``.
    """
    # set seed to ensure reproducibility
    pl.seed_everything(cfg.seed)
    cfg.model.config.use_prompt = cfg.use_prompt

    model = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim)
    model = torch.compile(model)
    # torch.compile returns a wrapper whose state-dict keys are prefixed with
    # `_orig_mod.`; keep a handle on the underlying module for every
    # state-dict and attribute operation so keys stay clean.
    raw_model = _unwrap_compiled(model)

    num_params = _count_trainable_parameters(model)
    print(f"Trainable parameters: {num_params:,}", flush=True)
    logger = WandbLogger(project=f"{cfg.model_type}_{cfg.dataset_type}")

    # sanity check to make sure the correct model is used
    assert cfg.model_type == cfg.model._target_.split(".")[-1]

    lr_monitor = LearningRateMonitor(logging_interval="step")
    checkpoint_callback = ModelCheckpoint(**cfg.modelcheckpoint)
    tqdm_callback = TQDMProgressBar(refresh_rate=1)

    trainer = pl.Trainer(
        logger=logger,
        callbacks=[lr_monitor, checkpoint_callback, tqdm_callback],
        **cfg.trainer,
    )
    print("Loading data...", flush=True)
    print("Train data:", cfg.dataset.train, flush=True)

    # Both loaders share the same collate function; resolve it once.
    collate_fn = hydra.utils.get_method(cfg.dataset.collate_fn)
    train_loader = DataLoader(
        hydra.utils.instantiate(cfg.dataset.train),
        **cfg.dataloader.train,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        hydra.utils.instantiate(cfg.dataset.val),
        **cfg.dataloader.val,
        collate_fn=collate_fn,
    )

    # Dynamically calculate num_steps_per_epoch. len(DataLoader) is the number
    # of batches per epoch and, unlike a manual ceil-division of the dataset
    # size, honours `drop_last` and works when `batch_size` is None.
    num_steps_per_epoch = len(train_loader)
    print(f"Calculated num_steps_per_epoch: {num_steps_per_epoch}")

    # Set on the underlying module so the value is visible to it regardless of
    # how the compile wrapper forwards attribute writes.
    raw_model.num_steps_per_epoch = num_steps_per_epoch

    # Handle checkpoint loading based on the lightweight flag
    if not cfg.path:
        trainer.fit(model, train_loader, val_loader)
    elif cfg.use_lightweight_checkpoint:
        print("Creating and using a lightweight checkpoint...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(cfg.path, map_location="cpu")

        # Extract only the model's weights
        state_dict = checkpoint.get("state_dict", checkpoint)

        # Derive the new name from the file extension so the source
        # checkpoint is never overwritten, whatever its suffix is.
        lightweight_ckpt_path = _lightweight_checkpoint_path(cfg.path)
        torch.save({"state_dict": state_dict}, lightweight_ckpt_path)
        print(f"Saved lightweight checkpoint to {lightweight_ckpt_path}")

        # Load into the unwrapped module; tolerate keys that were saved from a
        # compiled wrapper by dropping their leading `_orig_mod.` prefix.
        raw_model.load_state_dict(
            {_strip_prefix(k, "_orig_mod."): v for k, v in state_dict.items()}
        )

        # Start fresh training
        trainer.fit(model, train_loader, val_loader, ckpt_path=None)
    elif cfg.path.endswith(".ckpt"):
        print(f"Validating and resuming training on {cfg.path}...")
        trainer.validate(model, val_loader, ckpt_path=cfg.path)
        trainer.fit(model, train_loader, val_loader, ckpt_path=cfg.path)
    elif cfg.path.endswith(".pth"):
        print(f"Loading weights from {cfg.path}...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(cfg.path, map_location="cpu")
        state_dict = checkpoint.get("model_state_dict", checkpoint)

        # Adjust state_dict keys if saved with DDP wrapper. Only the leading
        # `module.` is removed so names that merely contain that substring
        # (e.g. `submodule.weight`) are not corrupted.
        new_state_dict = {
            _strip_prefix(k, "module."): v for k, v in state_dict.items()
        }
        missing_keys, unexpected_keys = raw_model.model.load_state_dict(
            new_state_dict, strict=False
        )
        print("Missing keys:", missing_keys, flush=True)
        print("Unexpected keys:", unexpected_keys, flush=True)

        trainer.fit(model, train_loader, val_loader)
    else:
        raise ValueError(f"Invalid extension for path: {cfg.path}")

    # save the model in .pt format
    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    export_dir = os.path.join(
        output_dir,
        f"{cfg.model_type}_{cfg.dataset_type}",
        "version_0",
        "checkpoints",
    )
    os.makedirs(export_dir, exist_ok=True)
    export_path = os.path.join(export_dir, "last.pt")

    model.eval()
    # Build the state dict once (each call re-creates the whole mapping) and
    # strip only the leading `model.` so nested names are left intact.
    exported_weights = {
        _strip_prefix(key, "model."): value
        for key, value in raw_model.state_dict().items()
    }
    torch.save(exported_weights, export_path)
    print(f"Saved model in {export_path}.")


# Script entry point: the hydra decorator on `main` supplies `cfg`.
if __name__ == "__main__":
    main()
