import os
import wandb
import random
import string
import subprocess
import torch
import numpy as np
import re
import hashlib
from omegaconf import OmegaConf


def get_pod_name():
    """
    Retrieve the pod name (the machine / Kubernetes hostname).

    Runs the external ``hostname`` command and returns its stripped output.

    Returns:
        The hostname string, or None if the command fails or cannot be run
        (e.g. the ``hostname`` binary is not installed in the image).
    """
    try:
        return subprocess.check_output(["hostname"]).decode("utf-8").strip()
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers FileNotFoundError/PermissionError when the binary is
        # missing or not executable; previously that escaped as a crash.
        print(f"An error occurred: {e}")
        return None


def get_core_config_dict(config):
    """
    Return the config as a plain, resolved container without per-run fields.

    ``run.hash``, ``run.pod_name`` and ``run.out_dir`` are dropped so that two
    configs can be compared (or hashed) on their hyperparameters alone.
    Adjust the dropped keys to match your actual config structure.

    Returns None when ``config`` is None.
    """
    if config is None:
        return None

    container = OmegaConf.to_container(config, resolve=True)
    if "run" in container:
        run_section = container["run"]
        for ephemeral_key in ("hash", "pod_name", "out_dir"):
            run_section.pop(ephemeral_key, None)
    return container


def compute_stable_config_hash(config):
    """
    Return a deterministic 20-character hex hash of the config's core fields.

    The core fields (as returned by ``get_core_config_dict``) are dumped to
    YAML with sorted keys, hashed with MD5 (32 hex chars), and truncated to
    the first 20 characters. MD5 serves only as a fingerprint here.
    """
    canonical_yaml = OmegaConf.to_yaml(get_core_config_dict(config), sort_keys=True)
    digest = hashlib.md5(canonical_yaml.encode("utf-8")).hexdigest()
    return digest[:20]


def parse_step_from_ckpt_filename(ckpt_path):
    """
    Extract the integer step from a checkpoint filename.

    The basename of ``ckpt_path`` is searched for ``model_<digits>.pth`` at
    its end, e.g. ``'model_1000000.pth'`` -> ``1000000``.

    Returns None when the basename does not contain that pattern.
    """
    found = re.search(r"model_(\d+)\.pth$", os.path.basename(ckpt_path))
    return int(found.group(1)) if found else None


def find_last_checkpoint(models_dir):
    """
    Return the path of the checkpoint with the largest step in ``models_dir``.

    Only entries named ``model_<step>.pth`` (step parsed by
    ``parse_step_from_ckpt_filename``) are considered.

    Returns:
        The full path of the highest-step checkpoint, or None if
        ``models_dir`` is not an existing directory or contains no matching
        checkpoint.
    """
    # isdir (not exists): a regular file at this path would make listdir raise.
    if not os.path.isdir(models_dir):
        return None

    step_ckpts = []
    for name in os.listdir(models_dir):
        if not (name.startswith("model_") and name.endswith(".pth")):
            continue
        ck_path = os.path.join(models_dir, name)
        step = parse_step_from_ckpt_filename(ck_path)
        if step is not None:
            step_ckpts.append((step, ck_path))

    if not step_ckpts:
        return None

    # Largest step => newest. max() keeps the first maximal entry, matching
    # the tie-breaking of a stable reverse sort.
    return max(step_ckpts, key=lambda item: item[0])[1]


