import copy
import os
import random
import shutil
import sys
import time
import warnings
from glob import glob
from pathlib import Path

import hydra
from hydra.utils import get_class
from lightning.pytorch.utilities.model_summary import ModelSummary
from omegaconf import OmegaConf, DictConfig

from impugen.scenarios import simulate_missing
from impugen.utils import setup_logger, SeedContext, rank_zero_print
from impugen.utils.eval import evaluate_runtime
from impugen.utils.io import setup_model_from_checkpoint

# Process-wide setup: silence every Python warning, ask Hydra for full
# tracebacks, and lower TensorFlow's C++ log verbosity.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL only takes effect if TensorFlow has not
# already been initialised by one of the imports above — confirm.
warnings.filterwarnings("ignore")
os.environ.update(
    {
        "HYDRA_FULL_ERROR": "1",
        "TF_CPP_MIN_LOG_LEVEL": "2",
    }
)


def _print_config_and_summary(cfg: DictConfig, model) -> None:
    """
    Emit the config as YAML, then a Lightning ``ModelSummary`` of the model.

    Both outputs go through ``rank_zero_print``. The YAML is rendered without
    resolving interpolations.

    Args:
        cfg (DictConfig): Hydra configuration to dump.
        model: Model whose layer/parameter summary is printed.
    """
    config_text = OmegaConf.to_yaml(cfg)
    rank_zero_print(config_text)

    model_overview = ModelSummary(model)
    rank_zero_print(model_overview)


def _initialize_model(cfg: DictConfig):
    """
    Construct the model class referenced by ``cfg.model._target_``.

    The class is resolved with Hydra's ``get_class``. It is called with the full
    config as ``cfg=`` plus every entry of ``cfg.model`` (``_target_`` included)
    unpacked as keyword arguments.

    Args:
        cfg (DictConfig): Hydra configuration; ``cfg.model`` holds the model settings.

    Returns:
        The freshly constructed model instance.
    """
    model_section = cfg.model
    model_class = get_class(model_section._target_)
    return model_class(cfg=cfg, **model_section)


def _train_model(cfg: DictConfig, model) -> None:
    """
    Train the model, optionally on data from a missing-data scenario.

    When ``cfg.scenario._target_`` names a scenario, in-sample missingness is
    simulated with ``simulate_missing(..., drop_observed=False)`` and the result
    is passed to ``model.fit``. Otherwise ``model.fit()`` is called with no data
    argument. Training runs inside ``SeedContext(cfg.seed)``. Afterwards the
    wall-clock duration is stored on ``model._elapsed_time`` and reported via
    ``evaluate_runtime``.

    Args:
        cfg (DictConfig): Hydra configuration. Reads ``seed``, ``scenario`` and
            ``skip_same_run``. With a scenario it may set
            ``cfg.scenario.random_state`` and always sets ``cfg['missing']``.
        model: The model instance to train. It must provide ``fit``,
            ``_transform``, ``name`` and ``log_dir``.
    """
    start = time.time()
    with SeedContext(cfg.seed):
        use_scenario = cfg.scenario._target_ not in [None, 'None', 'none']

        if use_scenario:
            # Tie the scenario's randomness to the run seed when none was given,
            # so the simulated missingness is reproducible.
            if not isinstance(cfg.scenario.random_state, int):
                cfg.scenario.random_state = cfg.seed
            cfg['missing'] = cfg.scenario  # TODO: Refactor scenario settings handling

        # Run after the scenario fields are filled in so the duplicate-run
        # comparison sees the final configuration. May exit the process.
        if cfg.skip_same_run:
            _check_run(cfg, model)

        if use_scenario:
            scenario = simulate_missing(
                cfg, model._transform, drop_observed=False
            )
            model.fit(scenario)
        else:
            model.fit()
    model._elapsed_time = time.time() - start
    evaluate_runtime(cfg, model, model.log_dir)


def _comparable_yaml(conf: DictConfig) -> str:
    """
    Render ``conf`` as sorted, resolved YAML for duplicate-run comparison.

    ``skip_same_run`` and ``model.batch_mul`` are dropped first because they do
    not define the experiment. Mutates ``conf`` in place, so pass a copy when
    the original must be preserved.
    """
    conf.pop("skip_same_run", None)
    model_node = conf.get("model")
    if isinstance(model_node, DictConfig):
        model_node.pop("batch_mul", None)
    return OmegaConf.to_yaml(conf, resolve=True, sort_keys=True)


