# file: main.py
import argparse
import os
import re

os.environ["MKL_THREADING_LAYER"] = "GNU"
import sys
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
from pathlib import Path
from types import SimpleNamespace

import pytorch_lightning as pl
import torch
import torch.multiprocessing as mp
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.tensorboard import SummaryWriter

from prism.callbacks.data_gatherer import DataGathererCallback
from prism.core.registry import CALLBACKS, DATASETS, METRICS, MODELS, SYSTEMS, VISUALIZATIONS
from prism.utils.config import load_config, process_derived_config

# This import populates the registries (DATASETS, etc.) before they are used.
import user_extensions


def find_version_dir(log_dir: Path, sweep_name: str, version_to_find: int) -> Path | None:
    exp_dir = log_dir / sweep_name
    if not exp_dir.is_dir():
        return None

    if version_to_find >= 0:
        specific_version_dir = exp_dir / f"version_{version_to_find}"
        return specific_version_dir if specific_version_dir.is_dir() else None

    version_dirs = [d for d in exp_dir.iterdir() if re.match(r"version_\d+", d.name)]
    if not version_dirs:
        return None

    latest_version = max(version_dirs, key=lambda d: int(d.name.split('_')[1]))
    return latest_version


def find_last_checkpoint(version_dir: Path) -> str:
    """Return the path of ``checkpoints/last.ckpt`` inside ``version_dir`` as a string.

    Raises:
        FileNotFoundError: If ``last.ckpt`` is missing or is not a regular file.
    """
    candidate = version_dir.joinpath("checkpoints", "last.ckpt")

    # is_file() is already False for a path that does not exist.
    if not candidate.is_file():
        raise FileNotFoundError(f"Checkpoint 'last.ckpt' not found in {version_dir / 'checkpoints'}")

    print(f"Found checkpoint to resume from: {candidate}")
    return str(candidate)


def run_training(args, cli_overrides):
    """Fit and then test a ``PrismSystem``, optionally resuming an earlier run.

    Args:
        args: Parsed CLI namespace. Reads ``args.config`` (config file path) and
            ``args.resume`` (``None`` for a fresh run, otherwise the version
            number handed to ``find_version_dir``).
        cli_overrides: Extra command-line tokens forwarded to ``load_config``.

    Exits the process with status 1 when a resume is requested but either the
    version directory or its ``last.ckpt`` cannot be found.
    """
    print("--- Running PRISM Training Pipeline ---")

    config = process_derived_config(load_config(args.config, cli_overrides))

    # Fresh-run defaults: no checkpoint to restore, and version=None for the logger.
    resume_from_checkpoint, version_to_log = None, None

    if args.resume is not None:
        base_log_dir = Path(config.run.log_dir)
        sweep_name = config.run.sweep_name
        version_dir_to_resume = find_version_dir(base_log_dir, sweep_name, args.resume)

        if version_dir_to_resume is None:
            if args.resume == -1:
                version_str = "latest"
            else:
                version_str = f"version_{args.resume}"
            print(f"Error: Cannot resume. {version_str} directory not found in '{base_log_dir / sweep_name}'.", file=sys.stderr)
            sys.exit(1)

        try:
            resume_from_checkpoint = find_last_checkpoint(version_dir_to_resume)
        except FileNotFoundError as e:
            print(f"Error: Cannot resume. {e}", file=sys.stderr)
            sys.exit(1)
        else:
            # Log into the same version directory the checkpoint came from.
            version_to_log = int(version_dir_to_resume.name.split('_')[1])

    logger = TensorBoardLogger(
        save_dir=config.run.log_dir,
        name=config.run.sweep_name,
        version=version_to_log,
    )
    log_dir = Path(logger.log_dir)

    datamodule_cls = DATASETS.get(config.data.name)
    datamodule = datamodule_cls(config)
    system_cls = SYSTEMS.get("PrismSystem")
    system = system_cls(config)

    # The data gatherer always runs first; registry callbacks follow in config
    # order unless explicitly disabled, and the checkpoint callback goes last.
    callbacks = [DataGathererCallback(config)]
    for cb_name, cb_config in config.get('callbacks', {}).items():
        if cb_config.get('enabled', True):
            callbacks.append(CALLBACKS.get(cb_name)(config))

    callbacks.append(
        ModelCheckpoint(
            dirpath=log_dir / "checkpoints",
            filename='prism-epoch{epoch:03d}-vloss{val/class_loss:.2f}',
            monitor='val/class_loss',
            mode='min',
            save_last=True,
            auto_insert_metric_name=False,
        )
    )

    trainer_kwargs = dict(
        max_epochs=config.training.epochs,
        logger=logger,
        callbacks=callbacks,
        log_every_n_steps=config.evaluation.log_interval,
        accelerator="auto",
        devices="auto",
        # DDP with unused-parameter detection only when several GPUs are visible.
        strategy='ddp_find_unused_parameters_true' if torch.cuda.device_count() > 1 else 'auto',
    )
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model=system, datamodule=datamodule, ckpt_path=resume_from_checkpoint)
    trainer.test(model=system, datamodule=datamodule)
    print("\n--- PRISM Training Pipeline Complete ---")


