import contextlib
import math
import os
import random
import subprocess
from datetime import datetime
from pathlib import Path

from omegaconf import DictConfig
import numpy as np
import torch
import wandb


def maybe_compile(fn=None, **kwargs):
    """Wrap a function with ``torch.compile`` only when explicitly enabled.

    Compilation is off unless the ``TORCH_COMPILE`` environment variable is
    ``1`` or ``true`` (case-insensitive). The variable is read once per call
    to ``maybe_compile``.

    Usable both bare and with arguments::

        @maybe_compile
        def f(...): ...

        @maybe_compile(dynamic=True)
        def g(...): ...

    Args:
        fn: Function to wrap. When omitted, a decorator is returned instead.
        **kwargs: Forwarded unchanged to ``torch.compile`` (e.g. ``dynamic=True``).

    Returns:
        The compiled function when enabled, otherwise ``fn`` itself; or a
        decorator with that behaviour when ``fn`` is None.
    """
    enabled = os.environ.get("TORCH_COMPILE", "0").lower() in {"1", "true"}

    def decorator(target):
        if not enabled:
            return target
        return torch.compile(target, **kwargs)

    return decorator if fn is None else decorator(fn)


def wandb_login() -> None:
    """Log in to wandb."""

    # Check .secret file for API key
    secret_path = Path(__file__).parent.parent / ".secret"
    if secret_path.exists():
        with open(secret_path) as f:
            for line in f:
                if line.startswith("WANDB_API_KEY="):
                    api_key = line.strip().split("=", 1)[1]
                    wandb.login(key=api_key)
                    return

    # If the .secret file does not exist, prompt user to enter API key
    wandb.login()


def set_seed(seed):
    """Seed every global RNG used here and force deterministic cuDNN kernels.

    Seeds Python's ``random`` module, NumPy's global RNG and torch's CPU and
    CUDA generators. It also sets ``torch.backends.cudnn.deterministic = True``
    and ``benchmark = False``; both flags are process-wide.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def temp_seed(seed: int | None):
    if seed is None:
        return contextlib.nullcontext()
    return _temp_seed(seed)


@contextlib.contextmanager
def _temp_seed(seed: int):
    """Seed all global RNGs for the duration of a ``with`` block, then restore.

    On entry it snapshots:

    * the state of Python's ``random`` module, NumPy's global RNG and torch's
      CPU and per-device CUDA generators;
    * the ``torch.backends.cudnn.deterministic``/``benchmark`` flags.

    It then calls ``set_seed(seed)``. On exit (normal or exceptional) every
    snapshot is restored, so code after the block sees the same RNG streams
    and cuDNN configuration as before it.

    Args:
        seed: Seed passed to ``set_seed`` inside the block.
    """
    random_state = random.getstate()
    np_state = np.random.get_state()
    torch_state = torch.get_rng_state()
    torch_cuda_states = torch.cuda.get_rng_state_all()
    # set_seed also overrides these process-wide cuDNN flags; remember them so
    # the temporary seed does not permanently disable cuDNN benchmarking.
    cudnn_deterministic = torch.backends.cudnn.deterministic
    cudnn_benchmark = torch.backends.cudnn.benchmark

    try:
        # Inside the try so a failure while seeding still restores the state.
        set_seed(seed)
        yield
    finally:
        random.setstate(random_state)
        np.random.set_state(np_state)
        torch.set_rng_state(torch_state)
        torch.cuda.set_rng_state_all(torch_cuda_states)
        torch.backends.cudnn.deterministic = cudnn_deterministic
        torch.backends.cudnn.benchmark = cudnn_benchmark


def get_git_hash() -> str:
    """Get the current git hash."""
    try:
        return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
    except Exception:
        return "unknown"


def get_save_dir(
    target: DictConfig,
    algorithm: DictConfig,
    exp_name: str | None,
) -> str:
    """Return the relative save directory string based on config values.

    Calculates path based on target and algorithm settings.
    """
    target_name = target.name
    if target.name == "gmm":
        dim = target.spatial_dim
        nbits = target.n_bits
        n = target.n_centres
        var = target.variance
        target_name += f"_dim{dim}nbits{nbits}n{n}var{var}"
    elif target.name == "manywell":
        dim = target.spatial_dim
        rotated = "rotated" if target.rotated else ""
        beta = target.beta
        nbits = target.n_bits
        target_name += f"_dim{dim}{rotated}b{beta}nbits{nbits}"
    elif target.name == "ising":
        L = target.ising_L
        beta = target.ising_beta
        J = target.ising_J
        target_name += f"L{L}beta{beta}J{J}"
    elif target.name == "potts":
        L = target.potts_L
        q = target.potts_q
        beta = target.potts_beta
        J = target.potts_J
        target_name += f"L{L}q{q}beta{beta}J{J}"
    else:
        raise ValueError(f"Unknown target: {target.name}")

    parts = []
    if exp_name:
        parts.append(exp_name)

    parts.append(algorithm.name)
    parts.append(datetime.now().strftime("%Y%m%d_%H%M%S"))
    run_name = "_".join(parts)

    return os.path.join(os.getcwd(), "results", target_name, run_name)


def to_binary(x: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
    """Convert {-1, +1} to {0, 1}."""
    return (x + 1) // 2


def to_spin(x: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
    """Convert {0, 1} to {-1, +1}."""
    return 2 * x - 1


def linear_annealing(
    current: int,
    n_rounds: int,
    min_val: float,
    max_val: float,
    descending=False,
    log=False,
    avoid_zero=False,
) -> float:
    """Interpolate between ``min_val`` and ``max_val`` according to progress.

    Progress is measured as ``num / denom``, where ``num = current`` and
    ``denom = n_rounds``. With ``avoid_zero`` both are shifted by one:
    ``num = current + 1`` and ``denom = n_rounds + 1``.

    With ``log=True`` the progress is instead ``log(num) / log(denom)``.
    Because log is concave, this advances quickly early and slowly later.

    Args:
        current: Current round index.
        n_rounds: Round at (and after) which the end value is returned.
        min_val: Lower bound of the schedule.
        max_val: Upper bound of the schedule; must be ``>= min_val``.
        descending: Anneal from ``max_val`` down to ``min_val`` instead of up.
        log: Use logarithmic instead of linear progress.
        avoid_zero: Shift numerator and denominator by one (in log mode this
            avoids evaluating ``log(0)`` at ``current == 0``).

    Returns:
        The interpolated value. Special cases:

        * ``min_val`` when the bounds are equal.
        * The end value once ``current >= n_rounds``.
        * The start value in log mode when ``num <= 0``.

    Raises:
        ValueError: If ``min_val > max_val``.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if min_val > max_val:
        raise ValueError(f"min_val ({min_val}) must not exceed max_val ({max_val})")
    if min_val == max_val:
        return min_val

    start_val, end_val = (max_val, min_val) if descending else (min_val, max_val)

    if current >= n_rounds:
        return end_val

    num = current + 1 if avoid_zero else current
    denom = n_rounds + 1 if avoid_zero else n_rounds

    if log:
        if num <= 0:
            # log(0) raises a math domain error; clamp to the start of the
            # schedule (progress 0, as at num == 1) instead of crashing.
            return start_val
        # Here 1 <= num < denom, so denom >= 2 and log(denom) > 0.
        multiplier = math.log(num) / math.log(denom)
    else:
        multiplier = num / denom

    return start_val + (end_val - start_val) * multiplier
