# file: user_extensions/baselines/main.py
import argparse
import os
# Set before the numeric/ML imports below so it is already in the environment
# when they load. Presumably works around an MKL threading-layer conflict —
# TODO confirm the motivating issue before moving or removing this line.
os.environ["MKL_THREADING_LAYER"] = "GNU"

from pathlib import Path
import warnings
# Suppress FutureWarning messages; the filter is installed before the
# third-party imports below are executed.
warnings.filterwarnings("ignore", category=FutureWarning)

import pytorch_lightning as pl
import torch
import torch.multiprocessing as mp
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from prism.callbacks.data_gatherer import DataGathererCallback
from prism.core.registry import DATASETS, SYSTEMS
from prism.utils.config import load_config, process_derived_config
from user_extensions.baselines.fader_networks.callbacks import BaselineEvaluationCallback

# Not referenced by name anywhere in this file; presumably imported for its
# import-time side effects (e.g. populating DATASETS / SYSTEMS) — verify
# before removing.
import user_extensions


def run_training(args, cli_overrides):
    """Train a baseline system and then run it on the test set.

    The config is loaded from ``args.config`` with ``cli_overrides`` applied
    and passed through ``process_derived_config``. The datamodule and system
    classes are looked up in the ``DATASETS`` / ``SYSTEMS`` registries by the
    names given in the config, and each is constructed with the config.

    Args:
        args: Parsed CLI namespace; only ``args.config`` is read here.
        cli_overrides: Extra command-line tokens forwarded to ``load_config``.
    """
    print("--- Running Baseline Training Pipeline ---")

    config = process_derived_config(load_config(args.config, cli_overrides))

    # TensorBoard logger; checkpoints are written beneath its log directory.
    tb_logger = TensorBoardLogger(save_dir=config.run.log_dir, name=config.run.sweep_name)
    run_dir = Path(tb_logger.log_dir)

    datamodule_factory = DATASETS.get(config.data.name)
    datamodule = datamodule_factory(config)

    system_factory = SYSTEMS.get(config.system.name)
    system = system_factory(config)
    print(f"  > Successfully instantiated system: '{config.system.name}'")

    # Track the minimum of 'val/recon_loss' and additionally keep the last
    # checkpoint (save_last=True).
    checkpoint_options = {
        "dirpath": run_dir / "checkpoints",
        "filename": 'baseline-epoch{epoch:03d}',
        "monitor": 'val/recon_loss',
        "mode": 'min',
        "save_last": True,
        "auto_insert_metric_name": False,
    }
    best_model_saver = ModelCheckpoint(**checkpoint_options)

    trainer_callbacks = [DataGathererCallback(config), BaselineEvaluationCallback(config)]
    trainer_callbacks.append(best_model_saver)

    # With more than one visible CUDA device use the DDP variant that tolerates
    # unused parameters; otherwise let Lightning choose the strategy.
    if torch.cuda.device_count() > 1:
        strategy_name = 'ddp_find_unused_parameters_true'
    else:
        strategy_name = 'auto'

    trainer = pl.Trainer(
        max_epochs=config.training.epochs,
        logger=tb_logger,
        callbacks=trainer_callbacks,
        log_every_n_steps=config.evaluation.log_interval,
        accelerator="auto",
        devices="auto",
        strategy=strategy_name,
    )

    trainer.fit(model=system, datamodule=datamodule)

    print("\n--- Running Final Evaluation on Test Set ---")
    # No model is passed here; only the datamodule is supplied to the trainer.
    trainer.test(datamodule=datamodule)

    print("\n--- Baseline Training Pipeline Complete ---")


def main_cli():
    """Parse the command line and start the baseline training run.

    Only ``--config`` (required) is declared. Arguments the parser does not
    recognise are returned separately by ``parse_known_args`` and forwarded
    to ``run_training`` as its ``cli_overrides`` argument.
    """
    cli = argparse.ArgumentParser(
        description="PRISM Baselines - A runner for training alternative models.",
    )
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path to the main experiment config file for the baseline.",
    )

    known_args, unparsed_tokens = cli.parse_known_args()
    run_training(known_args, unparsed_tokens)


# Script entry point: apply process-wide settings, then hand off to the CLI.
if __name__ == "__main__":
    # Use the "spawn" start method; force=True replaces any method that was
    # already set for this process.
    mp.set_start_method("spawn", force=True)
    # Select the 'high' float32 matmul precision mode (see the PyTorch docs
    # for the exact numerical trade-off of each mode).
    torch.set_float32_matmul_precision('high')
    main_cli()