def _setup_post_hoc_context(run_dir: Path, device_str: str):
    hparams_path = run_dir / "hparams.yaml"
    if not hparams_path.exists():
        print(f"Error: hparams.yaml not found in {run_dir}. Cannot proceed.", file=sys.stderr)
        return None, None, None

    device = torch.device(device_str)
    config = load_config(hparams_path)
    config = process_derived_config(config)

    datamodule = DATASETS.get(config.data.name)(config)
    datamodule.prepare_data()
    datamodule.setup('test')

    return config, datamodule, device


def _find_artifact_dirs(run_dir: Path):
    artifacts_path = run_dir / "artifacts"
    if not artifacts_path.is_dir():
        print(f"Warning: No 'artifacts' directory found in {run_dir}. Skipping.")
        return []

    epoch_dirs = sorted(
        [d for d in artifacts_path.iterdir() if d.is_dir() and d.name.startswith("epoch_")],
        key=lambda x: int(x.name.split('_')[-1])
    )
    test_artifacts_dir = artifacts_path / "test_set_results"
    if test_artifacts_dir.is_dir():
        epoch_dirs.append(test_artifacts_dir)

    if not epoch_dirs:
        print(f"Warning: No artifact directories found in {artifacts_path}. Nothing to process.")

    return epoch_dirs


def _load_data_artifacts(artifact_dir: Path):
    try:
        z_full = torch.load(artifact_dir / "Z_full.pt", map_location='cpu', weights_only=True)
        y_targets = torch.load(artifact_dir / "Y_targets.pt", map_location='cpu', weights_only=True)
        y_style_path = artifact_dir / "Y_style.pt"
        y_style = torch.load(y_style_path, map_location='cpu', weights_only=True) if y_style_path.exists() else None
        return z_full, y_targets, y_style
    except FileNotFoundError:
        print(f"    Warning: Could not load all data artifacts from {artifact_dir}. Skipping.")
        return None, None, None


def run_evaluation(run_dir: Path, device: str = "cpu"):
    """Compute every ``post-hoc`` metric for each artifact directory of a run.

    Results are written as TensorBoard scalars under ``eval_post-hoc/`` into
    ``run_dir``. The step is the epoch number parsed from the artifact
    directory name, or ``-1`` for directories without "epoch" in their name.

    Args:
        run_dir: Run directory containing ``hparams.yaml`` and ``artifacts``.
        device: Device string; it is converted to a ``torch.device`` by
            ``_setup_post_hoc_context`` and passed on to the metrics.

    A metric that raises is reported on stderr and skipped; the remaining
    metrics and artifact directories are still processed.
    """
    print(f"--- Running Post-Hoc Evaluation for Run: {run_dir} ---")
    config, datamodule, device = _setup_post_hoc_context(run_dir, device)
    if config is None:
        return

    # Discover artifacts before opening the writer so a run with nothing to
    # evaluate does not get an empty event file.
    artifact_dirs = _find_artifact_dirs(run_dir)
    if not artifact_dirs:
        return

    # The context manager closes the writer even if an unexpected error
    # (e.g. a corrupt tensor file) escapes the loop.
    with SummaryWriter(log_dir=str(run_dir)) as writer:
        for artifact_dir in artifact_dirs:
            epoch = int(artifact_dir.name.split('_')[-1]) if "epoch" in artifact_dir.name else -1
            print(f"  > Evaluating artifacts from: {artifact_dir.name}")

            z_full, y_targets, y_style = _load_data_artifacts(artifact_dir)
            if z_full is None:
                continue

            metric_kwargs = {
                "z_full": z_full, "y_targets": y_targets, "y_style": y_style,
                "style_feature_map": datamodule.style_feature_map, "datamodule": datamodule,
                "device": device, "config": config, "artifact_dir": artifact_dir,
            }

            for name, metric_cfg in config.evaluation.metrics.items():
                if metric_cfg.get("mode") != "post-hoc":
                    continue

                print(f"    - Calculating metric: {name}")
                try:
                    metric_instance = METRICS.get(name)(config)
                    result = metric_instance.calculate(**metric_kwargs)

                    # A dict result is logged as one scalar per entry.
                    if isinstance(result, dict):
                        for sub_name, sub_value in result.items():
                            writer.add_scalar(f'eval_post-hoc/{name}_{sub_name}', sub_value, epoch)
                    else:
                        writer.add_scalar(f'eval_post-hoc/{name}', result, epoch)
                except Exception as e:
                    # Deliberately broad: one failing metric must not stop the rest.
                    print(f"      [ERROR] Failed to calculate metric '{name}': {e}", file=sys.stderr)

    print(f"--- Evaluation complete. Results logged to TensorBoard in '{run_dir}'. ---")


