#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.multiprocessing import Pool
from torch.nn import DataParallel
from torch.distributions.categorical import Categorical
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import gym
from gym.spaces import Discrete
import random
import math
import time
import warnings
from tqdm import tqdm
from sklearn import metrics
from matplotlib import pyplot as plt
from matplotlib import animation
from itertools import repeat
from scipy.interpolate import Rbf
import scipy.stats as st

# ======================== GAME THEORY IMPORTS ========================
import copy
import csv
from collections import defaultdict
from typing import List, Dict, Tuple, Set
from itertools import combinations
from scipy.optimize import minimize
import pickle

# Suppress warnings
warnings.filterwarnings("ignore", category=Warning)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# =============================================================================
# PERFORMANCE CSV EXPORT
# (The CONFIGURATION class `Config`, whose parameters you may modify as
#  needed, follows after these helpers.)
# =============================================================================
import csv
import os
import pandas as pd
import numpy as np

class PerformanceCSVSaver:
    """Export training/validation return curves (plus optional game-theory
    metrics) to CSV files.

    Every file is written below ``<save_dir>/performance_csv``; that directory
    is created when the saver is constructed.
    """

    # (csv column, metrics-dict key) pairs written by the per-strategy export.
    _STRATEGY_METRIC_COLUMNS = (
        ('num_attackers', 'num_attackers'),
        ('num_defended', 'num_defended'),
        ('performance_drop', 'performance_drop'),
        ('theoretical_damage', 'theoretical_damage'),
        ('successful_attacks', 'successful_attacks'),
        ('defense_budget_util', 'defense_budget_utilization'),
        ('attack_budget_util', 'attack_budget_utilization'),
        ('avg_attacker_utility', 'avg_attacker_utility'),
        ('current_era', 'current_era'),
        ('deterrence_efficiency', 'deterrence_efficiency'),
    )

    # Subset of columns written by the combined (all-strategies) export.
    _COMBINED_METRIC_COLUMNS = (
        ('num_attackers', 'num_attackers'),
        ('num_defended', 'num_defended'),
        ('performance_drop', 'performance_drop'),
        ('theoretical_damage', 'theoretical_damage'),
        ('successful_attacks', 'successful_attacks'),
        ('defense_budget_util', 'defense_budget_utilization'),
        ('attack_budget_util', 'attack_budget_utilization'),
        ('current_era', 'current_era'),
    )

    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.csv_dir = os.path.join(save_dir, 'performance_csv')
        os.makedirs(self.csv_dir, exist_ok=True)

    @staticmethod
    def _value_at(seq, i, missing):
        """Return ``seq[i]``, or *missing* when *seq* is shorter than ``i + 1``."""
        return seq[i] if i < len(seq) else missing

    @staticmethod
    def _metric_columns(columns, metrics=None):
        """Map (column, key) pairs to values from *metrics*, defaulting to 0."""
        source = metrics or {}
        return {column: source.get(key, 0) for column, key in columns}

    @staticmethod
    def _format_last(series) -> str:
        """Format the last non-missing value of *series* as ``%.2f``, or ``'N/A'``.

        A curve that was never recorded (e.g. no validation run) is all-None;
        formatting ``None`` with ``:.2f`` would raise ``TypeError``.
        """
        present = series.dropna()
        if present.empty:
            return 'N/A'
        return f"{present.iloc[-1]:.2f}"

    def save_strategy_performance_csv(self, strategy_name: str, memory_data: dict,
                                    game_theory_metrics: dict = None):
        """Write ``<strategy_name>_performance.csv`` for the first recorded run.

        Args:
            strategy_name: label for the file name and ``strategy`` column.
            memory_data: dict with ``'steps'``, ``'training_values'`` and
                ``'eval_values'``, each mapping run_id -> list of values. Only the
                first run_id in ``memory_data['steps']`` is exported.
            game_theory_metrics: optional mapping run_id -> list of per-step
                metric dicts; when present, their values are added as columns.
        """
        if not memory_data['steps']:
            print(f"No data to save for strategy: {strategy_name}")
            return

        run_id = next(iter(memory_data['steps']))
        steps = memory_data['steps'][run_id]
        # A run may have no training/eval entry at all (e.g. eval never ran).
        training_values = (memory_data.get('training_values') or {}).get(run_id, [])
        eval_values = (memory_data.get('eval_values') or {}).get(run_id, [])
        run_metrics = (game_theory_metrics or {}).get(run_id) or []

        csv_data = []
        max_len = max(len(steps), len(training_values), len(eval_values))
        for i in range(max_len):
            row = {
                'step': self._value_at(steps, i, None),
                'training_return': self._value_at(training_values, i, None),
                'validation_return': self._value_at(eval_values, i, None),
                'strategy': strategy_name,
            }
            if i < len(run_metrics):
                row.update(self._metric_columns(self._STRATEGY_METRIC_COLUMNS, run_metrics[i]))
            csv_data.append(row)

        csv_filename = os.path.join(self.csv_dir, f'{strategy_name}_performance.csv')

        if csv_data:
            df = pd.DataFrame(csv_data)
            df.to_csv(csv_filename, index=False)
            print(f"Saved performance data for {strategy_name}: {csv_filename}")
            print(f"  Rows: {len(df)}")
            print(f"  Final training return: {self._format_last(df['training_return'])}")
            print(f"  Final validation return: {self._format_last(df['validation_return'])}")

    def save_combined_performance_csv(self, all_strategy_results: dict):
        """Write every strategy's curves into ``all_strategies_performance.csv``.

        Args:
            all_strategy_results: mapping strategy name -> dict with list-valued
                ``'steps'``, ``'training_values'``, ``'eval_values'`` and optional
                ``'game_theory_metrics'`` (list of per-step metric dicts).
        """
        combined_data = []

        for strategy_name, results in all_strategy_results.items():
            steps = results.get('steps', [])
            training_values = results.get('training_values', [])
            eval_values = results.get('eval_values', [])
            game_metrics = results.get('game_theory_metrics') or []

            max_len = max(len(steps), len(training_values), len(eval_values))
            for i in range(max_len):
                row = {
                    'strategy': strategy_name,
                    'step': self._value_at(steps, i, np.nan),
                    'training_return': self._value_at(training_values, i, np.nan),
                    'validation_return': self._value_at(eval_values, i, np.nan),
                }
                # Rows without metrics (e.g. the baseline) are zero-filled so the
                # combined file has one uniform schema.
                step_metrics = game_metrics[i] if i < len(game_metrics) else None
                row.update(self._metric_columns(self._COMBINED_METRIC_COLUMNS, step_metrics))
                combined_data.append(row)

        combined_filename = os.path.join(self.csv_dir, 'all_strategies_performance.csv')

        if combined_data:
            df = pd.DataFrame(combined_data)
            df.to_csv(combined_filename, index=False)
            print(f"Saved combined performance data: {combined_filename}")
            print(f"  Total rows: {len(df)}")
            print(f"  Strategies: {df['strategy'].unique().tolist()}")

    def create_plotting_ready_csv(self, all_strategy_results: dict):
        """Write long-format (step, return, strategy, type) CSVs for plotting.

        Produces separate training and validation files plus one file with both.
        Steps are paired with values via ``zip``, so extra trailing entries on
        either side are dropped.
        """
        training_data = []
        validation_data = []

        for strategy_name, results in all_strategy_results.items():
            steps = results.get('steps', [])
            training_values = results.get('training_values', [])
            eval_values = results.get('eval_values', [])

            for step, train_val in zip(steps, training_values):
                training_data.append({
                    'step': step,
                    'return': train_val,
                    'strategy': strategy_name,
                    'type': 'training'
                })

            for step, eval_val in zip(steps, eval_values):
                validation_data.append({
                    'step': step,
                    'return': eval_val,
                    'strategy': strategy_name,
                    'type': 'validation'
                })

        if training_data:
            training_df = pd.DataFrame(training_data)
            training_csv = os.path.join(self.csv_dir, 'training_returns_all_strategies.csv')
            training_df.to_csv(training_csv, index=False)
            print(f"Saved training data: {training_csv}")

        if validation_data:
            validation_df = pd.DataFrame(validation_data)
            validation_csv = os.path.join(self.csv_dir, 'validation_returns_all_strategies.csv')
            validation_df.to_csv(validation_csv, index=False)
            print(f"Saved validation data: {validation_csv}")

        all_data = training_data + validation_data
        if all_data:
            all_df = pd.DataFrame(all_data)
            all_csv = os.path.join(self.csv_dir, 'all_returns_plotting_ready.csv')
            all_df.to_csv(all_csv, index=False)
            print(f"Saved plotting-ready data: {all_csv}")

# Intended to be attached to the Agent class as a method (it takes `self` and reads `self.opts` / `self.memory`).
def save_detailed_performance_csv(self, save_dir=None, strategy_name='default'):
    """Export this agent's recorded curves for one strategy to CSV.

    Args:
        save_dir: target directory. When None, ``self.opts.save_dir`` is used,
            falling back to ``'./results'`` if that is falsy.
        strategy_name: label used in the CSV file name and ``strategy`` column.
    """
    if save_dir is None:
        save_dir = self.opts.save_dir or './results'

    saver = PerformanceCSVSaver(save_dir)
    recorded = self.memory
    saver.save_strategy_performance_csv(
        strategy_name,
        {
            'steps': recorded.steps,
            'training_values': recorded.training_values,
            'eval_values': recorded.eval_values,
        },
        getattr(recorded, 'game_theory_metrics', None),
    )

