import shutil
import time
from pathlib import Path
from typing import Optional

import torch
from safetensors.torch import load_file as safe_load_file

from dae.utils.generic_utils import ensure_path


def is_training_state_dir(path):
    """Tell whether ``path`` is a directory that holds a saved training state.

    A directory qualifies when it contains both ``model.safetensors`` and
    ``optimizer.bin``. Falsy inputs (``None``, ``""``) yield ``False``.
    """
    if not path:
        return False

    candidate = Path(path)
    if not candidate.is_dir():
        return False

    required_files = ("model.safetensors", "optimizer.bin")
    return all((candidate / file_name).exists() for file_name in required_files)


def find_checkpoint(path, log_dir=None, ensure=True, training_state=False, model_name=None):
    """Resolve ``path`` to an existing checkpoint file or training-state directory.

    Args:
        path: Checkpoint location as ``str`` or ``pathlib.Path``; may be falsy.
            When it is a directory, a few well-known sub-directories are probed
            for a training state and the first match replaces ``path``.
        log_dir: Not used by this function; kept so existing callers keep working.
        ensure: When true, raise instead of returning ``None`` if nothing is found.
        training_state: When true, only a training-state directory is accepted.
        model_name: When given and the resolved path is a directory, return the
            weights file of that model inside it (the ``_ema`` variant is preferred).

    Returns:
        The resolved path as ``str``, or ``None`` when nothing was found and
        ``ensure`` is false.

    Raises:
        TypeError: ``path`` is truthy but neither ``str`` nor ``Path``.
        FileNotFoundError: Nothing was found and ``ensure`` is true.
    """
    if path:
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if not isinstance(path, (str, Path)):
            raise TypeError(f"Checkpoint path must be a string or Path, got {path!r}")
        path = str(path)

        # If it's a directory, find a training state
        if Path(path).is_dir():
            check_dirs = [".", "checkpoints/last", "checkpoints/best", "last", "best", "model-checkpoint"]
            for sub_dir in check_dirs:
                candidate = Path(path) / sub_dir
                if is_training_state_dir(candidate):
                    path = str(candidate)
                    break
            else:
                if ensure:
                    raise FileNotFoundError(f"Can't find any training state inside {path}")

    # Description of what was searched; kept separately because ``path`` is reset to
    # ``None`` below and the error message would otherwise only say "None".
    missing = path

    # Ensure that if we are looking for a training state, we only return a directory
    if training_state and not is_training_state_dir(path):
        missing = f"{path} (not a training state directory)"
        path = None

    # Ensure that if we are looking for a specific model checkpoint, we return the right file
    if model_name and path and Path(path).is_dir():
        model_files = [f"model_{model_name}_ema.safetensors", f"model_{model_name}.safetensors"]
        for file_name in model_files:
            candidate = Path(path) / file_name
            if candidate.exists():
                path = str(candidate)
                break
        else:
            missing = f"{path} (no weights file for model '{model_name}')"
            path = None

    # Ensure that the path exists
    if not path or not Path(path).exists():
        if ensure:
            raise FileNotFoundError(f"Can't find checkpoint {missing}")
        return None

    return path


def force_rename(src, dst, retries=10, delay=1):
    """Move ``src`` to ``dst``, replacing whatever currently sits at ``dst``.

    Filesystem errors are retried, sleeping ``delay`` seconds between attempts;
    the error of the final attempt is re-raised.

    Args:
        src: Existing file or directory to move.
        dst: Target location; an existing file, directory or symlink there is removed.
        retries: Maximum number of attempts (at least one is always made).
        delay: Seconds to wait after a failed attempt.

    Raises:
        OSError: The last attempt failed. This is ``FileNotFoundError`` when ``src``
            does not exist, in which case ``dst`` is left untouched.
    """
    src_path = Path(src)
    dst_path = Path(dst)
    attempts = max(1, retries)

    for attempt in range(attempts):
        try:
            # Check the source first: without it the rename cannot succeed, and
            # clearing ``dst`` would destroy data with nothing to replace it.
            if not (src_path.exists() or src_path.is_symlink()):
                raise FileNotFoundError(f"Can't rename {src_path} to {dst_path}: source does not exist")

            if dst_path.is_symlink():
                # Remove the link itself (also covers broken links and links to
                # directories, which ``shutil.rmtree`` refuses to delete).
                dst_path.unlink()
            elif dst_path.is_dir():
                shutil.rmtree(dst_path, ignore_errors=True)
            elif dst_path.exists():
                dst_path.unlink()

            src_path.rename(dst_path)
        except OSError:
            if attempt == attempts - 1:
                raise
            time.sleep(delay)
        else:
            break