def set_exp(config):
    """
    Create or resume an experiment run keyed by a single 20-character hash.

    The hash names both the run folder ``out/<exp_name>/<hash>/`` and the W&B
    run ID.

    - If ``config.yaml`` in that folder has the same core config (see
      ``get_core_config_dict``), the run is resumed. The W&B run is rewound
      to the step of the newest ``models/model_<step>.pth`` checkpoint (step 0
      if none). If W&B reports a ``CommError`` for the rewind, the run is
      resumed in ``resume="allow"`` mode instead.
    - Otherwise a new folder and a new W&B run are created.
    - Experiments whose name contains ``"debug"`` or that set
      ``run.start_from_scratch`` always get a random hash, hence a new run.

    Side effects:
    - Sets ``config.run.hash``, ``out_dir``, ``pod_name`` and ``device``.
    - Seeds torch / numpy / ``random`` with ``config.run.random_seed``.
    - Optionally makes cuDNN deterministic.
    - Writes ``config.yaml`` (and ``wandb_id.txt`` when W&B is on) for new runs.

    NOTE:
    - Multiple runs with identical hyperparameters conflict (same hash =>
      same W&B run ID).
    - Changing ``random_seed`` in the config changes the hash, hence a new
      run & folder.

    Returns:
        ``(wandb_writer, last_ckpt_path)``:
        - ``wandb_writer`` is the W&B run (None when ``run.wandb_writer`` is off).
        - ``last_ckpt_path`` is the newest checkpoint in the run's models folder
          (None if there is none).

    Raises:
        ValueError: when resuming with W&B enabled but ``wandb_id.txt`` is
            missing from the run folder.
    """
    # 1) Compute the 20-char hash for this config
    if "debug" in config.run.exp_name.lower() or config.run.start_from_scratch:
        stable_hash = "".join(random.choices(string.ascii_letters + string.digits, k=20))
    else:
        stable_hash = compute_stable_config_hash(config)

    # 2) Build out/<exp_name>/<stable_hash>/ as the folder
    base_exp_dir = os.path.join(os.getcwd(), "out", config.run.exp_name)
    run_dir = os.path.join(base_exp_dir, stable_hash)
    models_dir = os.path.join(run_dir, "models")
    old_config_path = os.path.join(run_dir, "config.yaml")
    # The W&B run ID (== stable_hash) is persisted here when W&B is enabled.
    wandb_id_file = os.path.join(run_dir, "wandb_id.txt")
    os.makedirs(base_exp_dir, exist_ok=True)

    # 3) If we find an old config, compare it
    if os.path.exists(old_config_path) and not config.run.start_from_scratch:
        old_config = OmegaConf.load(old_config_path)
    else:
        old_config = None

    old_core = get_core_config_dict(old_config)
    new_core = get_core_config_dict(config)

    wandb_writer = None

    # 4) Old config matches => resume run
    if old_config is not None and old_core == new_core and not config.run.start_from_scratch:
        print(f"[set_exp] Found existing config => rewinding/resuming run {stable_hash}")
        # Keep config.run.hash consistent with the new-run path (it is set
        # before wandb.init so the logged W&B config carries it too).
        config.run.hash = stable_hash

        # The W&B run ID is the same as stable_hash (since we used it originally)
        if config.run.wandb_writer:
            if not os.path.exists(wandb_id_file):
                raise ValueError(f"No wandb_id.txt found in {run_dir}, cannot resume properly.")
            with open(wandb_id_file, "r") as f:
                existing_run_id = f.read().strip()

        last_ckpt_path = find_last_checkpoint(models_dir)
        if last_ckpt_path is None:
            last_step = 0
            print("No checkpoint found => rewinding to step=0.")
        else:
            last_step = parse_step_from_ckpt_filename(last_ckpt_path) or 0
            print(f"Rewinding to checkpoint {last_ckpt_path}, step={last_step}")

        if config.run.wandb_writer:
            rewind_spec = f"{existing_run_id}?_step={last_step}"
            # Use W&B's "resume_from" param for rewinding
            try:
                wandb_writer = wandb.init(
                    entity=config.run.wandb_entity,
                    project=config.run.exp_name,
                    id=existing_run_id,
                    resume_from=rewind_spec,
                    config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True),
                    settings=wandb.Settings(init_timeout=120),
                )
            except wandb.errors.CommError as e:
                print(f"Rewind failed ({e}); falling back to resume mode.")
                wandb_writer = wandb.init(
                    entity=config.run.wandb_entity,
                    project=config.run.exp_name,
                    id=existing_run_id,
                    resume="allow",
                    config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True),
                    settings=wandb.Settings(init_timeout=3600),
                )
    else:
        # No match => new run
        print(f"[set_exp] Creating NEW run directory => {run_dir}")
        os.makedirs(run_dir, exist_ok=True)
        os.makedirs(models_dir, exist_ok=True)

        # Our single ID is the stable hash; store it in config for visibility
        new_run_id = stable_hash
        config.run.hash = new_run_id

        # Start a fresh W&B run with id = stable_hash
        if config.run.wandb_writer:
            wandb_writer = wandb.init(
                entity=config.run.wandb_entity,
                project=config.run.exp_name,
                name=new_run_id,  # optional 'name' for the UI
                config=OmegaConf.to_container(config, resolve=True, throw_on_missing=True),
                id=new_run_id,
                resume="never",
                save_code=False,
                settings=wandb.Settings(
                    save_code=False, _save_requirements=False, disable_code=True
                ),
            )

        # Save the config and the run_id (before out_dir/pod_name/device are set
        # below, so those values are not persisted)
        OmegaConf.save(config, old_config_path)
        if config.run.wandb_writer:
            with open(wandb_id_file, "w") as f:
                f.write(new_run_id)

    # 5) Basic logging
    config.run.out_dir = run_dir
    print(f"Experiment folder: {run_dir}")

    config.run.pod_name = get_pod_name()

    # Seeds
    torch.manual_seed(config.run.random_seed)
    torch.cuda.manual_seed_all(config.run.random_seed)
    np.random.seed(config.run.random_seed)
    random.seed(config.run.random_seed)

    if config.run.det_run:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    config.run.device = "cuda" if torch.cuda.is_available() else "cpu"

    if config.run.wandb_writer and wandb_writer is not None:
        # Paths (relative to the logged root) that should not be uploaded.
        excluded_prefixes = (
            "env/",
            "out/",
            "outputs/",
            "notebook/",
            "figures/",
            "wandb/",
            ".git/",
            "slurm-",
        )

        def keep(path: str, root: str):
            return not os.path.relpath(path, root).startswith(excluded_prefixes)

        wandb_writer.log_code(
            ".",
            include_fn=keep,
        )

    print("Final config:")
    print(OmegaConf.to_yaml(config))

    # 6) Return last checkpoint path if it exists (find_last_checkpoint
    # already returns None for a missing models folder)
    last_ckpt_path = find_last_checkpoint(models_dir)

    return wandb_writer, last_ckpt_path