# Strategy-comparison driver that also exports per-strategy and combined CSV files (drop-in replacement for the earlier comparison function).
def run_strategy_comparison_with_csv_export(opts):
    """Train every strategy in ``opts.strategies_to_compare`` and export CSVs.

    For each strategy name this:
      * sets ``opts.enable_game_theory`` (False only for ``'baseline'``) --
        note that ``opts`` is mutated in place;
      * opens a TensorBoard writer under ``opts.log_dir/<strategy>`` unless
        ``opts.no_tb`` is set, and closes it once that strategy is finished;
      * builds an ``Agent`` and, when game theory is on and the agent has a
        ``game_integrator``, selects the strategy on it;
      * for each seed in ``opts.seeds``: seeds torch / numpy / ``random``, loads
        the freshly initialised ``logits_net`` of a new ``Worker`` into the
        agent's master network, and calls ``agent.start_training``;
      * writes that strategy's CSV via ``PerformanceCSVSaver``.
    Afterwards the combined and plotting-ready CSVs are written and, if at least
    one strategy ran, the last agent's ``save_training_plots`` is called.

    Returns:
        dict mapping strategy name -> {'steps', 'training_values', 'eval_values'}
        (plus 'game_theory_metrics' when recorded), for the FIRST seed only.
    """
    import pprint
    pprint.pprint(vars(opts))

    all_strategy_results = {}
    agent = None  # last constructed agent; used for the comparison plots

    csv_saver = PerformanceCSVSaver(opts.save_dir if opts.save_dir else './results')

    print(f"\n{'='*80}")
    print(f"RUNNING STRATEGY COMPARISON: {len(opts.strategies_to_compare)} strategies")
    print(f"{'='*80}")

    for strategy_name in opts.strategies_to_compare:
        print(f"\n{'='*60}")
        print(f"RUNNING STRATEGY: {strategy_name.upper()}")
        print(f"{'='*60}")

        opts.enable_game_theory = strategy_name != 'baseline'

        tb_writer = None if opts.no_tb else SummaryWriter(os.path.join(opts.log_dir, strategy_name))
        try:
            agent = Agent(opts)

            if opts.enable_game_theory and hasattr(agent, 'game_integrator'):
                agent.game_integrator.current_strategy = strategy_name

            for run_id in opts.seeds:
                # Seed every RNG this file draws from, including Python's `random`.
                torch.manual_seed(run_id)
                np.random.seed(run_id)
                random.seed(run_id)

                nn_parms_worker = Worker(
                    id=0, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                # Load the random policy of the fresh worker into the master network.
                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                agent.start_training(tb_writer, run_id)
                if tb_writer is not None:
                    agent.log_performance(tb_writer)
        finally:
            # Flush and release this strategy's TensorBoard event file.
            if tb_writer is not None:
                tb_writer.close()

        first_seed = opts.seeds[0]
        memory = agent.memory
        strategy_results = {
            'steps': memory.steps[first_seed] if first_seed in memory.steps else [],
            'training_values': memory.training_values[first_seed] if first_seed in memory.training_values else [],
            'eval_values': memory.eval_values[first_seed] if first_seed in memory.eval_values else [],
        }

        game_metrics = getattr(memory, 'game_theory_metrics', None)
        if opts.enable_game_theory and game_metrics is not None and first_seed in game_metrics:
            strategy_results['game_theory_metrics'] = game_metrics[first_seed]

        all_strategy_results[strategy_name] = strategy_results

        memory_data = {
            'steps': memory.steps,
            'training_values': memory.training_values,
            'eval_values': memory.eval_values,
        }
        csv_saver.save_strategy_performance_csv(strategy_name, memory_data, game_metrics)

        print(f"\nStrategy {strategy_name.upper()} completed and CSV saved!")
        if strategy_results['training_values']:
            print(f"Final training return: {strategy_results['training_values'][-1]:.2f}")
        if strategy_results['eval_values']:
            print(f"Final validation return: {strategy_results['eval_values'][-1]:.2f}")

    print(f"\n{'='*60}")
    print("SAVING COMBINED CSV FILES")
    print(f"{'='*60}")

    csv_saver.save_combined_performance_csv(all_strategy_results)
    csv_saver.create_plotting_ready_csv(all_strategy_results)

    if agent is not None:
        agent.save_training_plots(opts.save_dir, all_strategy_results)

    return all_strategy_results
class Config:
    """Configuration class for easy parameter modification.

    Class attributes hold the defaults; ``__init__`` adds environment-specific
    hyperparameters and derives output paths, the torch device and run seeds.
    Subclass or edit the attributes below to change a run.
    """

    # Environment and run settings
    env_name = 'LunarLander-v2'  # Options: 'CartPole-v1', 'HalfCheetah-v2', 'LunarLander-v2'
    eval_only = False
    no_saving = False
    no_tb = False
    render = False
    mode = 'human'  # 'human' or 'rgb'
    log_dir = 'logs'
    run_name = 'run_name'

    # Multiple runs
    multiple_run = 1
    seed = 2

    # Federation parameters
    num_worker = 25
    alpha = 0.4
    num_Byzantine = 0
    # RL Algorithms
    SVRPG = False
    FedPG_BR = False
    attack_type = None
    # Training and validation
    val_size = 10
    val_max_steps = 1000
    load_path = None

    # Device settings
    use_cuda = False

    # ======================== GAME THEORY PARAMETERS ========================
    # Game theory integration parameters
    enable_game_theory = True
    compare_all_strategies = True  # Run all strategies and compare results
    strategies_to_compare = ['baseline', 'random', 'highest_risk', 'highest_value', 'true_stackelberg']
    initial_defense_budget_ratio = 0.05
    initial_attack_budget_ratio = 0.3
    defense_effectiveness = 1.0
    base_attack_intensity = 0.8
    damage_scenario = "tiered_importance"  # "critical_infrastructure", "tiered_importance", "network_hubs"
    reshuffle_frequency = 500  # Reshuffle game parameters every N rounds
    single_best_attacker_only = False

    def __init__(self):
        # Per-instance copy: the class-level list is shared by all instances, so
        # editing one config's strategy list would otherwise leak into others.
        self.strategies_to_compare = list(self.strategies_to_compare)

        # Set environment-specific hyperparameters
        self.set_environment_params()

        # Setup paths, device and seeds
        self.setup_paths()

    def set_environment_params(self):
        """Set hyperparameters based on ``self.env_name``.

        NOTE: an env_name not listed here leaves these attributes unset.
        NOTE(review): ``max_trajectories`` is a float (1e4) for two envs; cast
        with ``int()`` wherever it is used as a loop bound.
        """
        if self.env_name == 'CartPole-v1':
            # Task-Specified Hyperparameters
            self.max_epi_len = 500
            self.max_trajectories = 5000
            self.gamma = 0.999
            self.min_reward = 0
            self.max_reward = 600

            # Shared parameters
            self.do_sample_for_training = True
            self.lr_model = 1e-3
            self.hidden_units = '16,16'
            self.activation = 'ReLU'
            self.output_activation = 'Tanh'

            # Batch sizes
            self.B = 16  # for SVRPG and GOMDP
            self.Bmin = 12  # for FT-FedScsPG
            self.Bmax = 20  # for FT-FedScsPG
            self.b = 4  # mini batch_size for SVRPG and FT-FedScsPG

            # Inner loop iteration for SVRPG
            self.N = 3

            # Filtering hyperparameters for FT-FedScsPG
            self.delta = 0.6
            self.sigma = 0.06

        elif self.env_name == 'HalfCheetah-v2':
            self.max_epi_len = 500
            self.max_trajectories = 1e4
            self.gamma = 0.995
            self.min_reward = 0
            self.max_reward = 4000

            self.do_sample_for_training = True
            self.lr_model = 8e-5
            self.hidden_units = '64,64'
            self.activation = 'Tanh'
            self.output_activation = 'Tanh'

            self.B = 48
            self.Bmin = 46
            self.Bmax = 50
            self.b = 16
            self.N = 3
            self.delta = 0.6
            self.sigma = 0.9

        elif self.env_name == 'LunarLander-v2':
            self.max_epi_len = 1000
            self.max_trajectories = 1e4
            self.gamma = 0.99
            self.min_reward = -1000
            self.max_reward = 300

            self.do_sample_for_training = True
            self.lr_model = 1e-3
            self.hidden_units = '64,64'
            self.activation = 'Tanh'
            self.output_activation = 'Tanh'

            self.B = 32
            self.Bmin = 26
            self.Bmax = 38
            self.b = 8
            self.N = 3
            self.delta = 0.6
            self.sigma = 0.07

    def setup_paths(self):
        """Derive save/log directories, the torch device and the run seeds.

        ``save_dir`` is None when ``no_saving``; ``log_dir`` is None when
        ``no_tb``. Both otherwise share one run-identifying suffix.
        """
        run_tag = (
            f'{self.env_name}_shuffle_{self.reshuffle_frequency}'
            f'_single_attacker_{self.single_best_attacker_only}'
            f'_attack_{self.base_attack_intensity}_seed{self.seed}'
        )

        self.save_dir = None if self.no_saving else f'outputs_{run_tag}'
        self.log_dir = None if self.no_tb else f'{self.log_dir}_{run_tag}'

        # Set device
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        # Consecutive seeds starting at `seed`, one per run.
        self.seeds = (np.arange(self.multiple_run) + self.seed).tolist()

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def torch_load_cpu(load_path):
    """Load a ``torch.save`` file with every storage kept on the CPU."""
    def _keep_storage(storage, _location):
        # The storage arrives already deserialized on the CPU; hand it back as-is.
        return storage

    return torch.load(load_path, map_location=_keep_storage)

def get_inner_model(model):
    """Unwrap a ``DataParallel`` container; any other model is returned as-is."""
    if isinstance(model, DataParallel):
        return model.module
    return model

def move_to(var, device):
    """Call ``.to(device)`` on *var*, recursing into the values of dicts."""
    if not isinstance(var, dict):
        return var.to(device)
    moved = {}
    for key, value in var.items():
        moved[key] = move_to(value, device)
    return moved

def env_wrapper(name, obs):
    """Return the observation *obs* untouched.

    Placeholder hook for per-environment observation preprocessing; *name*
    is accepted but currently not used.
    """
    unchanged = obs
    return unchanged

def save_frames_as_gif(frames, path='./', filename='gym_animation.gif'):
    """Render a sequence of image frames to an animated GIF.

    Args:
        frames: non-empty sequence of images (each with a ``.shape`` of
            (height, width, ...)); the figure is sized from the first frame at
            72 dpi.
        path: output directory, joined with *filename* via ``os.path.join``
            (so a missing trailing separator no longer glues the names).
        filename: output file name.

    Raises:
        ValueError: if *frames* is empty.
    """
    if len(frames) == 0:
        raise ValueError("save_frames_as_gif() needs at least one frame")

    height, width = frames[0].shape[0], frames[0].shape[1]
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    try:
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        # NOTE(review): the saved GIF's timing comes from `fps`, not `interval`;
        # the 'imagemagick' writer needs the external ImageMagick binary.
        anim.save(os.path.join(path, filename), writer='imagemagick', fps=120)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def euclidean_dist(x, y):
    """Pairwise Euclidean distances between the rows of two 2-D tensors.

    Args:
      x: tensor of shape [m, d]
      y: tensor of shape [n, d]
    Returns:
      dist: tensor of shape [m, n] with dist[i, j] = ||x[i] - y[j]||_2
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).T
    # ||x||^2 + ||y||^2 - 2 x.y^T. The keyword form replaces the deprecated
    # positional addmm_(beta, alpha, mat1, mat2) overload.
    dist = xx + yy
    dist.addmm_(x, y.T, beta=1, alpha=-2)
    # Floating-point cancellation can leave tiny negatives; clamp before sqrt.
    dist = dist.clamp_(min=0).sqrt()
    return dist

# ======================== GAME THEORY UTILITY FUNCTIONS ========================

def fedavg(parameters_list):
    """Federated averaging: element-wise mean of a list of state dicts.

    Args:
        parameters_list: non-empty list of mappings name -> tensor. Every
            mapping must contain the keys of the first one, with matching shapes.

    Returns:
        dict name -> averaged tensor (same dtype as the inputs). Floating-point
        and complex tensors get the exact mean. Integer buffers (e.g. a
        BatchNorm ``num_batches_tracked`` counter) get the floor of the mean,
        because in-place true division is not allowed on integer tensors.

    Raises:
        ValueError: if *parameters_list* is empty.
    """
    if not parameters_list:
        raise ValueError("fedavg() requires at least one set of parameters")

    count = len(parameters_list)
    with torch.no_grad():
        new_params = {}
        for name in parameters_list[0].keys():
            total = torch.zeros_like(parameters_list[0][name])
            for param in parameters_list:
                total += param[name]
            if total.is_floating_point() or total.is_complex():
                total /= count
            else:
                total = torch.div(total, count, rounding_mode='floor')
            new_params[name] = total
    return new_params

def create_realistic_unequal_damage_weights(num_clients: int, scenario: str = "critical_infrastructure", seed: int = 42) -> np.ndarray:
    """
    Create unequal per-client damage weights (strategic value of each client).

    Even though all clients may hold equal data, some are treated as more
    valuable targets, which is later reflected in attack severity.

    Args:
        num_clients: length of the returned weight vector.
        scenario: "critical_infrastructure" (max(1, n // 5) clients at 15.0,
            rest 3.0), "tiered_importance" (~30% at 8.0, ~20% at 20.0, rest
            2.0) or "network_hubs" (~15% at 18.0, rest 4.0). Unknown names fall
            back to "critical_infrastructure".
        seed: passed to ``np.random.seed`` before drawing, so NumPy's *global*
            random state is reset as a side effect.

    Returns:
        float array of shape (num_clients,).
    """
    np.random.seed(seed)

    if scenario == "critical_infrastructure":
        # Some clients are critical infrastructure (hospitals, banks);
        # the others are regular IoT devices.
        damage_weights = np.ones(num_clients) * 3.0  # Base damage for regular clients

        num_critical = max(1, num_clients // 5)
        critical_indices = np.random.choice(num_clients, num_critical, replace=False)
        damage_weights[critical_indices] = 15.0  # 5x more strategic value

        print("CRITICAL INFRASTRUCTURE SCENARIO:")
        print(f"   Critical clients: {sorted(critical_indices.tolist())}")
        print("   Regular damage weight: 3.0, Critical damage weight: 15.0")

    elif scenario == "tiered_importance":
        damage_weights = np.ones(num_clients) * 2.0  # Tier 3 (low value)

        # Tier 2 (medium value) - 30% of clients
        tier2_count = max(1, int(num_clients * 0.3))
        tier2_indices = np.random.choice(num_clients, tier2_count, replace=False)
        damage_weights[tier2_indices] = 8.0

        # Tier 1 (high value) - 20% of clients, drawn from those not in tier 2.
        tier2_set = set(tier2_indices.tolist())
        remaining_indices = [i for i in range(num_clients) if i not in tier2_set]
        # With a single client, tier 2 already uses everyone; never request more
        # tier-1 clients than remain (np.random.choice would raise).
        tier1_count = min(max(1, int(num_clients * 0.2)), len(remaining_indices))
        if tier1_count > 0:
            tier1_indices = np.random.choice(remaining_indices, tier1_count, replace=False)
        else:
            tier1_indices = np.array([], dtype=int)
        damage_weights[tier1_indices] = 20.0

        print("TIERED IMPORTANCE SCENARIO:")
        print(f"   Tier 1 (high value): {sorted(tier1_indices.tolist())} → weight: 20.0")
        print(f"   Tier 2 (medium value): {sorted(tier2_indices.tolist())} → weight: 8.0")
        print("   Tier 3 (low value): remaining → weight: 2.0")

    elif scenario == "network_hubs":
        # Some clients are network hubs with higher connectivity
        damage_weights = np.ones(num_clients) * 4.0  # Base damage

        num_hubs = max(1, int(num_clients * 0.15))
        hub_indices = np.random.choice(num_clients, num_hubs, replace=False)
        damage_weights[hub_indices] = 18.0  # 4.5x more impact

        print("NETWORK HUBS SCENARIO:")
        print(f"   Hub clients: {sorted(hub_indices.tolist())}")
        print("   Regular damage weight: 4.0, Hub damage weight: 18.0")

    else:  # default to critical_infrastructure
        return create_realistic_unequal_damage_weights(num_clients, "critical_infrastructure", seed)

    print(f"   Damage range: {damage_weights.min():.1f} to {damage_weights.max():.1f}")
    print(f"   Ratio: {damage_weights.max()/damage_weights.min():.1f}x difference")

    return damage_weights

def create_strategic_variety_costs(num_clients: int, base_values: np.ndarray,
                                 cost_multiplier_range: tuple = (0.5, 1.5),
                                 seed: int = 42) -> np.ndarray:
    """
    Derive per-client costs by randomly perturbing *base_values*.

    Each cost is ``base * U(low, high) + N(0, 0.2)``, floored at 0.1 so that
    every cost is positive. NumPy's global RNG is re-seeded with *seed*; the
    uniform multipliers are drawn before the Gaussian noise.

    Args:
        num_clients: number of costs to produce.
        base_values: base damage weights to scale.
        cost_multiplier_range: (low, high) bounds of the uniform multiplier.
        seed: random seed for reproducibility.
    """
    np.random.seed(seed)

    low, high = cost_multiplier_range[0], cost_multiplier_range[1]
    multipliers = np.random.uniform(low, high, num_clients)
    jitter = np.random.normal(0, 0.2, num_clients)

    perturbed = base_values * multipliers + jitter
    return np.maximum(perturbed, 0.1)

# ======================== KNAPSACK PROBLEM SOLVERS ========================

def solve_knapsack_dp(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """
    Solve the 0/1 knapsack problem with dynamic programming.

    Real-valued costs and budget are discretised onto an integer grid with a
    scale factor of at most 1000. The factor is smaller for large magnitudes,
    so the scaled budget never exceeds ~2**30. Costs are rounded *up* and the
    budget *down*, so every returned selection also fits the original
    real-valued budget.

    Small instances use a full (n+1) x (budget+1) table; larger ones use the
    O(budget)-memory variant.

    Args:
        values: per-item values; items with value <= 0 are never selected.
        costs: per-item non-negative costs (same length as *values*).
        budget: total budget.

    Returns:
        Set of selected item indices; empty when there are no items or the
        budget is negative.
    """
    values = np.asarray(values, dtype=float)
    costs = np.asarray(costs, dtype=float)
    n = len(values)

    if n == 0 or budget < 0:
        return set()

    max_value = max(budget, costs.max()) if len(costs) > 0 else budget
    # Cap the scaled magnitude at 2**30. Unlike int(2**30 / max_value), a float
    # scale cannot collapse to 0 (which would zero every cost and the budget).
    safe_scale_factor = min(1000, 2**30 / max_value) if max_value > 0 else 1000

    # The tolerance absorbs representation noise such as 0.3 * 1000 == 300.00000000000006.
    tol = 1e-6
    budget_int = int(np.floor(budget * safe_scale_factor + tol))
    costs_int = np.ceil(costs * safe_scale_factor - tol).astype(np.int64)

    # 50M float32 cells is ~200 MB; beyond that switch to O(budget) memory.
    memory_threshold = 50_000_000
    use_space_optimized = (n + 1) * (budget_int + 1) > memory_threshold

    if use_space_optimized:
        return _solve_knapsack_space_optimized(values, costs_int, budget_int, safe_scale_factor)
    return _solve_knapsack_full_table(values, costs_int, budget_int)

def _solve_knapsack_full_table(values: np.ndarray, costs_int: np.ndarray, budget_int: int) -> Set[int]:
    """Standard DP with full table for smaller instances"""
    n = len(values)
    
    # Full DP table
    dp = np.zeros((n + 1, budget_int + 1), dtype=np.float32)  # Use float32 to save memory
    
    for i in range(1, n + 1):
        for w in range(budget_int + 1):
            # Don't take item i-1
            dp[i][w] = dp[i-1][w]
            
            # Take item i-1 if profitable
            if costs_int[i-1] <= w and values[i-1] > 0:
                dp[i][w] = max(dp[i][w], dp[i-1][w - costs_int[i-1]] + values[i-1])
    
    # Backtracking to find selected items
    selected = set()
    w = budget_int
    for i in range(n, 0, -1):
        # Check if item i-1 was actually taken in optimal solution
        if w >= costs_int[i-1] and abs(dp[i][w] - (dp[i-1][w - costs_int[i-1]] + values[i-1])) < 1e-9:
            selected.add(i-1)
            w -= costs_int[i-1]
    
    return selected

def _solve_knapsack_space_optimized(values: np.ndarray, costs_int: np.ndarray, budget_int: int, scale_factor: int) -> Set[int]:
    """
    Space-optimized DP using O(budget) memory
    Runs algorithm twice: once for value, once for solution reconstruction
    """
    n = len(values)
    
    # First pass: compute optimal value using only two rows
    prev_row = np.zeros(budget_int + 1, dtype=np.float32)
    curr_row = np.zeros(budget_int + 1, dtype=np.float32)
    
    for i in range(1, n + 1):
        for w in range(budget_int + 1):
            # Don't take item i-1
            curr_row[w] = prev_row[w]
            
            # Take item i-1 if profitable
            if costs_int[i-1] <= w and values[i-1] > 0:
                curr_row[w] = max(curr_row[w], prev_row[w - costs_int[i-1]] + values[i-1])
        
        # Swap rows
        prev_row, curr_row = curr_row, prev_row
        curr_row.fill(0)
    
    optimal_value = prev_row[budget_int]
    
    # Second pass: reconstruct solution by testing which items were taken
    selected = set()
    remaining_budget = budget_int
    remaining_value = optimal_value
    
    # Test items in reverse order
    for i in range(n - 1, -1, -1):
        if values[i] <= 0 or costs_int[i] > remaining_budget:
            continue
            
        # Check if removing this item reduces the optimal value
        # by running mini-DP without this item
        temp_budget = remaining_budget - costs_int[i]
        temp_value = remaining_value - values[i]
        
        # Quick check: if including this item achieves the remaining value target
        if abs(temp_value - _compute_optimal_value_without_item(
            values, costs_int, temp_budget, exclude_item=i)) < 1e-9:
            selected.add(i)
            remaining_budget = temp_budget
            remaining_value = temp_value
    
    return selected

def _compute_optimal_value_without_item(values: np.ndarray, costs_int: np.ndarray, 
                                      budget: int, exclude_item: int) -> float:
    """Helper function to compute optimal value excluding one specific item"""
    if budget <= 0:
        return 0.0
        
    prev_row = np.zeros(budget + 1, dtype=np.float32)
    curr_row = np.zeros(budget + 1, dtype=np.float32)
    
    for i in range(len(values)):
        if i == exclude_item:
            continue  # Skip the excluded item
            
        for w in range(budget + 1):
            curr_row[w] = prev_row[w]
            if costs_int[i] <= w and values[i] > 0:
                curr_row[w] = max(curr_row[w], prev_row[w - costs_int[i]] + values[i])
        
        prev_row, curr_row = curr_row, prev_row
        curr_row.fill(0)
    
    return prev_row[budget]

def solve_knapsack_greedy(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """Approximate 0/1 knapsack by taking items in decreasing value/cost order.

    Items with value <= 0 are never taken. Zero-cost items are given efficiency
    0 for sorting, but they still fit whenever the running cost is within
    budget. Inputs may be any numeric sequences (including integer arrays or
    lists); they are converted to float first.

    Returns:
        Set of selected indices as Python ints.
    """
    values = np.asarray(values, dtype=float)
    costs = np.asarray(costs, dtype=float)

    # Float output buffer: with an integer buffer (np.zeros_like on int input)
    # np.divide cannot store its float quotient and raises.
    efficiency = np.divide(values, costs, out=np.zeros(values.shape, dtype=float), where=costs != 0)

    # Highest efficiency first.
    order = np.argsort(efficiency)[::-1]

    selected = set()
    spent = 0.0
    for idx in order:
        idx = int(idx)
        if values[idx] > 0 and spent + costs[idx] <= budget:
            selected.add(idx)
            spent += costs[idx]

    return selected

def generate_feasible_defense_combinations(defense_costs: np.ndarray, budget: float, max_combinations: int = 10000):
    """Generate a pool of distinct budget-feasible defense sets.

    The pool is heuristic, not exhaustive:

    1. The greedy "cheapest clients first" set. It is always the first entry,
       even when ``max_combinations`` is 0.
    2. Swap variations of that set: drop one random member, then add each
       other client that keeps the set within ``budget``.
    3. For instances with at most 20 clients, random subsets of 1-10 clients
       that fit the budget.

    Steps 2 and 3 draw from the global ``random`` module, so results depend
    on its state.

    Args:
        defense_costs: Per-client defense cost.
        budget: Maximum total defense cost of a combination.
        max_combinations: Cap on the size of the pool (the greedy set is
            always included).

    Returns:
        List of sets of client indices, each with total cost <= ``budget``
        (the greedy set is built to respect the budget as well).
    """
    defense_costs = np.asarray(defense_costs)
    num_clients = len(defense_costs)
    feasible_combinations = []
    # Hashable mirror of feasible_combinations: O(1) duplicate checks instead
    # of a linear scan of the growing list.
    seen = set()

    def _record(combo):
        """Append *combo* if it is new; return True when it was added."""
        key = frozenset(combo)
        if key in seen:
            return False
        seen.add(key)
        feasible_combinations.append(combo)
        return True

    # Method 1: greedy, cheapest clients first. A zero cost gives infinite
    # efficiency, so the item is visited first.
    with np.errstate(divide='ignore'):
        efficiency = 1.0 / defense_costs
    sorted_indices = np.argsort(efficiency)[::-1]

    current_cost = 0.0
    current_combo = set()
    for idx in sorted_indices:
        if current_cost + defense_costs[idx] <= budget:
            current_combo.add(idx)
            current_cost += defense_costs[idx]
    _record(current_combo)

    # Method 2: swap one member of the greedy set for another client
    base_combo = current_combo.copy()
    if base_combo:  # the base set never changes, so test this once
        all_clients = set(range(num_clients))
        for _ in range(min(1000, max_combinations // 10)):
            to_remove = random.choice(list(base_combo))
            new_combo = base_combo - {to_remove}
            new_cost = sum(defense_costs[i] for i in new_combo)

            for candidate in all_clients - new_combo:
                if new_cost + defense_costs[candidate] <= budget:
                    if _record(new_combo | {candidate}) and len(feasible_combinations) >= max_combinations:
                        return feasible_combinations

    # Method 3: random subsets for small instances (skipped when there are
    # no clients, where randint(1, 0) would raise).
    if 0 < num_clients <= 20:
        for _ in range(min(2000, max_combinations - len(feasible_combinations))):
            subset_size = random.randint(1, min(num_clients, 10))
            subset = set(random.sample(range(num_clients), subset_size))
            cost = sum(defense_costs[i] for i in subset)
            if cost <= budget:
                _record(subset)

    return feasible_combinations

def solve_knapsack_simple_greedy_debug(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """Greedy value/cost knapsack solver that prints every decision.

    Items are visited from highest to lowest ``values[i] / costs[i]``
    (zero-cost items get efficiency 0). An item is taken when its value is
    positive and it still fits in ``budget``.

    Args:
        values: Per-item value.
        costs: Per-item cost, same length as ``values``.
        budget: Total cost allowed.

    Returns:
        Set of selected item indices.
    """
    values = np.asarray(values)
    costs = np.asarray(costs)

    print(f"   Greedy Debug:")
    print(f"   Values: {values}")
    print(f"   Costs: {costs}")
    print(f"   Budget: {budget}")

    # Float output buffer: an integer one (zeros_like on int inputs) makes
    # np.divide raise a casting error. Float inputs keep their precision.
    out_dtype = np.result_type(values, costs, 1.0)
    efficiency = np.divide(values, costs, out=np.zeros(values.shape, dtype=out_dtype), where=costs != 0)
    print(f"   Efficiency (value/cost): {efficiency}")

    # Sort by efficiency (descending)
    sorted_indices = np.argsort(efficiency)[::-1]
    print(f"   Sorted by efficiency: {sorted_indices}")

    selected = set()
    current_cost = 0.0

    for idx in sorted_indices:
        if values[idx] > 0 and current_cost + costs[idx] <= budget:
            selected.add(idx)
            current_cost += costs[idx]
            print(f"   Selected item {idx}: value={values[idx]:.4f}, cost={costs[idx]:.4f}, total_cost={current_cost:.4f}")
        else:
            # Report the actual reason: items are also skipped for non-positive value.
            reason = "non-positive value" if values[idx] <= 0 else "would exceed budget"
            print(f"   Skipped item {idx}: value={values[idx]:.4f}, cost={costs[idx]:.4f} ({reason})")

    print(f"   Final selection: {selected}")
    print(f"   Total cost: {current_cost:.4f}/{budget:.4f}")
    print(f"   Total value: {sum(values[i] for i in selected):.4f}")

    return selected

# ======================== GAME THEORY CLASSES ========================

class BudgetConstrainedAttacker:
    """
    Budget-constrained attacker that chooses which clients to attack.

    Client ``i`` has net attack utility ``damage_weights[i] - attack_costs[i]``.
    Defended clients get utility ``-inf`` (perfect defense), so they are never
    attacked. Only clients with strictly positive utility are candidates.

    - Single-attacker mode: the highest-utility candidate whose cost fits the
      budget is attacked (first one wins on ties).
    - Otherwise: a greedy value/cost knapsack over the candidates is solved
      (``solve_knapsack_simple_greedy_debug``).
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray, attack_costs: np.ndarray, 
             attack_budget: float, single_attacker_mode: bool = False):
        self.num_clients = num_clients
        self.damage_weights = damage_weights.copy()
        self.attack_costs = attack_costs.copy()
        self.attack_budget = attack_budget
        self.single_attacker_mode = single_attacker_mode

    def update_parameters(self, damage_weights: np.ndarray, attack_costs: np.ndarray, 
                     attack_budget: float, single_attacker_mode: bool = None):
        """Update game parameters for dynamic reshuffling.

        ``single_attacker_mode=None`` keeps the current mode.
        """
        self.damage_weights = damage_weights.copy()
        self.attack_costs = attack_costs.copy()
        self.attack_budget = attack_budget
        if single_attacker_mode is not None:
            self.single_attacker_mode = single_attacker_mode

    def _net_utilities(self, defended_clients: Set[int]) -> np.ndarray:
        """Per-client net attack utility; ``-inf`` for defended clients."""
        utilities = np.zeros(self.num_clients)
        for i in range(self.num_clients):
            if i in defended_clients:
                utilities[i] = float('-inf')  # Attack completely blocked
            else:
                utilities[i] = self.damage_weights[i] - self.attack_costs[i]
        return utilities

    def optimal_response(self, defended_clients: Set[int]) -> Set[int]:
        """
        Return the set of clients to attack given the defender's committed set.

        In the Stackelberg setting the attacker observes ``defended_clients``
        and never attacks them.
        """
        utilities = self._net_utilities(defended_clients)
        # getattr keeps objects built without __init__ (e.g. older pickles) working.
        single_mode = getattr(self, 'single_attacker_mode', False)

        print(f"DEBUG ATTACKER:")
        print(f"   Defended clients: {defended_clients}")
        print(f"   Attack budget: {self.attack_budget:.2f}")
        print(f"   Single attacker mode: {single_mode}")

        # Only attack clients with positive utility
        positive_utility_mask = utilities > 0
        if not np.any(positive_utility_mask):
            return set()

        valid_indices = np.where(positive_utility_mask)[0]
        valid_utilities = utilities[positive_utility_mask]
        valid_costs = self.attack_costs[positive_utility_mask]

        if single_mode:
            # Highest-utility candidate that fits the budget (first wins ties)
            best_attacker_idx = None
            best_utility = -float('inf')
            for i, (utility, cost) in enumerate(zip(valid_utilities, valid_costs)):
                if cost <= self.attack_budget and utility > best_utility:
                    best_utility = utility
                    best_attacker_idx = i

            if best_attacker_idx is None:
                print(f"   No attacker within budget in single attacker mode")
                return set()
            attacked_clients = {valid_indices[best_attacker_idx]}
            print(f"   Single best attacker selected: {attacked_clients}")
            return attacked_clients

        # Multi-attacker: greedy knapsack over candidates. The original code had
        # an if/else on num_clients with two identical branches; it is collapsed here.
        print(f"   Valid indices: {valid_indices}")
        print(f"   Valid utilities: {valid_utilities}")
        print(f"   Valid costs: {valid_costs}")

        selected_valid = solve_knapsack_simple_greedy_debug(valid_utilities, valid_costs, self.attack_budget)

        attacked_clients = {valid_indices[i] for i in selected_valid}
        print(f"   Multiple attackers selected: {attacked_clients}")
        return attacked_clients

    def get_attack_utilities(self, defended_clients: Set[int]) -> np.ndarray:
        """Return each client's net attack utility (``-inf`` if defended), for analysis."""
        return self._net_utilities(defended_clients)

class EfficientStackelbergSolver:
    """
    Stackelberg game solver: the defender commits first, the attacker best-responds.

    For a candidate defense set ``D`` the defender's loss is

        damage(D) + sum(defense_costs[d] for d in D)

    where ``damage(D)`` sums ``damage_weights`` over the clients the attacker
    picks (via ``_solve_attacker_knapsack``) in response to ``D``.
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray, 
                 defense_costs: np.ndarray, attack_costs: np.ndarray,
                 defense_budget: float, attack_budget: float):
        self.num_clients = num_clients
        # Private copies (as the other game classes do): in-place edits of the
        # caller's arrays must not silently change the game being solved.
        self.damage_weights = np.array(damage_weights, copy=True)
        self.defense_costs = np.array(defense_costs, copy=True)
        self.attack_costs = np.array(attack_costs, copy=True)
        self.defense_budget = defense_budget
        self.attack_budget = attack_budget

    def _evaluate_defense(self, defense_set: Set[int]) -> Tuple[float, float]:
        """Return ``(damage, defense_cost)`` when the attacker best-responds to *defense_set*."""
        attack_set = self._solve_attacker_knapsack(defense_set)
        # Only undefended attacks cause damage. The attacker model never picks
        # defended clients, so this filter is a safety net.
        damage = sum(self.damage_weights[a] for a in attack_set if a not in defense_set)
        defense_cost = sum(self.defense_costs[i] for i in defense_set)
        return damage, defense_cost

    def solve_stackelberg_exact(self) -> Set[int]:
        """
        Pick the best defense from a generated pool of feasible candidates.

        1. Build budget-feasible defense sets with
           ``generate_feasible_defense_combinations``. That pool is heuristic,
           so the result is exact only over the pool.
        2. For each candidate, compute the attacker's best response.
        3. Return the candidate with the lowest damage + defense cost.
        """
        print(f"Solving Stackelberg game exactly...")

        feasible_defenses = generate_feasible_defense_combinations(
            self.defense_costs, self.defense_budget, max_combinations=5000
        )
        print(f"   Generated {len(feasible_defenses)} feasible defense combinations")

        best_defense = set()
        best_loss = float('inf')
        for defense_set in feasible_defenses:
            # Defender commits; attacker observes and responds optimally.
            damage, defense_cost = self._evaluate_defense(defense_set)
            total_cost = damage + defense_cost
            if total_cost < best_loss:
                best_loss = total_cost
                best_defense = defense_set

        print(f"   Best defense found: {len(best_defense)} clients defended")
        # The minimised quantity is damage + defense cost; label it accordingly.
        print(f"   Resulting loss (damage + defense cost): {best_loss:.2f}")

        return best_defense

    def solve_stackelberg_bilevel(self) -> Set[int]:
        """
        Relaxed bilevel approach: optimise a [0, 1] defense vector with SLSQP.

        The objective rounds the vector at 0.5, evaluates the attacker's best
        response, and adds a penalty proportional to any budget overrun. If the
        optimiser fails, or the rounded solution is over budget, it falls back
        to the enumeration method (fewer than 25 clients) or the greedy defense.
        """
        # Loop-invariant penalty scale
        max_possible_damage = sum(self.damage_weights)
        # Guard the relative-violation denominator. A zero budget used to raise
        # ZeroDivisionError, and a negative one flipped the penalty's sign.
        budget_scale = self.defense_budget if self.defense_budget > 0 else 1.0

        def defender_objective(defense_vars):
            # NOTE(review): rounding makes this objective piecewise constant,
            # so SLSQP's finite-difference gradient is usually zero and the
            # result tends to stay near x0.
            defense_set = {i for i in range(self.num_clients) if defense_vars[i] > 0.5}
            damage, cost = self._evaluate_defense(defense_set)
            total_cost = damage + cost

            if cost > self.defense_budget:
                budget_violation = cost - self.defense_budget
                penalty = max_possible_damage * 10 * (budget_violation / budget_scale)
                return total_cost + penalty

            return total_cost

        def budget_constraint(defense_vars):
            # SLSQP 'ineq' constraints require fun(x) >= 0
            return self.defense_budget - np.dot(defense_vars, self.defense_costs)

        x0 = self._greedy_defense_vector()
        constraints = {'type': 'ineq', 'fun': budget_constraint}
        bounds = [(0, 1) for _ in range(self.num_clients)]  # Binary relaxation

        result = minimize(
            defender_objective, x0, method='SLSQP',
            bounds=bounds, constraints=constraints,
            options={'maxiter': 200, 'disp': False}
        )

        if result.success:
            defense_set = {i for i in range(self.num_clients) if result.x[i] > 0.5}
            actual_cost = sum(self.defense_costs[i] for i in defense_set)
            if actual_cost <= self.defense_budget:
                return defense_set

        # Fallback to enumeration for small instances or greedy for large
        if self.num_clients < 25:
            return self.solve_stackelberg_exact()
        return self._greedy_defense_initialization()

    def solve_stackelberg_smart(self) -> Set[int]:
        """Use enumeration for <= 20 clients, otherwise the bilevel relaxation."""
        if self.num_clients <= 20:
            return self.solve_stackelberg_exact()
        return self.solve_stackelberg_bilevel()

    def _solve_attacker_knapsack(self, defended_clients: Set[int]) -> Set[int]:
        """Attacker best response: knapsack over profitable undefended clients.

        Uses ``solve_knapsack_dp`` for up to 25 profitable clients, otherwise
        ``solve_knapsack_greedy``.
        """
        utilities = np.zeros(self.num_clients)
        for i in range(self.num_clients):
            if i in defended_clients:
                utilities[i] = float('-inf')  # Attack completely blocked
            else:
                utilities[i] = self.damage_weights[i] - self.attack_costs[i]

        # Only consider profitable attacks
        profitable = utilities > 0
        if not np.any(profitable):
            return set()

        profitable_indices = np.where(profitable)[0]
        profitable_utilities = utilities[profitable]
        profitable_costs = self.attack_costs[profitable]

        if len(profitable_indices) <= 25:
            selected = solve_knapsack_dp(profitable_utilities, profitable_costs, self.attack_budget)
        else:
            selected = solve_knapsack_greedy(profitable_utilities, profitable_costs, self.attack_budget)

        # Map back to original indices
        return {profitable_indices[i] for i in selected}

    def _greedy_defense_initialization(self) -> Set[int]:
        """Greedy defense: take clients by damage/defense-cost ratio while the budget allows."""
        # Zero costs yield inf/nan ratios; silence the warning, keep the ordering.
        with np.errstate(divide='ignore', invalid='ignore'):
            efficiency = self.damage_weights / self.defense_costs
        sorted_indices = np.argsort(efficiency)[::-1]

        defense = set()
        current_cost = 0
        for idx in sorted_indices:
            if current_cost + self.defense_costs[idx] <= self.defense_budget:
                defense.add(idx)
                current_cost += self.defense_costs[idx]

        return defense

    def _greedy_defense_vector(self) -> np.ndarray:
        """The greedy defense as a 0/1 vector (initial point for SLSQP)."""
        x0 = np.zeros(self.num_clients)
        for idx in self._greedy_defense_initialization():
            x0[idx] = 1
        return x0

class MathematicalStackelbergDefender:
    """
    Defender that commits to the Stackelberg-optimal defense set.

    Holds a snapshot of the game parameters and delegates the search to an
    ``EfficientStackelbergSolver`` built from that same snapshot.
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray, defense_costs: np.ndarray, 
                 defense_budget: float, attack_costs: np.ndarray, attack_budget: float, 
                 defense_effectiveness: float = 1.0):
        self.num_clients = num_clients
        self.damage_weights = damage_weights.copy()
        self.defense_costs = defense_costs.copy()
        self.defense_budget = defense_budget
        self.attack_costs = attack_costs.copy()
        self.attack_budget = attack_budget
        # NOTE(review): stored for reporting only. The solver treats defended
        # clients as fully blocked and never reads this value.
        self.defense_effectiveness = defense_effectiveness

        self.solver = self._build_solver()

    def _build_solver(self) -> "EfficientStackelbergSolver":
        """Create a solver from this defender's own parameter copies.

        The caller's arrays are not used directly, so the defender and its
        solver always agree on the game being solved.
        """
        return EfficientStackelbergSolver(
            self.num_clients, self.damage_weights, self.defense_costs,
            self.attack_costs, self.defense_budget, self.attack_budget
        )

    def solve_stackelberg_equilibrium(self) -> Set[int]:
        """Solve the Stackelberg game (defender first) and return the defended set."""
        start_time = time.perf_counter()  # monotonic clock for durations

        solution = self.solver.solve_stackelberg_smart()

        solve_time = time.perf_counter() - start_time
        print(f"   Stackelberg solved in {solve_time:.3f}s")

        return solution

    def update_parameters(self, damage_weights: np.ndarray, defense_costs: np.ndarray, 
                         defense_budget: float, attack_costs: np.ndarray, attack_budget: float):
        """Replace all game parameters (dynamic reshuffling) and rebuild the solver."""
        self.damage_weights = damage_weights.copy()
        self.defense_costs = defense_costs.copy()
        self.defense_budget = defense_budget
        self.attack_costs = attack_costs.copy()
        self.attack_budget = attack_budget

        self.solver = self._build_solver()

