#!/usr/bin/env python3
"""
run.py - Unified entrypoint for training and testing pipeline

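Usage:
    python run.py --config config.yaml [--mode train|test]
    python run.py --exp_dir EXP_DIR [--mode train|test] [--ckpt_num N|best]

If --mode is omitted, training runs first and testing follows. --mode test
requires --exp_dir; when --config is omitted, EXP_DIR/config.yaml is used.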
"""
import argparse
from types import SimpleNamespace
import os
os.environ["PYTORCH_FLASH_ATTENTION"] = "1"
os.environ["TIMM_FUSED_ATTN"] = "1"
# import torch


# from torch.nn.attention import SDPBackend, sdpa_kernel
# sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]).__enter__()



import yaml
from enum import Enum

# Import your training and testing routines
# Adjust these imports according to your project structure.
from unconditional.train import main as train_main
from unconditional.test import main as test_main

from utils import load_config, seed_everything, setup_experiment_dir, find_latest_checkpoint



class MaskChoice(Enum):
    """Mask choices selectable through the ``mask_choice`` config entry.

    A member can be referred to by its case-insensitive name (``str(member)``
    gives the lowercase form) or by its integer value, which equals its
    position in definition order.
    """

    PREDICT = 0
    BACKWARD = 1
    INTERPOLATION = 2
    UNCONDITIONAL = 3
    ONE_FRAME = 4
    ARBITRARY_INTERPOLATION = 5
    SPATIAL_TEMPORAL = 6

    def __str__(self):
        """Return the lowercase member name, e.g. ``"one_frame"``."""
        return self.name.lower()

    @classmethod
    def from_string(cls, value):
        """Return the member whose name matches ``value`` case-insensitively.

        Raises:
            ValueError: If ``value`` does not name any member.
        """
        try:
            return cls[value.upper()]
        except KeyError:
            # `from None` keeps the internal KeyError out of the user-facing traceback.
            raise ValueError(
                f"Invalid mask choice: {value}. Must be one of {[e.name.lower() for e in cls]}."
            ) from None

    @classmethod
    def from_index(cls, index):
        """Return the member at position ``index`` in definition order.

        Raises:
            ValueError: If ``index`` is outside ``[0, len(cls))``.
        """
        if not (0 <= index < len(cls)):
            raise ValueError(f"Index {index} is out of range for MaskChoice enum.")
        return list(cls)[index]


# Lowercase name -> integer value lookup. Derived from MaskChoice so the two
# definitions cannot drift apart.
possible_mask_choices = {member.name.lower(): member.value for member in MaskChoice}


def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``config``, ``mode``, ``results_dir``,
        ``exp_dir`` and ``ckpt_num``. If ``--config`` was omitted, ``config``
        is set to ``<exp_dir>/config.yaml``.

    Exits through ``parser.error`` (usage printed, exit status 2) when the
    combination of arguments is inconsistent.
    """
    parser = argparse.ArgumentParser(
        description="Run training and/or testing from a YAML config file."
    )
    parser.add_argument(
        "--config", "-c", default=None, help="Path to YAML config file"
    )
    parser.add_argument(
        "--mode",
        choices=["train", "test"],
        default=None,
        help="Mode to run: train, test, or both if omitted",
    )

    parser.add_argument(
        "--results_dir",
        default="exp_results",
        help="Directory to save new experiment results. Default is 'exp_results'.",
    )

    parser.add_argument(
        "--exp_dir",
        default=None,
        help="Experiment directory to load for testing.",
    )

    parser.add_argument(
        "--ckpt_num",
        type=str,
        default=None,
        help="Checkpoint number to load for testing or to resume the training from. If None, the latest checkpoint will be used for test, or a new training will be started.",
    )

    args = parser.parse_args(argv)

    # Validate with parser.error() instead of assert: asserts are stripped
    # under `python -O`, and parser.error() reports usage with a clean exit.
    if args.mode == "test" and args.exp_dir is None:
        parser.error(
            "If --mode=test, --exp_dir must be specified to load the experiment."
        )
    if args.ckpt_num is not None and args.exp_dir is None:
        parser.error(
            "If --ckpt_num is specified, --exp_dir must also be specified to load the experiment."
        )

    if args.config is None:
        # Fall back to the config saved inside the experiment directory.
        if args.exp_dir is None or not os.path.isfile(
            os.path.join(args.exp_dir, "config.yaml")
        ):
            parser.error(
                "If --config is not specified, --exp_dir must point to an existing experiment directory with a config.yaml file."
            )
        args.config = os.path.join(args.exp_dir, "config.yaml")

    return args


def merge_args(common, specific):
    """Return a new dict of ``common`` overlaid with ``specific``.

    Keys in ``specific`` win on conflict. A falsy ``specific`` (e.g. ``None``)
    yields a plain copy of ``common``. Neither input is modified.
    """
    merged = common.copy()
    if specific:
        merged.update(specific)
    return merged


def dict_to_ns(d):
    """Recursively turn nested dicts into ``SimpleNamespace`` objects.

    Anything that is not a dict, including lists even when they contain
    dicts, is returned unchanged.
    """
    if isinstance(d, dict):
        converted = {}
        for key, value in d.items():
            converted[key] = dict_to_ns(value)
        return SimpleNamespace(**converted)
    return d


def _normalize_mask_choice(value):
    """Resolve a config ``mask_choice`` given by name or by index to a MaskChoice.

    Args:
        value: A case-insensitive member name (e.g. ``"predict"``) or an
            integer index into ``MaskChoice``.

    Returns:
        The matching ``MaskChoice`` member.

    Raises:
        ValueError: If the name or index matches no member.
        TypeError: If ``value`` is neither a ``str`` nor an ``int``.
    """
    if isinstance(value, str):
        return MaskChoice.from_string(value)
    if isinstance(value, int):
        return MaskChoice.from_index(value)
    raise TypeError(
        f"mask_choice must be a str or an int, got {type(value).__name__}: {value!r}"
    )


def main():
    """Run training and/or testing as selected on the command line.

    Loads the YAML config and overlays command-line values onto its
    ``common_args`` section. It normalizes ``mask_choice`` to its integer
    index, then runs training (``--mode train`` or no mode) and testing
    (``--mode test`` or no mode). Each phase receives ``common_args`` merged
    with its own ``train``/``test`` section, converted to a SimpleNamespace.
    """
    args = parse_args()
    cfg = load_config(args.config)

    # When the section exists this is the same dict object as cfg["common_args"],
    # so the updates below also end up in the config.yaml written before training.
    common_args = cfg.get("common_args", {})

    # load additional arguments from command line
    common_args["results_dir"] = args.results_dir
    common_args["exp_dir"] = args.exp_dir

    try:
        common_args["ckpt_num"] = int(args.ckpt_num)
    except (ValueError, TypeError):
        # Keep None and non-integer values such as 'best' unchanged.
        common_args["ckpt_num"] = args.ckpt_num

    # mask_choice may be given by name or by index (configs saved by this script
    # store the index); downstream code always receives the integer index.
    mask_choice = _normalize_mask_choice(common_args["mask_choice"])
    common_args["mask_choice"] = mask_choice.value
    # Canonical lowercase name, used as a results sub-directory. Deriving it from
    # the enum keeps it a str even when the config holds an index.
    str_mask_choice = str(mask_choice)

    # Run training if requested or if no mode specified
    if args.mode in (None, "train"):
        train_cfg = dict_to_ns(merge_args(common_args, cfg.get("train")))
        model_name = train_cfg.model.name
        if train_cfg.exp_dir is None:
            # New run: create a fresh experiment directory and checkpoint directory.
            data_name = train_cfg.data_path.split("/")[-1].split(".")[0]
            exp_dir, checkpoint_dir = setup_experiment_dir(
                SimpleNamespace(
                    exp_dir_name=f"{model_name}_{data_name}_{train_cfg.timefreq_transform}",
                    results_dir=os.path.join(
                        train_cfg.results_dir,
                        str_mask_choice,
                        str(train_cfg.seq_len),
                    ),
                )
            )

            # update args both for training and for using them in testing if needed
            train_cfg.exp_dir = common_args["exp_dir"] = exp_dir
            train_cfg.checkpoint_dir = common_args["checkpoint_dir"] = checkpoint_dir
        else:
            # Existing experiment directory: resume from its checkpoints.
            checkpoint_dir = os.path.join(train_cfg.exp_dir, "checkpoints")
            train_cfg.checkpoint_dir = common_args["checkpoint_dir"] = checkpoint_dir
            if train_cfg.ckpt_num is None:
                # If no checkpoint number is specified, find the latest checkpoint
                train_cfg.ckpt_num = find_latest_checkpoint(
                    train_cfg.checkpoint_dir,
                    return_int=True
                )

        seed_everything(common_args["seed"])
        with open(os.path.join(train_cfg.exp_dir, "config.yaml"), "w") as f:
            # Save the configuration used for this run
            yaml.dump(cfg, f)
        train_main(train_cfg)

    # Run testing if requested or if no mode specified
    if args.mode in (None, "test"):
        test_cfg = dict_to_ns(merge_args(common_args, cfg.get("test")))

        if args.mode == "test":
            # Training did not run, so checkpoint_dir has not been set yet;
            # derive it from the experiment directory.
            test_cfg.checkpoint_dir = os.path.join(test_cfg.exp_dir, "checkpoints")

        if test_cfg.ckpt_num is None:
            if os.path.exists(os.path.join(test_cfg.checkpoint_dir, "best.pt")):
                # If a best checkpoint exists, use it
                test_cfg.ckpt_num = "best"
            else:
                # If no checkpoint number is specified, find the latest checkpoint
                test_cfg.ckpt_num = find_latest_checkpoint(
                    test_cfg.checkpoint_dir,
                    return_int=True
                )

        seed_everything(common_args["seed"])
        test_main(test_cfg)


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