def run_visualization(run_dir: Path, device: str = "cpu"):
    """Generate every ``post-hoc`` visualization from a run's latest artifacts.

    Uses the last directory returned by ``_find_artifact_dirs``, restores the
    Encoder, Generator and Classifier weights saved there, pairs the stored
    tensors with the first samples of the test dataloader, and runs each
    configured visualization. Output goes to ``run_dir / "post_hoc_visualizations"``.

    Args:
        run_dir: Run directory containing ``hparams.yaml`` and ``artifacts``.
        device: Device string; converted to a ``torch.device`` by
            ``_setup_post_hoc_context`` and used for the restored models.

    A visualization that raises is reported on stderr and skipped. The function
    returns early (after printing a message) when the config, artifacts, model
    weights, data tensors or test data are unavailable.
    """
    print(f"--- Running Post-Hoc Visualization for Run: {run_dir} ---")
    config, datamodule, device = _setup_post_hoc_context(run_dir, device)
    if config is None:
        return

    artifact_dirs = _find_artifact_dirs(run_dir)
    if not artifact_dirs:
        return

    latest_artifact_dir = artifact_dirs[-1]
    print(f"  > Using artifacts from: {latest_artifact_dir.name}")
    plot_dir = run_dir / "post_hoc_visualizations"
    plot_dir.mkdir(exist_ok=True)

    try:
        encoder = MODELS.get("Encoder")(config).to(device)
        generator = MODELS.get("Generator")(config).to(device)
        classifier = MODELS.get("Classifier")(config).to(device)
        encoder.load_state_dict(torch.load(latest_artifact_dir / "encoder.pth", map_location=device, weights_only=True))
        generator.load_state_dict(torch.load(latest_artifact_dir / "generator.pth", map_location=device, weights_only=True))
        classifier.load_state_dict(torch.load(latest_artifact_dir / "classifier.pth", map_location=device, weights_only=True))
    except FileNotFoundError as e:
        print(f"    Error: Could not load required model artifacts from {latest_artifact_dir}: {e}", file=sys.stderr)
        return

    encoder.eval()
    generator.eval()
    classifier.eval()

    z_full, y_targets, y_style = _load_data_artifacts(latest_artifact_dir)
    if z_full is None:
        return

    print("  > Loading corresponding data from test dataloader...")
    test_loader = datamodule.test_dataloader()
    num_samples_to_fetch = z_full.shape[0]

    # Pull batches only until enough samples are collected, tracking a running
    # count instead of re-summing every batch on each iteration.
    data_batches = []
    num_fetched = 0
    for batch in test_loader:
        inputs = batch[0]
        data_batches.append(inputs)
        num_fetched += inputs.shape[0]
        if num_fetched >= num_samples_to_fetch:
            break

    if not data_batches:
        # torch.cat on an empty list raises; report the real cause instead.
        print("    Error: Test dataloader yielded no batches. Cannot generate visualizations.", file=sys.stderr)
        return
    if num_fetched < num_samples_to_fetch:
        print(f"    Warning: Test dataloader provided {num_fetched} samples but artifacts contain {num_samples_to_fetch}.")

    full_data = torch.cat(data_batches, dim=0)[:num_samples_to_fetch]

    # Lightweight stand-ins exposing only the attributes set here.
    pl_module_mock = SimpleNamespace(
        device=device, encoder=encoder, generator=generator, classifier=classifier,
        config=config
    )
    # NOTE(review): log_dir points two levels above run_dir (for
    # "runs/exp/version_0" that is "runs") -- confirm this is what the
    # visualizations expect from trainer.logger.log_dir.
    trainer_mock = SimpleNamespace(datamodule=datamodule, logger=SimpleNamespace(log_dir=run_dir.parent.parent))

    viz_kwargs = {"z_full": z_full, "y_targets": y_targets, "y_style": y_style, "data": full_data}

    # The epoch label depends only on the artifact directory, so compute it
    # once; -1 is used when the name carries no parsable epoch number.
    epoch_num_for_viz = -1
    if "epoch" in latest_artifact_dir.name:
        try:
            epoch_num_for_viz = int(latest_artifact_dir.name.split('_')[-1])
        except ValueError:
            pass  # keep the -1 sentinel

    for name, viz_cfg in config.evaluation.visualizations.items():
        if viz_cfg.get("mode") != "post-hoc":
            continue

        print(f"    - Generating visualization: {name}")
        try:
            viz_instance = VISUALIZATIONS.get(name)(config)
            viz_instance.run(trainer=trainer_mock, pl_module=pl_module_mock, plot_dir=plot_dir, epoch=epoch_num_for_viz, **viz_kwargs)
        except Exception as e:
            # Deliberately broad: one failing visualization must not stop the rest.
            print(f"      [ERROR] Failed to generate visualization '{name}': {e}", file=sys.stderr)

    print(f"--- Visualizations saved to '{plot_dir}'. ---")


