import argparse
import re
import shutil
from pathlib import Path

def _closest_checkpoint(checkpoints, target):
    """Return the file of the checkpoint whose reward is nearest to *target*.

    *checkpoints* is a non-empty list of ``(file, reward, step)`` tuples sorted
    by ``(reward, step)``. ``min`` keeps the first of several equally close
    entries, so ties go to the lower reward and then to the earlier step.
    """
    return min(checkpoints, key=lambda ckpt: abs(ckpt[1] - target))[0]


def extract_policies(path: str, out_directory: str) -> None:
    """Select early/middle/late checkpoints for every agent in a run directory.

    Expects the following layout::

        path/
            {agent_id}/
                params_{step}_{reward}.pt

    For each ``agent_id``, three checkpoints are chosen by reward, relative to
    the best reward ``R`` found for that agent:

    - early:  reward closest to ``0.10 * R``
    - middle: reward closest to ``0.50 * R``
    - late:   reward closest to ``R`` (the best checkpoint)

    They are copied to::

        out_directory/
            {run_name}/
                {agent_id}_early.pt
                {agent_id}_middle.pt
                {agent_id}_late.pt

    ``run_name`` is the final component of ``path`` with everything up to and
    including the last ``"___"`` removed (``"sweep___run7"`` -> ``"run7"``).
    The destination directory is printed.

    Only files named exactly ``params_<int>_<decimal>.pt`` (for example
    ``params_12_456.7.pt``) are considered. Because neither number may carry
    a sign, all rewards are non-negative, which the fraction-of-best targets
    rely on. Non-directory entries of ``path``, non-matching files and agents
    with no matching checkpoint are skipped.

    Args:
        path: Run directory containing one sub-directory per agent. A
            ``pathlib.Path`` is accepted as well.
        out_directory: Root under which the population is written. It is
            created (with parents) if missing.

    Raises:
        OSError: (e.g. ``FileNotFoundError``) if ``path`` cannot be listed.
            This is raised before any output directory is created.
    """
    path = Path(path)
    # absolute() gives "." (which has no parts) a real directory name instead
    # of raising IndexError. It does not resolve symlinks, so other paths keep
    # their original last component.
    run_name = path.absolute().name.split("___")[-1]

    # List the agents before touching the output so an invalid `path` fails
    # early. Sorting makes the processing order deterministic.
    agent_dirs = sorted(entry for entry in path.iterdir() if entry.is_dir())

    out_dir = Path(out_directory) / run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    print(out_dir)

    # Use fullmatch so names such as "params_1_2.pt.pt" are rejected instead
    # of being parsed from their prefix.
    pattern = re.compile(r"params_(\d+)_(\d+(?:\.\d+)?)\.pt")

    for agent_dir in agent_dirs:
        checkpoints = []
        for file in agent_dir.glob("params_*.pt"):
            match = pattern.fullmatch(file.name)
            if match:
                step = int(match.group(1))
                reward = float(match.group(2))
                checkpoints.append((file, reward, step))

        if not checkpoints:
            continue

        # Sort by reward, then step, so ties resolve the same way on every
        # filesystem instead of depending on glob order.
        checkpoints.sort(key=lambda ckpt: (ckpt[1], ckpt[2]))
        best_reward = checkpoints[-1][1]

        agent_id = agent_dir.name
        for label, ratio in (("early", 0.10), ("middle", 0.50), ("late", 1.00)):
            source = _closest_checkpoint(checkpoints, ratio * best_reward)
            shutil.copy(source, out_dir / f"{agent_id}_{label}.pt")



def main():
    """Command-line entry point: ``<script> PATH OUT``.

    Checks that PATH is an existing directory, then hands off to
    ``extract_policies`` and returns its result. If PATH is missing or is
    not a directory, ``parser.error`` prints the usage message plus the error
    to stderr and exits with status 2.
    """
    parser = argparse.ArgumentParser(description="Script to parse policies into population.")
    parser.add_argument("path", type=Path, help="Path to directory")
    parser.add_argument("out", type=Path, help="Where to store")
    args = parser.parse_args()

    # Stop here on bad input. Continuing would only fail later inside
    # extract_policies with a raw traceback.
    if not args.path.is_dir():
        parser.error(f"The path '{args.path}' does not exist or is not a directory.")

    print(f"Received path: {args.path}")
    return extract_policies(args.path, args.out)


if __name__ == '__main__':
    main()
