# Copyright (c) 2026 Anonymous
# All Rights Reserved
# This codebase is provided for peer review purposes only.

import functools
import multiprocessing
import multiprocessing.connection
import os
import shutil
import time

import torch


def timeout_decorator(seconds):
    """Return a decorator that runs the wrapped function in a child process with a time-out.

    Every call of the decorated function starts a daemon child process (using the
    ``fork`` start method) that executes the wrapped function and sends back one
    status tuple: ``(True, return_value)`` or ``(False, exception)``. The parent
    waits at most ``seconds`` for that status.

    The ``fork`` start method is requested explicitly because the child entry
    point is a closure over ``func``; ``spawn``/``forkserver`` would have to
    pickle it, which fails for nested functions. ``fork`` is not available on
    Windows.

    NOTE(review): a forked child cannot use CUDA if the parent has already
    initialized it — confirm the wrapped functions only touch CPU data.

    Args:
        seconds: Maximum number of seconds to wait for the child's status.
            ``None`` waits without a limit.

    Returns:
        A decorator. The decorated function returns whatever the wrapped
        function returned in the child (the value must be picklable).

    Raises:
        TimeoutError: No status arrived within ``seconds``; the child is killed.
        Exception: The child exited without sending a status (e.g. it crashed,
            was killed, or its exception could not be pickled).
        Any exception raised by the wrapped function (or raised while pickling
        its return value) is re-raised in the caller.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def target(conn, call_args, call_kwargs):
                # Child side: report exactly one (success, payload) status.
                try:
                    result = func(*call_args, **call_kwargs)
                    conn.send((True, result))  # Success case
                except Exception as e:
                    # Also reached when the result itself cannot be pickled.
                    conn.send((False, e))  # Capture exception
                finally:
                    conn.close()

            ctx = multiprocessing.get_context("fork")
            parent_conn, child_conn = ctx.Pipe(duplex=False)
            # args/kwargs travel as plain values so user kwargs cannot clash
            # with target's own parameter names.
            process = ctx.Process(target=target, args=(child_conn, args, kwargs))
            print("Starting an external process for model saving.")
            process.daemon = True
            deadline = None if seconds is None else time.monotonic() + seconds
            process.start()
            # Drop the parent's copy of the write end so recv() reports EOF if
            # the child exits without sending a status.
            child_conn.close()
            try:
                print("Starting to wait for the process to end, with a time-out.")
                # Wake on whichever comes first: a status message or child exit.
                # Reading the status before joining avoids the deadlock where the
                # child blocks writing a large payload that nobody is reading.
                ready = multiprocessing.connection.wait(
                    [parent_conn, process.sentinel], timeout=seconds
                )
                if not ready:
                    print("Time-out occurred! Attempting to terminate the process.")
                    process.kill()
                    print("Waiting for the process to end.")
                    process.join()
                    raise TimeoutError(f"Function '{func.__name__}' timed out.")
                print("Time-out did not occur.")
                try:
                    status = parent_conn.recv()
                except EOFError:
                    status = None  # Child exited without sending anything.
            finally:
                parent_conn.close()

            # The status (or EOF) is in hand; give the child the rest of the
            # budget to exit, and kill it if it hangs during shutdown.
            if deadline is None:
                remaining = None
            else:
                remaining = max(0.0, deadline - time.monotonic())
            process.join(remaining)
            if process.is_alive():
                process.kill()
                process.join()

            if status is None:
                raise Exception("Unexpected error: Subprocess exited without posting status")
            success, payload = status
            if success:
                return payload
            raise payload  # Re-raise the captured exception
        return wrapper
    return decorator


@timeout_decorator(seconds=300)
def process_iter_folders(vault_path, idx_iter):
    """Prune old iteration folders and create the folder for ``idx_iter``.

    Every entry of ``<vault_path>/checkpoints`` is interpreted as an integer
    iteration number. All entries except the numerically largest one are
    removed with ``shutil.rmtree`` (oldest first); afterwards
    ``<vault_path>/checkpoints/<idx_iter>`` is created.

    Args:
        vault_path: Root directory that contains the ``checkpoints`` folder.
        idx_iter: Iteration number whose folder is created.

    Raises:
        ValueError: If an entry in ``checkpoints`` is not an integer literal.
        FileExistsError: If the folder for ``idx_iter`` already exists.
    """
    checkpoints_root = os.path.join(vault_path, "checkpoints")
    by_iteration = sorted(os.listdir(checkpoints_root), key=int)
    # Keep only the newest existing folder (presumably as a fallback in case the
    # upcoming save fails); every older one is deleted.
    for stale_name in by_iteration[:-1]:
        shutil.rmtree(os.path.join(checkpoints_root, stale_name))
    os.makedirs(os.path.join(checkpoints_root, str(idx_iter)))


@timeout_decorator(seconds=300)
def save_checkpoint(vault_path, checkpoint_dict, idx_iter, rank):
    """Save one rank's checkpoint to ``<vault_path>/checkpoints/<idx_iter>/rank_<rank>.pt``.

    The object is first written to ``rank_<rank>.pt.tmp`` in the same directory
    and then moved into place with ``os.replace``. A save that is killed on
    time-out or fails mid-write therefore never leaves a truncated
    ``rank_<rank>.pt`` that looks like a valid checkpoint.

    Args:
        vault_path: Root directory that contains the ``checkpoints`` folder.
        checkpoint_dict: Object to serialize with ``torch.save``.
        idx_iter: Iteration number; selects the sub-folder. The folder must
            already exist (presumably created by ``process_iter_folders``).
        rank: Rank identifier used in the file name.

    Raises:
        Whatever ``torch.save`` or ``os.replace`` raises; the temporary file is
        removed on a best-effort basis before re-raising.
    """
    iter_dir = os.path.join(vault_path, "checkpoints", str(idx_iter))
    final_path = os.path.join(iter_dir, f"rank_{rank}.pt")
    tmp_path = final_path + ".tmp"
    try:
        torch.save(checkpoint_dict, tmp_path)
        # Same-directory rename: atomic on POSIX, so readers see either the old
        # state or the complete new file.
        os.replace(tmp_path, final_path)
    except BaseException:
        # Best effort only: a missing or undeletable temp file must not mask
        # the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
