"""
Adapted from original train.py script in https://github.com/Stability-AI/stable-audio-tools
"""
from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os
import torch
import pytorch_lightning as pl
import random

from stable_audio_tools.data.dataset import create_dataloader_from_config
from stable_audio_tools.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from stable_audio_tools.training import create_training_wrapper_from_config, create_demo_callback_from_config
from stable_audio_tools.training.utils import copy_state_dict
from stable_audio_tools.models import create_model_from_config
from stable_audio_tools.data.dataset import collation_fn
from stable_audio_tools.data.dataset_edit import EditDataset



class ExceptionCallback(pl.Callback):
    """Lightning callback that prints any exception raised during training.

    It only reports the error; it does not handle or suppress it.
    """

    def on_exception(self, trainer, module, err):
        # Printed as "<ExceptionClassName>: <message>".
        error_name = type(err).__name__
        print(f'{error_name}: {err}')


class ModelConfigEmbedderCallback(pl.Callback):
    """Lightning callback that stores the model config inside every saved checkpoint.

    Args:
        model_config: Model configuration object to write under the
            ``"model_config"`` key of each checkpoint dict.
    """

    def __init__(self, model_config):
        # Initialise the base Callback so any base-class setup is not skipped.
        super().__init__()
        self.model_config = model_config

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Presumably embedded so a checkpoint is self-describing and can rebuild
        # its model without the original JSON file — confirm against loaders.
        checkpoint["model_config"] = self.model_config


# Let PyTorch use faster reduced-precision internal math (e.g. TF32 on GPUs that
# support it) for float32 matrix multiplications. Runs once at import time.
torch.set_float32_matmul_precision("high")

def _seed_rngs(base_seed):
    """Seed Python's and PyTorch's RNGs, offsetting by the SLURM process id if set.

    Args:
        base_seed: Seed from the command line (``args.seed``).

    Returns:
        The seed actually used: ``base_seed + int(SLURM_PROCID)`` when that
        environment variable is set, otherwise ``base_seed``.
    """
    seed = base_seed
    slurm_procid = os.environ.get("SLURM_PROCID")
    if slurm_procid is not None:
        # Use a different seed for each process when running under SLURM.
        seed += int(slurm_procid)
    random.seed(seed)
    torch.manual_seed(seed)
    return seed


def _load_json(path):
    """Parse the JSON file at *path* (read as UTF-8) and return the resulting object."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _make_edit_dataset(datasets, model_config, dataset_config):
    """Build an EditDataset over *datasets* using the model's sample size and rate.

    ``random_crop`` comes from *dataset_config* (default ``False``). Channels are
    forced to ``"stereo"``, as in the original script.
    """
    return EditDataset(
        datasets=datasets,
        sample_size=model_config["sample_size"],
        sample_rate=model_config["sample_rate"],
        random_crop=dataset_config.get("random_crop", False),
        force_channels="stereo",
    )


def _make_dataloader(dataset, batch_size, num_workers, *, shuffle, drop_last):
    """Wrap *dataset* in a pinned-memory DataLoader that batches with ``collation_fn``."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # DataLoader raises ValueError if persistent_workers=True with num_workers == 0.
        persistent_workers=num_workers > 0,
        pin_memory=True,
        drop_last=drop_last,
        collate_fn=collation_fn,
    )


def _load_model(args, model_config):
    """Create the model from *model_config* and apply the optional checkpoint/weight-norm steps.

    Steps, in order:
      1. copy weights from ``args.pretrained_ckpt_path`` (if given);
      2. strip weight norm from the pretransform if
         ``args.remove_pretransform_weight_norm == "pre_load"``;
      3. load ``args.pretransform_ckpt_path`` into the pretransform (if given);
      4. strip weight norm if ``args.remove_pretransform_weight_norm == "post_load"``.
    """
    model = create_model_from_config(model_config)

    if args.pretrained_ckpt_path:
        copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path))

    # Removing weight norm *before* loading presumably matches pretransform
    # checkpoints saved without weight norm — confirm.
    if args.remove_pretransform_weight_norm == "pre_load":
        remove_weight_norm_from_model(model.pretransform)

    if args.pretransform_ckpt_path:
        print(f"Loading pretransform checkpoint from {args.pretransform_ckpt_path}")
        model.pretransform.load_state_dict(load_ckpt_state_dict(args.pretransform_ckpt_path))

    if args.remove_pretransform_weight_norm == "post_load":
        remove_weight_norm_from_model(model.pretransform)

    return model