class AttackProbabilityEstimator:
    """Learn per-client attack probabilities from observed attacks.

    Each round records a 0/1 "was attacked" flag per client. The probability
    is an exponential moving average (weight ``alpha``) of the attack rate
    over the client's most recent 10 rounds. Every client starts at 0.1.
    """

    def __init__(self, num_clients: int, alpha: float = 0.3):
        self.num_clients = num_clients
        self.alpha = alpha
        self.attack_history = defaultdict(list)  # client id -> list of 0/1 flags
        self.current_probs = np.full(num_clients, 0.1)

    def update_history(self, round_num: int, attackers: Set[int]):
        """Record which clients attacked this round and refresh the estimates."""
        weight = self.alpha
        for client in range(self.num_clients):
            flags = self.attack_history[client]
            flags.append(int(client in attackers))
            # Each client's update depends only on its own history, so
            # recording and smoothing can share one pass.
            window_rate = np.mean(flags[-10:])
            self.current_probs[client] = (
                weight * window_rate +
                (1 - weight) * self.current_probs[client]
            )

    def get_probabilities(self) -> np.ndarray:
        """Return a copy of the current probability estimates."""
        return np.array(self.current_probs, copy=True)

    def reset_for_new_era(self):
        """Forget all history; called when the game parameters are reshuffled."""
        self.attack_history = defaultdict(list)
        self.current_probs = np.full(self.num_clients, 0.1)

class DefenseStrategy:
    """Common base for defense strategies that must fit a defense budget.

    Subclasses implement ``select_defended_clients``. This class only keeps
    a private copy of the per-client defense costs and the budget.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float):
        self.num_clients = num_clients
        self.defense_budget = defense_budget
        self.defense_costs = defense_costs.copy()

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float):
        """Swap in new costs and budget after a reshuffle."""
        self.defense_budget = defense_budget
        self.defense_costs = defense_costs.copy()

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Return the clients to defend this round (implemented by subclasses)."""
        raise NotImplementedError()

class RandomDefense(DefenseStrategy):
    """Defend a random subset of clients, packed greedily into the budget."""

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Shuffle the clients and take each one that still fits the budget."""
        order = list(range(self.num_clients))
        random.shuffle(order)

        chosen = set()
        spent = 0.0
        for client in order:
            price = self.defense_costs[client]
            if not spent + price <= self.defense_budget:
                continue
            chosen.add(client)
            spent += price
        return chosen

class HighestRiskDefense(DefenseStrategy):
    """Budget-constrained defense that prioritises clients by estimated attack probability."""

    def select_defended_clients(self, round_num: int, attack_probs: np.ndarray = None, **kwargs) -> Set[int]:
        """Take clients from most to least likely attacked while the budget allows.

        With no ``attack_probs``, every client gets a uniform 0.1.
        """
        probs = (np.ones(self.num_clients) * 0.1) if attack_probs is None else attack_probs
        ranking = np.argsort(probs)[::-1]

        chosen = set()
        spent = 0.0
        for client in ranking:
            price = self.defense_costs[client]
            if not spent + price <= self.defense_budget:
                continue
            chosen.add(client)
            spent += price
        return chosen

class HighestValueDefense(DefenseStrategy):
    """Budget-constrained defense that prioritises clients by damage weight."""

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.damage_weights = damage_weights.copy()

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray = None):
        """Refresh costs/budget; damage weights change only when provided."""
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is None:
            return
        self.damage_weights = damage_weights.copy()

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Take clients from highest to lowest damage weight while the budget allows."""
        ranking = np.argsort(self.damage_weights)[::-1]

        chosen = set()
        spent = 0.0
        for client in ranking:
            price = self.defense_costs[client]
            if not spent + price <= self.defense_budget:
                continue
            chosen.add(client)
            spent += price
        return chosen