def _check_run(cfg: DictConfig, model) -> None:
    """
    Detect previous runs with an identical configuration and decide whether
    to skip, retrain, or clean up.

    Config flag
    -----------
    cfg.skip_same_run : bool  (default = True when the key is absent)
        * True  → exit the process if a finished identical run exists.
        * False → (re)train even if a finished identical run exists.
    Note: in this file ``_train_model`` only calls this function when
    ``cfg.skip_same_run`` is truthy.

    Behaviour
    ---------
    1. Search every ``config.yaml`` under
       ``{root_dir}/{dataset.name}/{model.name}/version_*/`` (sorted order).
    2. Compare each one to the current config, ignoring ``skip_same_run`` and
       ``model.batch_mul``. A run counts as finished when its directory
       contains ``report.csv``.
    3. If any identical run is finished:
         ▸ ``skip_same_run`` True → print ``[SKIP]`` and ``sys.exit(0)``.
         ▸ otherwise print ``[RETRAIN]`` for each and continue.
       A finished run takes precedence over unfinished ones, regardless of
       the order in which the directories are found.
    4. The first identical unfinished run (if any) is deleted.
    5. With no identical config, continue as a brand-new run.

    ``cfg`` itself is never modified.
    """
    skip_same_run = cfg.get("skip_same_run", True)
    pattern = os.path.join(
        cfg.root_dir,
        cfg.dataset.name,
        model.name,
        "version_*",
        "config.yaml",
    )
    # Deep copy: a shallow DictConfig copy shares nested nodes (e.g. ``model``),
    # so stripping keys from it would also strip them from the live config.
    current_yaml = _comparable_yaml(copy.deepcopy(cfg))

    finished, unfinished = [], []
    for cfg_file in sorted(glob(pattern)):
        cfg_path = Path(cfg_file)
        if _comparable_yaml(OmegaConf.load(cfg_path)) != current_yaml:
            continue  # different experiment
        exp_dir = cfg_path.parent
        if (exp_dir / "report.csv").is_file():
            finished.append(exp_dir)
        else:
            unfinished.append(exp_dir)

    if finished and skip_same_run:
        print(
            f"[SKIP] {finished[0]} – identical config already finished. "
            "Set 'skip_same_run=false' to train with the same configuration."
        )
        sys.exit(0)  # graceful exit

    for exp_dir in finished:
        print(
            f"[RETRAIN] {exp_dir} – identical finished run found, "
            "but 'skip_same_run=false'; retraining will start."
        )
        # caller will create a new version_* folder

    if unfinished:
        # one unfinished identical run is enough to clean before proceeding
        exp_dir = unfinished[0]
        print(f"[CLEANUP] {exp_dir} – identical config unfinished. Deleting folder.")
        shutil.rmtree(exp_dir, ignore_errors=True)


def _load_checkpoint_for_evaluation(cfg: DictConfig, ckpt: str):
    """
    Restore a model from ``ckpt`` and prepare its evaluation log directory.

    The logger is created under the checkpoint's directory with
    ``name='evaluation'``, and the model's ``log_dir`` is pointed at it. For the
    ``generation`` scenario, ``cfg['scenario']`` is replaced by the scenario
    section of the ``config.yaml`` stored next to the checkpoint, if present
    (presumably the training run's config). The final ``cfg`` is printed and
    saved to ``<log_dir>/config.yaml``.

    Args:
        cfg (DictConfig): Hydra configuration; may be modified as described.
        ckpt (str): Path to an existing checkpoint file.

    Returns:
        The restored model.
    """
    ckpt_dir = os.path.dirname(ckpt)
    model, _ = setup_model_from_checkpoint(ckpt)
    logger = setup_logger(model, cfg, log_dir=ckpt_dir, name='evaluation')
    model.log_dir = logger.log_dir

    stored_cfg_path = os.path.join(ckpt_dir, 'config.yaml')
    if cfg.scenario.name == 'generation' and os.path.isfile(stored_cfg_path):
        cfg['scenario'] = OmegaConf.load(stored_cfg_path)['scenario']

    _print_config_and_summary(cfg, model)
    os.makedirs(model.log_dir, exist_ok=True)
    print(model.log_dir)
    OmegaConf.save(config=cfg, f=os.path.join(model.log_dir, 'config.yaml'))
    return model


@hydra.main(version_base=None, config_path="impugen/configs", config_name="run")
def main(cfg: DictConfig) -> None:
    """
    Main entrypoint: train and/or evaluate a model based on the Hydra config.

    Workflow:
        1. Unlock the config for dynamic updates. Draw a random seed if
           ``cfg.seed`` is not an int.
        2. If ``cfg.ckpt`` is an existing file, restore the model from it
           (no training).
        3. Otherwise, if ``cfg.model._target_`` is set, build the model and
           train it, with or without a missing-data scenario.
        4. If neither applies, print a message and return.
        5. Move the model to ``cfg.device`` and run ``model.evaluation(cfg)``.

    Args:
        cfg (DictConfig): The Hydra configuration object.
    """
    # Allow dynamic updates to cfg
    OmegaConf.set_struct(cfg, False)

    # Set random seed if not provided
    if not isinstance(cfg.seed, int):
        cfg.seed = random.randint(0, 2 ** 32 - 1)

    # A null/absent ckpt would make os.path.isfile raise TypeError, so only
    # accept string paths.
    ckpt = cfg.get("ckpt")
    if isinstance(ckpt, str) and os.path.isfile(ckpt):
        model = _load_checkpoint_for_evaluation(cfg, ckpt)
    elif cfg.model._target_ not in [None, 'None', 'none']:
        model = _initialize_model(cfg)
        _train_model(cfg, model)
    else:
        print('Valid checkpoint path or model target must be provided.')
        return

    model = model.to(cfg.device)

    # Run evaluations
    model.evaluation(cfg)


if __name__ == "__main__":
    # The @hydra.main decorator composes cfg (config_name="run" under
    # impugen/configs) and passes it in, so main() takes no arguments here.
    main()