def save_training_state(state, path, allow_backup=True):
    """Save the accelerator state of ``state`` into the directory ``path``.

    When ``path`` already exists and ``allow_backup`` is true, the new state is first
    written to ``<path>.new``; the existing directory is then moved to ``<path>.back``,
    the new one is moved into place, and the backup is deleted.

    After saving, one ``model_<name>.safetensors`` symlink per entry of
    ``state.registered_models`` is created next to the saved weights.

    Args:
        state: Object exposing ``accelerator`` (with ``sync_ctx()`` and
            ``save_state()``) and ``registered_models``.
        path: Target directory, as ``str`` or ``pathlib.Path``.
        allow_backup: Whether to use the ``.new`` / ``.back`` swap when ``path`` exists.

    NOTE(review): ``sync_ctx()`` presumably synchronises processes and yields a flag
    that is true on the main process only; confirm against the accelerator wrapper.
    """
    # Accept plain strings too: everything below relies on pathlib methods.
    path = Path(path)

    if allow_backup and path.exists():
        new_path = path.with_name(f"{path.name}.new")
        back_path = path.with_name(f"{path.name}.back")
        # Write the complete new state next to the old one before touching the old one.
        save_training_state(state, new_path, allow_backup=False)
        with state.accelerator.sync_ctx() as main:
            if main:
                force_rename(path, back_path)
                force_rename(new_path, path)
                shutil.rmtree(back_path, ignore_errors=True)

        return

    with state.accelerator.sync_ctx() as main:
        if main:
            ensure_path(path)

    with state.accelerator.sync_ctx():
        state.accelerator.save_state(str(path))

    # Create symbolic links for each model
    with state.accelerator.sync_ctx() as main:
        if main:
            for i, model_name in enumerate(state.registered_models):
                # The first model is expected in "model.safetensors", later ones in
                # "model_<i>.safetensors".
                sd_name = "model.safetensors" if i == 0 else f"model_{i}.safetensors"
                sym_path = path / f"model_{model_name}.safetensors"
                # ``is_symlink`` also catches dangling links, for which ``exists`` is false.
                if sym_path.exists() or sym_path.is_symlink():
                    sym_path.unlink()
                assert (path / sd_name).exists(), f"Model checkpoint {path / sd_name} does not exist"
                # Relative target, so the link stays valid when the directory is moved.
                sym_path.symlink_to(sd_name)


def load_training_state(state, state_paths):
    """Resume ``state`` from the first usable training state among ``state_paths``.

    Each candidate is resolved with ``find_checkpoint(..., training_state=True)``;
    the first one that is a training-state directory is loaded through
    ``state.accelerator.load_state``. Loading is attempted up to ten times and the
    error of the last attempt is propagated.

    Returns:
        ``True`` after a successful load, ``False`` when no candidate qualified.
    """
    # 1. Find an existing training state
    resolved = None
    for candidate in state_paths:
        found = find_checkpoint(candidate, state.cfg.log_dir, ensure=False, training_state=True)
        if found and is_training_state_dir(found):
            resolved = found
            break

    if resolved is None:
        state.accelerator.print(f"No training state found to resume from inside paths [{state_paths}]")
        return False

    # 2: Load training state
    state.accelerator.print(f"Resuming training from state {resolved}")

    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        try:
            state.accelerator.load_state(str(resolved), strict=True, load_kwargs={"weights_only": False})
            break
        except Exception:
            # Earlier failures are ignored on purpose; only the final one surfaces.
            if attempt == max_attempts:
                raise

    return True