class NaiveStackelbergDefense(DefenseStrategy):
    """Budget-constrained defense ranked by risk = attack probability * damage weight.

    This is a heuristic, not an equilibrium computation.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.damage_weights = damage_weights.copy()

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray = None):
        """Refresh costs/budget; damage weights change only when provided."""
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is None:
            return
        self.damage_weights = damage_weights.copy()

    def select_defended_clients(self, round_num: int, attack_probs: np.ndarray = None, **kwargs) -> Set[int]:
        """Take clients from highest to lowest risk while the budget allows.

        With no ``attack_probs``, every client gets a uniform 0.1.
        """
        probs = (np.ones(self.num_clients) * 0.1) if attack_probs is None else attack_probs
        ranking = np.argsort(probs * self.damage_weights)[::-1]

        chosen = set()
        spent = 0.0
        for client in ranking:
            price = self.defense_costs[client]
            if not spent + price <= self.defense_budget:
                continue
            chosen.add(client)
            spent += price
        return chosen

class EfficientTrueStackelbergDefense(DefenseStrategy):
    """Defense strategy that plays the Stackelberg-optimal defense every round.

    Delegates to ``MathematicalStackelbergDefender``, which models the
    attacker's best response to the committed defense.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float, 
                 damage_weights: np.ndarray, attack_costs: np.ndarray, attack_budget: float, 
                 defense_effectiveness: float = 1.0):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.defender = MathematicalStackelbergDefender(
            num_clients, damage_weights, defense_costs, defense_budget, 
            attack_costs, attack_budget, defense_effectiveness
        )

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float, 
                         damage_weights: np.ndarray = None, attack_costs: np.ndarray = None, 
                         attack_budget: float = None):
        """Update parameters for dynamic reshuffling.

        Omitted (``None``) game parameters keep their current values. The
        defender is always refreshed, so it never keeps solving with stale
        defense costs or budget after a partial update.
        """
        super().update_parameters(defense_costs, defense_budget)
        solver = self.defender.solver
        self.defender.update_parameters(
            damage_weights if damage_weights is not None else self.defender.damage_weights,
            defense_costs,
            defense_budget,
            attack_costs if attack_costs is not None else solver.attack_costs,
            attack_budget if attack_budget is not None else solver.attack_budget,
        )

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Return the defended set from solving the Stackelberg game."""
        return self.defender.solve_stackelberg_equilibrium()

class GameTheoryIntegrator:
    """Integrates game theory logic into the FRL system"""
    
    def __init__(self, opts):
        self.opts = opts
        
        if not opts.enable_game_theory:
            return
            
        print(f"\n{'='*80}")
        print(f"INITIALIZING STACKELBERG GAME THEORY FOR FRL")
        print(f"{'='*80}")
        
        # Create damage weights based on scenario
        self.current_damage_weights = create_realistic_unequal_damage_weights(
            opts.num_worker, opts.damage_scenario, opts.seed
        )
        
        # Create strategic variety in costs
        self.current_defense_costs = create_strategic_variety_costs(
            opts.num_worker, self.current_damage_weights, (0.5, 1.5), opts.seed
        )
        self.current_attack_costs = create_strategic_variety_costs(
            opts.num_worker, self.current_damage_weights, (0.5, 1.5), opts.seed + 1
        )
        
        # Set budgets
        total_defense_cost = sum(self.current_defense_costs)
        total_attack_cost = sum(self.current_attack_costs)
        self.current_defense_budget = opts.initial_defense_budget_ratio * total_defense_cost
        self.current_attack_budget = opts.initial_attack_budget_ratio * total_attack_cost
        
        # Store initial budgets to keep them fixed
        self.fixed_defense_budget = self.current_defense_budget
        self.fixed_attack_budget = self.current_attack_budget
        
        # Initialize game components
        self.prob_estimator = AttackProbabilityEstimator(opts.num_worker)
        self.attacker = BudgetConstrainedAttacker(
        opts.num_worker, 
        self.current_damage_weights, 
        self.current_attack_costs, 
        self.current_attack_budget,
        single_attacker_mode=opts.single_best_attacker_only  # NEW PARAMETER
    )
        
        # Track eras
        self.current_era = 0
        
        # Initialize defense strategies
        self.strategies = {
            'random': RandomDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget),
            'highest_risk': HighestRiskDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget),
            'highest_value': HighestValueDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, self.current_damage_weights),
            'naive_stackelberg': NaiveStackelbergDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, self.current_damage_weights),
            'true_stackelberg': EfficientTrueStackelbergDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, 
                                                               self.current_damage_weights, self.current_attack_costs, self.current_attack_budget, opts.defense_effectiveness)
        }
        
        # Set default strategy
        self.current_strategy = 'true_stackelberg'
        
        print(f"GAME THEORY PARAMETERS:")
        print(f"   Scenario: {opts.damage_scenario}")
        print(f"   Damage weights: {self.current_damage_weights.min():.1f} to {self.current_damage_weights.max():.1f}")
        print(f"   Defense budget: {self.current_defense_budget:.1f}")
        print(f"   Attack budget: {self.current_attack_budget:.1f}")
        print(f"   Defense effectiveness: {opts.defense_effectiveness:.1%}")
        print(f"   Reshuffle frequency: every {opts.reshuffle_frequency} rounds")
        print(f"   Current strategy: {self.current_strategy}")
        
    def reshuffle_game_parameters(self, round_num: int):
        """Reshuffle game parameters periodically"""
        if not self.opts.enable_game_theory:
            return
            
        self.current_era = round_num // self.opts.reshuffle_frequency
        era_seed = self.opts.seed + self.current_era * 1000
        
        print(f"\n{'='*60}")
        print(f"STRATEGIC RESHUFFLING! Round {round_num}, Era {self.current_era}")
        print(f"{'='*60}")
        
        # Keep damage weights consistent with scenario
        self.current_damage_weights = create_realistic_unequal_damage_weights(
            self.opts.num_worker, self.opts.damage_scenario, self.opts.seed
        )
        
        # Create new cost variations
        self.current_defense_costs = create_strategic_variety_costs(
            self.opts.num_worker, self.current_damage_weights, (0.5, 1.5), era_seed
        )
        self.current_attack_costs = create_strategic_variety_costs(
            self.opts.num_worker, self.current_damage_weights, (0.5, 1.5), era_seed + 1
        )
        
        # Keep budgets fixed
        self.current_defense_budget = self.fixed_defense_budget
        self.current_attack_budget = self.fixed_attack_budget
        
        # Update game components
        self.attacker.update_parameters(
        self.current_damage_weights, 
        self.current_attack_costs, 
        self.current_attack_budget,
        single_attacker_mode=self.attacker.single_attacker_mode  # Preserve setting
    )
        
        # Update defense strategies
        for strategy_name, strategy in self.strategies.items():
            if strategy_name in ['highest_value', 'naive_stackelberg']:
                strategy.update_parameters(self.current_defense_costs, self.current_defense_budget, self.current_damage_weights)
            elif strategy_name == 'true_stackelberg':
                strategy.update_parameters(self.current_defense_costs, self.current_defense_budget, 
                                         self.current_damage_weights, self.current_attack_costs, self.current_attack_budget)
            else:
                strategy.update_parameters(self.current_defense_costs, self.current_defense_budget)
        
        # Reset learning
        self.prob_estimator.reset_for_new_era()
        
        print(f"Reshuffling complete for era {self.current_era}")
        
    def get_attack_defense_decisions(self, round_num: int):
        """Return ``(attackers, defended_clients)`` for the given training round.

        Follows a Stackelberg timeline: the defender commits first through the
        active strategy, then ``self.attacker.optimal_response`` reacts to the
        observed defense.

        During the last ``no_attack_threshold`` rounds before
        ``opts.max_trajectories`` the defense is still computed, but no
        attackers are returned. In that phase no reshuffling and no
        probability-estimator update take place. The window length is read from
        ``opts.no_attack_threshold`` when that option exists; otherwise it is
        1000, the previously hard-coded value.

        Args:
            round_num: Current training step / round.

        Returns:
            Tuple ``(attackers, defended_clients)``. Both are empty sets when
            game theory is disabled.
        """
        if not self.opts.enable_game_theory:
            # No game theory - return empty sets
            return set(), set()

        def select_defense():
            # 'true_stackelberg' computes its own commitment; every other
            # strategy is driven by the learned attack-probability estimates.
            if self.current_strategy == 'true_stackelberg':
                return self.strategies[self.current_strategy].select_defended_clients(round_num)
            attack_probs = self.prob_estimator.get_probabilities()
            return self.strategies[self.current_strategy].select_defended_clients(
                round_num, attack_probs=attack_probs
            )

        no_attack_threshold = getattr(self.opts, 'no_attack_threshold', 1000)
        if round_num >= (self.opts.max_trajectories - no_attack_threshold):
            print(f"FINAL PHASE: No attacks enabled (step {round_num}/{self.opts.max_trajectories})")
            # Still run the defense strategy, but report no attackers.
            return set(), select_defense()

        # Reshuffle game parameters at each era boundary
        if round_num > 0 and round_num % self.opts.reshuffle_frequency == 0:
            self.reshuffle_game_parameters(round_num)

        # STEP 1: DEFENDER MOVES FIRST (Stackelberg timeline)
        defended_clients = select_defense()

        # STEP 2: ATTACKER OBSERVES AND RESPONDS
        attackers = self.attacker.optimal_response(defended_clients)

        # Feed the realized attack into the probability estimator
        self.prob_estimator.update_history(round_num, attackers)

        return attackers, defended_clients
    
    def calculate_damage_metrics(self, round_num: int, attackers: Set[int], defended_clients: Set[int], 
                               baseline_performance: float, actual_performance: float):
        """Compute per-round game-theory damage/cost metrics.

        An attack on a defended client counts as blocked. Only attacks on
        undefended clients contribute to ``theoretical_damage`` and
        ``successful_attacks``, so ``successful_attacks + blocked_attacks ==
        num_attackers``.

        Args:
            round_num: Current training round.
            attackers: Indices of attacking clients (any iterable of ints).
            defended_clients: Indices of defended clients (any iterable of ints).
            baseline_performance: Reference performance before the attack.
            actual_performance: Observed performance this round.

        Returns:
            Dict of metrics, or ``{}`` when game theory is disabled.
        """
        if not self.opts.enable_game_theory:
            return {}

        # Accept lists/tuples too; the set algebra below requires real sets.
        attackers = set(attackers)
        defended_clients = set(defended_clients)

        # Calculate performance drop (never negative)
        performance_drop = max(0, baseline_performance - actual_performance)

        # Damage only comes from attacks on undefended clients
        theoretical_damage = sum(self.current_damage_weights[i] for i in attackers if i not in defended_clients)

        # Calculate costs
        defense_cost_used = sum(self.current_defense_costs[i] for i in defended_clients)
        attack_cost_used = sum(self.current_attack_costs[i] for i in attackers)

        # Get attack utilities
        attack_utilities = self.attacker.get_attack_utilities(defended_clients)
        avg_attacker_utility = np.mean([attack_utilities[i] for i in attackers]) if attackers else 0
        print(f"   DAMAGE DEBUG:")
        print(f"   Baseline performance: {baseline_performance:.3f}")
        print(f"   Actual performance: {actual_performance:.3f}")
        print(f"   Performance drop: {performance_drop:.3f}")
        return {
            'performance_drop': performance_drop,
            'theoretical_damage': theoretical_damage,
            'num_attackers': len(attackers),
            'num_defended': len(defended_clients),
            # Previously len(attackers), which also counted blocked attacks.
            'successful_attacks': len(attackers - defended_clients),
            'blocked_attacks': len(attackers & defended_clients),
            # NOTE(review): counts every defended client, attacked or not -- confirm intended definition.
            'deterred_attacks': len(defended_clients),
            'avg_attacker_utility': avg_attacker_utility,
            'defense_cost_used': defense_cost_used,
            'attack_cost_used': attack_cost_used,
            'defense_budget_utilization': defense_cost_used / self.current_defense_budget if self.current_defense_budget > 0 else 0,
            'attack_budget_utilization': attack_cost_used / self.current_attack_budget if self.current_attack_budget > 0 else 0,
            'current_era': self.current_era,
            'era_round': round_num % self.opts.reshuffle_frequency,
            'deterrence_efficiency': len(defended_clients) / defense_cost_used if defense_cost_used > 0 else 0,
            'damage_weight_range': f"{self.current_damage_weights.min():.1f}-{self.current_damage_weights.max():.1f}",
            'current_strategy': self.current_strategy
        }

# =============================================================================
# POLICY NETWORKS
# =============================================================================

def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a feed-forward stack ``Linear -> act -> ... -> Linear -> out_act``.

    Args:
        sizes: Layer widths; ``len(sizes) - 1`` linear layers are created.
        activation: Module class instantiated after every hidden linear layer.
        output_activation: Module class instantiated after the last linear layer.

    Returns:
        ``nn.Sequential`` holding the alternating linear/activation modules.
        The result is empty when ``sizes`` has fewer than two entries.
    """
    last_index = len(sizes) - 2
    modules = []
    for index, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if index == last_index else activation
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)

class MlpPolicy(nn.Module):
    """Categorical (discrete-action) policy backed by a feed-forward MLP.

    Args:
        sizes: Layer widths ``[obs_dim, *hidden, n_actions]`` passed to ``mlp``.
        activation: Hidden activation name, ``'Tanh'`` or ``'ReLU'``.
        output_activation: Output activation name, one of ``'Identity'``,
            ``'Tanh'``, ``'ReLU'`` or ``'Softmax'``.

    Raises:
        NotImplementedError: If an activation name is not supported.
    """

    _HIDDEN_ACTIVATIONS = {'Tanh': nn.Tanh, 'ReLU': nn.ReLU}
    _OUTPUT_ACTIVATIONS = {
        'Identity': nn.Identity,
        'Tanh': nn.Tanh,
        'ReLU': nn.ReLU,
        'Softmax': nn.Softmax,
    }

    def __init__(self, sizes, activation='Tanh', output_activation='Identity'):
        super(MlpPolicy, self).__init__()

        # Resolve activation names to module classes
        if activation not in self._HIDDEN_ACTIVATIONS:
            raise NotImplementedError
        if output_activation not in self._OUTPUT_ACTIVATIONS:
            raise NotImplementedError
        self.activation = self._HIDDEN_ACTIVATIONS[activation]
        self.output_activation = self._OUTPUT_ACTIVATIONS[output_activation]

        # Make policy network
        self.sizes = sizes
        self.logits_net = mlp(self.sizes, self.activation, self.output_activation)

        # Init parameters
        self.init_parameters()

    def init_parameters(self):
        """Re-initialise every parameter uniformly in +-1/sqrt(last-dim size)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, obs, sample=True, fixed_action=None):
        """Pick an action for ``obs`` and return ``(action_int, log_prob)``.

        Args:
            obs: Observation tensor; it is flattened with ``view(-1)``.
            sample: Sample from the policy when True, otherwise take the argmax.
            fixed_action: If given, evaluate this action instead of choosing one.
        """
        obs = obs.view(-1)

        # Forward pass the policy net
        logits = self.logits_net(obs)

        # Get the policy dist
        policy = Categorical(logits=logits)

        if fixed_action is not None:
            # Take the pre-set action (used for importance-sampling re-evaluation)
            action = torch.tensor(fixed_action, device=obs.device)
        elif sample:
            try:
                action = policy.sample()
            except Exception:
                # Dump the offending inputs for debugging, then propagate.
                # Swallowing the error here would leave `action` unbound.
                print(logits, obs)
                raise
        else:
            # Take greedy action
            action = policy.probs.argmax()

        return action.item(), policy.log_prob(action)

class DiagonalGaussianMlpPolicy(nn.Module):
    """Diagonal-Gaussian policy for continuous action spaces.

    A shared MLP trunk (``sizes[:-1]``, identity output) feeds two bias-free
    linear heads: one for the mean and one for the log standard deviation.

    Args:
        sizes: ``[obs_dim, *hidden, action_dim]``.
        activation: Hidden activation name, ``'Tanh'`` or ``'ReLU'``.
        output_activation: Stored only. The mean head always uses ``tanh``.
        geer: Scale applied to the ``tanh``-squashed mean.

    Raises:
        NotImplementedError: If ``activation`` is not supported.
    """

    def __init__(self, sizes, activation='Tanh', output_activation='Tanh', geer=1):
        super(DiagonalGaussianMlpPolicy, self).__init__()

        # Store parameters
        self.activation = activation
        self.output_activation = output_activation

        if activation == 'Tanh':
            self.activation = nn.Tanh
        elif activation == 'ReLU':
            self.activation = nn.ReLU
        else:
            raise NotImplementedError

        # Make policy network: trunk + mean head + log-sigma head
        self.sizes = sizes
        self.geer = geer
        self.logits_net = mlp(self.sizes[:-1], self.activation, nn.Identity)
        self.mu_net = nn.Linear(self.sizes[-2], self.sizes[-1], bias=False)
        self.log_sigma_net = nn.Linear(self.sizes[-2], self.sizes[-1], bias=False)
        self.LOG_SIGMA_MIN = -20
        self.LOG_SIGMA_MAX = -2

        # Init parameters
        self.init_parameters()

    def init_parameters(self):
        """Re-initialise every parameter uniformly in +-1/sqrt(last-dim size)."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, obs, sample=True, fixed_action=None):
        """Return ``(action_as_numpy, summed_log_prob)`` for ``obs``.

        Args:
            obs: Observation tensor.
            sample: Sample from N(mu, sigma) when True, otherwise return ``mu``.
            fixed_action: If given, evaluate this action instead of choosing one.
        """
        features = self.logits_net(obs)

        # Mean, squashed to [-geer, geer]
        mu = torch.tanh(self.mu_net(features)) * self.geer

        # log-sigma is clamped to [LOG_SIGMA_MIN, LOG_SIGMA_MAX] before exp;
        # the extra tanh is kept from the original formulation.
        sigma = torch.tanh(torch.clamp(self.log_sigma_net(features), self.LOG_SIGMA_MIN, self.LOG_SIGMA_MAX).exp())

        policy = Normal(mu, sigma)

        if fixed_action is not None:
            # Take the pre-set action
            action = torch.tensor(fixed_action, device=obs.device)
        elif sample:
            action = policy.sample()
        else:
            action = mu.detach()

        # Floor very negative log-likelihoods at -1e5 (same values as the
        # former in-place masked assignment, without mutating autograd output).
        ll = policy.log_prob(action).clamp(min=-1e5)

        # detach().cpu() so conversion also works for CUDA tensors.
        return action.detach().cpu().numpy(), ll.sum()

# =============================================================================
# ENHANCED WORKER CLASS WITH GAME THEORY
# =============================================================================

