#!/usr/bin/env python3
"""Resume helpers to keep controller lightweight."""

import os
import json
import numpy as np
import glob
from typing import List, Dict, Tuple, Optional


def _collect_population_codes(pops_dir: str, limit: int) -> List[str]:
    """Return up to ``limit`` distinct truthy ``code`` values from EoH population files.

    Files named ``population_generation_<N>*.json`` in ``pops_dir`` are scanned
    from the largest N to the smallest. A file whose generation number does not
    parse, that cannot be read or decoded, or that is not a JSON list is skipped
    on its own instead of aborting the whole scan.
    """
    codes: List[str] = []
    if limit <= 0 or not os.path.isdir(pops_dir):
        return codes
    try:
        names = os.listdir(pops_dir)
    except OSError:
        return codes

    generations: List[Tuple[int, str]] = []
    for fname in names:
        if not (fname.startswith("population_generation_") and fname.endswith(".json")):
            continue
        try:
            generations.append((int(fname.split("_")[2].split(".")[0]), fname))
        except (IndexError, ValueError):
            # Non-numeric suffix (e.g. "population_generation_best.json"): ignore just this file.
            continue
    generations.sort(key=lambda item: item[0], reverse=True)

    for _, fname in generations:
        try:
            with open(os.path.join(pops_dir, fname), "r", encoding="utf-8") as pf:
                pop = json.load(pf)
        except (OSError, ValueError):
            continue
        if not isinstance(pop, list):
            continue
        for ind in pop:
            code = ind.get("code") if isinstance(ind, dict) else None
            if code and code not in codes:
                codes.append(code)
                if len(codes) >= limit:
                    return codes
    return codes


def _recovered_generator_params(controller) -> Dict:
    """Build the problem-specific params attached to recovered generators.

    Includes ``n_cities`` when ``controller.n_cities`` exists and is not None,
    and ``bp_capacity`` / ``bp_num_items`` when ``controller.cfg`` has them.
    A controller without ``cfg`` simply yields no ``bp_*`` keys.
    """
    params: Dict = {}
    n_cities = getattr(controller, "n_cities", None)
    if n_cities is not None:
        params["n_cities"] = n_cities
    cfg = getattr(controller, "cfg", None)
    for key in ("bp_capacity", "bp_num_items"):
        if hasattr(cfg, key):
            params[key] = getattr(cfg, key)
    return params


def recover_missing_strategies(controller, state_dir: str, solver_ids: List[str], generator_ids: List[str], pools_data: Dict) -> None:
    """Recover solvers/generators listed by the checkpoint but absent from pools.json.

    For each missing ID, a distinct program ``code`` is taken from the EoH
    population files under ``<state_dir>/solver_eoh/results/pops`` or
    ``<state_dir>/generator_eoh/results/pops``, newest generation first. Codes
    are assigned to missing IDs in order; this is a best-effort substitute, not
    the original program of that ID. Recovered entries are added through
    ``controller.pools.add_solver`` / ``add_generator`` with metadata
    ``{"source": "resume_recovered"}``. A warning is printed when fewer codes
    than missing IDs were found.

    Args:
        controller: Controller whose ``pools`` receives the recovered entries.
        state_dir: State directory containing the EoH result folders.
        solver_ids: Solver IDs expected by the checkpoint.
        generator_ids: Generator IDs expected by the checkpoint.
        pools_data: Parsed pools.json content (``solvers``/``generators`` lists
            of dicts with ``program_id``).
    """
    # Solvers
    known_solvers = {s["program_id"] for s in pools_data.get("solvers", [])}
    missing_solver_ids = [sid for sid in solver_ids if sid not in known_solvers]
    if missing_solver_ids:
        codes = _collect_population_codes(
            os.path.join(state_dir, "solver_eoh", "results", "pops"), len(missing_solver_ids)
        )
        for sid, code in zip(missing_solver_ids, codes):
            controller.pools.add_solver(sid, code, "Recovered from EoH", {}, {"source": "resume_recovered"})
        if len(codes) < len(missing_solver_ids):
            print(f"   Warning: Only recovered {len(codes)}/{len(missing_solver_ids)} missing solvers from EoH results")

    # Generators
    known_generators = {g["program_id"] for g in pools_data.get("generators", [])}
    missing_generator_ids = [gid for gid in generator_ids if gid not in known_generators]
    if missing_generator_ids:
        codes = _collect_population_codes(
            os.path.join(state_dir, "generator_eoh", "results", "pops"), len(missing_generator_ids)
        )
        params = _recovered_generator_params(controller)
        for gid, code in zip(missing_generator_ids, codes):
            # Copy so that recovered generators never share one mutable params dict.
            controller.pools.add_generator(gid, code, "Recovered generator", dict(params), {"source": "resume_recovered"})
        if len(codes) < len(missing_generator_ids):
            print(f"  Warning: Only recovered {len(codes)}/{len(missing_generator_ids)} missing generators from EoH results")