def load_resume_training_metadata(cfg) -> Optional[dict]:
    """Load ``custom_checkpoint_0.pkl`` from the first training state referenced by ``cfg``.

    Candidates are tried in order: ``cfg.checkpoint_path``, then the ``"checkpoint"``
    entry of ``cfg.model``. Returns ``None`` when neither resolves to a training state.
    """
    candidates = (cfg.checkpoint_path, cfg.model.get("checkpoint", None))
    for candidate in candidates:
        resolved = find_checkpoint(candidate, cfg.log_dir, ensure=False, training_state=True)
        if not resolved:
            continue
        metadata_file = Path(resolved) / "custom_checkpoint_0.pkl"
        return torch.load(metadata_file)
    return None


def load_checkpoint(accelerator, model: torch.nn.Module, ckpt_path: str, verbose=True, ckpt_map=None, strict=True, log_dir=None, model_name=None):
    """Load weights from a checkpoint file into ``model``.

    Args:
        accelerator: Object providing ``print(..., show=...)`` and ``wait_for_everyone()``.
        model: Module whose ``load_state_dict`` receives the weights.
        ckpt_path: Checkpoint location, resolved through ``find_checkpoint``.
        verbose: Forwarded as ``show`` to ``accelerator.print``.
        ckpt_map: Optional callable applied to the loaded state dict before loading.
        strict: Forwarded to ``load_state_dict``; when false, key mismatches are
            reported through ``accelerator.print`` instead of asserted.
        log_dir: Forwarded to ``find_checkpoint``.
        model_name: Forwarded to ``find_checkpoint`` to select one model's weights.

    Returns:
        ``True`` once the weights have been loaded.

    Raises:
        ValueError: The resolved file has an unsupported extension.
        FileNotFoundError: The file could not be read after ten attempts.
    """
    assert isinstance(ckpt_path, str)

    # 1: Get path to a single-file checkpoint
    ckpt_path = find_checkpoint(ckpt_path, log_dir, model_name=model_name)
    accelerator.print(f"Loading {model.__class__.__name__} checkpoint from {ckpt_path}", show=verbose)
    accelerator.wait_for_everyone()

    # 2. Load weights from checkpoint
    ckpt_path = Path(ckpt_path)
    max_attempts = 10
    for attempt in range(1, max_attempts + 1):
        try:
            suffix = ckpt_path.suffix
            if suffix == ".safetensors":
                state_dict = safe_load_file(ckpt_path, device="cpu")
            elif suffix in (".pt", ".bin", ".pth"):
                # NOTE: weights_only=False unpickles arbitrary objects; only use with trusted files.
                state_dict = torch.load(str(ckpt_path), map_location=torch.device("cpu"), weights_only=False)
            else:
                raise ValueError(f"Unsupported file extension for weights: {ckpt_path.suffix} (trying to load from {ckpt_path})")
            break
        except FileNotFoundError:
            # Only a missing file is retried; any other error propagates immediately.
            if attempt == max_attempts:
                raise
            accelerator.print(f"Failed to load checkpoint from {ckpt_path}, retrying in 10s...", show=verbose)
            time.sleep(10)

    # 3. Map weights, remove wrappers
    if ckpt_map is not None:
        state_dict = ckpt_map(state_dict)

    # 4. Load state dict into model
    load_result = model.load_state_dict(state_dict, strict=strict)
    if strict:
        assert not load_result.missing_keys and not load_result.unexpected_keys
        return True

    if load_result.missing_keys:
        accelerator.print(f"Missing keys in checkpoint: {load_result.missing_keys}", show=verbose)
    if load_result.unexpected_keys:
        accelerator.print(f"Unexpected keys in checkpoint: {load_result.unexpected_keys}", show=verbose)
    return True