class Worker:
    """A federated policy-gradient node; also used for the master/old-master.

    Each instance owns a gym environment and a policy network:
    ``MlpPolicy`` for ``Discrete`` action spaces, ``DiagonalGaussianMlpPolicy``
    otherwise. When game theory is enabled, ``update_game_status`` marks the
    worker as attacked/defended. An attacked, undefended worker perturbs the
    gradient it reports from ``train_one_epoch``.
    """

    def __init__(self, id, env_name, hidden_units, gamma, 
                 activation='Tanh', output_activation='Identity', 
                 max_epi_len=0, opts=None):
        super(Worker, self).__init__()

        # Setup
        self.id = id
        self.gamma = gamma
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.max_epi_len = max_epi_len

        assert opts is not None
        self.opts = opts

        # Game theory attributes (only used when game theory is enabled)
        self.damage_weight = 1.0  
        self.is_currently_attacked = False
        self.is_currently_defended = False
        self.attack_intensity = 0.0

        # Get observation / action dims
        obs_dim = self.env.observation_space.shape[0]
        if isinstance(self.env.action_space, Discrete):
            n_acts = self.env.action_space.n
        else:
            n_acts = self.env.action_space.shape[0]

        # NOTE(review): `eval` executes the hidden_units config string as code.
        # If it is always a literal such as "64,64", ast.literal_eval would be
        # safer. Confirm before switching (expressions like "[64]*2" would break).
        hidden_sizes = list(eval(hidden_units))
        self.sizes = [obs_dim] + hidden_sizes + [n_acts]

        # Get policy net
        if isinstance(self.env.action_space, Discrete):
            self.logits_net = MlpPolicy(self.sizes, activation, output_activation)
        else:
            self.logits_net = DiagonalGaussianMlpPolicy(self.sizes, activation)

        if self.id == 1:
            print(self.logits_net)

    def update_game_status(self, damage_weight: float, is_attacked: bool, is_defended: bool, attack_intensity: float):
        """Update game theory status for this worker"""
        self.damage_weight = damage_weight
        self.is_currently_attacked = is_attacked
        self.is_currently_defended = is_defended
        self.attack_intensity = attack_intensity

    def load_param_from_master(self, param):
        """Overlay ``param`` (a state-dict subset) onto this worker's policy."""
        model_actor = get_inner_model(self.logits_net)
        model_actor.load_state_dict({**model_actor.state_dict(), **param})

    def rollout(self, device, max_steps=1000, render=False, env=None, obs=None, 
                sample=True, mode='human', save_dir='./', filename='.'):
        """Run one episode (at most ``max_steps``) and return (return, length, rewards).

        Uses ``self.env`` and resets it when neither ``env`` nor ``obs`` is given.
        In ``mode='rgb'`` the frames are collected and written via ``save_frames_as_gif``.
        """
        if env is None and obs is None:
            env = self.env
            obs = env.reset()

        done = False  
        ep_rew = []
        frames = []
        step = 0
        while not done and step < max_steps:
            step += 1
            if render:
                if mode == 'rgb':
                    frames.append(env.render(mode="rgb_array"))
                else:
                    env.render()

            obs = env_wrapper(env.unwrapped.spec.id, obs)
            action = self.logits_net(torch.as_tensor(obs, dtype=torch.float32).to(device), sample=sample)[0]
            obs, rew, done, _ = env.step(action)
            ep_rew.append(rew)

        if mode == 'rgb': 
            save_frames_as_gif(frames, save_dir, filename)
        return np.sum(ep_rew), len(ep_rew), ep_rew

    def collect_experience_for_training(self, B, device, record=False, sample=True):
        """Collect ``B`` episodes with the current policy.

        Each step's log-prob is weighted by the whitened discounted
        reward-to-go of its episode.

        Returns:
            ``(weights, logp, batch_rets, batch_lens)``, plus
            ``(batch_states, batch_actions)`` when ``record`` is True.
        """
        # Make some empty lists for logging
        batch_weights = []
        batch_rets = []
        batch_lens = []
        batch_log_prob = []

        # Reset episode-specific variables
        obs = self.env.reset()
        done = False
        ep_rews = []

        # Make two lists for recording the trajectory
        if record:
            batch_states = []
            batch_actions = []

        # Collect experience by acting in the environment with current policy
        while True:
            # Save trajectory (raw observation, before env_wrapper)
            if record:
                batch_states.append(obs)

            # Act in the environment  
            obs = env_wrapper(self.env_name, obs)

            # Normal behavior - game theory attacks are applied at gradient level only
            act, log_prob = self.logits_net(torch.as_tensor(obs, dtype=torch.float32).to(device), sample=sample)

            obs, rew, done, _ = self.env.step(act)

            # Save action_log_prob, reward
            batch_log_prob.append(log_prob)
            ep_rews.append(rew)

            # Save trajectory
            if record:
                batch_actions.append(act)

            # NOTE(review): with max_epi_len == 0 (the constructor default)
            # every single step ends an episode; callers should pass a positive value.
            if done or len(ep_rews) >= self.max_epi_len:
                # If episode is over, record info about episode
                batch_rets.append(sum(ep_rews))
                batch_lens.append(len(ep_rews))

                # The weight for each logprob(a_t|s_t) is sum_{t'>=t} gamma^(t'-t) * r_t'
                returns = []
                R = 0
                for r in reversed(ep_rews):
                    R = r + self.gamma * R
                    returns.append(R)
                returns.reverse()
                returns = torch.tensor(returns, dtype=torch.float32)

                # Return whitening. std() of a single element is NaN, so a
                # length-1 episode gets zero advantage instead of NaN weights.
                std = returns.std() if returns.numel() > 1 else torch.zeros(())
                advantage = (returns - returns.mean()) / (std + 1e-20)
                batch_weights += advantage

                # End experience loop if we have enough of it
                if len(batch_lens) >= B:
                    break

                # Reset episode-specific variables
                obs, done, ep_rews = self.env.reset(), False, []

        # Make torch tensors
        weights = torch.as_tensor(batch_weights, dtype=torch.float32).to(device)
        logp = torch.stack(batch_log_prob)

        if record:
            return weights, logp, batch_rets, batch_lens, batch_states, batch_actions
        else:
            return weights, logp, batch_rets, batch_lens

    def train_one_epoch(self, B, device, sample):
        """Run one local policy-gradient epoch and return the gradient to report.

        If game theory is enabled and this worker is attacked but not defended,
        each parameter gradient is perturbed before being returned. The noise is
        uniform in [-1, 1), scaled by the gradient's value range and by
        ``opts.base_attack_intensity * damage_weight / 10``.

        Returns:
            ``(grads, loss, mean_return, mean_episode_length)``.
        """
        # Collect experience by acting in the environment with current policy
        weights, logp, batch_rets, batch_lens = self.collect_experience_for_training(
            B, device, sample=sample)

        # Calculate policy gradient loss
        batch_loss = -(logp * weights).mean()

        # Compute gradients (the step itself happens on the master)
        self.logits_net.zero_grad()
        batch_loss.backward()

        if self.opts.enable_game_theory and self.is_currently_attacked and not self.is_currently_defended:
            # Scale attack strength by damage weight relative to a reference weight.
            # Uses self.opts: the previous bare `opts` was an undefined free name here.
            reference_weight = 10.0
            attack_multiplier = self.opts.base_attack_intensity * (self.damage_weight / reference_weight)

            grad = []
            for item in self.parameters():
                # Original FRL random noise attack with damage-weight scaling
                spread = item.grad.max().data - item.grad.min().data
                rnd = (torch.rand(item.grad.shape, device=item.device) * 2 - 1) * spread * attack_multiplier
                grad.append(item.grad + rnd)
        else:
            # Return true gradient (normal behavior)
            grad = [item.grad for item in self.parameters()]

        # Report the results to the agent for training purpose
        return grad, batch_loss.item(), np.mean(batch_rets), np.mean(batch_lens)

    def to(self, device):
        self.logits_net.to(device)
        return self

    def eval(self):
        self.logits_net.eval()
        return self

    def train(self):
        self.logits_net.train()
        return self

    def parameters(self):
        return self.logits_net.parameters()

# =============================================================================
# AGENT AND MEMORY CLASSES WITH GAME THEORY
# =============================================================================

class Memory:
    """Plain container for the per-run logging dictionaries.

    ``steps``, ``eval_values``, ``training_values`` and
    ``game_theory_metrics`` each start as a separate empty dict.
    """

    _FIELDS = ('steps', 'eval_values', 'training_values', 'game_theory_metrics')

    def __init__(self):
        for field_name in self._FIELDS:
            setattr(self, field_name, {})

def worker_run(worker, param, opts, Batch_size, seed):
    """Pool entry point: sync a worker with the master and run one local epoch.

    Args:
        worker: Node to train.
        param: Master state-dict loaded into the worker's policy first.
        opts: Run configuration (reads ``device`` and ``do_sample_for_training``).
        Batch_size: Number of episodes the worker collects.
        seed: Environment seed applied before collecting.

    Returns:
        Whatever ``worker.train_one_epoch`` returns.
    """
    worker.load_param_from_master(param)
    worker.env.seed(seed)
    return worker.train_one_epoch(Batch_size, opts.device, opts.do_sample_for_training)

