import os
from pathlib import Path
import subprocess
from dataclasses import dataclass, field
from typing import Any, Literal, Optional, Union

import hydra
from omegaconf import DictConfig, OmegaConf, MISSING

from vis_models.io.dirs import set_artifacts_dir, get_artifacts_dir
from vis_datasets.lib.dirs import set_dataset_dir
from vis_analysis_utils.publish.upload import (
    UploadConfig,
    register_upload_config,
    set_upload_config,
)

from experiments.transforms_vs_identity import (
    tvi_experiment,
    TvIExperimentConfig,
    register_transforms_vs_identity_configs,
    set_tvi_seeds,
)
from experiments.transforms_mismatch import (
    tm_experiment,
    TMExperimentConfig,
    register_transforms_mismatch_configs,
    set_tm_seeds,
)
from experiments.real_world_transforms_vs_identity import (
    rw_tvi_experiment,
    RWTvIExperimentConfig,
    register_real_world_transforms_vs_identities_configs,
)
from experiments.invariance_transfer import (
    it_experiment,
    ITExperimentConfig,
    register_ti_configs,
    set_it_seeds,
)
from experiments.cross_transforms import (
    ct_experiment,
    CTExperimentConfig,
    register_cross_transforms_configs,
    set_ct_seeds,
)
from experiments.representation_impact import (
    ri_experiment,
    RIExperimentConfig,
    register_ri_configs,
    set_ri_seeds,
)
from experiments.random_to_cifar import (
    r2c_experiment,
    R2CExperimentConfig,
    register_r2c_configs,
    set_r2c_seeds,
)
from experiments.irrelevant_feature_extraction import (
    ife_experiment,
    IFEExperimentConfig,
    register_ife_configs,
    set_ife_seeds,
)


@dataclass
class DirectoryConfig:
    """Filesystem roots used by a run.

    Attributes:
        artifacts: Directory handed to ``set_artifacts_dir`` at startup.
        data: Directory handed to ``set_dataset_dir`` at startup.
    """

    # Both accept plain strings (e.g. from YAML) or Path objects.
    artifacts: Union[str, Path]
    data: Union[str, Path]

# Experiment keys accepted by ``Config.tb`` (which experiment TensorBoard is
# started for). Kept in sync with the experiment sub-config fields of
# ``Config`` so every runnable experiment is a valid choice.
ExperimentType = Literal[
    "tvi", "tm", "rw_tvi", "it", "ct", "ri", "r2c", "ife",
]

@dataclass
class Config:
    """Top-level schema of the composed Hydra configuration.

    Which mode ``main`` runs depends on which optional fields are set; every
    optional field defaults to ``None`` (disabled).

    Attributes:
        dirs: Artifact and dataset directories.
        upload: Upload settings, forwarded to ``set_upload_config``.
        sid: Integer handed to each experiment's ``set_*_seeds`` helper
            (presumably a seed/run id -- confirm).
        tb: If set, the experiment to start TensorBoard for.
    """

    dirs: DirectoryConfig
    upload: UploadConfig
    sid: int
    tb: Optional[ExperimentType] = field(default=None)
    # experiments.transforms_vs_identity
    tvi: Optional[TvIExperimentConfig] = field(default=None)
    # experiments.transforms_mismatch
    tm: Optional[TMExperimentConfig] = field(default=None)
    # experiments.real_world_transforms_vs_identity
    rw_tvi: Optional[RWTvIExperimentConfig] = field(default=None)
    # experiments.invariance_transfer
    it: Optional[ITExperimentConfig] = field(default=None)
    # experiments.cross_transforms
    ct: Optional[CTExperimentConfig] = field(default=None)
    # experiments.representation_impact
    ri: Optional[RIExperimentConfig] = field(default=None)
    # experiments.random_to_cifar
    r2c: Optional[R2CExperimentConfig] = field(default=None)
    # experiments.irrelevant_feature_extraction
    ife: Optional[IFEExperimentConfig] = field(default=None)

# Register every sub-config at import time, i.e. before ``main()`` is invoked
# and composes the configuration (presumably into Hydra's ConfigStore --
# confirm in each register_* helper). Order matches the original call order.
for _register_configs in (
    register_upload_config,
    register_transforms_vs_identity_configs,
    register_transforms_mismatch_configs,
    register_real_world_transforms_vs_identities_configs,
    register_ti_configs,
    register_cross_transforms_configs,
    register_ri_configs,
    register_r2c_configs,
    register_ife_configs,
):
    _register_configs()
del _register_configs


@hydra.main(
    version_base=None,
    config_path="../conf",
    config_name="config",
)
def main(cfg: Config) -> None:
    """Hydra entry point: set global state, then launch exactly one mode.

    Directories and the upload configuration are installed first. The first
    applicable mode then wins, in this order:

    1. ``tb`` set -> run TensorBoard over the artifacts directory.
    2. An experiment sub-config set (checked as tvi, tm, rw_tvi, it, ct, ri,
       r2c, ife) -> run that experiment.
    3. Nothing set -> start Jupyter Lab through ``poetry run``, exporting the
       upload settings as environment variables.

    Args:
        cfg: The composed config. At runtime this is whatever Hydra passes
            (shaped like ``Config``); keys may be absent, so each lookup
            treats a missing attribute the same as ``None``.
    """
    set_artifacts_dir(cfg.dirs.artifacts)
    set_dataset_dir(cfg.dirs.data)
    set_upload_config(cfg.upload)

    if getattr(cfg, "tb", None) is not None:
        print(f"Starting tensorboard for {cfg.tb}")
        # TODO: set subdirectory based on the selected experiment
        subprocess.run(["tensorboard", "--logdir", get_artifacts_dir()])
        return

    # Ordered (config key, runner) pairs; the first key whose sub-config is
    # not None is run. ``cfg.sid`` is read only when a seeded runner fires.
    experiment_runners = (
        ("tvi", lambda sub: tvi_experiment(set_tvi_seeds(sub, cfg.sid))),
        ("tm", lambda sub: tm_experiment(set_tm_seeds(sub, cfg.sid))),
        # rw_tvi is launched with its config as-is (no seed adjustment).
        ("rw_tvi", rw_tvi_experiment),
        ("it", lambda sub: it_experiment(set_it_seeds(sub, cfg.sid))),
        ("ct", lambda sub: ct_experiment(set_ct_seeds(sub, cfg.sid))),
        ("ri", lambda sub: ri_experiment(set_ri_seeds(sub, cfg.sid))),
        ("r2c", lambda sub: r2c_experiment(set_r2c_seeds(sub, cfg.sid))),
        ("ife", lambda sub: ife_experiment(set_ife_seeds(sub, cfg.sid))),
    )
    for key, run_experiment in experiment_runners:
        sub_config = getattr(cfg, key, None)
        if sub_config is not None:
            run_experiment(sub_config)
            return

    print("No experiment config provided, starting Jupyter Lab")
    # Child environment = current environment overlaid with upload settings.
    jupyter_env = {
        **os.environ,
        **UploadConfig(**cfg.upload).to_env_vars(),
    }
    subprocess.run(
        [
            "poetry", "run",
            "jupyter", "lab",
            "--allow-root", "--no-browser", "--ip=0.0.0.0",
        ],
        env=jupyter_env,
    )


if __name__ == "__main__":
    # The @hydra.main decorator supplies ``cfg``, so main() takes no arguments.
    # Optional toggle: uncomment to force the "spawn" start method. Doing so
    # also requires ``import multiprocessing``, which this file does not have.
    # multiprocessing.set_start_method("spawn", force=True)
    main()
