import os
import shutil
from fire import Fire


def get_all_subfolders(base_dir):
    """Return every directory beneath ``base_dir``, at any depth.

    ``base_dir`` itself is excluded. Paths appear in ``os.walk`` traversal
    order and are built by joining onto ``base_dir`` exactly as it was passed.
    A missing ``base_dir`` yields an empty list.
    """
    return [
        directory
        for directory, _child_dirs, _child_files in os.walk(base_dir)
        if directory != base_dir
    ]


def gather_agent_files(base_dir, output_dir, env_name, dim, max_files=10, algo=None):
    """Copy up to ``max_files`` checkpoint pairs found under ``base_dir`` into ``output_dir``.

    A directory below ``base_dir`` is selected when its path *relative to*
    ``base_dir`` contains ``f"{env_name}_{dim}_"`` and it holds both
    ``agent.pth`` and ``agent_trained.pth`` as regular files. For the i-th
    selected directory (i counting from 0) the files are copied as::

        agent.pth         -> <output_dir>/adversary_{i}_{env_name}_{dim}.pth
        agent_trained.pth -> <output_dir>/agent_{i}_{env_name}_{dim}.pth

    Subdirectories are visited in sorted order, so the selection and the
    indices are reproducible across filesystems.

    Args:
        base_dir: Root directory to search recursively.
        output_dir: Destination directory; created if it does not exist.
        env_name: Environment name used to build the folder filter.
        dim: Dimension (any value formattable in an f-string) used to build
            the folder filter and the output file names.
        max_files: Maximum number of pairs to copy.
        algo: Only used in the final status message.

    Returns:
        The number of pairs copied.
    """
    os.makedirs(output_dir, exist_ok=True)

    prefix = f"{env_name}_{dim}_"
    file_count = 0

    for root, dirs, _files in os.walk(base_dir):
        # Sorting in place steers os.walk's top-down traversal, so runs are
        # picked in a stable order instead of raw directory-listing order.
        dirs.sort()
        if file_count >= max_files:
            break
        if root == base_dir:
            continue
        # Match on the path relative to base_dir: a matching component in
        # base_dir itself must not select every folder below it.
        if prefix not in os.path.relpath(root, base_dir):
            continue

        # NOTE(review): agent.pth is exported as the *adversary* and
        # agent_trained.pth as the *agent*; kept as-is, presumably intentional.
        adversary_file_path = os.path.join(root, "agent.pth")
        agent_file_path = os.path.join(root, "agent_trained.pth")
        if not (os.path.isfile(adversary_file_path) and os.path.isfile(agent_file_path)):
            continue

        output_adversary_file_path = os.path.join(
            output_dir, f"adversary_{file_count}_{env_name}_{dim}.pth"
        )
        output_agent_file_path = os.path.join(
            output_dir, f"agent_{file_count}_{env_name}_{dim}.pth"
        )
        shutil.copy(adversary_file_path, output_adversary_file_path)
        shutil.copy(agent_file_path, output_agent_file_path)
        file_count += 1

    print(f"Completed! {file_count} files have been copied. {algo} {env_name} {dim}")
    return file_count


def main(
    all_logs_folder: str,
    output_folder: str,
    max_files: int = 10,
):
    """Gather checkpoints for every (environment, algorithm, dim) combination.

    For each environment, every algorithm below is gathered for dims "2" and
    "3", followed by the "vanilla" baseline at dim "0". Logs are read from
    ``<all_logs_folder>/<algo>`` and written to
    ``<output_folder>/<algo>/<env_name>/<dim>``.

    Args:
        all_logs_folder: Root directory holding one sub-directory per algorithm.
        output_folder: Root directory for the gathered checkpoints.
        max_files: Maximum number of checkpoint pairs copied per combination
            (defaults to 10, the previously hard-coded value).
    """
    env_names = ["Ant", "HalfCheetah", "Hopper", "Walker", "HumanoidStandup"]
    algos = [
        "m2td3",
        "rarl",
        "dr",
        "oracle_tc_m2td3",
        "oracle_tc_rarl",
        "stacked_tc_m2td3",
        "stacked_tc_rarl",
        "oracle_m2td3",
        "oracle_rarl",
        "vanilla_tc_rarl",
        "vanilla_tc_m2td3",
    ]
    # One ordered job list per environment: every algo at dims "2" and "3",
    # then the vanilla baseline, which only exists at dim "0".
    jobs = [(algo, dim) for algo in algos for dim in ("2", "3")]
    jobs.append(("vanilla", "0"))

    for env_name in env_names:
        for algo, dim in jobs:
            gather_agent_files(
                base_dir=os.path.join(all_logs_folder, algo),
                output_dir=os.path.join(output_folder, algo, env_name, dim),
                env_name=env_name,  # Environment name to filter folders
                dim=dim,  # Dimension to filter folders
                max_files=max_files,  # Maximum number of pairs to copy
                algo=algo,
            )


if __name__ == "__main__":
    # Script entry point: fire.Fire turns `main`'s parameters
    # (all_logs_folder, output_folder) into command-line arguments.
    Fire(main)