def main_cli():
    """Parse the command line and dispatch to the selected PRISM command.

    Commands:
        train: requires ``--config``; optional ``--resume [VERSION]``. Any
            arguments the parser does not recognise are forwarded to
            ``run_training`` as config overrides.
        evaluate / visualize: require ``--run_dir``; optional ``--device``.
            Unrecognised arguments are not used by these commands, so they are
            reported on stderr instead of being dropped silently.
    """
    parser = argparse.ArgumentParser(description="PRISM - Prototype-Regulated Identity-Style Model")
    subparsers = parser.add_subparsers(dest="command", required=True, help="Available commands")

    # Parent parser for training commands that need a config file
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("--config", type=str, required=True, help="Path to the main experiment config file.")

    # Parent parser for post-hoc commands that need a run directory
    post_hoc_parser = argparse.ArgumentParser(add_help=False)
    post_hoc_parser.add_argument("--run_dir", type=Path, required=True, help="Path to a specific run's version directory (e.g., runs/exp/version_0).")
    post_hoc_parser.add_argument("--device", type=str, default="cpu", help="Device to use for evaluation (e.g., 'cuda:0', 'cpu').")

    # User-facing 'train' command
    train_parser = subparsers.add_parser("train", help="Run the training pipeline.", parents=[config_parser])

    # --resume is tri-state: absent -> None (fresh run), bare flag -> -1
    # (latest version), "--resume N" -> that specific version.
    train_parser.add_argument(
        "--resume",
        nargs='?',
        const=-1,
        type=int,
        default=None,
        help="Resume training. Provide a version number (e.g., --resume 0) or use without a number to resume the latest version."
    )

    # Post-hoc commands
    subparsers.add_parser("evaluate", help="Run post-hoc evaluation on a trained model.", parents=[post_hoc_parser])
    subparsers.add_parser("visualize", help="Generate post-hoc visualizations for a trained model.", parents=[post_hoc_parser])

    # parse_known_args keeps unrecognised tokens so 'train' can treat them as overrides.
    args, cli_overrides = parser.parse_known_args()

    if args.command == "train":
        run_training(args, cli_overrides)
        return

    # Only 'train' consumes the leftover tokens; surface them here so a
    # mistyped option does not go unnoticed.
    if cli_overrides:
        print(f"Warning: Ignoring unrecognized arguments for '{args.command}': {' '.join(cli_overrides)}", file=sys.stderr)

    if args.command == "evaluate":
        run_evaluation(args.run_dir, args.device)

    elif args.command == "visualize":
        run_visualization(args.run_dir, args.device)


# Script entry point: apply process-wide settings, then hand over to the CLI.
if __name__ == "__main__":
    # Select the "spawn" start method for torch.multiprocessing before any
    # command runs; force=True is passed so the call is applied unconditionally.
    mp.set_start_method("spawn", force=True)
    # Set PyTorch's float32 matmul precision mode to 'high' for every command.
    torch.set_float32_matmul_precision('high')
    main_cli()