class Agent:
    def __init__(self, opts):
        """Create the game integrator, master/old-master/worker nodes, optimizer and pool.

        Construction order is preserved: master (id 0), old master (id -1),
        then workers with ids 1..num_worker. Each ``Worker`` draws initial
        weights from the global RNG, so the order matters for reproducibility.
        """
        self.opts = opts

        # Game-theory attack/defense coordinator
        self.game_integrator = GameTheoryIntegrator(opts)

        # Number of distributed worker nodes
        self.world_size = opts.num_worker

        def build_node(node_id):
            # All nodes share one configuration; only the id differs.
            return Worker(
                id=node_id,
                env_name=opts.env_name,
                gamma=opts.gamma,
                hidden_units=opts.hidden_units,
                activation=opts.activation,
                output_activation=opts.output_activation,
                max_epi_len=opts.max_epi_len,
                opts=opts,
            ).to(opts.device)

        # Master node, plus a copy of it used for importance sampling
        self.master = build_node(0)
        self.old_master = build_node(-1)

        # Worker nodes (all honest by default)
        self.workers = []
        for rank in range(1, self.world_size + 1):
            node = build_node(rank)
            if opts.enable_game_theory and hasattr(self.game_integrator, 'current_damage_weights'):
                node.update_game_status(
                    damage_weight=self.game_integrator.current_damage_weights[rank - 1],
                    is_attacked=False,
                    is_defended=False,
                    attack_intensity=0.0,
                )
            self.workers.append(node)

        print(f'{opts.num_worker} workers initialized (all honest by default).')
        if opts.enable_game_theory:
            print(f'Game theory integration enabled with {opts.damage_scenario} scenario.')

        # Only the master's parameters are optimized, and only when training
        if not opts.eval_only:
            self.optimizer = optim.Adam(self.master.logits_net.parameters(), lr=opts.lr_model)

        self.pool = Pool(self.world_size)
        self.memory = Memory()
    
    def load(self, load_path):
        """Restore master weights and, unless ``eval_only``, optimizer and RNG state.

        Weights stored under ``'master'`` are overlaid on the current state
        dict. Optimizer tensors are moved to ``opts.device`` after loading.
        """
        assert load_path is not None
        checkpoint = torch_load_cpu(load_path)

        # Overlay saved actor weights on top of the current ones
        actor = get_inner_model(self.master.logits_net)
        merged_state = dict(actor.state_dict())
        merged_state.update(checkpoint.get('master', {}))
        actor.load_state_dict(merged_state)

        if not self.opts.eval_only:
            # Optimizer state, relocated to the training device
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            target_device = self.opts.device
            for opt_state in self.optimizer.state.values():
                for key, value in opt_state.items():
                    if torch.is_tensor(value):
                        opt_state[key] = value.to(target_device)

            # RNG state for torch (and CUDA when enabled)
            torch.set_rng_state(checkpoint['rng_state'])
            if self.opts.use_cuda:
                torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state'])

        print(f' [*] Loading data from {load_path}')
        
    def save(self, epoch, run_id):
        """Write master weights, optimizer state and RNG states to ``save_dir``.

        The file is named ``r{run_id}-epoch-{epoch}.pt``.
        """
        print('Saving model and state...')
        checkpoint = {
            'master': get_inner_model(self.master.logits_net).state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state_all(),
        }
        target_path = os.path.join(self.opts.save_dir, f'r{run_id}-epoch-{epoch}.pt')
        torch.save(checkpoint, target_path)

    def save_performance_to_csv(self, strategy_name='default'):
        """Dump logged steps, training/eval values and game metrics to CSV.

        Writes via ``PerformanceCSVSaver``. When ``opts.save_dir`` is falsy,
        ``./results`` is used instead.
        """
        csv_saver = PerformanceCSVSaver(self.opts.save_dir or './results')

        log = self.memory
        memory_data = dict(
            steps=log.steps,
            training_values=log.training_values,
            eval_values=log.eval_values,
        )

        csv_saver.save_strategy_performance_csv(
            strategy_name, memory_data, getattr(log, 'game_theory_metrics', None)
        )

    def eval(self):
        # Turn model to eval mode
        self.master.eval()
        
    def train(self):
        # Turn model to training mode
        self.master.train()
        
    def start_training(self, tb_logger=None, run_id=None):
        # Parameters of running
        opts = self.opts

        # For storing number of trajectories sampled
        step = 0
        epoch = 0
        ratios_step = 0
        
        # Store baseline performance for game theory damage calculation
        baseline_performance = 0
        
        # Start the training loop
        while step <= opts.max_trajectories:
            # Epoch for storing checkpoints of model
            epoch += 1
            
            # Turn model into training mode
            print('\n\n')
            print("|", format(f" Training step {step} run_id {run_id} in {opts.seeds}", "*^60"), "|")
            self.train()
            
            # Get game theory decisions
            attackers = set()
            defended_clients = set()
            if opts.enable_game_theory:
                attackers, defended_clients = self.game_integrator.get_attack_defense_decisions(step)
                print(f"Game Theory Round {step}: {len(attackers)} attackers, {len(defended_clients)} defended")
                
                # Update worker game status
                for i, worker in enumerate(self.workers):
                    is_attacked = i in attackers
                    is_defended = i in defended_clients
                    damage_weight = self.game_integrator.current_damage_weights[i]
                    
                    # Scale attack intensity by damage weight for game theory
                    scaled_attack_intensity = opts.base_attack_intensity * (damage_weight / 10.0)
                    
                    worker.update_game_status(
                        damage_weight=damage_weight,
                        is_attacked=is_attacked,
                        is_defended=is_defended,
                        attack_intensity=scaled_attack_intensity
                    )
            
            # Setup lr_scheduler
            print("Training with lr={:.3e}".format(self.optimizer.param_groups[0]['lr']), flush=True)
            
            # Some empty list for training and logging purpose
            gradient = []
            batch_loss = []
            batch_rets = []
            batch_lens = []
            
            # Distribute current params and Batch_Size to all workers
            param = get_inner_model(self.master.logits_net).state_dict()
            
            if opts.FedPG_BR:
                Batch_size = np.random.randint(opts.Bmin, opts.Bmax + 1)
            else:
                Batch_size = opts.B
        
            seeds = np.random.randint(1, 100000, self.world_size).tolist()
            args = zip(self.workers, repeat(param), repeat(opts), repeat(Batch_size), seeds)
            
            results = self.pool.starmap(worker_run, args)

            # Collect the gradient(for training), loss(for logging only), returns(for logging only), and epi_length(for logging only) from workers         
            for out in tqdm(results, desc='Worker node'):
                grad, loss, rets, lens = out
                
                # Store all values
                gradient.append(grad)
                batch_loss.append(loss)
                batch_rets.append(rets)
                batch_lens.append(lens)
            
            # Simulate FedScsPG-attack (if needed) on server for demo
            if opts.attack_type == 'FedScsPG-attack' and opts.num_Byzantine > 0:  
                for idx, _ in enumerate(self.master.parameters()):
                    tmp = []
                    for bad_worker in range(opts.num_Byzantine):
                        tmp.append(gradient[bad_worker][idx].view(-1))
                    tmp = torch.stack(tmp)

                    estimated_2sigma = euclidean_dist(tmp, tmp).max()
                    estimated_mean = tmp.mean(0)
                    
                    # Change the gradient to be estimated_mean + 3sigma (with a random direction rnd)
                    rnd = torch.rand(gradient[0][idx].shape) * 2. - 1.
                    rnd = rnd / rnd.norm()
                    attacked_gradient = estimated_mean.view(gradient[bad_worker][idx].shape) + rnd * estimated_2sigma * 3. / 2.
                    for bad_worker in range(opts.num_Byzantine):
                        gradient[bad_worker][idx] = attacked_gradient
                        
            # Make the old policy as a copy of the current master node
            self.old_master.load_param_from_master(param)
            
            # Do Aggregate Algorithm to detect Byzantine worker on master node
            if opts.FedPG_BR:
                
                # Flatten the gradient vectors of each worker and put them together, shape [num_worker, -1]
                mu_vec = None
                for idx, item in enumerate(self.old_master.parameters()):
                    # Stack gradient[idx] from all worker nodes
                    grad_item = []
                    for i in range(self.world_size):
                         grad_item.append(gradient[i][idx])
                    grad_item = torch.stack(grad_item).view(self.world_size, -1)
                
                    # Concat stacked grad vector
                    if mu_vec is None:
                        mu_vec = grad_item.clone()
                    else:
                        mu_vec = torch.cat((mu_vec, grad_item.clone()), -1)
                    
                # Calculate the norm distance between each worker's gradient vector, shape [num_worker, num_worker]
                dist = euclidean_dist(mu_vec, mu_vec)
                
                # Calculate C, Variance Bound V, threshold, and alpha
                V = 2 * np.log(2 * opts.num_worker / opts.delta)
                sigma = opts.sigma

                threshold = 2 * sigma * np.sqrt(V / Batch_size)
                alpha = opts.alpha
                
                # To find MOM: |dist <= threshold| > 0.5 * num_worker
                mu_med_vec = None
                k_prime = (dist <= threshold).sum(-1) > (0.5 * self.world_size)
                
                # Computes the mom of the gradients, mu_med_vec, and
                # filter the gradients it believes to be Byzantine and store the index of non-Byzantine gradients in Good_set
                if torch.sum(k_prime) > 0:
                    mu_mean_vec = torch.mean(mu_vec[k_prime], 0).view(1, -1)
                    mu_med_vec = mu_vec[k_prime][euclidean_dist(mu_mean_vec, mu_vec[k_prime]).argmin()].view(1, -1)
                    # Applying R1 to filter
                    Good_set = euclidean_dist(mu_vec, mu_med_vec) <= 1 * threshold
                else:
                    Good_set = k_prime  # skip this step if k_prime is empty (i.e., all False)
                
                # Avoid the scenarios that Good_set is empty or can have |Gt| < (1 − α)K.
                if torch.sum(Good_set) < (1 - alpha) * self.world_size or torch.sum(Good_set) == 0:
                    
                    # Re-calculate mom of the gradients
                    k_prime = (dist <= 2 * sigma).sum(-1) > (0.5 * self.world_size)
                    if torch.sum(k_prime) > 0:
                        mu_mean_vec = torch.mean(mu_vec[k_prime], 0).view(1, -1)
                        mu_med_vec = mu_vec[k_prime][euclidean_dist(mu_mean_vec, mu_vec[k_prime]).argmin()].view(1, -1)
                        # Re-filter with R2
                        Good_set = euclidean_dist(mu_vec, mu_med_vec) <= 2 * sigma
                    else:
                        Good_set = torch.zeros(self.world_size, 1).to(opts.device).bool()
            
            # Else will treat all nodes as non-Byzantine nodes
            else:
                Good_set = torch.ones(self.world_size, 1).to(opts.device).bool()
            
            # Calculate number of good gradients for logging
            N_good = torch.sum(Good_set)
            
            # Aggregate all detected non-Byzantine gradients to get mu
            if N_good > 0:
                mu = []
                for idx, item in enumerate(self.old_master.parameters()):
                    grad_item = []
                    for i in range(self.world_size):
                        if Good_set[i]:  # only aggregate non-Byzantine gradients
                            grad_item.append(gradient[i][idx])
                    mu.append(torch.stack(grad_item).mean(0))
            else:  # if still all nodes are detected to be Byzantine, check the sigma. If sigma is set properly, this situation will not happen.
                mu = None
            
            # Perform gradient update in master node
            grad_array = []  # store gradients for logging

            if opts.FedPG_BR or opts.SVRPG:
                
                if opts.FedPG_BR:
                    # For n=1 to Nt ~ Geom(B/B+b) do grad update
                    b = opts.b
                    N_t = np.random.geometric(p=1 - Batch_size/(Batch_size + b))
                    
                elif opts.SVRPG:
                    b = opts.b
                    N_t = opts.N
                    
                for n in tqdm(range(N_t), desc='Master node'):
                   
                    # Calculate new gradient in master node
                    self.optimizer.zero_grad()

                    # Sample b trajectory using the latest policy (\theta_n) of master node
                    weights, new_logp, batch_rets, batch_lens, batch_states, batch_actions = self.master.collect_experience_for_training(
                        b, opts.device, record=True, sample=opts.do_sample_for_training)
                        
                    # Calculate gradient for the new policy (\theta_n)
                    loss_new = -(new_logp * weights).mean()
                    self.master.logits_net.zero_grad()
                    loss_new.backward()
                    
                    if mu:
                        # Get the old log_p with the old policy (\theta_0) but fixing the actions to be the same as the sampled trajectory
                        old_logp = []
                        for idx, obs in enumerate(batch_states):
                            # Act in the environment with the fixed action
                            obs = env_wrapper(opts.env_name, obs)
                            _, old_log_prob = self.old_master.logits_net(torch.as_tensor(obs, dtype=torch.float32).to(opts.device), 
                                                                         fixed_action=batch_actions[idx])
                            # Store in the old_logp
                            old_logp.append(old_log_prob)
                        old_logp = torch.stack(old_logp)
                        
                        # Finding the ratio (pi_theta / pi_theta__old):
                        ratios = torch.exp(old_logp.detach() - new_logp.detach())
                        ratios_step += 1
                        
                        # Calculate gradient for the old policy (\theta_0)
                        loss_old = -(old_logp * weights * ratios).mean()
                        self.old_master.logits_net.zero_grad()
                        loss_old.backward()
                        grad_old = [item.grad for item in self.old_master.parameters()]   
                    
                        # Early stop if ratio is not within [0.995, 1.005]
                        if torch.abs(ratios.mean()) < 0.995 or torch.abs(ratios.mean()) > 1.005:
                            N_t = n
                            break
                        
                        if tb_logger is not None:
                            tb_logger.add_scalar(f'params/ratios_{run_id}', ratios.mean(), ratios_step)
                        
                        # Adjust and set the gradient for latest policy (\theta_n)
                        for idx, item in enumerate(self.master.parameters()):
                            item.grad = item.grad - grad_old[idx] + mu[idx]  # if mu is None, use grad from master 
                            grad_array += (item.grad.data.view(-1).cpu().tolist())
                        
                    # Take a gradient step
                    self.optimizer.step()
        
            else:  # GOMDP in this case
                
                b = 0
                N_t = 0
                
                # Perform gradient descent with mu vector
                for idx, item in enumerate(self.master.parameters()):
                    item.grad = mu[idx]
                    grad_array += (item.grad.data.view(-1).cpu().tolist())
                    
                # Take a gradient step
                self.optimizer.step()  
            
            # Calculate current performance for game theory metrics
            current_performance = np.mean(batch_rets)
            if step == 0:
                baseline_performance = current_performance
            
            # Calculate game theory metrics
            game_metrics = {}
            if opts.enable_game_theory:
                game_metrics = self.game_integrator.calculate_damage_metrics(
                    step, attackers, defended_clients, baseline_performance, current_performance
                )
                
            print('\nepoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f \t N_good: %d' %
                (epoch, np.mean(batch_loss), np.mean(batch_rets), np.mean(batch_lens), N_good))
            
            if opts.enable_game_theory and game_metrics:
                print(f'Game Theory: attackers: {game_metrics["num_attackers"]} \t defended: {game_metrics["num_defended"]} \t damage: {game_metrics["performance_drop"]:.3f} \t strategy: {game_metrics["current_strategy"]}')
            
            # Current step: number of trajectories sampled
            step += round((Batch_size * self.world_size + b * N_t) / (1 + self.world_size)) if self.world_size > 1 else Batch_size + b * N_t
            
            # Logging to tensorboard
            if tb_logger is not None:
                
                # Training log
                tb_logger.add_scalar(f'train/total_rewards_{run_id}', np.mean(batch_rets), step)
                tb_logger.add_scalar(f'train/epi_length_{run_id}', np.mean(batch_lens), step)
                tb_logger.add_scalar(f'train/loss_{run_id}', np.mean(batch_loss), step)
                # Grad log
                tb_logger.add_scalar(f'grad/grad_{run_id}', np.mean(grad_array), step)
                # Optimizer log
                tb_logger.add_scalar(f'params/lr_{run_id}', self.optimizer.param_groups[0]['lr'], step)
                tb_logger.add_scalar(f'params/N_t_{run_id}', N_t, step)

                # Game theory logging
                if opts.enable_game_theory and game_metrics:
                    tb_logger.add_scalar(f'game/num_attackers_{run_id}', game_metrics['num_attackers'], step)
                    tb_logger.add_scalar(f'game/num_defended_{run_id}', game_metrics['num_defended'], step)
                    tb_logger.add_scalar(f'game/performance_drop_{run_id}', game_metrics['performance_drop'], step)
                    tb_logger.add_scalar(f'game/theoretical_damage_{run_id}', game_metrics['theoretical_damage'], step)
                    tb_logger.add_scalar(f'game/successful_attacks_{run_id}', game_metrics['successful_attacks'], step)
                    tb_logger.add_scalar(f'game/defense_budget_util_{run_id}', game_metrics['defense_budget_utilization'], step)
                    tb_logger.add_scalar(f'game/attack_budget_util_{run_id}', game_metrics['attack_budget_utilization'], step)
                    tb_logger.add_scalar(f'game/avg_attacker_utility_{run_id}', game_metrics['avg_attacker_utility'], step)
                    tb_logger.add_scalar(f'game/current_era_{run_id}', game_metrics['current_era'], step)
                    tb_logger.add_scalar(f'game/deterrence_efficiency_{run_id}', game_metrics['deterrence_efficiency'], step)

                # Byzantine filtering log (for FedPG_BR algorithm)
                if opts.FedPG_BR:
                    # No predetermined Byzantine workers, but log filtering results
                    y_pred = (~ Good_set).view(-1).cpu().tolist()
                    
                    tb_logger.add_scalar(f'Byzantine/threshold_{run_id}', threshold, step)
                    tb_logger.add_scalar(f'grad_norm_mean/ALL_{run_id}', torch.mean(dist), step)
                    tb_logger.add_scalar(f'grad_norm_max/ALL_{run_id}', torch.max(dist), step)
                    tb_logger.add_scalar(f'Byzantine/N_good_pred_{run_id}', N_good, step)
                        
                # For performance plot
                if run_id not in self.memory.steps.keys():
                    self.memory.steps[run_id] = []
                    self.memory.eval_values[run_id] = []
                    self.memory.training_values[run_id] = []
                    if opts.enable_game_theory:
                        self.memory.game_theory_metrics[run_id] = []
                
                self.memory.steps[run_id].append(step)
                self.memory.training_values[run_id].append(np.mean(batch_rets))
                if opts.enable_game_theory and game_metrics:
                    self.memory.game_theory_metrics[run_id].append(game_metrics)
                             
            # Do validating
            eval_reward = self.start_validating(tb_logger, step, max_steps=opts.val_max_steps, 
                                              render=opts.render, run_id=run_id)
            if tb_logger is not None:
                 self.memory.eval_values[run_id].append(eval_reward)
                            
            # Save current model
            if not opts.no_saving:
                self.save(epoch, run_id)
                
    # Validate the new model   
    def start_validating(self, tb_logger=None, id=0, max_steps=1000, render=False, run_id=0, mode='human'):
        """Evaluate the master policy over ``opts.val_size`` rollouts.

        Args:
            tb_logger: Optional TensorBoard writer. When given, the mean return and
                mean episode length are logged under ``validate/*_{run_id}`` at
                x-coordinate ``id`` and the writer is flushed (not closed, because
                the training loop keeps logging to it afterwards).
            id: Step counter used in the printout and as the TensorBoard step.
            max_steps: Maximum number of environment steps per rollout.
            render: Forwarded to ``self.master.rollout``.
            run_id: Run identifier used in TensorBoard tags and GIF filenames.
            mode: Render mode forwarded to ``self.master.rollout``.

        Returns:
            The mean return over the ``opts.val_size`` rollouts.
        """
        print('\nValidating...', flush=True)

        n_episodes = self.opts.val_size
        val_ret = 0.0
        val_len = 0.0

        for episode in range(n_episodes):
            # sample=False is passed to rollout; what that means for action
            # selection is defined in rollout (not visible here).
            epi_ret, epi_len, _ = self.master.rollout(self.opts.device, max_steps=max_steps, render=render,
                                                      sample=False, mode=mode, save_dir='./outputs/',
                                                      filename=f'gym_{run_id}_{episode}.gif')
            val_ret += epi_ret
            val_len += epi_len

        val_ret /= n_episodes
        val_len /= n_episodes

        print('\nGradient step: %3d \t return: %.3f \t ep_len: %.3f' %
              (id, np.mean(val_ret), np.mean(val_len)))

        if tb_logger is not None:
            tb_logger.add_scalar(f'validate/total_rewards_{run_id}', np.mean(val_ret), id)
            tb_logger.add_scalar(f'validate/epi_length_{run_id}', np.mean(val_len), id)
            # Flush instead of close: callers keep using this writer after
            # validation, and closing forces a new event file on the next write.
            tb_logger.flush()

        return np.mean(val_ret)
    
    def plot_graph(self, array):
        """Plot the across-run mean of a per-run metric with a 90% normal interval.

        Each run's ``(self.memory.steps[run_id], array[run_id])`` series is
        resampled with a linear RBF onto the integer grid
        ``0 .. int(opts.max_trajectories) - 1``. The resampled curves are then
        averaged, and the band is ``st.norm.interval(0.90)`` scaled by the
        standard error across runs.

        Args:
            array: Mapping ``run_id -> values``, aligned with
                ``self.memory.steps[run_id]`` (e.g. ``self.memory.eval_values``).

        Returns:
            The created matplotlib figure. It is not closed here.
        """
        plt.ioff()
        fig = plt.figure(figsize=(8, 4))

        # Use one grid for both the RBF evaluation and the band so lengths agree.
        n_points = int(self.opts.max_trajectories)
        grid = np.arange(n_points)

        y = []
        for run_id in self.memory.steps.keys():
            x = self.memory.steps[run_id]
            y.append(Rbf(x, array[run_id], function='linear')(grid))

        mean = np.mean(y, axis=0)
        low, high = st.norm.interval(0.90, loc=mean, scale=st.sem(y, axis=0))

        # Draw on (and later configure) the figure's current axes. Calling
        # plt.axes() here would add a new full-figure axes that covers the plot.
        ax = fig.gca()
        ax.plot(mean)
        ax.fill_between(grid, low, high, alpha=0.5)
        ax.set_ylim([self.opts.min_reward, self.opts.max_reward])

        ax.set_xlabel("Number of Trajectories")
        ax.set_ylabel("Reward")
        ax.grid(True)
        fig.tight_layout()
        return fig
    
    def plot_game_theory_metrics(self):
        """Plot six game-theory metrics per run on a 2x3 grid.

        Returns:
            The matplotlib figure, or ``None`` when game theory is disabled or
            no metrics were recorded.
        """
        if not self.opts.enable_game_theory or not self.memory.game_theory_metrics:
            return None

        # (row, col, metrics key, panel title, y-axis label)
        panels = (
            (0, 0, 'performance_drop', 'Performance Drop', 'Performance Drop'),
            (0, 1, 'theoretical_damage', 'Theoretical Damage', 'Theoretical Damage'),
            (0, 2, 'num_attackers', 'Number of Attackers', 'Number of Attackers'),
            (1, 0, 'num_defended', 'Number of Defended Clients', 'Number Defended'),
            (1, 1, 'defense_budget_utilization', 'Defense Budget Utilization', 'Budget Utilization'),
            (1, 2, 'attack_budget_utilization', 'Attack Budget Utilization', 'Budget Utilization'),
        )

        plt.ioff()
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        for run_id, metrics_list in self.memory.game_theory_metrics.items():
            if not metrics_list:
                continue
            steps = self.memory.steps[run_id]
            # Steps are recorded every epoch but metrics only when non-empty, so
            # the two series may differ in length; pair them to the shorter one.
            n = min(len(steps), len(metrics_list))
            if n == 0:
                continue
            xs = steps[:n]
            for row, col, key, _, _ in panels:
                values = [m[key] for m in metrics_list[:n]]
                axes[row, col].plot(xs, values, alpha=0.7, label=f'Run {run_id}')

        for row, col, _, title, ylabel in panels:
            ax = axes[row, col]
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.grid(True)
            # Avoid matplotlib's "no artists with labels" warning on empty panels.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

        for ax in axes.flat:
            ax.set_xlabel('Training Steps')

        plt.tight_layout()
        return fig
    
    def log_performance(self, tb_logger):
        """Send aggregate performance figures for all runs so far to TensorBoard.

        The validation and training curves (built by ``plot_graph``) are always
        logged. When ``opts.enable_game_theory`` is set and
        ``plot_game_theory_metrics`` returns a figure, that figure is logged too.
        Every tag embeds, and every figure is logged at, the number of runs
        currently stored in ``self.memory.steps``.
        """
        eval_fig = self.plot_graph(self.memory.eval_values)
        train_fig = self.plot_graph(self.memory.training_values)

        n_runs = len(self.memory.steps.keys())
        for prefix, fig in (('validate/performance', eval_fig), ('train/performance', train_fig)):
            tb_logger.add_figure(f'{prefix}_until_{n_runs}_runs', fig, n_runs)

        if not self.opts.enable_game_theory:
            return

        gt_fig = self.plot_game_theory_metrics()
        if gt_fig:
            tb_logger.add_figure(f'game/metrics_until_{n_runs}_runs', gt_fig, n_runs)
    
    def save_training_plots(self, save_dir=None, strategy_results=None):
        """Write training/validation plots to ``save_dir``.

        Args:
            save_dir: Output directory. Defaults to ``opts.save_dir``, or
                ``'./plots'`` when that is empty. Created if missing.
            strategy_results: Optional mapping of strategy name to result
                curves. When given, a strategy comparison is plotted instead of
                this agent's own per-run curves.

        Individual training/validation PNGs are written in both cases. Nothing
        happens (beyond a message) when no training steps were recorded.
        """
        if not self.memory.steps:
            print("No training data to plot")
            return

        if save_dir is None:
            save_dir = self.opts.save_dir or './plots'
        os.makedirs(save_dir, exist_ok=True)

        if strategy_results is None:
            # Side-by-side training / validation curves for this agent's runs.
            plt.figure(figsize=(12, 5))
            panels = (
                (1, self.memory.training_values, 'Training Returns', 'Training'),
                (2, self.memory.eval_values, 'Validation Returns', 'Validation'),
            )
            for position, values, title, tag in panels:
                plt.subplot(1, 2, position)
                self.plot_returns(values, title, tag)

            plt.tight_layout()
            curves_path = os.path.join(save_dir, 'training_curves.png')
            plt.savefig(curves_path, dpi=300, bbox_inches='tight')
            print(f"Training plots saved to: {curves_path}")
            plt.show()  # Display the plot
            plt.close()

            if self.opts.enable_game_theory:
                gt_fig = self.plot_game_theory_metrics()
                if gt_fig:
                    gt_path = os.path.join(save_dir, 'game_theory_metrics.png')
                    gt_fig.savefig(gt_path, dpi=300, bbox_inches='tight')
                    print(f"Game theory plots saved to: {gt_path}")
                    plt.close(gt_fig)
        else:
            self.plot_strategy_comparison(strategy_results, save_dir)

        self.save_individual_plots(save_dir)
    
    def plot_strategy_comparison(self, strategy_results, save_dir):
        """Plot a 2x3 comparison of defense strategies and save a summary CSV.

        Panels 1-3 show training returns, validation returns and a bar chart of
        final returns. Panels 4-6 (performance drop, successful attacks,
        deterrence efficiency) are drawn when any non-``'baseline'`` strategy
        carries a ``'game_theory_metrics'`` entry.

        Args:
            strategy_results: Mapping ``strategy name -> {'steps',
                'training_values', 'eval_values'[, 'game_theory_metrics']}``.
                Empty series are tolerated (their final value is shown as NaN).
            save_dir: Directory receiving ``strategy_comparison.png`` and the
                summary CSV.
        """
        def _aligned(steps, values):
            # Pair x/y series to the shorter length so plot() cannot raise on a
            # partially recorded epoch.
            n = min(len(steps), len(values))
            return steps[:n], values[:n]

        def _last(values):
            return values[-1] if len(values) > 0 else np.nan

        plt.figure(figsize=(15, 10))

        # Plots 1-2: training / validation return curves
        for position, key, title in ((1, 'training_values', 'Training Returns Comparison'),
                                     (2, 'eval_values', 'Validation Returns Comparison')):
            plt.subplot(2, 3, position)
            for strategy_name, results in strategy_results.items():
                xs, ys = _aligned(results['steps'], results[key])
                plt.plot(xs, ys, label=strategy_name, linewidth=2, alpha=0.8)
            plt.title(title)
            plt.xlabel('Training Steps')
            plt.ylabel('Episode Return')
            plt.legend()
            plt.grid(True, alpha=0.3)

        # Plot 3: Final Performance Bar Chart (NaN bars for strategies without data)
        plt.subplot(2, 3, 3)
        final_training = [_last(results['training_values']) for results in strategy_results.values()]
        final_validation = [_last(results['eval_values']) for results in strategy_results.values()]

        x = np.arange(len(strategy_results))
        width = 0.35
        plt.bar(x - width / 2, final_training, width, label='Training', alpha=0.8)
        plt.bar(x + width / 2, final_validation, width, label='Validation', alpha=0.8)

        plt.title('Final Performance Comparison')
        plt.xlabel('Strategy')
        plt.ylabel('Episode Return')
        plt.xticks(x, list(strategy_results.keys()), rotation=45)
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Game Theory Specific Plots (if available)
        game_theory_strategies = {k: v for k, v in strategy_results.items()
                                  if k != 'baseline' and 'game_theory_metrics' in v}

        if game_theory_strategies:
            # (subplot position, metrics key, title, y-label)
            gt_panels = (
                (4, 'performance_drop', 'Performance Drop Comparison', 'Performance Drop'),
                (5, 'successful_attacks', 'Successful Attacks Comparison', 'Number of Successful Attacks'),
                (6, 'deterrence_efficiency', 'Defense Efficiency Comparison', 'Deterrence Efficiency'),
            )
            for position, key, title, ylabel in gt_panels:
                plt.subplot(2, 3, position)
                for strategy_name, results in game_theory_strategies.items():
                    metrics_list = results['game_theory_metrics']
                    if not metrics_list:
                        continue
                    xs, ms = _aligned(results['steps'], metrics_list)
                    plt.plot(xs, [m[key] for m in ms], label=strategy_name, linewidth=2, alpha=0.8)
                plt.title(title)
                plt.xlabel('Training Steps')
                plt.ylabel(ylabel)
                plt.legend()
                plt.grid(True, alpha=0.3)

        plt.tight_layout()
        comparison_path = os.path.join(save_dir, 'strategy_comparison.png')
        plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
        print(f"Strategy comparison plots saved to: {comparison_path}")
        plt.show()
        plt.close()

        # Summary statistics table
        self.save_strategy_summary_table(strategy_results, save_dir)
    
    def plot_returns(self, data_dict, title, label):
        """Plot every run's return curve on the current axes, plus mean ± 1 std.

        Each run is drawn faintly against ``self.memory.steps[run_id]``. When
        more than one run is present, runs with at least two points are linearly
        interpolated onto a shared 100-point step grid, and their mean and
        ±1 standard-deviation band are overlaid.

        Args:
            data_dict: Mapping ``run_id -> sequence of returns``, aligned with
                ``self.memory.steps[run_id]``.
            title: Axes title (suffixed with " - No Data" when ``data_dict`` is empty).
            label: Currently unused; kept so existing callers keep working.
        """
        if not data_dict:
            plt.title(f"{title} - No Data")
            return

        all_steps = []
        all_returns = []

        for run_id, returns in data_dict.items():
            steps = self.memory.steps[run_id]
            # Steps and values can be appended at different points of the
            # training loop; pair them to the shorter length so plot/interp
            # cannot raise on a partially recorded epoch.
            n = min(len(steps), len(returns))
            steps, returns = list(steps[:n]), list(returns[:n])

            plt.plot(steps, returns, alpha=0.3, linewidth=1)
            all_steps.append(steps)
            all_returns.append(returns)

        if len(all_returns) > 1:
            # Common grid spans up to the largest recorded step (empty runs ignored).
            non_empty = [s for s in all_steps if s]
            max_step = max(max(s) for s in non_empty) if non_empty else 0
            common_steps = np.linspace(0, max_step, 100)

            interpolated_returns = [
                np.interp(common_steps, s, r)
                for s, r in zip(all_steps, all_returns)
                if len(s) > 1
            ]

            if interpolated_returns:
                mean_returns = np.mean(interpolated_returns, axis=0)
                std_returns = np.std(interpolated_returns, axis=0)

                # Label with the number of runs actually averaged.
                plt.plot(common_steps, mean_returns, 'b-', linewidth=2,
                         label=f'Mean ({len(interpolated_returns)} runs)')
                plt.fill_between(common_steps,
                                 mean_returns - std_returns,
                                 mean_returns + std_returns,
                                 alpha=0.2, color='blue', label='±1 std')
                plt.legend()

        plt.title(title)
        plt.xlabel('Training Steps')
        plt.ylabel('Episode Return')
        plt.grid(True, alpha=0.3)

        # Set y-axis limits based on environment
        if hasattr(self.opts, 'min_reward') and hasattr(self.opts, 'max_reward'):
            plt.ylim(self.opts.min_reward, self.opts.max_reward)
    
    def save_individual_plots(self, save_dir):
        """Write the training and validation return curves as two PNG files.

        Produces ``training_returns.png`` and ``validation_returns.png`` in
        ``save_dir`` at 300 dpi. Each is drawn by ``plot_returns`` on its own
        10x6 figure, which is closed after saving.
        """
        jobs = (
            (self.memory.training_values, 'Training Returns vs Steps', 'Training', 'training_returns.png'),
            (self.memory.eval_values, 'Validation Returns vs Steps', 'Validation', 'validation_returns.png'),
        )
        for values, title, tag, filename in jobs:
            plt.figure(figsize=(10, 6))
            self.plot_returns(values, title, tag)
            plt.savefig(os.path.join(save_dir, filename), dpi=300, bbox_inches='tight')
            plt.close()
        
    def save_strategy_summary_table(self, strategy_results, save_dir):
        """Write ``strategy_comparison_summary.csv`` with one row per strategy.

        Columns: final training/validation return and their improvement (last
        minus first value). Strategies with non-empty ``'game_theory_metrics'``
        also get total performance drop, mean successful attacks, mean
        deterrence efficiency and total theoretical damage. Strategies whose
        value lists are empty get NaN instead of raising IndexError.

        Args:
            strategy_results: Mapping ``strategy name -> result dict`` with
                ``'training_values'`` and ``'eval_values'`` lists.
            save_dir: Existing directory that receives the CSV.
        """
        import pandas as pd

        def _first(values):
            return values[0] if len(values) > 0 else np.nan

        def _last(values):
            return values[-1] if len(values) > 0 else np.nan

        summary_data = []
        for strategy_name, results in strategy_results.items():
            training = results['training_values']
            evaluation = results['eval_values']

            row = {
                'Strategy': strategy_name,
                'Final_Training_Return': _last(training),
                'Final_Validation_Return': _last(evaluation),
                'Training_Improvement': _last(training) - _first(training),
                'Validation_Improvement': _last(evaluation) - _first(evaluation),
            }

            # Add game theory metrics if available
            metrics = results.get('game_theory_metrics')
            if metrics:
                row.update({
                    'Total_Performance_Drop': sum(m['performance_drop'] for m in metrics),
                    'Avg_Successful_Attacks': np.mean([m['successful_attacks'] for m in metrics]),
                    'Avg_Defense_Efficiency': np.mean([m['deterrence_efficiency'] for m in metrics]),
                    'Total_Theoretical_Damage': sum(m['theoretical_damage'] for m in metrics),
                })

            summary_data.append(row)

        df = pd.DataFrame(summary_data)
        csv_path = os.path.join(save_dir, 'strategy_comparison_summary.csv')
        df.to_csv(csv_path, index=False)
        print(f"Strategy comparison summary saved to: {csv_path}")

        print(f"Individual plots saved to: {save_dir}/training_returns.png and validation_returns.png")

