import os

from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.cli import LightningCLI

from sed.models.callbacks.sed_logging import (SparseCaloImageLogger,
                                              SparseImageLogger,
                                              SparseScrnaLogger)
from sed.models.callbacks.weight_averaging import EMAWeightAveraging


import os  # Needed for directory operations like os.makedirs


def link_arguments(parser):
    """Register the ``--unet_name`` / ``-u_name`` string option on *parser*.

    Despite its name, this helper does not link any arguments; it only adds
    one option whose default value is ``"default_name"``.

    Args:
        parser: Any object exposing an argparse-style ``add_argument`` method.
    """
    flags = ("-u_name", "--unet_name")
    settings = {
        "type": str,
        "default": "default_name",
        "help": "name of experiment",
    }
    parser.add_argument(*flags, **settings)


def add_train_args(self, parser):
    """Register the training options ``--name``, ``--debug`` and ``--sample_every``.

    Args:
        self: Unused. This is a module-level function that merely takes the
            calling object as its first parameter.
        parser: Any object exposing an argparse-style ``add_argument`` method.

    NOTE(review): ``--debug`` is declared with ``type=bool``. A plain
    ``argparse`` parser turns every non-empty string into ``True``; whether
    "false" is parsed as ``False`` depends on the parser that is passed in.
    """
    option_specs = (
        # (option flags, add_argument keyword settings)
        (("-n", "--name"),
         {"type": str, "default": "default_name", "help": "name of experiment"}),
        (("-d", "--debug"),
         {"type": bool, "default": False, "help": "debug flag"}),
        (("-se", "--sample_every"),
         {"type": int, "default": 10000, "help": "when to sample"}),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)


class SedCLI(LightningCLI):
    """Command-line interface for this script, extending ``LightningCLI``.

    (``LightningCLI`` is imported from ``lightning.pytorch.cli``.)

    On top of the base CLI it:
      * registers the extra options ``--name``, ``--debug``, ``--sample_every``
        and ``--unet_name``;
      * links ``input_mode`` and ``image_size`` from the data config to the
        model config;
      * rewrites the run name and validation schedule before instantiation;
      * attaches the EMA, sample-logging and learning-rate callbacks to the
        trainer after instantiation.
    """

    def add_arguments_to_parser(self, parser):
        """Add the script-specific options and data -> model argument links."""
        # add_train_args is a module-level function that takes the CLI object
        # as its first parameter; it is not a method of this class, so it has
        # to be called as a plain function (``self.add_train_args`` would
        # raise AttributeError).
        add_train_args(self, parser)
        # Link some arguments between the data and model configs to keep them
        # synchronized
        parser.link_arguments("data.init_args.input_mode",
                              "model.init_args.input_mode")
        parser.link_arguments("data.init_args.image_size",
                              "model.init_args.image_size")
        # Add UNet name argument
        link_arguments(parser)

    def before_instantiate_classes(self):
        """Adjust the parsed config before model/datamodule/trainer are built.

        Raises:
            IndexError: If ``model.init_args.vae_dir`` has fewer than three
                "/"-separated components, or the selected component contains
                no "_".
        """
        # Run name becomes "<name>_<vae tag>_<unet_name>", where the tag is the
        # second "_"-separated field of the third-from-last path component of
        # the configured VAE directory.
        vae_run_component = self.config.model.init_args.vae_dir.split("/")[-3]
        vae_tag = vae_run_component.split("_")[1]
        self.config.name = "_".join(
            (self.config.name, vae_tag, self.config.unet_name))

        # Call parent method to proceed with usual setup steps
        super().before_instantiate_classes()

        # Drive validation by ``sample_every`` instead of by epoch:
        # ``val_check_interval`` is set to ``sample_every`` and the per-epoch
        # check is disabled.
        self.config.trainer.val_check_interval = self.config.sample_every
        self.config.trainer.check_val_every_n_epoch = None

        # Create a directory for reconstructed outputs inside the run log
        # directory.
        # NOTE(review): ``self.logdir`` is not assigned anywhere in this file;
        # unless it is provided elsewhere this line raises AttributeError.
        # Confirm where the log directory is supposed to come from.
        self.reconstr_dir = os.path.join(self.logdir, "reconstructed")
        os.makedirs(self.reconstr_dir, exist_ok=True)

    def after_instantiate_classes(self):
        """Attach the EMA, sample-logging and LR-monitor callbacks.

        Raises:
            ValueError: If ``self.model.vae.input_mode`` is not one of
                ``'image'``, ``'scrna'`` or ``'calo_image'``.
        """
        # EMA weight-averaging callback; the meaning of the two positional
        # arguments is defined by EMAWeightAveraging (not visible in this file).
        ema_callback = EMAWeightAveraging(10, 0.9999)
        self.trainer.callbacks.append(ema_callback)
        # Also exposed as an attribute on the trainer.
        self.trainer.ema_callback = ema_callback

        # Pick the logger class matching the VAE's input mode. An unknown mode
        # is reported explicitly instead of surfacing as an unbound local
        # variable further down.
        input_mode = self.model.vae.input_mode
        if input_mode == 'image':
            logger_cls = SparseImageLogger
        elif input_mode == 'scrna':
            logger_cls = SparseScrnaLogger
        elif input_mode == 'calo_image':
            logger_cls = SparseCaloImageLogger
        else:
            raise ValueError(
                f"Unsupported VAE input_mode {input_mode!r}; "
                "expected 'image', 'scrna' or 'calo_image'.")

        # NOTE(review): ``self.sampled_dir`` is not assigned anywhere in this
        # file (only ``reconstr_dir`` is created in
        # before_instantiate_classes) - confirm where it is set.
        intermediate_logger = logger_cls(
            batch_size=self.datamodule.batch_size,
            sample_every=self.config.sample_every,
            sampled_dir=self.sampled_dir)
        self.trainer.callbacks.append(intermediate_logger)

        # Add learning rate monitor callback to track LR over training steps
        lr_monitor = LearningRateMonitor(logging_interval='step')
        self.trainer.callbacks.append(lr_monitor)

def cli_main():
    """Build the CLI without running a subcommand, then start fitting.

    ``SedCLI`` is constructed with ``run=False``, and the resulting trainer's
    ``fit`` is called explicitly with the model and datamodule held by the CLI.
    """
    sed_cli = SedCLI(run=False)
    trainer = sed_cli.trainer
    trainer.fit(sed_cli.model, sed_cli.datamodule)

if __name__ == "__main__":
    # Script entry point: only when this file is executed directly (not when
    # it is imported) call cli_main().
    cli_main()