def _read_pools_order(state_dir: str) -> Tuple[List[str], List[str]]:
    """Return the solver/generator ``program_id`` order stored in ``psro_results/pools.json``.

    Two empty lists mean "no ordering information" (file missing or unreadable).
    """
    pools_path = os.path.join(state_dir, "psro_results", "pools.json")
    if not os.path.exists(pools_path):
        return [], []
    try:
        with open(pools_path, "r", encoding="utf-8") as f:
            pools_data = json.load(f)
        solver_order = [s["program_id"] for s in pools_data.get("solvers", [])]
        generator_order = [g["program_id"] for g in pools_data.get("generators", [])]
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as e:
        print(f"   Warning: Could not read strategy order from pools.json ({e})")
        return [], []
    return solver_order, generator_order


def _saved_to_checkpoint_index(saved_order: List[str], checkpoint_ids: List[str]) -> Dict[int, int]:
    """Map each saved-matrix index to the checkpoint index of the same ID.

    An ID repeated in ``saved_order`` resolves to its first occurrence. IDs
    missing from ``saved_order`` are left out, so their row/column stays zero.
    """
    first_pos: Dict[str, int] = {}
    for idx, pid in enumerate(saved_order):
        first_pos.setdefault(pid, idx)
    mapping: Dict[int, int] = {}
    for ckpt_idx, pid in enumerate(checkpoint_ids):
        if pid in first_pos:
            mapping[first_pos[pid]] = ckpt_idx
    return mapping


def _align_utilities_to_checkpoint(
    utilities_saved: np.ndarray,
    pools_solver_order: List[str],
    pools_generator_order: List[str],
    solver_ids: List[str],
    generator_ids: List[str],
) -> Tuple[np.ndarray, bool]:
    """Permute ``utilities_saved`` from pools.json order into checkpoint order.

    Returns ``(matrix, reordered)``. The input is returned unchanged
    (``reordered=False``) when no pools.json order is known or when both orders
    already match the checkpoint. Otherwise a new matrix of the same shape is
    built, with zeros for IDs that have no saved position.
    """
    if not (pools_solver_order and pools_generator_order):
        return utilities_saved, False

    saved_solvers = pools_solver_order[:len(solver_ids)]
    saved_generators = pools_generator_order[:len(generator_ids)]
    if saved_solvers == list(solver_ids) and saved_generators == list(generator_ids):
        return utilities_saved, False

    solver_map = _saved_to_checkpoint_index(saved_solvers, solver_ids)
    generator_map = _saved_to_checkpoint_index(saved_generators, generator_ids)
    reordered = np.zeros_like(utilities_saved)
    if solver_map and generator_map:
        saved_rows, ckpt_rows = zip(*solver_map.items())
        saved_cols, ckpt_cols = zip(*generator_map.items())
        reordered[np.ix_(ckpt_rows, ckpt_cols)] = utilities_saved[np.ix_(saved_rows, saved_cols)]
    return reordered, True