def run_strategy_comparison(opts):
    """Train one agent per strategy in ``opts.strategies_to_compare`` and compare them.

    For each strategy, ``opts.enable_game_theory`` is set to False for
    ``'baseline'`` and True otherwise. Note that this mutates ``opts`` in place.
    One ``Agent`` is created per strategy. For every seed in ``opts.seeds``, its
    master policy is re-initialised from a freshly seeded ``Worker`` and then
    trained. Only the curves of the first seed are kept for the comparison.
    Comparison plots and a summary CSV are written through the last agent's
    ``save_training_plots``.

    Returns:
        dict mapping strategy name -> ``{'steps', 'training_values',
        'eval_values'[, 'game_theory_metrics']}``. Lists are empty when nothing
        was recorded. The dict is empty when no strategies are configured.
    """
    import pprint
    pprint.pprint(vars(opts))

    all_strategy_results = {}
    agent = None
    # Curves of the first seed are compared. With no seeds, lookups below fall
    # back to empty lists instead of raising IndexError.
    first_seed = opts.seeds[0] if len(opts.seeds) > 0 else None

    print(f"\n{'='*80}")
    print(f"RUNNING STRATEGY COMPARISON: {len(opts.strategies_to_compare)} strategies")
    print(f"{'='*80}")

    for strategy_name in opts.strategies_to_compare:
        print(f"\n{'='*60}")
        print(f"RUNNING STRATEGY: {strategy_name.upper()}")
        print(f"{'='*60}")

        # 'baseline' is the only strategy trained without the game-theory layer.
        opts.enable_game_theory = strategy_name != 'baseline'

        # Setup tensorboard for this strategy
        tb_writer = None if opts.no_tb else SummaryWriter(os.path.join(opts.log_dir, strategy_name))

        # Setup save directory and record the exact configuration for this strategy
        if not opts.no_saving:
            strategy_save_dir = os.path.join(opts.save_dir, strategy_name)
            os.makedirs(strategy_save_dir, exist_ok=True)

            opts_dict = vars(opts).copy()
            opts_dict['device'] = str(opts_dict['device'])
            opts_dict['current_strategy'] = strategy_name
            with open(os.path.join(strategy_save_dir, "args.json"), 'w') as f:
                json.dump(opts_dict, f, indent=True)

        # Create agent for this strategy
        agent = Agent(opts)

        # Set specific strategy if game theory is enabled
        if opts.enable_game_theory and hasattr(agent, 'game_integrator'):
            agent.game_integrator.current_strategy = strategy_name

        try:
            # Run training for all seeds
            for run_id in opts.seeds:
                torch.manual_seed(run_id)
                np.random.seed(run_id)

                # Initialize model
                nn_parms_worker = Worker(
                    id=0, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                # Load random policy
                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                # Start training
                agent.start_training(tb_writer, run_id)
                if tb_writer:
                    agent.log_performance(tb_writer)
        finally:
            # Release this strategy's event file even if training fails.
            if tb_writer is not None:
                tb_writer.close()

        # Store results for this strategy (first seed only)
        memory = agent.memory
        strategy_results = {
            'steps': memory.steps[first_seed] if first_seed in memory.steps else [],
            'training_values': memory.training_values[first_seed] if first_seed in memory.training_values else [],
            'eval_values': memory.eval_values[first_seed] if first_seed in memory.eval_values else [],
        }

        if (opts.enable_game_theory and hasattr(memory, 'game_theory_metrics')
                and first_seed in memory.game_theory_metrics):
            strategy_results['game_theory_metrics'] = memory.game_theory_metrics[first_seed]

        all_strategy_results[strategy_name] = strategy_results

        print(f"\nStrategy {strategy_name.upper()} completed!")
        if strategy_results['training_values']:
            print(f"Final training return: {strategy_results['training_values'][-1]:.2f}")
        if strategy_results['eval_values']:
            print(f"Final validation return: {strategy_results['eval_values'][-1]:.2f}")

    if agent is None:
        # No strategies configured: there is no agent to plot with.
        print("No strategies to compare: opts.strategies_to_compare is empty.")
        return all_strategy_results

    # Generate comparative plots
    print(f"\n{'='*60}")
    print("GENERATING COMPARATIVE ANALYSIS")
    print(f"{'='*60}")

    # Use the last agent for plotting (all agents share the same plotting methods)
    agent.save_training_plots(opts.save_dir, all_strategy_results)

    # Print final comparison
    print(f"\n{'='*80}")
    print("FINAL STRATEGY COMPARISON RESULTS")
    print(f"{'='*80}")

    for strategy_name, results in all_strategy_results.items():
        if results['training_values'] and results['eval_values']:
            print(f"{strategy_name:15s}: Train={results['training_values'][-1]:6.2f}, "
                  f"Val={results['eval_values'][-1]:6.2f}, "
                  f"Improvement={results['training_values'][-1] - results['training_values'][0]:+6.2f}")

    return all_strategy_results


def save_game_theory_results(agent, save_dir=None):
    """Save an agent's per-step game theory metrics to one CSV file per run.

    This function was previously nested (unreachably) after the ``return`` of
    ``run_strategy_comparison``. It now lives at module level. It takes the agent
    as its first argument, so it can be called directly or bound as a method.

    Args:
        agent: Object exposing ``opts`` (``enable_game_theory``, ``save_dir``) and
            ``memory`` (``steps`` and ``game_theory_metrics`` dicts keyed by run id).
        save_dir: Output directory. Defaults to ``opts.save_dir``, or
            ``'./results'`` when that is empty. Created if missing.

    Writes ``game_theory_run_{run_id}.csv`` per run. Rows pair each recorded step
    with its metrics dict; extra entries in the longer series are dropped.
    Does nothing when game theory is disabled or no metrics were recorded.
    """
    if not agent.opts.enable_game_theory or not agent.memory.game_theory_metrics:
        return

    if save_dir is None:
        save_dir = agent.opts.save_dir if agent.opts.save_dir else './results'

    os.makedirs(save_dir, exist_ok=True)

    for run_id, metrics_list in agent.memory.game_theory_metrics.items():
        csv_filename = os.path.join(save_dir, f'game_theory_run_{run_id}.csv')

        with open(csv_filename, mode="w", newline="", encoding="utf-8") as csv_file:
            writer = csv.writer(csv_file)
            writer.writerow([
                "step", "era", "era_round", "performance_drop", "theoretical_damage",
                "num_attackers", "num_defended", "successful_attacks", "blocked_attacks",
                "avg_attacker_utility", "defense_cost_used", "attack_cost_used",
                "defense_budget_util", "attack_budget_util", "deterrence_efficiency",
                "current_strategy", "damage_weight_range"
            ])

            for step, m in zip(agent.memory.steps[run_id], metrics_list):
                writer.writerow([
                    step, m['current_era'], m['era_round'], m['performance_drop'],
                    m['theoretical_damage'], m['num_attackers'], m['num_defended'],
                    m['successful_attacks'], m['blocked_attacks'], m['avg_attacker_utility'],
                    m['defense_cost_used'], m['attack_cost_used'], m['defense_budget_utilization'],
                    m['attack_budget_utilization'], m['deterrence_efficiency'],
                    m['current_strategy'], m['damage_weight_range']
                ])

        print(f"Game theory results for run {run_id} saved to: {csv_filename}")

# =============================================================================
# MAIN EXECUTION FUNCTION
# =============================================================================

def run(opts):
    """Main entry point: single-configuration run or strategy comparison.

    When ``opts.compare_all_strategies`` is truthy, delegates to
    ``run_strategy_comparison(opts)`` and returns its result. Otherwise:

    * ``opts.eval_only``: seeds torch/numpy with ``opts.seed``, optionally
      loads weights from ``opts.load_path`` and runs validation.
    * otherwise: for every seed in ``opts.seeds`` re-initialises the master
      policy from a freshly built worker and trains; afterwards saves the
      training plots (and game-theory results if ``opts.enable_game_theory``).

    Args:
        opts: run configuration (attribute access; ``vars(opts)`` must work).

    Returns:
        The comparison results in comparison mode, otherwise ``None``.
    """
    if hasattr(opts, 'compare_all_strategies') and opts.compare_all_strategies:
        # Run all strategies and compare
        return run_strategy_comparison(opts)

    # Run single strategy (original behavior)
    import pprint
    pprint.pprint(vars(opts))

    # Setup tensorboard (closed in the `finally` below so the event file is
    # flushed and its writer thread released).
    tb_writer = None if opts.no_tb else SummaryWriter(opts.log_dir)
    try:
        if not opts.no_saving:
            # Create the output directory; exist_ok avoids a check-then-create race.
            os.makedirs(opts.save_dir, exist_ok=True)

            # Save arguments so the exact configuration can always be found.
            opts_dict = vars(opts).copy()
            opts_dict['device'] = str(opts_dict['device'])  # device object -> JSON-friendly string
            with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
                json.dump(opts_dict, f, indent=True)

        # Figure out the RL
        agent = Agent(opts)

        if opts.eval_only:
            # Validation only
            torch.manual_seed(opts.seed)
            np.random.seed(opts.seed)

            if opts.load_path is not None:
                agent.load(opts.load_path)

            agent.start_validating(tb_writer, 0, opts.val_max_steps, opts.render, mode=opts.mode)
            return None

        for run_id in opts.seeds:
            # Each run is seeded with its own id.
            torch.manual_seed(run_id)
            np.random.seed(run_id)

            nn_parms_worker = Worker(
                id=0,
                is_Byzantine=False,
                env_name=opts.env_name,
                gamma=opts.gamma,
                hidden_units=opts.hidden_units,
                activation=opts.activation,
                output_activation=opts.output_activation,
                max_epi_len=opts.max_epi_len,
                opts=opts
            ).to(opts.device)

            # Overwrite the master policy with the freshly initialised worker
            # weights so every seed starts from a new random policy.
            model_actor = get_inner_model(agent.master.logits_net)
            model_actor.load_state_dict({**model_actor.state_dict(),
                                         **get_inner_model(nn_parms_worker.logits_net).state_dict()})

            agent.start_training(tb_writer, run_id)
            if tb_writer:
                agent.log_performance(tb_writer)

        # Save training plots after all runs complete
        print("\n" + "="*60)
        print("Training completed! Generating plots...")
        agent.save_training_plots()

        if opts.enable_game_theory:
            agent.save_game_theory_results()
            print("Game theory results saved!")

        print("="*60)
        return None
    finally:
        if tb_writer is not None:
            tb_writer.close()

# =============================================================================
# STRATEGY CHECKPOINTING AND CHECKPOINT-AWARE ENTRY POINTS
# =============================================================================
import os
import json
import pickle
import csv
import numpy as np
from typing import Dict, Any, Optional

class StrategyCheckpointManager:
    """Persist per-strategy results so completed experiments are not re-run.

    Files live in ``<base_dir>/checkpoints/<experiment_id>``. Each strategy gets:

    * ``<strategy>_checkpoint.pkl`` -- pickled dict with ``strategy_name``,
      ``results``, ``config`` and ``timestamp``. This is the authoritative record.
    * ``<strategy>_results.csv`` -- a human-readable export, written on a
      best-effort basis (failures only print a warning).
    """

    _CHECKPOINT_SUFFIX = '_checkpoint.pkl'
    _CSV_SUFFIX = '_results.csv'

    def __init__(self, base_dir: str, experiment_id: Optional[str] = None):
        """
        Args:
            base_dir: Base directory for saving results.
            experiment_id: Unique identifier for this experiment run. Defaults
                to ``exp_<unix seconds>``, so omitting it yields a fresh
                directory (and therefore no reusable checkpoints).
        """
        if experiment_id is None:
            experiment_id = f"exp_{int(time.time())}"

        self.checkpoint_dir = os.path.join(base_dir, "checkpoints", experiment_id)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        print(f"Checkpoint system initialized: {self.checkpoint_dir}")

    def get_strategy_checkpoint_path(self, strategy_name: str) -> str:
        """Return the pickle checkpoint path for *strategy_name*."""
        return os.path.join(self.checkpoint_dir, f"{strategy_name}{self._CHECKPOINT_SUFFIX}")

    def get_strategy_csv_path(self, strategy_name: str) -> str:
        """Return the CSV export path for *strategy_name*."""
        return os.path.join(self.checkpoint_dir, f"{strategy_name}{self._CSV_SUFFIX}")

    def strategy_exists(self, strategy_name: str) -> bool:
        """Return True if a checkpoint for *strategy_name* exists.

        Only the pickle is required. The CSV is a best-effort companion whose
        write errors are swallowed, so requiring it would re-run strategies
        that actually completed. This also matches
        ``list_completed_strategies`` and ``load_strategy_results``.
        """
        exists = os.path.exists(self.get_strategy_checkpoint_path(strategy_name))
        if exists:
            print(f"Found existing results for strategy: {strategy_name}")
        return exists

    def save_strategy_results(self, strategy_name: str, results: Dict[str, Any],
                              config_dict: Optional[Dict[str, Any]] = None):
        """Save *results* (and optional *config_dict*) as a pickle plus a CSV export.

        The pickle is written to a temporary file and atomically renamed, so
        an interrupted run never leaves a truncated checkpoint under the real
        name. Errors writing the pickle propagate to the caller.
        """
        checkpoint_path = self.get_strategy_checkpoint_path(strategy_name)
        csv_path = self.get_strategy_csv_path(strategy_name)

        checkpoint_data = {
            'strategy_name': strategy_name,
            'results': results,
            'config': config_dict,
            'timestamp': time.time()
        }

        tmp_path = checkpoint_path + '.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(checkpoint_data, f)
            os.replace(tmp_path, checkpoint_path)
        finally:
            # Only still present if writing or renaming failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        # Human-readable CSV (best-effort)
        self._save_results_csv(csv_path, strategy_name, results)

        print(f"Saved results for strategy: {strategy_name}")
        print(f"  Checkpoint: {checkpoint_path}")
        print(f"  CSV: {csv_path}")

    def load_strategy_results(self, strategy_name: str) -> Optional[Dict[str, Any]]:
        """Return the saved results dict for *strategy_name*, or None.

        Returns None when the checkpoint is missing or cannot be read, so the
        caller can fall back to re-running the strategy.
        """
        checkpoint_path = self.get_strategy_checkpoint_path(strategy_name)

        if not os.path.exists(checkpoint_path):
            return None

        try:
            # NOTE: pickle executes code on load; only load checkpoints this
            # class wrote, never files from an untrusted source.
            with open(checkpoint_path, 'rb') as f:
                checkpoint_data = pickle.load(f)

            print(f"Loaded existing results for strategy: {strategy_name}")
            return checkpoint_data['results']

        except Exception as e:
            # Best-effort: a corrupt/incompatible checkpoint means "re-run".
            print(f"Error loading checkpoint for {strategy_name}: {e}")
            return None

    def _save_results_csv(self, csv_path: str, strategy_name: str, results: Dict[str, Any]):
        """Write *results* as CSV: one row per index of the longest series.

        Shorter series are padded with empty cells. Game-theory columns are
        added only when ``results`` contains ``'game_theory_metrics'``.
        Failures are reported as a warning and otherwise ignored.
        """
        try:
            with open(csv_path, 'w', newline='') as f:
                writer = csv.writer(f)

                has_game_metrics = 'game_theory_metrics' in results
                header = ['step', 'training_return', 'validation_return']
                if has_game_metrics:
                    header.extend(['num_attackers', 'num_defended', 'performance_drop',
                                   'theoretical_damage', 'defense_budget_util', 'attack_budget_util'])
                writer.writerow(header)

                steps = results.get('steps', [])
                training_values = results.get('training_values', [])
                eval_values = results.get('eval_values', [])
                game_metrics = results.get('game_theory_metrics', [])

                max_len = max(len(steps), len(training_values), len(eval_values))

                for i in range(max_len):
                    row = [
                        steps[i] if i < len(steps) else '',
                        training_values[i] if i < len(training_values) else '',
                        eval_values[i] if i < len(eval_values) else ''
                    ]

                    if game_metrics and i < len(game_metrics):
                        # Named `gm` to avoid shadowing the module-level `metrics` import.
                        gm = game_metrics[i]
                        row.extend([
                            gm.get('num_attackers', ''),
                            gm.get('num_defended', ''),
                            gm.get('performance_drop', ''),
                            gm.get('theoretical_damage', ''),
                            gm.get('defense_budget_utilization', ''),
                            gm.get('attack_budget_utilization', '')
                        ])
                    elif has_game_metrics:
                        row.extend([''] * 6)

                    writer.writerow(row)

        except Exception as e:
            print(f"Warning: Could not save CSV for {strategy_name}: {e}")

    def list_completed_strategies(self) -> list:
        """Return the sorted names of all strategies that have a checkpoint pickle."""
        if not os.path.exists(self.checkpoint_dir):
            return []
        suffix = self._CHECKPOINT_SUFFIX
        # Strip only the trailing suffix (str.replace would hit every occurrence).
        return sorted(
            filename[:-len(suffix)]
            for filename in os.listdir(self.checkpoint_dir)
            if filename.endswith(suffix)
        )

    def get_experiment_summary(self) -> Dict[str, Any]:
        """Summarise the directory and the strategies completed in this experiment."""
        completed = self.list_completed_strategies()  # scan the directory once
        return {
            'experiment_dir': self.checkpoint_dir,
            'completed_strategies': completed,
            'total_completed': len(completed)
        }


def run_strategy_comparison_with_checkpoints(opts):
    """Run every strategy in ``opts.strategies_to_compare``, reusing checkpoints.

    For each strategy, a checkpoint from a previous invocation is loaded if one
    exists. The experiment id is built from ``env_name``, ``num_worker`` and
    ``seed``. Otherwise a fresh ``Agent`` is trained on every seed in
    ``opts.seeds``. Game theory is disabled only for ``'baseline'``. Only the
    FIRST seed's curves are checkpointed.

    Note: mutates ``opts.enable_game_theory``.

    Returns:
        dict mapping strategy name -> results dict with keys ``steps``,
        ``training_values``, ``eval_values`` and optionally ``game_theory_metrics``.
    """
    import pprint
    pprint.pprint(vars(opts))

    # Used for both the checkpoint store and per-strategy output dirs, so a
    # missing save_dir no longer crashes os.path.join(None, ...).
    base_dir = opts.save_dir if opts.save_dir else './results'

    checkpoint_manager = StrategyCheckpointManager(
        base_dir=base_dir,
        experiment_id=f"{opts.env_name}_workers{opts.num_worker}_seed{opts.seed}"
    )

    summary = checkpoint_manager.get_experiment_summary()
    print(f"\n{'='*80}")
    print("CHECKPOINT SYSTEM STATUS")
    print(f"{'='*80}")
    print(f"Experiment directory: {summary['experiment_dir']}")
    print(f"Completed strategies: {summary['completed_strategies']}")
    print(f"Total completed: {summary['total_completed']}")

    all_strategy_results = {}
    # Most recently trained agent (used for plotting); stays None when every
    # strategy was restored from a checkpoint.
    last_agent = None

    def _first_seed_series(store):
        # Series recorded for the first configured seed, or [] if absent.
        seed = opts.seeds[0]
        return store[seed] if seed in store else []

    print(f"\n{'='*80}")
    print(f"RUNNING STRATEGY COMPARISON: {len(opts.strategies_to_compare)} strategies")
    print(f"{'='*80}")

    for strategy_name in opts.strategies_to_compare:
        print(f"\n{'='*60}")
        print(f"CHECKING STRATEGY: {strategy_name.upper()}")
        print(f"{'='*60}")

        if checkpoint_manager.strategy_exists(strategy_name):
            print(f"LOADING EXISTING RESULTS for {strategy_name}...")
            existing_results = checkpoint_manager.load_strategy_results(strategy_name)
            if existing_results:
                all_strategy_results[strategy_name] = existing_results
                print(f"Successfully loaded {strategy_name} results!")
                continue
            # Unreadable or empty checkpoint: fall through and re-run.

        print(f"RUNNING NEW EXPERIMENT for {strategy_name}...")

        opts.enable_game_theory = strategy_name != 'baseline'

        tb_writer = None if opts.no_tb else SummaryWriter(os.path.join(opts.log_dir, strategy_name))
        try:
            if not opts.no_saving:
                strategy_save_dir = os.path.join(base_dir, strategy_name)
                os.makedirs(strategy_save_dir, exist_ok=True)

                opts_dict = vars(opts).copy()
                opts_dict['device'] = str(opts_dict['device'])
                opts_dict['current_strategy'] = strategy_name
                with open(os.path.join(strategy_save_dir, "args.json"), 'w') as f:
                    json.dump(opts_dict, f, indent=True)

            agent = Agent(opts)
            last_agent = agent

            if opts.enable_game_theory and hasattr(agent, 'game_integrator'):
                agent.game_integrator.current_strategy = strategy_name

            for run_id in opts.seeds:
                torch.manual_seed(run_id)
                np.random.seed(run_id)

                # Fresh randomly initialised worker (explicitly non-Byzantine,
                # matching run()) whose weights seed the master policy.
                nn_parms_worker = Worker(
                    id=0, is_Byzantine=False, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                agent.start_training(tb_writer, run_id)
                if tb_writer:
                    agent.log_performance(tb_writer)
        finally:
            # One writer per strategy: close it so events are flushed and the
            # writer thread/file handle are released before the next strategy.
            if tb_writer is not None:
                tb_writer.close()

        strategy_results = {
            'steps': _first_seed_series(agent.memory.steps),
            'training_values': _first_seed_series(agent.memory.training_values),
            'eval_values': _first_seed_series(agent.memory.eval_values)
        }

        gt_store = getattr(agent.memory, 'game_theory_metrics', None)
        if opts.enable_game_theory and gt_store is not None and opts.seeds[0] in gt_store:
            strategy_results['game_theory_metrics'] = gt_store[opts.seeds[0]]

        config_dict = vars(opts).copy()
        config_dict['device'] = str(config_dict['device'])  # Make serializable
        checkpoint_manager.save_strategy_results(strategy_name, strategy_results, config_dict)

        all_strategy_results[strategy_name] = strategy_results

        print(f"\nStrategy {strategy_name.upper()} completed and saved!")
        if strategy_results['training_values']:
            print(f"Final training return: {strategy_results['training_values'][-1]:.2f}")
        if strategy_results['eval_values']:
            print(f"Final validation return: {strategy_results['eval_values'][-1]:.2f}")

    print(f"\n{'='*60}")
    print("GENERATING COMPARATIVE ANALYSIS")
    print(f"{'='*60}")

    # Plot with the last trained agent (it receives all strategies' results).
    if last_agent is not None:
        last_agent.save_training_plots(opts.save_dir, all_strategy_results)
    else:
        print("No strategy was trained in this invocation; skipping plots.")

    print(f"\n{'='*80}")
    print("FINAL STRATEGY COMPARISON RESULTS")
    print(f"{'='*80}")

    for strategy_name, results in all_strategy_results.items():
        if results['training_values'] and results['eval_values']:
            print(f"{strategy_name:15s}: Train={results['training_values'][-1]:6.2f}, "
                  f"Val={results['eval_values'][-1]:6.2f}, "
                  f"Improvement={results['training_values'][-1] - results['training_values'][0]:+6.2f}")

    def _last_or_zero(values):
        # float() so numpy/torch scalars (if present) are JSON-serializable.
        return float(values[-1]) if values else 0

    final_summary = checkpoint_manager.get_experiment_summary()
    final_summary['final_results'] = {k: {
        'final_train': _last_or_zero(v['training_values']),
        'final_val': _last_or_zero(v['eval_values'])
    } for k, v in all_strategy_results.items()}

    summary_path = os.path.join(checkpoint_manager.checkpoint_dir, 'experiment_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(final_summary, f, indent=2)

    print(f"\nExperiment complete! All results saved to: {checkpoint_manager.checkpoint_dir}")

    return all_strategy_results


# Update the main run function to use checkpoints
def run_with_checkpoints(opts):
    """Pick the checkpoint-aware comparison or the plain single-strategy run.

    Args:
        opts: run configuration; ``compare_all_strategies`` may be absent,
            in which case it is treated as False.

    Returns:
        Whatever the selected runner returns.
    """
    compare_requested = getattr(opts, 'compare_all_strategies', False)
    if not compare_requested:
        # Single strategy: checkpointing is not wired in for this path.
        return run(opts)
    # Multi-strategy comparison; completed strategies are restored from disk.
    return run_strategy_comparison_with_checkpoints(opts)


def run_with_checkpoints_and_csv(opts):
    """Main entry point that also exports per-strategy performance CSVs.

    In comparison mode (``opts.compare_all_strategies`` truthy), delegates to
    ``run_strategy_comparison_with_csv_export(opts)``. Otherwise it runs one
    configuration like ``run()``. After training it additionally calls
    ``agent.save_performance_to_csv`` under a strategy name. The name is
    ``'baseline'`` without game theory; otherwise it is the integrator's
    ``current_strategy``, falling back to ``'game_theory'``.

    Returns:
        The comparison results in comparison mode, otherwise the ``Agent``.
    """
    if hasattr(opts, 'compare_all_strategies') and opts.compare_all_strategies:
        return run_strategy_comparison_with_csv_export(opts)

    import pprint
    pprint.pprint(vars(opts))

    # Closed in the `finally` below.
    tb_writer = None if opts.no_tb else SummaryWriter(opts.log_dir)
    try:
        if not opts.no_saving:
            os.makedirs(opts.save_dir, exist_ok=True)

            opts_dict = vars(opts).copy()
            opts_dict['device'] = str(opts_dict['device'])
            with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
                json.dump(opts_dict, f, indent=True)

        agent = Agent(opts)

        if opts.enable_game_theory:
            # Guarded like elsewhere in the file: the agent may not expose
            # a game_integrator, in which case fall back to the default name.
            integrator = getattr(agent, 'game_integrator', None)
            strategy_name = getattr(integrator, 'current_strategy', 'game_theory')
        else:
            strategy_name = 'baseline'

        if opts.eval_only:
            torch.manual_seed(opts.seed)
            np.random.seed(opts.seed)

            if opts.load_path is not None:
                agent.load(opts.load_path)

            agent.start_validating(tb_writer, 0, opts.val_max_steps, opts.render, mode=opts.mode)

        else:
            for run_id in opts.seeds:
                torch.manual_seed(run_id)
                np.random.seed(run_id)

                # Explicitly non-Byzantine, matching run().
                nn_parms_worker = Worker(
                    id=0, is_Byzantine=False, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                # Re-initialise the master policy from the fresh worker's weights.
                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                agent.start_training(tb_writer, run_id)
                if tb_writer:
                    agent.log_performance(tb_writer)

            print("\n" + "="*60)
            print("Training completed! Generating plots and saving CSV...")
            agent.save_training_plots()
            agent.save_performance_to_csv(strategy_name)

            if opts.enable_game_theory:
                agent.save_game_theory_results()
                print("Game theory results saved!")

            print("CSV performance data saved!")
            print("="*60)
    finally:
        if tb_writer is not None:
            tb_writer.close()

    return agent

if __name__ == "__main__":

    # Build the run configuration.
    opts = Config()

    # SVRPG and FedPG_BR are mutually exclusive (GPMDP runs when neither is
    # set). Explicit check instead of `assert`, which `python -O` strips.
    if opts.SVRPG + opts.FedPG_BR > 1:
        raise ValueError("Invalid config: SVRPG and FedPG_BR cannot both be enabled")

    if opts.enable_game_theory:
        print('='*80)
        print('FEDERATED REINFORCEMENT LEARNING WITH STACKELBERG GAME THEORY')
        print('='*80)
        print(f'Environment: {opts.env_name}')
        print(f'Workers: {opts.num_worker}')
        print(f'Byzantine Workers: {opts.num_Byzantine}')
        print(f'Game Theory Scenario: {opts.damage_scenario}')
        print(f'Defense Budget Ratio: {opts.initial_defense_budget_ratio}')
        print(f'Attack Budget Ratio: {opts.initial_attack_budget_ratio}')
        print(f'Base Attack Intensity: {opts.base_attack_intensity}')
        print(f'Defense Effectiveness: {opts.defense_effectiveness}')
        print(f'Reshuffle Frequency: every {opts.reshuffle_frequency} rounds')
        print('='*80)
        if opts.SVRPG + opts.FedPG_BR == 0:
            print('Running: GPMDP with Stackelberg Game Theory')
        elif opts.FedPG_BR:
            print('Running: FT-FedScsPG with Stackelberg Game Theory')
        else:
            print('Running: SVRPG with Stackelberg Game Theory')
    else:
        print('run GPMDP\n' if opts.SVRPG + opts.FedPG_BR == 0 else ('run FT-FedScsPG\n' if opts.FedPG_BR else 'run SVRPG\n'))

    if hasattr(opts, 'compare_all_strategies') and opts.compare_all_strategies:
        # Strategy comparison with per-strategy CSV export.
        all_strategy_results = run_strategy_comparison_with_csv_export(opts)

        print(f"\n{'='*80}")
        print(f"CSV FILES SAVED!")
        print(f"{'='*80}")
        print(f"Check the following directory for CSV files:")
        csv_dir = os.path.join(opts.save_dir if opts.save_dir else './results', 'performance_csv')
        print(f"  {csv_dir}")
        print(f"Files generated:")
        print(f"  - Individual strategy files: [strategy_name]_performance.csv")
        print(f"  - Combined file: all_strategies_performance.csv")
        print(f"  - Training data: training_returns_all_strategies.csv")
        print(f"  - Validation data: validation_returns_all_strategies.csv")
        print(f"  - Plotting ready: all_returns_plotting_ready.csv")

    else:
        # Single-strategy run: run_with_checkpoints() delegates to run(), which
        # saves plots (and game-theory results) but not the per-strategy
        # performance CSV; run_with_checkpoints_and_csv(opts) also exports it.
        run_with_checkpoints(opts)