def _build_strategy(args):
    """Return the Lightning strategy for the run.

    An explicit ``args.strategy`` wins. ``"deepspeed"`` maps to a ZeRO stage-2
    DeepSpeedStrategy; any other value is passed to Lightning unchanged.
    Without an explicit strategy, DDP (tolerating unused parameters) is used
    for more than one GPU, and ``"auto"`` otherwise.
    """
    if not args.strategy:
        return 'ddp_find_unused_parameters_true' if args.num_gpus > 1 else "auto"

    if args.strategy == "deepspeed":
        from pytorch_lightning.strategies import DeepSpeedStrategy
        return DeepSpeedStrategy(stage=2,
                                 contiguous_gradients=True,
                                 overlap_comm=True,
                                 reduce_scatter=True,
                                 reduce_bucket_size=5e8,
                                 allgather_bucket_size=5e8,
                                 load_full_weights=True
                                 )

    return args.strategy


def main():
    """Train a model on EditDataset data with PyTorch Lightning, logging to W&B.

    All settings come from the command line via prefigure's ``get_all_args``.
    ``args.model_config`` and ``args.dataset_config`` are paths to JSON files.
    """
    args = get_all_args()

    _seed_rngs(args.seed)

    model_config = _load_json(args.model_config)
    dataset_config = _load_json(args.dataset_config)

    train_set = _make_edit_dataset(dataset_config["datasets"], model_config, dataset_config)
    demo_set = _make_edit_dataset(dataset_config["demo_datasets"], model_config, dataset_config)

    train_dl = _make_dataloader(train_set, args.batch_size, args.num_workers,
                                shuffle=True, drop_last=True)
    demo_dl = _make_dataloader(demo_set, args.batch_size, 1,
                               shuffle=False, drop_last=False)

    model = _load_model(args, model_config)
    training_wrapper = create_training_wrapper_from_config(model_config, model)

    # If WANDB_DIR is unset, fall back to "." (WandbLogger's own default)
    # instead of raising KeyError.
    wandb_logger = pl.loggers.WandbLogger(project=args.name,
                                          save_dir=os.environ.get("WANDB_DIR", "."))
    wandb_logger.watch(training_wrapper)

    exc_callback = ExceptionCallback()

    # NOTE(review): the isinstance check presumably guards processes where
    # `experiment` is a placeholder whose `id` is not a string — confirm.
    if args.save_dir and isinstance(wandb_logger.experiment.id, str):
        checkpoint_dir = os.path.join(args.save_dir, wandb_logger.experiment.project,
                                      wandb_logger.experiment.id, "checkpoints")
    else:
        checkpoint_dir = None

    # save_top_k=-1 keeps every checkpoint instead of only the best k.
    ckpt_callback = pl.callbacks.ModelCheckpoint(every_n_train_steps=args.checkpoint_every,
                                                 dirpath=checkpoint_dir,
                                                 save_top_k=-1)
    save_model_config_callback = ModelConfigEmbedderCallback(model_config)

    demo_callback = create_demo_callback_from_config(model_config, demo_dl=demo_dl)

    # Log CLI args plus both parsed configs. Build a fresh dict: vars(args) is
    # the namespace's own __dict__, so updating it in place would overwrite
    # args.model_config / args.dataset_config (the paths) with the parsed dicts.
    args_dict = {**vars(args), "model_config": model_config, "dataset_config": dataset_config}
    push_wandb_config(wandb_logger, args_dict)

    strategy = _build_strategy(args)

    trainer = pl.Trainer(
        devices=args.num_gpus,
        accelerator="gpu",
        num_nodes=args.num_nodes,
        strategy=strategy,
        precision=args.precision,
        accumulate_grad_batches=args.accum_batches,
        callbacks=[ckpt_callback, demo_callback, exc_callback, save_model_config_callback],
        logger=wandb_logger,
        log_every_n_steps=1,
        max_epochs=10000000,  # effectively unbounded; training runs until stopped
        default_root_dir=args.save_dir,
        gradient_clip_val=args.gradient_clip_val,
        reload_dataloaders_every_n_epochs=0
    )

    # An empty/falsy ckpt_path means "start fresh" rather than resume.
    trainer.fit(training_wrapper, train_dl, ckpt_path=args.ckpt_path or None)



if __name__ == "__main__":
    # Launch training only when run as a script, not when imported.
    main()