def _load_npy_matrix(utils_path: str) -> Optional[np.ndarray]:
    """Load ``utils_path`` with ``np.load``. Return None (after a warning) if it is unreadable or not an array."""
    try:
        loaded = np.load(utils_path)
    except (OSError, ValueError, EOFError) as e:
        print(f"   Warning: Failed to load {os.path.basename(utils_path)} ({e})")
        return None
    if not isinstance(loaded, np.ndarray):
        # e.g. an .npz archive saved under this name; release its file handle.
        close = getattr(loaded, "close", None)
        if close is not None:
            close()
        print(f"   Warning: {os.path.basename(utils_path)} does not contain a single array")
        return None
    return loaded


def _read_utility_csv(csv_path: str, n_rows: int, n_cols: int) -> np.ndarray:
    """Parse a comma/tab-separated numeric CSV into an ``(n_rows, n_cols)`` float matrix.

    Blank lines and empty fields are skipped. The parsed data is truncated or
    zero-padded to the requested shape. Non-numeric or ragged rows raise
    ValueError.
    """
    rows = []
    with open(csv_path, 'r') as cf:
        for line in cf:
            line = line.strip()
            if not line:
                continue
            parts = [p for p in line.replace('\t', ',').split(',') if p != '']
            rows.append([float(x) for x in parts])
    csv_mat = np.array(rows, dtype=float)
    if csv_mat.ndim == 1:
        csv_mat = csv_mat.reshape(1, -1)
    h = min(csv_mat.shape[0], n_rows)
    g = min(csv_mat.shape[1], n_cols)
    out = np.zeros((n_rows, n_cols), dtype=float)
    out[:h, :g] = csv_mat[:h, :g]
    return out


def load_utilities_with_fallback(controller, state_dir: str, solver_ids: List[str], generator_ids: List[str], utils_path: str) -> None:
    """Load utilities from NPY; fall back to CSV or zeros if needed.

    IMPORTANT: The utility matrix rows must correspond to solver_ids order,
    and columns must correspond to generator_ids order.

    The saved matrix is assumed to follow the order of ``psro_results/pools.json``.
    Its top-left ``len(solver_ids) x len(generator_ids)`` block is permuted into
    checkpoint order when the two orders differ. Sources are tried in this order:

    1. ``utils_path`` (NPY). Used if it loads as a 2-D array at least as large
       as the checkpoint dimensions.
    2. ``psro_results/utility_matrix.csv``. Truncated or zero-padded to the
       checkpoint dimensions.
    3. An all-zero matrix.

    The result is stored in ``controller.meta.utilities``.
    """
    n_solvers_checkpoint = len(solver_ids)
    n_generators_checkpoint = len(generator_ids)
    pools_solver_order, pools_generator_order = _read_pools_order(state_dir)

    full_utilities = _load_npy_matrix(utils_path) if os.path.exists(utils_path) else None
    if full_utilities is not None:
        if (full_utilities.ndim == 2
                and full_utilities.shape[0] >= n_solvers_checkpoint
                and full_utilities.shape[1] >= n_generators_checkpoint):
            # Extract submatrix matching checkpoint dimensions
            utilities_saved = full_utilities[:n_solvers_checkpoint, :n_generators_checkpoint]
            matrix, reordered = _align_utilities_to_checkpoint(
                utilities_saved, pools_solver_order, pools_generator_order, solver_ids, generator_ids
            )
            controller.meta.utilities = matrix
            if reordered:
                print(f"   Reordering utility matrix to match checkpoint order...")
                print(f"   Loaded and reordered utilities matrix: {controller.meta.utilities.shape}")
            else:
                print(f"   Loaded utilities matrix: {controller.meta.utilities.shape}")
            return
        print(f"   Warning: Utilities matrix shape {full_utilities.shape} doesn't match checkpoint state")

    # Try CSV
    csv_path = os.path.join(state_dir, "psro_results", "utility_matrix.csv")
    if os.path.exists(csv_path):
        try:
            utilities_saved = _read_utility_csv(csv_path, n_solvers_checkpoint, n_generators_checkpoint)
            matrix, reordered = _align_utilities_to_checkpoint(
                utilities_saved, pools_solver_order, pools_generator_order, solver_ids, generator_ids
            )
            controller.meta.utilities = matrix
            if reordered:
                print(f"   Reordering utility matrix from CSV to match checkpoint order...")
                print(f"   Loaded and reordered utilities matrix from CSV: {controller.meta.utilities.shape}")
            else:
                print(f"   Loaded utilities matrix from CSV: {controller.meta.utilities.shape}")
            return
        except Exception as e:
            # Best-effort: any CSV problem degrades to the zero matrix below.
            print(f"   Warning: Failed to load utility_matrix.csv ({e}), initializing empty matrix")

    # Fallback to zeros
    controller.meta.utilities = np.zeros((n_solvers_checkpoint, n_generators_checkpoint))
    print(f"   Warning: utilities.npy not found, initializing empty matrix")


def find_latest_checkpoint(psro_results_dir: str) -> Optional[Tuple[int, str]]:
    """
    Locate the highest-numbered ``checkpoint_round_<N>.json`` file.

    Args:
        psro_results_dir: Directory holding the PSRO checkpoint files.

    Returns:
        ``(N, filename)`` for the largest parsed round number N, or None when the
        directory does not exist or holds no checkpoint name whose round parses
        as an integer. If several names parse to the same N, the one listed
        first by ``os.listdir`` is returned.
    """
    if not os.path.exists(psro_results_dir):
        return None

    latest: Optional[Tuple[int, str]] = None
    for name in os.listdir(psro_results_dir):
        if not name.startswith("checkpoint_round_") or not name.endswith(".json"):
            continue
        try:
            round_no = int(name.split("_")[2].split(".")[0])
        except (IndexError, ValueError):
            continue
        # Strict ">" keeps the first of equal rounds, matching max()'s tie behavior.
        if latest is None or round_no > latest[0]:
            latest = (round_no, name)
    return latest


def load_checkpoint_data(checkpoint_path: str) -> Dict:
    """
    Read a checkpoint JSON file and return its decoded contents.

    Args:
        checkpoint_path: Path to the checkpoint JSON file.

    Returns:
        The value produced by ``json.load`` (a dict for well-formed checkpoints).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "r") as handle:
            contents = json.load(handle)
        return contents
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")


def restore_pools_from_checkpoint(
    controller,
    state_dir: str,
    solver_ids: List[str],
    generator_ids: List[str]
) -> None:
    """
    Restore pools from checkpoint state.

    Rebuilds ``controller.pools`` from ``<state_dir>/psro_results/pools.json``,
    adding only the strategies whose IDs the checkpoint lists, in checkpoint
    order. IDs absent from pools.json are then passed to
    ``recover_missing_strategies``, and the utility matrix is loaded with
    ``load_utilities_with_fallback``. Pool sizes are checked only after that
    recovery step, so the count warnings describe the final pools.

    If pools.json does not exist, an error is printed and nothing else happens
    (no recovery and no utilities load).

    Args:
        controller: Controller instance to restore state to
        state_dir: State directory path
        solver_ids: List of solver IDs from checkpoint (these are the ONLY ones we want)
        generator_ids: List of generator IDs from checkpoint (these are the ONLY ones we want)
    """
    pools_path = os.path.join(state_dir, "psro_results", "pools.json")

    if not os.path.exists(pools_path):
        print(f"   Error: pools.json not found at {pools_path}")
        return

    with open(pools_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Build maps for quick lookup
    solver_map = {s["program_id"]: s for s in data.get("solvers", [])}
    generator_map = {g["program_id"]: g for g in data.get("generators", [])}

    # Report strategies in pools.json that the checkpoint does not list (they are skipped below)
    extra_solvers = set(solver_map) - set(solver_ids)
    extra_generators = set(generator_map) - set(generator_ids)

    if extra_solvers or extra_generators:
        print(f"   Warning: pools.json contains strategies not in checkpoint:")
        if extra_solvers:
            print(f"      Extra solvers: {sorted(extra_solvers)}")
        if extra_generators:
            print(f"      Extra generators: {sorted(extra_generators)}")
        print(f"      These will be filtered out - pools.json should have been cleaned earlier")

    # Restore solvers - ONLY those in solver_ids
    for solver_id in solver_ids:
        if solver_id in solver_map:
            s = solver_map[solver_id]
            params = s.get("params", {})
            controller.pools.add_solver(
                s["program_id"],
                s["code"],
                s.get("algorithm", "Unknown"),
                params,
                s.get("metadata", {})
            )
        else:
            print(f"   Warning: Solver {solver_id} not found in pools.json")

    # Restore generators - ONLY those in generator_ids
    for generator_id in generator_ids:
        if generator_id in generator_map:
            g = generator_map[generator_id]
            if "config" in g:
                # Entries with a nested "config" dict (presumably an older format): code lives
                # inside it and the remaining keys, minus code/algorithm/type, become params.
                old_config = g["config"]
                code = old_config.get("code", "")
                params = {k: v for k, v in old_config.items()
                          if k not in ["code", "algorithm", "type"]}
            else:
                code = g.get("code", "")
                params = g.get("params", {})

            controller.pools.add_generator(
                g["program_id"],
                code,
                g.get("algorithm", "Unknown"),
                params,
                g.get("metadata", {})
            )
        else:
            print(f"   Warning: Generator {generator_id} not found in pools.json")

    # Delegate advanced recovery and utilities fallback to helpers
    utils_path = os.path.join(state_dir, "psro_results", "utilities.npy")
    recover_missing_strategies(controller, state_dir, solver_ids, generator_ids, data)
    load_utilities_with_fallback(controller, state_dir, solver_ids, generator_ids, utils_path)

    # Verify the final pool sizes (after recovery had a chance to fill gaps)
    if controller.pools.n_solvers != len(solver_ids):
        print(f"    Warning: Restored {controller.pools.n_solvers} solvers, expected {len(solver_ids)}")
    if controller.pools.n_generators != len(generator_ids):
        print(f"    Warning: Restored {controller.pools.n_generators} generators, expected {len(generator_ids)}")


def _round_number_from_name(basename: str) -> Optional[int]:
    """Return the round number encoded in an underscore-separated file name, or None.

    Recognises ``..._round_<N>[_...].<ext>`` (the integer token right after the
    first ``round`` token) and the compact ``...round<N>.<ext>`` form.
    """
    parts = basename.split("_")
    for i, part in enumerate(parts):
        if part == "round" and i + 1 < len(parts):
            try:
                return int(parts[i + 1].split(".")[0])
            except ValueError:
                return None
        if part.startswith("round"):
            digits = part[len("round"):].split(".")[0]
            if digits.isdigit():
                return int(digits)
    return None


def _remove_files_after_round(pattern: str, target_round: int, announce: bool = False) -> int:
    """Delete files matching the glob ``pattern`` whose round number exceeds ``target_round``.

    Files without a parseable round number are kept. Returns the number of files removed.
    """
    removed = 0
    for path in glob.glob(pattern):
        name = os.path.basename(path)
        round_num = _round_number_from_name(name)
        if round_num is None or round_num <= target_round:
            continue
        os.remove(path)
        removed += 1
        if announce:
            print(f"     Removed {name}")
    return removed


def _clean_nash_history(psro_results_dir: str, target_round: int) -> int:
    """Drop entries with ``iteration > target_round`` from nash_mixture_history.json. Returns 1 if it was rewritten."""
    history_file = os.path.join(psro_results_dir, "nash_mixture_history.json")
    if not os.path.exists(history_file):
        return 0
    try:
        with open(history_file, 'r') as f:
            history_data = json.load(f)
        original_len = len(history_data)
        history_data = [entry for entry in history_data if entry.get("iteration", 0) <= target_round]
        if len(history_data) < original_len:
            with open(history_file, 'w') as f:
                json.dump(history_data, f, indent=2, ensure_ascii=False)
            print(f"     Cleaned nash_mixture_history.json (removed {original_len - len(history_data)} entries)")
            return 1
    except Exception as e:
        print(f"     Warning: Failed to clean nash_mixture_history.json: {e}")
    return 0


def _clean_pools_json(psro_results_dir: str, pools_state: Dict) -> int:
    """Keep only checkpoint-listed strategies in pools.json. Returns 1 if the file was rewritten.

    A category whose ID list is absent from ``pools_state`` is left untouched.
    Without a list there is no basis for filtering, and treating it as empty
    would delete every entry of that category.
    """
    pools_path = os.path.join(psro_results_dir, "pools.json")
    if not os.path.exists(pools_path):
        return 0
    try:
        solver_ids = pools_state.get("solver_ids")
        generator_ids = pools_state.get("generator_ids")
        with open(pools_path, 'r', encoding="utf-8") as f:
            pools_data = json.load(f)

        solvers = pools_data.get("solvers", [])
        generators = pools_data.get("generators", [])
        if solver_ids is None:
            filtered_solvers = list(solvers)
        else:
            keep_solvers = set(solver_ids)
            filtered_solvers = [s for s in solvers if s.get("program_id") in keep_solvers]
        if generator_ids is None:
            filtered_generators = list(generators)
        else:
            keep_generators = set(generator_ids)
            filtered_generators = [g for g in generators if g.get("program_id") in keep_generators]

        removed_solvers = len(solvers) - len(filtered_solvers)
        removed_generators = len(generators) - len(filtered_generators)
        if removed_solvers > 0 or removed_generators > 0:
            pools_data["solvers"] = filtered_solvers
            pools_data["generators"] = filtered_generators
            with open(pools_path, 'w', encoding="utf-8") as f:
                json.dump(pools_data, f, indent=2, ensure_ascii=False)
            print(f"     Cleaned pools.json: removed {removed_solvers} solvers, {removed_generators} generators")
            print(f"      Kept: {len(filtered_solvers)} solvers, {len(filtered_generators)} generators")
            return 1
    except Exception as e:
        print(f"     Warning: Failed to clean pools.json: {e}")
    return 0


def _trim_utilities(psro_results_dir: str, pools_state: Dict) -> int:
    """Trim utilities.npy to ``pools_state``'s ``n_solvers`` x ``n_generators``. Returns 1 if the file was rewritten.

    Nothing happens when either size is missing from ``pools_state``.
    """
    utils_path = os.path.join(psro_results_dir, "utilities.npy")
    try:
        target_n_solvers = pools_state.get("n_solvers", None)
        target_n_generators = pools_state.get("n_generators", None)
        if not os.path.exists(utils_path) or target_n_solvers is None or target_n_generators is None:
            return 0
        utilities = np.load(utils_path)
        original_shape = utilities.shape
        # Keeping the leading rows/columns assumes later strategies were appended at the end.
        if utilities.shape[0] > target_n_solvers or utilities.shape[1] > target_n_generators:
            trimmed_utilities = utilities[:target_n_solvers, :target_n_generators]
            np.save(utils_path, trimmed_utilities)
            print(f"     Cleaned utilities.npy: {original_shape} -> {trimmed_utilities.shape}")
            return 1
    except Exception as e:
        print(f"     Warning: Failed to clean utilities.npy: {e}")
    return 0


def clean_files_after_round(state_dir: str, target_round: int) -> None:
    """
    Clean all files and data from rounds after target_round.

    This ensures that when resuming from a specific round, all data from
    later rounds is removed to prevent dimension mismatches and inconsistencies.
    The following are deleted when their round is greater than ``target_round``:

    - checkpoint, nash-mixture, testing and timing files under ``psro_results``
    - round-numbered PNGs under ``<state_dir>/visualizations``

    In addition, nash_mixture_history.json entries are filtered, and pools.json
    and utilities.npy are cut down to the strategies recorded in
    ``checkpoint_round_<target_round>.json``.

    Args:
        state_dir: State directory path
        target_round: Target round number (keep files up to and including this round)
    """
    psro_results_dir = os.path.join(state_dir, "psro_results")
    if not os.path.exists(psro_results_dir):
        return

    print(f"   Cleaning files from rounds after {target_round}...")

    files_cleaned = 0

    # 1. Checkpoint files (each removal is announced)
    files_cleaned += _remove_files_after_round(
        os.path.join(psro_results_dir, "checkpoint_round_*.json"), target_round, announce=True
    )

    # 2. Per-round nash mixture files
    files_cleaned += _remove_files_after_round(
        os.path.join(psro_results_dir, "nash_mixture_round_*.json"), target_round
    )

    # 3. nash_mixture_history.json entries
    files_cleaned += _clean_nash_history(psro_results_dir, target_round)

    # 4. Testing files
    files_cleaned += _remove_files_after_round(
        os.path.join(psro_results_dir, "round_*_testing.json"), target_round
    )

    # 5. Visualization files
    viz_dir = os.path.join(state_dir, "visualizations")
    if os.path.exists(viz_dir):
        for pattern in ("solver_visualization_round_*.png",
                        "generator_visualization_round_*.png",
                        "round_*_*.png"):
            files_cleaned += _remove_files_after_round(os.path.join(viz_dir, pattern), target_round)

    # 6. Timing statistics files
    files_cleaned += _remove_files_after_round(
        os.path.join(psro_results_dir, "timing_round_*.json"), target_round
    )

    # 7 & 8. pools.json and utilities.npy, driven by the target checkpoint (read once)
    checkpoint_path = os.path.join(psro_results_dir, f"checkpoint_round_{target_round}.json")
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, 'r') as f:
                checkpoint_data = json.load(f)
            pools_state = checkpoint_data.get("pools_state") or {}
        except Exception as e:
            print(f"     Warning: Failed to clean pools.json: {e}")
            print(f"     Warning: Failed to clean utilities.npy: {e}")
        else:
            files_cleaned += _clean_pools_json(psro_results_dir, pools_state)
            files_cleaned += _trim_utilities(psro_results_dir, pools_state)

    if files_cleaned > 0:
        print(f"    Cleaned {files_cleaned} file(s) from rounds after {target_round}")
    else:
        print(f"    No files found to clean (all rounds <= {target_round})")


def clean_strategies_after_round(controller, state_dir: str, target_round: int) -> None:
    """
    Report pool strategies that the checkpoint of ``target_round`` does not list.

    Despite the name, nothing is removed here. The function only prints the IDs
    of solvers and generators in ``controller.pools`` that are absent from
    ``checkpoint_round_<target_round>.json``. Pool rebuilding is left to
    ``restore_pools_from_checkpoint``. It returns silently when that checkpoint
    file does not exist. Any error while reading it or inspecting the pools is
    printed as a warning, not raised.

    Args:
        controller: Controller whose ``pools.solver_pool`` and
            ``pools.generator_pool`` items expose ``program_id``.
        state_dir: State directory containing ``psro_results/``.
        target_round: Round whose checkpoint defines the expected strategies.
    """
    checkpoint_path = os.path.join(state_dir, "psro_results", f"checkpoint_round_{target_round}.json")
    if not os.path.exists(checkpoint_path):
        return

    try:
        with open(checkpoint_path, 'r') as f:
            expected = json.load(f).get("pools_state", {})
        wanted_solvers = set(expected.get("solver_ids", []))
        wanted_generators = set(expected.get("generator_ids", []))

        extra_solvers = {s.program_id for s in controller.pools.solver_pool} - wanted_solvers
        extra_generators = {g.program_id for g in controller.pools.generator_pool} - wanted_generators

        if not extra_solvers and not extra_generators:
            return

        print(f"    Removing strategies added after round {target_round}...")
        print(f"    Extra solvers: {extra_solvers}")
        print(f"    Extra generators: {extra_generators}")
        # Pools are intentionally left as-is; restore_pools_from_checkpoint rebuilds them.
        print(f"   Strategy cleanup handled by restore_pools_from_checkpoint")
    except Exception as e:
        print(f"    Warning: Failed to check strategy cleanup: {e}")


def _fresh_resume_state() -> Dict:
    """Return the resume state used when no checkpoint can be loaded (start from iteration 0)."""
    return {
        "iteration": 0,
        "eoh_state": {
            "solver_eoh_generation": 0,
            "generator_eoh_generation": 0
        }
    }


def resume_from_checkpoint(controller, state_dir: str, resume_from_round: Optional[int] = None) -> Dict:
    """
    Complete resume logic: find checkpoint, load data, clean later rounds, restore pools.

    When ``resume_from_round`` is None, the highest-numbered checkpoint in
    ``<state_dir>/psro_results`` is used. If no checkpoint exists or it cannot
    be loaded, a warning or error is printed and a fresh state (iteration 0) is
    returned. In that case no files are cleaned and the controller is not
    touched. Otherwise files from later rounds are removed via
    ``clean_files_after_round`` and the controller's pools are rebuilt via
    ``restore_pools_from_checkpoint``.

    Args:
        controller: Controller instance to resume
        state_dir: State directory path
        resume_from_round: Optional specific round number to resume from; it
            also overrides the checkpoint's stored iteration.

    Returns:
        Dictionary with resume state information. ``eoh_state`` always contains
        both generation counters (0 when the checkpoint omits them), plus any
        other keys stored in the checkpoint:
        {
            "iteration": int,
            "eoh_state": {
                "solver_eoh_generation": int,
                "generator_eoh_generation": int
            }
        }
    """
    # Determine which checkpoint to load
    psro_results_dir = os.path.join(state_dir, "psro_results")

    if resume_from_round is not None:
        checkpoint_to_load = f"checkpoint_round_{resume_from_round}.json"
    else:
        latest = find_latest_checkpoint(psro_results_dir)
        if latest is None:
            print(f"    Warning: No checkpoint files found, starting from iteration 0")
            return _fresh_resume_state()
        _, checkpoint_to_load = latest

    # Load checkpoint data
    checkpoint_path = os.path.join(psro_results_dir, checkpoint_to_load)
    try:
        checkpoint_data = load_checkpoint_data(checkpoint_path)
    except Exception as e:
        print(f"    Error loading checkpoint {checkpoint_to_load}: {e}")
        return _fresh_resume_state()

    # Extract state information; "or {}" / "or []" also covers keys stored as JSON null.
    iteration = checkpoint_data.get("iteration", 0)
    eoh_state = {
        "solver_eoh_generation": 0,
        "generator_eoh_generation": 0,
        **(checkpoint_data.get("eoh_state") or {}),
    }
    pools_state = checkpoint_data.get("pools_state") or {}
    solver_ids = pools_state.get("solver_ids") or []
    generator_ids = pools_state.get("generator_ids") or []

    target_round = iteration
    if resume_from_round is not None:
        target_round = resume_from_round
        if iteration != resume_from_round:
            print(f"    Warning: Checkpoint iteration ({iteration}) != resume_from_round ({resume_from_round})")
            print(f"      Using resume_from_round ({resume_from_round}) as target iteration")
        iteration = resume_from_round

    print(f"    Loading checkpoint: {checkpoint_to_load}")
    print(f"    Target state: iteration={iteration}, solvers={len(solver_ids)}, generators={len(generator_ids)}")
    print(f"    EOH states: solver_gen={eoh_state.get('solver_eoh_generation', 0)}, generator_gen={eoh_state.get('generator_eoh_generation', 0)}")

    clean_files_after_round(state_dir, target_round)

    # Restore pools
    restore_pools_from_checkpoint(controller, state_dir, solver_ids, generator_ids)

    return {
        "iteration": iteration,
        "eoh_state": eoh_state
    }

