#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.multiprocessing import Pool
from torch.nn import DataParallel
from torch.distributions.categorical import Categorical
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import gym
from gym.spaces import Discrete
import random
import math
import time
import warnings
from tqdm import tqdm
from sklearn import metrics
from matplotlib import pyplot as plt
from matplotlib import animation
from itertools import repeat
from scipy.interpolate import Rbf
import scipy.stats as st
import numpy as np
if not hasattr(np, "bool8"):
    np.bool8 = np.bool_

# ======================== GAME THEORY IMPORTS ========================
import copy
import csv
from collections import defaultdict
from typing import List, Dict, Tuple, Set, Optional
from itertools import combinations
from scipy.optimize import minimize
import pickle

# Suppress warnings
warnings.filterwarnings("ignore", category=Warning)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# =============================================================================
# PERFORMANCE CSV EXPORT & STRATEGY COMPARISON (run parameters live in Config)
# =============================================================================
import csv
import os
import pandas as pd

class PerformanceCSVSaver:
    """Write strategy performance curves to CSV files under ``<save_dir>/performance_csv``.

    The ``performance_csv`` directory is created on construction. Each public
    method writes one or more CSV files into it and prints the paths written.
    """

    # (CSV column, metrics-dict key) pairs copied into the per-seed CSVs.
    # Keys missing from a metrics dict are written as 0.
    _SEED_METRIC_COLUMNS = (
        ('num_attackers', 'num_attackers'),
        ('num_defended', 'num_defended'),
        ('performance_drop', 'performance_drop'),
        ('theoretical_damage', 'theoretical_damage'),
        ('successful_attacks', 'successful_attacks'),
        ('defense_budget_util', 'defense_budget_utilization'),
        ('attack_budget_util', 'attack_budget_utilization'),
        ('avg_attacker_utility', 'avg_attacker_utility'),
        ('current_era', 'current_era'),
        ('deterrence_efficiency', 'deterrence_efficiency'),
    )

    # (CSV column, metrics-dict key) pairs used by the combined all-strategies CSV.
    _COMBINED_METRIC_COLUMNS = (
        ('num_attackers', 'num_attackers'),
        ('num_defended', 'num_defended'),
        ('performance_drop', 'performance_drop'),
        ('theoretical_damage', 'theoretical_damage'),
        ('successful_attacks', 'successful_attacks'),
        ('defense_budget_util', 'defense_budget_utilization'),
        ('attack_budget_util', 'attack_budget_utilization'),
        ('current_era', 'current_era'),
    )

    def __init__(self, save_dir: str):
        self.save_dir = save_dir
        self.csv_dir = os.path.join(save_dir, 'performance_csv')
        os.makedirs(self.csv_dir, exist_ok=True)

    @staticmethod
    def _value_at(seq, i, default=None):
        """Return ``seq[i]``, or ``default`` when ``i`` is past the end (pads ragged series)."""
        return seq[i] if i < len(seq) else default

    @staticmethod
    def _metric_columns(metrics_row: dict, columns) -> dict:
        """Project one per-step metrics dict onto ``columns``, defaulting missing keys to 0."""
        return {col: metrics_row.get(key, 0) for col, key in columns}

    def save_strategy_performance_csv(self, strategy_name: str, memory_data: dict,
                                      game_theory_metrics: Optional[dict] = None):
        """Save per-seed CSVs in seed_{run_id}/ subfolders + a combined all-seeds CSV.

        Args:
            strategy_name: label used in file names and in the ``strategy`` column.
            memory_data: dict with ``'steps'``, ``'training_values'`` and
                ``'eval_values'``, each mapping run_id -> list of per-step values.
            game_theory_metrics: optional run_id -> list of per-step metric dicts.
                Row ``i`` of a seed gets metric columns only when entry ``i`` exists.

        Series of unequal length are padded with None; runs with no steps are skipped.
        Writes ``seed_<run_id>/<strategy>_performance.csv`` per run and
        ``<strategy>_all_seeds.csv`` with every run's rows.
        """
        if not memory_data['steps']:
            print(f"No data to save for strategy: {strategy_name}")
            return

        value_at = self._value_at
        all_seeds_data = []

        for run_id in memory_data['steps']:
            steps = memory_data['steps'].get(run_id, [])
            training_values = memory_data['training_values'].get(run_id, [])
            eval_values = memory_data['eval_values'].get(run_id, [])

            if not steps:
                continue

            # Per-seed subfolder inside performance_csv/
            seed_dir = os.path.join(self.csv_dir, f'seed_{run_id}')
            os.makedirs(seed_dir, exist_ok=True)

            seed_metrics = (
                game_theory_metrics[run_id]
                if game_theory_metrics and run_id in game_theory_metrics
                else []
            )

            csv_data = []
            max_len = max(len(steps), len(training_values), len(eval_values))
            for i in range(max_len):
                row = {
                    'seed': run_id,
                    'step': value_at(steps, i),
                    'training_return': value_at(training_values, i),
                    'validation_return': value_at(eval_values, i),
                    'strategy': strategy_name,
                }
                if i < len(seed_metrics):
                    row.update(self._metric_columns(seed_metrics[i], self._SEED_METRIC_COLUMNS))
                csv_data.append(row)
            all_seeds_data.extend(csv_data)

            # Per-seed CSV
            seed_csv = os.path.join(seed_dir, f'{strategy_name}_performance.csv')
            pd.DataFrame(csv_data).to_csv(seed_csv, index=False)
            print(f"  Saved seed {run_id}: {seed_csv}  ({len(csv_data)} rows)")

        # All-seeds combined CSV for this strategy
        if all_seeds_data:
            all_seeds_csv = os.path.join(self.csv_dir, f'{strategy_name}_all_seeds.csv')
            pd.DataFrame(all_seeds_data).to_csv(all_seeds_csv, index=False)
            print(f"Saved all-seeds CSV for {strategy_name}: {all_seeds_csv}")

    def save_combined_performance_csv(self, all_strategy_results: dict):
        """Save all strategies in one combined CSV file.

        ``all_strategy_results`` maps strategy name -> dict of flat lists
        (``steps``, ``training_values``, ``eval_values`` and optionally
        ``game_theory_metrics``). Missing curve values are NaN. Metric columns
        are zero-filled for rows without a metrics entry.
        """
        value_at = self._value_at
        combined_data = []

        for strategy_name, results in all_strategy_results.items():
            steps = results.get('steps', [])
            training_values = results.get('training_values', [])
            eval_values = results.get('eval_values', [])
            game_metrics = results.get('game_theory_metrics', [])

            max_len = max(len(steps), len(training_values), len(eval_values))
            for i in range(max_len):
                row = {
                    'strategy': strategy_name,
                    'step': value_at(steps, i, np.nan),
                    'training_return': value_at(training_values, i, np.nan),
                    'validation_return': value_at(eval_values, i, np.nan),
                }
                if game_metrics and i < len(game_metrics):
                    row.update(self._metric_columns(game_metrics[i], self._COMBINED_METRIC_COLUMNS))
                else:
                    # Non-game-theory strategies (or steps without metrics) get zeros.
                    row.update({col: 0 for col, _ in self._COMBINED_METRIC_COLUMNS})
                combined_data.append(row)

        combined_filename = os.path.join(self.csv_dir, 'all_strategies_performance.csv')
        if combined_data:
            df = pd.DataFrame(combined_data)
            df.to_csv(combined_filename, index=False)
            print(f"Saved combined performance data: {combined_filename}")
            print(f"  Total rows: {len(df)}")
            print(f"  Strategies: {df['strategy'].unique().tolist()}")

    def create_plotting_ready_csv(self, all_strategy_results: dict):
        """Create long-format CSVs (step, return, strategy, type) for plotting.

        Steps are paired with values via ``zip``, so each series is truncated
        to the shorter of ``steps`` and the value list.
        Writes separate training and validation files plus one file with both.
        """
        training_data = []
        validation_data = []

        for strategy_name, results in all_strategy_results.items():
            steps = results.get('steps', [])
            training_data.extend(
                {'step': step, 'return': val, 'strategy': strategy_name, 'type': 'training'}
                for step, val in zip(steps, results.get('training_values', []))
            )
            validation_data.extend(
                {'step': step, 'return': val, 'strategy': strategy_name, 'type': 'validation'}
                for step, val in zip(steps, results.get('eval_values', []))
            )

        if training_data:
            training_csv = os.path.join(self.csv_dir, 'training_returns_all_strategies.csv')
            pd.DataFrame(training_data).to_csv(training_csv, index=False)
            print(f"Saved training data: {training_csv}")

        if validation_data:
            validation_csv = os.path.join(self.csv_dir, 'validation_returns_all_strategies.csv')
            pd.DataFrame(validation_data).to_csv(validation_csv, index=False)
            print(f"Saved validation data: {validation_csv}")

        # Combined file for easy plotting
        all_data = training_data + validation_data
        if all_data:
            all_csv = os.path.join(self.csv_dir, 'all_returns_plotting_ready.csv')
            pd.DataFrame(all_data).to_csv(all_csv, index=False)
            print(f"Saved plotting-ready data: {all_csv}")

    def save_grand_combined_csv(self, all_memory_per_strategy: dict):
        """
        One CSV with every strategy x every seed x every step.
        Columns: strategy, seed, step, training_return, validation_return
        This is what you load in pandas to compute mean +/- std across seeds.
        Missing values in ragged series are written as empty cells.
        """
        value_at = self._value_at
        rows = []
        for strategy_name, memory_data in all_memory_per_strategy.items():
            for run_id, steps in memory_data['steps'].items():
                training_values = memory_data['training_values'].get(run_id, [])
                eval_values = memory_data['eval_values'].get(run_id, [])
                max_len = max(len(steps), len(training_values), len(eval_values))
                rows.extend(
                    {
                        'strategy': strategy_name,
                        'seed': run_id,
                        'step': value_at(steps, i),
                        'training_return': value_at(training_values, i),
                        'validation_return': value_at(eval_values, i),
                    }
                    for i in range(max_len)
                )

        if rows:
            grand_csv = os.path.join(self.csv_dir, 'grand_all_strategies_all_seeds.csv')
            pd.DataFrame(rows).to_csv(grand_csv, index=False)
            print(f"Saved grand combined CSV: {grand_csv}  ({len(rows)} rows)")

def save_detailed_performance_csv(self, save_dir=None, strategy_name='default'):
    """Write this object's per-seed performance curves to CSV.

    Meant to be bound as a method: it reads ``self.opts.save_dir`` and the
    ``steps`` / ``training_values`` / ``eval_values`` (and optional
    ``game_theory_metrics``) attributes of ``self.memory``.

    Args:
        save_dir: output root. When None, falls back to ``self.opts.save_dir``,
            or ``'./results'`` if that is falsy.
        strategy_name: label used in the CSV file names and ``strategy`` column.
    """
    target_dir = save_dir if save_dir is not None else (self.opts.save_dir or './results')
    saver = PerformanceCSVSaver(target_dir)

    memory = self.memory
    curves = {
        'steps': memory.steps,
        'training_values': memory.training_values,
        'eval_values': memory.eval_values,
    }
    saver.save_strategy_performance_csv(
        strategy_name,
        curves,
        getattr(memory, 'game_theory_metrics', None),
    )

def _strategy_performance_csv_candidates(base_save_dir: str, strategy_name: str) -> List[str]:
    """Return possible strategy CSV paths used by different save flows."""
    return [
        os.path.join(base_save_dir, "performance_csv", f"{strategy_name}_performance.csv"),
        os.path.join(base_save_dir, strategy_name, "performance_csv", f"{strategy_name}_performance.csv"),
    ]

def _find_completed_strategy_artifact(base_save_dir: str, strategy_name: str) -> Optional[str]:
    """
    Locate an artifact proving ``strategy_name`` already ran, or return None.

    Checked in order:
      1) the strategy's performance CSV candidates (first existing path wins)
      2) any file in ``<base_save_dir>/<strategy_name>/`` named like a
         checkpoint: starts with "r", contains "-epoch-", ends with ".pt"
         (first match in ``os.listdir`` order)
    """
    csv_hit = next(
        (candidate
         for candidate in _strategy_performance_csv_candidates(base_save_dir, strategy_name)
         if os.path.exists(candidate)),
        None,
    )
    if csv_hit is not None:
        return csv_hit

    strategy_dir = os.path.join(base_save_dir, strategy_name)
    if not os.path.isdir(strategy_dir):
        return None

    def _looks_like_checkpoint(fname: str) -> bool:
        return fname.startswith("r") and "-epoch-" in fname and fname.endswith(".pt")

    checkpoint = next((f for f in os.listdir(strategy_dir) if _looks_like_checkpoint(f)), None)
    return None if checkpoint is None else os.path.join(strategy_dir, checkpoint)

def _rebuild_grand_csv_from_individual(base_save_dir: str):
    """
    Merge all per-strategy *_all_seeds.csv files already present in
    performance_csv/ into grand_all_strategies_all_seeds.csv.
    This keeps reruns safe when only a subset of strategies is executed.
    """
    import glob as _glob
    import pandas as _pd

    csv_dir = os.path.join(base_save_dir, 'performance_csv')
    parts = []
    for path in _glob.glob(os.path.join(csv_dir, '*_all_seeds.csv')):
        if 'grand' in os.path.basename(path):
            continue
        df = _pd.read_csv(path)
        keep = [c for c in ('strategy', 'seed', 'step', 'training_return', 'validation_return') if c in df.columns]
        parts.append(df[keep])
    if not parts:
        print('[WARN] No individual strategy CSVs found; grand CSV not written.')
        return
    grand = _pd.concat(parts, ignore_index=True)
    cols = [c for c in ('strategy', 'seed', 'step', 'training_return', 'validation_return') if c in grand.columns]
    grand = grand[cols]
    out = os.path.join(csv_dir, 'grand_all_strategies_all_seeds.csv')
    grand.to_csv(out, index=False)
    print(f'Grand CSV written: {out} ({len(grand):,} rows)')

# Strategies implemented purely as a server-side aggregation defense.
_SYSTEM_LEVEL_DEFENSES = frozenset({'fltg', 'fedgreed'})

# Combined strategies: name -> (client-level game strategy, server aggregation defense).
_COMBINED_STRATEGY_MAP = {
    'stackelberg_fltg': ('true_stackelberg', 'fltg'),
    'ucb_fltg': ('ucb', 'fltg'),
    'thompson_fltg': ('thompson_sampling', 'fltg'),
    'stackelberg_fedgreed': ('true_stackelberg', 'fedgreed'),
    'ucb_fedgreed': ('ucb', 'fedgreed'),
    'thompson_fedgreed': ('thompson_sampling', 'fedgreed'),
}


def _configure_strategy(opts, strategy_name: str) -> str:
    """Set ``enable_game_theory`` / ``aggregation_defense`` / ``use_client_defense`` on ``opts``.

    Returns the client-level strategy name the game integrator should use.
    """
    if strategy_name == 'clean_baseline':
        # Clean run: no game theory attacks or defenses
        opts.enable_game_theory = False
        opts.aggregation_defense = 'none'
        opts.use_client_defense = False
        return 'no_defense'

    opts.enable_game_theory = True
    if strategy_name in _SYSTEM_LEVEL_DEFENSES:
        opts.aggregation_defense = strategy_name
        opts.use_client_defense = False
        return 'no_defense'
    if strategy_name in _COMBINED_STRATEGY_MAP:
        client_strategy, agg_defense = _COMBINED_STRATEGY_MAP[strategy_name]
        opts.aggregation_defense = agg_defense
        opts.use_client_defense = True
        return client_strategy

    # 'baseline' and plain client-level strategies: no aggregation defense.
    opts.aggregation_defense = 'none'
    opts.use_client_defense = True
    return 'no_defense' if strategy_name == 'baseline' else strategy_name


def run_strategy_comparison_with_csv_export(opts):
    """Strategy comparison with detailed CSV export.

    For every name in ``opts.strategies_to_compare``, the function:

    * skips strategies that already have a saved artifact under the base save dir;
    * configures ``opts`` for the strategy;
    * trains one ``Agent`` over all ``opts.seeds``;
    * writes that strategy's per-seed and all-seeds CSVs.

    Afterwards it writes the combined, plotting-ready and grand CSVs, then asks
    the last trained agent (if any) to draw the comparative plots.

    Args:
        opts: run configuration (e.g. ``Config``). Its game-theory/defense
            fields, ``current_strategy`` and ``save_dir`` are mutated per
            strategy; ``save_dir`` is restored to the base directory after
            each strategy, even if training raises.

    Returns:
        dict: strategy name -> first-seed results (``steps``,
        ``training_values``, ``eval_values`` and, when available,
        ``game_theory_metrics``) for the strategies run in this call.
    """
    import pprint
    pprint.pprint(vars(opts))

    all_strategy_results = {}

    # Initialize CSV saver
    base_save_dir = opts.save_dir if opts.save_dir else './results'
    csv_saver = PerformanceCSVSaver(base_save_dir)

    # Tracks full memory (all seeds) per strategy for the grand combined CSV
    all_memory_per_strategy = {}
    # Most recently trained agent; used for the comparative plots at the end.
    agent = None

    print(f"\n{'='*80}")
    print(f"RUNNING STRATEGY COMPARISON: {len(opts.strategies_to_compare)} strategies")
    print(f"{'='*80}")

    for strategy_name in opts.strategies_to_compare:
        print(f"\n{'='*60}")
        print(f"RUNNING STRATEGY: {strategy_name.upper()}")
        print(f"{'='*60}")

        # Skip any strategy that already has saved artifacts.
        existing_artifact = _find_completed_strategy_artifact(base_save_dir, strategy_name)
        if existing_artifact:
            print(f"Skipping {strategy_name} (found {existing_artifact}).")
            continue

        current_strategy = _configure_strategy(opts, strategy_name)
        opts.current_strategy = current_strategy

        # Setup tensorboard for this strategy
        tb_writer = None if opts.no_tb else SummaryWriter(os.path.join(opts.log_dir, strategy_name))

        # Set per-strategy save dir for debug artifacts
        if not opts.no_saving:
            strategy_save_dir = os.path.join(base_save_dir, strategy_name)
            os.makedirs(strategy_save_dir, exist_ok=True)
            opts.save_dir = strategy_save_dir
        else:
            opts.save_dir = None

        try:
            # Create agent for this strategy
            agent = Agent(opts)

            # Set specific strategy if game theory is enabled
            if opts.enable_game_theory and hasattr(agent, 'game_integrator'):
                agent.game_integrator.current_strategy = current_strategy

            # Run training for all seeds
            for run_id in opts.seeds:
                torch.manual_seed(run_id)
                np.random.seed(run_id)

                # Initialize model
                nn_parms_worker = Worker(
                    id=0, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                # Load the freshly initialised worker policy into the master actor;
                # keys absent from the worker's state dict keep the master's values.
                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                # Start training
                agent.start_training(tb_writer, run_id)
                if tb_writer is not None:
                    agent.log_performance(tb_writer)

            # Store first-seed results for this strategy
            first_seed = opts.seeds[0]
            memory = agent.memory
            strategy_results = {
                'steps': memory.steps[first_seed] if first_seed in memory.steps else [],
                'training_values': memory.training_values[first_seed] if first_seed in memory.training_values else [],
                'eval_values': memory.eval_values[first_seed] if first_seed in memory.eval_values else [],
            }
            if (opts.enable_game_theory and hasattr(memory, 'game_theory_metrics')
                    and first_seed in memory.game_theory_metrics):
                strategy_results['game_theory_metrics'] = memory.game_theory_metrics[first_seed]

            all_strategy_results[strategy_name] = strategy_results

            # Save individual strategy CSV (all seeds -> seed_{run_id}/ subfolders)
            memory_data = {
                'steps': memory.steps,
                'training_values': memory.training_values,
                'eval_values': memory.eval_values,
            }
            game_metrics = getattr(memory, 'game_theory_metrics', None)
            csv_saver.save_strategy_performance_csv(strategy_name, memory_data, game_metrics)

            # Store full memory for grand combined CSV at the end
            all_memory_per_strategy[strategy_name] = {
                'steps': dict(memory.steps),
                'training_values': dict(memory.training_values),
                'eval_values': dict(memory.eval_values),
            }

            print(f"\nStrategy {strategy_name.upper()} completed and CSV saved!")
            if strategy_results['training_values']:
                print(f"Final training return: {strategy_results['training_values'][-1]:.2f}")
            if strategy_results['eval_values']:
                print(f"Final validation return: {strategy_results['eval_values'][-1]:.2f}")
        finally:
            # Flush/close the per-strategy TensorBoard writer (it was never closed before).
            if tb_writer is not None:
                tb_writer.close()
            # Restore base save dir for next strategy (also on failure).
            opts.save_dir = base_save_dir

    # Save combined CSV files
    print(f"\n{'='*60}")
    print("SAVING COMBINED CSV FILES")
    print(f"{'='*60}")

    csv_saver.save_combined_performance_csv(all_strategy_results)
    csv_saver.create_plotting_ready_csv(all_strategy_results)

    # Grand combined: rebuild from all per-strategy CSVs already in the folder
    # so targeted reruns do not drop previously completed strategies.
    csv_saver.save_grand_combined_csv(all_memory_per_strategy)
    _rebuild_grand_csv_from_individual(base_save_dir)

    # Generate comparative plots with the last trained agent (none if all were skipped)
    if agent is not None:
        agent.save_training_plots(opts.save_dir, all_strategy_results)

    return all_strategy_results


class Config:
    """Configuration class for easy parameter modification.

    Edit the class attributes below to change a run. Instantiating the class
    fills in environment-specific hyperparameters (``set_environment_params``).
    It also derives output/log directory names, the torch device and the seed
    list (``setup_paths``).
    """

    # Environment and run settings
    env_name = 'CartPole-v1'  # Options: 'CartPole-v1', 'HalfCheetah-v2', 'LunarLander-v2'
    eval_only = False
    no_saving = False
    no_tb = False
    render = False
    mode = 'human'  # 'human' or 'rgb'
    log_dir = 'logs'
    run_name = 'run_name'

    # Multiple runs
    multiple_run = 1
    seed = 2
    debug_summary = True
    debug_summary_rounds = 200
    debug_summary_include_rounds = True

    # Federation parameters
    num_worker = 30
    alpha = 0.4
    num_Byzantine = 0
    # RL Algorithms
    SVRPG = False
    FedPG_BR = False
    attack_type = None
    # Training and validation
    val_size = 10
    val_max_steps = 1000
    load_path = None

    # Device settings
    use_cuda = False

    # ======================== GAME THEORY PARAMETERS ========================
    # Game theory integration parameters
    enable_game_theory = True
    compare_all_strategies = True  # Run all strategies and compare results
    strategies_to_compare = [
        'clean_baseline',
        'no_defense',
        'random',
        'true_stackelberg',
        'fltg',
        'fedgreed',
        'ucb',
        'thompson_sampling',
        'stackelberg_fltg',
        'ucb_fltg',
        'thompson_fltg',
        'stackelberg_fedgreed',
        'ucb_fedgreed',
        'thompson_fedgreed',
    ]
    initial_defense_budget_ratio = 0.3
    initial_attack_budget_ratio = 0.3
    defense_effectiveness = 1.0
    observation_accuracy = 0.8
    # Probabilistic defense (used instead of client-level defense)
    defense_strength = 0.8
    use_client_defense = False
    # System-level aggregation defenses (apply at server)
    aggregation_defense = 'none'  # Options: 'none', 'fltg', 'fedgreed'
    fedgreed_fraction = 0.7       # Fraction of clients selected by FedGreed (lowest distance)
    ucb_c = 2.0
    foundationfl_synth_ratio = 0.5
    foundationfl_trim_ratio = 0.2
    foundationfl_noise_scale = 0.1
    huber_mad_k = 2.5
    huber_iters = 5
    base_attack_intensity = 0.5
    attack_style = "grad_noise"  # Options: "action_flip", "reward_poison", "obs_noise", "grad_noise"
    attack_flip_prob = 0.25
    reward_poison_scale = 0.5
    obs_noise_std = 0.2
    damage_scenario = "engineered"  # "critical_infrastructure", "tiered_importance", "network_hubs", "engineered"
    reshuffle_frequency = 100  # Reshuffle game parameters every N rounds
    single_best_attacker_only = False

    # Environments that have a hyperparameter preset in set_environment_params.
    _SUPPORTED_ENVS = ('CartPole-v1', 'HalfCheetah-v2', 'LunarLander-v2')

    def __init__(self):
        # Set environment-specific hyperparameters
        self.set_environment_params()

        # Setup paths and run name
        self.setup_paths()

    def set_environment_params(self):
        """Set hyperparameters based on environment.

        Raises:
            ValueError: if ``env_name`` has no preset. Previously this left
                ``max_trajectories`` etc. unset and failed later with an
                opaque AttributeError.
        """
        if self.env_name == 'CartPole-v1':
            # Task-Specified Hyperparameters
            self.max_epi_len = 500
            self.max_trajectories = 5000
            self.gamma = 0.999
            self.min_reward = 0
            self.max_reward = 600

            # Shared parameters
            self.do_sample_for_training = True
            self.lr_model = 1e-3
            self.hidden_units = '16,16'
            self.activation = 'ReLU'
            self.output_activation = 'Tanh'

            # Batch sizes
            self.B = 16  # for SVRPG and GOMDP
            self.Bmin = 12  # for FT-FedScsPG
            self.Bmax = 20  # for FT-FedScsPG
            self.b = 4  # mini batch_size for SVRPG and FT-FedScsPG

            # Inner loop iteration for SVRPG
            self.N = 3

            # Filtering hyperparameters for FT-FedScsPG
            self.delta = 0.6
            self.sigma = 0.06

        elif self.env_name == 'HalfCheetah-v2':
            self.max_epi_len = 500
            self.max_trajectories = 1e4
            self.gamma = 0.995
            self.min_reward = 0
            self.max_reward = 4000

            self.do_sample_for_training = True
            self.lr_model = 8e-5
            self.hidden_units = '64,64'
            self.activation = 'Tanh'
            self.output_activation = 'Tanh'

            self.B = 48
            self.Bmin = 46
            self.Bmax = 50
            self.b = 16
            self.N = 3
            self.delta = 0.6
            self.sigma = 0.9

        elif self.env_name == 'LunarLander-v2':
            self.max_epi_len = 1000
            self.max_trajectories = 1e4
            self.gamma = 0.99
            self.min_reward = -1000
            self.max_reward = 300

            self.do_sample_for_training = True
            self.lr_model = 1e-3
            self.hidden_units = '64,64'
            self.activation = 'Tanh'
            self.output_activation = 'Tanh'

            self.B = 32
            self.Bmin = 26
            self.Bmax = 38
            self.b = 8
            self.N = 3
            self.delta = 0.6
            self.sigma = 0.07

        else:
            raise ValueError(
                f"Unsupported env_name {self.env_name!r}; "
                f"expected one of {self._SUPPORTED_ENVS}"
            )

    def _run_tag(self):
        """Return the run-identifying suffix shared by the output and log directory names."""
        return (
            f'{self.env_name}_workers{self.num_worker}_rounds{self.max_trajectories}'
            f'_shuffle_{self.reshuffle_frequency}_single_attacker_{self.single_best_attacker_only}'
            f'_attack_{self.base_attack_intensity}_seed{self.seed}_attack_style_{self.attack_style}'
        )

    def setup_paths(self):
        """Derive ``save_dir``, ``log_dir``, ``device`` and ``seeds``.

        ``save_dir`` / ``log_dir`` become None when saving / TensorBoard is
        disabled. Directory names are only computed here; nothing is created.
        """
        self.save_dir = None if self.no_saving else f'outputs_{self._run_tag()}'

        # Uses the current ``log_dir`` (class default 'logs') as the prefix.
        self.log_dir = None if self.no_tb else f'{self.log_dir}_{self._run_tag()}'

        # Set device
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        # Set seeds: multiple_run consecutive seeds starting at ``seed``
        self.seeds = (np.arange(self.multiple_run) + self.seed).tolist()

# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def torch_load_cpu(load_path):
    """Deserialize a ``torch.save`` payload, keeping every storage where torch first restored it.

    The ``map_location`` callable returns each storage unchanged, which
    keeps the data on the CPU instead of moving it to its original device.
    """
    def _keep_storage(storage, _location):
        return storage

    return torch.load(load_path, map_location=_keep_storage)

def get_inner_model(model):
    """Return the wrapped module of a ``DataParallel``; any other model is returned as-is."""
    if isinstance(model, DataParallel):
        return model.module
    return model

def move_to(var, device):
    """Recursively move ``var`` to ``device``.

    Dicts are rebuilt as plain dicts. Lists and tuples (including namedtuples)
    are now traversed too; previously they crashed because they have no
    ``.to``. Any other object is moved with its own ``.to(device)``.
    """
    if isinstance(var, dict):
        return {k: move_to(v, device) for k, v in var.items()}
    if isinstance(var, list):
        return [move_to(v, device) for v in var]
    if isinstance(var, tuple):
        moved = [move_to(v, device) for v in var]
        # namedtuples take positional fields; plain tuples take one iterable.
        return type(var)(*moved) if hasattr(var, '_fields') else tuple(moved)
    return var.to(device)

def env_wrapper(name, obs):
    """Hook for per-environment observation preprocessing.

    Currently a pass-through for every environment: ``name`` is ignored and
    the very same ``obs`` object is returned.
    """
    processed = obs
    return processed

def save_frames_as_gif(frames, path='./', filename='gym_animation.gif'):
    """Render a sequence of image frames to an animated GIF at ``path/filename``.

    The figure is sized to the first frame (72 dpi). Frames are written with
    the 'imagemagick' writer at 120 fps.

    Args:
        frames: sequence of image arrays (``frames[0].shape[:2]`` is read as
            height, width).
        path: output directory; joined with ``filename`` via ``os.path.join``
            (previously plain string concatenation, which broke without a
            trailing separator).
        filename: GIF file name.

    Raises:
        ValueError: if ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("save_frames_as_gif requires at least one frame")

    height, width = frames[0].shape[:2]
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    try:
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=50)
        anim.save(os.path.join(path, filename), writer='imagemagick', fps=120)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def euclidean_dist(x, y):
    """
    Pairwise Euclidean distances between the rows of ``x`` and ``y``.

    Args:
      x: pytorch Variable, with shape [m, d]
      y: pytorch Variable, with shape [n, d]
    Returns:
      dist: pytorch Variable, with shape [m, n]
    """
    m, n = x.size(0), y.size(0)
    xx = x.pow(2).sum(1, keepdim=True).expand(m, n)
    yy = y.pow(2).sum(1, keepdim=True).expand(n, m).T
    dist = xx + yy
    # dist = ||x||^2 + ||y||^2 - 2 x y^T. Keyword beta/alpha replaces the
    # deprecated positional addmm_(beta, alpha, mat1, mat2) overload.
    dist.addmm_(x, y.T, beta=1, alpha=-2)
    # Floating-point cancellation can yield tiny negatives; clamp before sqrt.
    return dist.clamp_(min=0).sqrt()

# ======================== GAME THEORY UTILITY FUNCTIONS ========================

def fedavg(parameters_list):
    """Federated averaging of model parameters (equal weights).

    Args:
        parameters_list: non-empty sequence of dicts mapping name -> tensor;
            every dict must contain the keys of the first one.

    Returns:
        dict: name -> element-wise mean tensor (new tensors; inputs untouched).
        Integer tensors (e.g. BatchNorm's ``num_batches_tracked``) are
        floor-averaged and keep their dtype. Previously in-place true division
        raised on them.

    Raises:
        ValueError: if ``parameters_list`` is empty.
    """
    if not parameters_list:
        raise ValueError("fedavg requires at least one parameter dict")

    count = len(parameters_list)
    with torch.no_grad():
        new_params = {}
        for name, first in parameters_list[0].items():
            total = torch.zeros_like(first)
            for params in parameters_list:
                total += params[name]
            if total.is_floating_point() or total.is_complex():
                total /= count
            else:
                total = torch.div(total, count, rounding_mode='floor')
            new_params[name] = total
    return new_params

def create_realistic_unequal_damage_weights(num_clients: int, scenario: str = "critical_infrastructure", seed: int = 42) -> np.ndarray:
    """
    Create realistic unequal damage weights that will be reflected in attack severity

    Even though all clients have equal data, some are more strategically valuable.

    Scenarios:
      - "critical_infrastructure": base 3.0; max(1, n // 5) random clients get 15.0
      - "tiered_importance": base 2.0; ~30% get 8.0, then ~20% of the rest get 20.0
      - "network_hubs": base 4.0; max(1, int(0.15 n)) random clients get 18.0
      - "engineered": all 1.0 (real weights are set in GameTheoryIntegrator)
      - anything else falls back to "critical_infrastructure"

    Note: calls ``np.random.seed(seed)``, i.e. it reseeds NumPy's *global*
    RNG, which also determines later ``np.random`` draws in the process.
    Prints a short summary of the chosen weights.
    """
    np.random.seed(seed)

    if scenario == "critical_infrastructure":
        # Some clients are critical infrastructure (hospitals, banks)
        # Others are regular IoT devices
        damage_weights = np.ones(num_clients) * 3.0  # Base damage for regular clients

        # Randomly assign some clients as critical (20% of clients)
        num_critical = max(1, num_clients // 5)
        critical_indices = np.random.choice(num_clients, num_critical, replace=False)
        damage_weights[critical_indices] = 15.0  # Critical clients have 5x more strategic value

        print(f"CRITICAL INFRASTRUCTURE SCENARIO:")
        # .tolist() keeps the printout as plain ints (NumPy 2 prints np.int64(...)).
        print(f"   Critical clients: {sorted(critical_indices.tolist())}")
        print(f"   Regular damage weight: 3.0, Critical damage weight: 15.0")

    elif scenario == "tiered_importance":
        # Create 3 tiers of importance
        damage_weights = np.ones(num_clients) * 2.0  # Tier 3 (low value)

        # Tier 2 (medium value) - 30% of clients
        tier2_count = max(1, int(num_clients * 0.3))
        tier2_indices = np.random.choice(num_clients, tier2_count, replace=False)
        damage_weights[tier2_indices] = 8.0

        # Tier 1 (high value) - 20% of clients, drawn from those not in tier 2.
        remaining_indices = [i for i in range(num_clients) if i not in tier2_indices]
        # Never request more tier-1 clients than remain (avoids a crash for tiny federations).
        tier1_count = min(max(1, int(num_clients * 0.2)), len(remaining_indices))
        if tier1_count > 0:
            tier1_indices = np.random.choice(remaining_indices, tier1_count, replace=False)
        else:
            tier1_indices = np.array([], dtype=int)
        damage_weights[tier1_indices] = 20.0

        print(f"TIERED IMPORTANCE SCENARIO:")
        print(f"   Tier 1 (high value): {sorted(tier1_indices.tolist())} → weight: 20.0")
        print(f"   Tier 2 (medium value): {sorted(tier2_indices.tolist())} → weight: 8.0")
        print(f"   Tier 3 (low value): remaining → weight: 2.0")

    elif scenario == "network_hubs":
        # Some clients are network hubs with higher connectivity
        damage_weights = np.ones(num_clients) * 4.0  # Base damage

        # Select hub clients (15% of total)
        num_hubs = max(1, int(num_clients * 0.15))
        hub_indices = np.random.choice(num_clients, num_hubs, replace=False)
        damage_weights[hub_indices] = 18.0  # Hub clients have 4.5x more impact

        print(f"NETWORK HUBS SCENARIO:")
        print(f"   Hub clients: {sorted(hub_indices.tolist())}")
        print(f"   Regular damage weight: 4.0, Hub damage weight: 18.0")

    elif scenario == "engineered":
        # Placeholder: engineered scenario handled in GameTheoryIntegrator
        damage_weights = np.ones(num_clients) * 1.0
        print(f"ENGINEERED SCENARIO: damage weights set in GameTheoryIntegrator")

    else:  # default to critical_infrastructure
        return create_realistic_unequal_damage_weights(num_clients, "critical_infrastructure", seed)

    print(f"   Damage range: {damage_weights.min():.1f} to {damage_weights.max():.1f}")
    print(f"   Ratio: {damage_weights.max()/damage_weights.min():.1f}x difference")

    return damage_weights


def create_strategic_variety_costs(num_clients: int, base_values: np.ndarray,
                                 cost_multiplier_range: tuple = (0.5, 1.5),
                                 seed: int = 42) -> np.ndarray:
    """
    Create cost variety for strategic game dynamics (tied to damage weights).

    Each cost is ``base_values * U(low, high) + N(0, 0.2)``, floored at 0.1.

    Note: this reseeds NumPy's *global* legacy RNG with ``seed``, so subsequent
    ``np.random`` calls elsewhere are affected.

    Args:
        num_clients: number of cost entries to draw.
        base_values: per-client base value (or a scalar) that the multipliers scale.
        cost_multiplier_range: ``(low, high)`` bounds of the uniform multiplier.
        seed: seed for the global NumPy RNG.

    Returns:
        Array of ``num_clients`` costs, each at least 0.1.
    """
    low, high = cost_multiplier_range
    np.random.seed(seed)
    # Draw order matters for reproducibility: multipliers first, then jitter.
    multipliers = np.random.uniform(low, high, num_clients)
    jitter = np.random.normal(0, 0.2, num_clients)
    return np.maximum(base_values * multipliers + jitter, 0.1)

# ======================== KNAPSACK PROBLEM SOLVERS ========================

def solve_knapsack_dp(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """
    Solve a 0/1 knapsack exactly by dynamic programming over integer-scaled costs.

    Costs and budget are multiplied by a scale factor (at most 1000, reduced so
    scaled integers stay below ~2**30) and converted to integers. Costs are
    rounded *up* and the budget *down*, with a small tolerance for float
    representation error. This keeps the returned selection within the real
    budget; plain truncation of costs could accept sets that overspend.

    Instances whose (n + 1) x (budget_int + 1) DP grid exceeds 50M cells use the
    O(budget)-memory solver; smaller ones use the full decision table.

    Args:
        values: per-item value; items with value <= 0 are never selected.
        costs: per-item cost (expected to be non-negative).
        budget: total cost allowed.

    Returns:
        Set of selected item indices (empty for no items or a negative budget).
    """
    n = len(values)
    if n == 0 or budget < 0:
        return set()

    costs = np.asarray(costs, dtype=np.float64)
    max_value = max(budget, float(costs.max())) if len(costs) > 0 else budget

    # Overflow protection: keep scaled integers below ~2**30. Take the min
    # before int() so a tiny max_value cannot overflow the conversion, and never
    # let the factor reach 0 (which would make every item free).
    scale_limit = 2**30 / max_value if max_value > 0 else 1000
    safe_scale_factor = max(1, int(min(1000, scale_limit)))

    tol = 1e-6  # absorbs float representation error in the scaled values
    budget_int = int(np.floor(budget * safe_scale_factor + tol))
    costs_int = np.ceil(costs * safe_scale_factor - tol).astype(np.int64)

    memory_threshold = 50_000_000  # DP grid cells
    use_space_optimized = (n + 1) * (budget_int + 1) > memory_threshold

    if use_space_optimized:
        return _solve_knapsack_space_optimized(values, costs_int, budget_int, safe_scale_factor)
    return _solve_knapsack_full_table(values, costs_int, budget_int)

def _solve_knapsack_full_table(values: np.ndarray, costs_int: np.ndarray, budget_int: int) -> Set[int]:
    """Standard DP with full table for smaller instances"""
    n = len(values)
    
    # Full DP table
    dp = np.zeros((n + 1, budget_int + 1), dtype=np.float32)  # Use float32 to save memory
    
    for i in range(1, n + 1):
        for w in range(budget_int + 1):
            # Don't take item i-1
            dp[i][w] = dp[i-1][w]
            
            # Take item i-1 if profitable
            if costs_int[i-1] <= w and values[i-1] > 0:
                dp[i][w] = max(dp[i][w], dp[i-1][w - costs_int[i-1]] + values[i-1])
    
    # Backtracking to find selected items
    selected = set()
    w = budget_int
    for i in range(n, 0, -1):
        # Check if item i-1 was actually taken in optimal solution
        if w >= costs_int[i-1] and abs(dp[i][w] - (dp[i-1][w - costs_int[i-1]] + values[i-1])) < 1e-9:
            selected.add(i-1)
            w -= costs_int[i-1]
    
    return selected

def _solve_knapsack_space_optimized(values: np.ndarray, costs_int: np.ndarray, budget_int: int, scale_factor: int) -> Set[int]:
    """
    Space-optimized DP using O(budget) memory
    Runs algorithm twice: once for value, once for solution reconstruction
    """
    n = len(values)
    
    # First pass: compute optimal value using only two rows
    prev_row = np.zeros(budget_int + 1, dtype=np.float32)
    curr_row = np.zeros(budget_int + 1, dtype=np.float32)
    
    for i in range(1, n + 1):
        for w in range(budget_int + 1):
            # Don't take item i-1
            curr_row[w] = prev_row[w]
            
            # Take item i-1 if profitable
            if costs_int[i-1] <= w and values[i-1] > 0:
                curr_row[w] = max(curr_row[w], prev_row[w - costs_int[i-1]] + values[i-1])
        
        # Swap rows
        prev_row, curr_row = curr_row, prev_row
        curr_row.fill(0)
    
    optimal_value = prev_row[budget_int]
    
    # Second pass: reconstruct solution by testing which items were taken
    selected = set()
    remaining_budget = budget_int
    remaining_value = optimal_value
    
    # Test items in reverse order
    for i in range(n - 1, -1, -1):
        if values[i] <= 0 or costs_int[i] > remaining_budget:
            continue
            
        # Check if removing this item reduces the optimal value
        # by running mini-DP without this item
        temp_budget = remaining_budget - costs_int[i]
        temp_value = remaining_value - values[i]
        
        # Quick check: if including this item achieves the remaining value target
        if abs(temp_value - _compute_optimal_value_without_item(
            values, costs_int, temp_budget, exclude_item=i)) < 1e-9:
            selected.add(i)
            remaining_budget = temp_budget
            remaining_value = temp_value
    
    return selected

def _compute_optimal_value_without_item(values: np.ndarray, costs_int: np.ndarray, 
                                      budget: int, exclude_item: int) -> float:
    """Helper function to compute optimal value excluding one specific item"""
    if budget <= 0:
        return 0.0
        
    prev_row = np.zeros(budget + 1, dtype=np.float32)
    curr_row = np.zeros(budget + 1, dtype=np.float32)
    
    for i in range(len(values)):
        if i == exclude_item:
            continue  # Skip the excluded item
            
        for w in range(budget + 1):
            curr_row[w] = prev_row[w]
            if costs_int[i] <= w and values[i] > 0:
                curr_row[w] = max(curr_row[w], prev_row[w - costs_int[i]] + values[i])
        
        prev_row, curr_row = curr_row, prev_row
        curr_row.fill(0)
    
    return prev_row[budget]

def solve_knapsack_greedy(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """
    Greedy 0/1 knapsack heuristic.

    Items are scanned in descending value/cost order, and each item with
    positive value that still fits in the budget is added. Zero-cost items are
    ranked with efficiency 0 but are still added when reached if they fit.

    Args:
        values: per-item value (int or float).
        costs: per-item cost (int or float).
        budget: total cost allowed.

    Returns:
        Set of selected item indices.
    """
    values = np.asarray(values, dtype=np.float64)
    costs = np.asarray(costs, dtype=np.float64)
    # The output buffer must be float: np.zeros_like(values) is an int array for
    # integer inputs, and np.divide cannot cast its float result into it.
    efficiency = np.divide(values, costs, out=np.zeros(values.shape, dtype=np.float64), where=costs != 0)

    # Sort by efficiency (descending)
    sorted_indices = np.argsort(efficiency)[::-1]

    selected = set()
    current_cost = 0.0

    for idx in sorted_indices:
        if current_cost + costs[idx] <= budget and values[idx] > 0:
            selected.add(int(idx))
            current_cost += costs[idx]

    return selected

def generate_feasible_defense_combinations(defense_costs: np.ndarray, budget: float, max_combinations: int = 10000):
    """
    Generate distinct budget-feasible defense sets via greedy, swaps, and random sampling.

    1. Greedy: add clients cheapest-first (ranked by 1/cost) while they fit.
    2. Swaps: repeatedly drop a random member of the greedy set and try adding
       each other client that fits (stops early once ``max_combinations`` is hit).
    3. Random subsets of size 1..min(n, 10), only when 1 <= n <= 20.

    Uses the global ``random`` module, so results depend on its state.

    Args:
        defense_costs: per-client defense cost.
        budget: maximum total defense cost.
        max_combinations: soft cap on the number of sets returned.

    Returns:
        List of index sets; the greedy set is always first.
    """
    defense_costs = np.asarray(defense_costs, dtype=float)
    num_clients = len(defense_costs)
    feasible_combinations = []
    seen = set()  # frozenset keys: O(1) duplicate checks instead of scanning the list

    def _add_if_new(combo) -> bool:
        """Append ``combo`` unless already present; return True if it was appended."""
        key = frozenset(combo)
        if key in seen:
            return False
        seen.add(key)
        feasible_combinations.append(combo)
        return True

    # Method 1: greedy, cheapest first
    efficiency = 1.0 / defense_costs  # Simple efficiency metric
    sorted_indices = np.argsort(efficiency)[::-1]
    current_cost = 0.0
    current_combo = set()
    for idx in sorted_indices:
        if current_cost + defense_costs[idx] <= budget:
            current_combo.add(idx)
            current_cost += defense_costs[idx]
    _add_if_new(current_combo)

    # Method 2: variations by swapping one element of the greedy set
    base_combo = current_combo.copy()
    for _ in range(min(1000, max_combinations // 10)):
        if not base_combo:
            break  # nothing to swap; further iterations would do nothing
        to_remove = random.choice(list(base_combo))
        new_combo = base_combo - {to_remove}
        new_cost = sum(defense_costs[i] for i in new_combo)

        remaining = set(range(num_clients)) - new_combo
        for candidate in remaining:
            if new_cost + defense_costs[candidate] <= budget:
                if _add_if_new(new_combo | {candidate}) and len(feasible_combinations) >= max_combinations:
                    return feasible_combinations

    # Method 3: random sampling for small (non-empty) instances; randint(1, 0)
    # would raise for zero clients.
    if 0 < num_clients <= 20:
        for _ in range(min(2000, max_combinations - len(feasible_combinations))):
            subset_size = random.randint(1, min(num_clients, 10))
            subset = set(random.sample(range(num_clients), subset_size))
            cost = sum(defense_costs[i] for i in subset)
            if cost <= budget:
                _add_if_new(subset)

    return feasible_combinations

def solve_knapsack_simple_greedy(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """
    Simple greedy knapsack heuristic: pick by highest value-to-cost until budget.

    Items with non-positive value are skipped. Zero-cost items rank with
    efficiency 0 but are still added when reached if they fit.

    Args:
        values: per-item value (int or float).
        costs: per-item cost (int or float).
        budget: total cost allowed.

    Returns:
        Set of selected item indices.
    """
    values = np.asarray(values, dtype=np.float64)
    costs = np.asarray(costs, dtype=np.float64)
    # Float output buffer: an int buffer (zeros_like on int values) makes np.divide raise.
    efficiency = np.divide(values, costs, out=np.zeros(values.shape, dtype=np.float64), where=costs != 0)
    sorted_indices = np.argsort(efficiency)[::-1]
    selected = set()
    current_cost = 0.0
    for idx in sorted_indices:
        if values[idx] > 0 and current_cost + costs[idx] <= budget:
            selected.add(int(idx))
            current_cost += costs[idx]
    return selected

# ======================== PARTIAL VISIBILITY + PROB DEFENSE ========================

class ProbabilisticDefenseModel:
    """
    Defense that blocks an attack only with probability ``defense_strength``.

    An undefended client is always hit; a defended one is hit with probability
    ``1 - defense_strength`` (strength is clipped to [0, 1]).
    """

    def __init__(self, defense_strength: float = 0.7):
        clipped = np.clip(defense_strength, 0.0, 1.0)
        self.defense_strength = float(clipped)

    def get_attack_success_probability(self, is_defended: bool) -> float:
        """Probability that an attack on a client with the given defense status succeeds."""
        if not is_defended:
            return 1.0
        return 1.0 - self.defense_strength

    def sample_attack_outcome(self, is_defended: bool, seed: Optional[int] = None) -> bool:
        """Draw one attack outcome; uses a fresh RandomState when ``seed`` is given, else the global RNG."""
        generator = np.random if seed is None else np.random.RandomState(seed)
        draw = generator.random()
        return draw < self.get_attack_success_probability(is_defended)


class PartialObservabilityModel:
    """
    Noisy observation channel: each client's true status is reported correctly
    with probability ``observation_accuracy`` (clipped to [0.5, 1]) and flipped
    otherwise, independently per client.
    """

    def __init__(self, observation_accuracy: float = 0.8):
        self.observation_accuracy = float(np.clip(observation_accuracy, 0.5, 1.0))

    def observe_defense_discrete(self, true_defended_set: Set[int],
                                 num_clients: int, seed: Optional[int] = None) -> np.ndarray:
        """Return a 0/1 int array of observed defense flags (1 = looks defended)."""
        generator = np.random if seed is None else np.random.RandomState(seed)
        readings = np.zeros(num_clients, dtype=int)
        for client in range(num_clients):
            truthful = generator.random() < self.observation_accuracy
            # A truthful reading echoes the truth; an untruthful one flips it.
            readings[client] = int((client in true_defended_set) == truthful)
        return readings

    def observe_attack_discrete(self, true_attackers: Set[int],
                                num_clients: int, seed: Optional[int] = None) -> Set[int]:
        """Noisy defender view of who attacked, using the same accuracy as the defense channel."""
        generator = np.random if seed is None else np.random.RandomState(seed)
        flagged = set()
        for client in range(num_clients):
            truthful = generator.random() < self.observation_accuracy
            # Flag true attackers on a truthful draw, and non-attackers on a flipped one.
            if (client in true_attackers) == truthful:
                flagged.add(client)
        return flagged


class KnapsackSolver:
    """Exact DP for small problems; 2-approx (max(greedy, best-single)) otherwise.

    ``solve`` returns ``(selected_indices, total_value, method_label)``.
    """

    @staticmethod
    def _solve_exact_dp(values: np.ndarray, costs: np.ndarray, budget: float) -> Tuple[Set[int], float]:
        """Exact 0/1 knapsack by DP over costs scaled to integer units of 1/1000.

        Costs are rounded up and the budget down (with a 1e-6 tolerance for
        float representation error), so the returned set always fits the real
        budget. Values are kept in one float64 row, and a boolean take table
        drives backtracking.

        Raises:
            ValueError: if the scaled budget exceeds 10**6, there are more than
                100 items, or any cost is negative.
        """
        n = len(values)
        if n == 0:
            return set(), 0.0
        scale, tol = 1000, 1e-6
        B = int(np.floor(budget * scale + tol))
        C = np.ceil(np.asarray(costs, dtype=np.float64) * scale - tol).astype(np.int64)
        if B < 0:
            return set(), 0.0
        if B > 10**6 or n > 100:
            raise ValueError("Exact DP too large.")
        if np.any(C < 0):
            raise ValueError("Exact DP requires non-negative costs.")

        vals = np.asarray(values, dtype=np.float64)
        best = np.zeros(B + 1, dtype=np.float64)
        take = np.zeros((n, B + 1), dtype=bool)
        for i in range(n):
            ci, vi = int(C[i]), vals[i]
            if vi <= 0 or ci > B:
                continue
            # Built from the previous row (fresh array): each item used at most once.
            cand = best[:B + 1 - ci] + vi
            improves = cand > best[ci:]
            take[i, ci:] = improves
            best[ci:] = np.where(improves, cand, best[ci:])

        sel, w = set(), B
        for i in range(n - 1, -1, -1):
            if take[i, w]:
                sel.add(i)
                w -= int(C[i])
        return sel, float(sum(vals[i] for i in sel))

    @staticmethod
    def _solve_greedy(values: np.ndarray, costs: np.ndarray, budget: float) -> Tuple[Set[int], float]:
        """Add items in descending value/cost order while they fit the budget.

        Zero-cost items with positive value rank first (infinite efficiency) and
        are taken. Items with negative cost or non-positive value are skipped.
        """
        vals = np.asarray(values, dtype=np.float64)
        cs = np.asarray(costs, dtype=np.float64)
        # Float buffer: np.zeros_like(values) is int for int inputs and np.divide would raise.
        eff = np.zeros(vals.shape, dtype=np.float64)
        np.divide(vals, cs, out=eff, where=cs > 0)
        eff[(cs == 0) & (vals > 0)] = np.inf
        order = np.argsort(eff)[::-1]
        sel, cost_acc = set(), 0.0
        for i in order:
            if cs[i] < 0 or vals[i] <= 0:
                continue
            if cost_acc + cs[i] <= budget:
                sel.add(int(i))
                cost_acc += cs[i]
        return sel, float(sum(vals[i] for i in sel))

    @staticmethod
    def _solve_best_single(values: np.ndarray, costs: np.ndarray, budget: float) -> Tuple[Set[int], float]:
        """Return the single highest-value item that fits the budget (or nothing)."""
        best, idx = 0.0, -1
        for i, (v, c) in enumerate(zip(values, costs)):
            if c <= budget and v > best:
                best, idx = v, i
        return ({idx}, best) if idx >= 0 else (set(), 0.0)

    @staticmethod
    def solve(values: np.ndarray, costs: np.ndarray, budget: float, force_greedy: bool = False) -> Tuple[Set[int], float, str]:
        """Solve a 0/1 knapsack.

        Tries the exact DP unless ``force_greedy`` is set. If the instance is
        unsupported by the DP (too large, negative costs, non-finite budget, out
        of memory), returns the better of greedy and best-single.

        Returns:
            ``(selected, value, method)`` where method is "exact_dp_optimal",
            "2approx_greedy" or "2approx_best_single".
        """
        if not force_greedy:
            try:
                sel, val = KnapsackSolver._solve_exact_dp(values, costs, budget)
            except (ValueError, OverflowError, MemoryError):
                pass  # instance not suitable for the exact DP; use the approximation
            else:
                return sel, val, "exact_dp_optimal"

        gsel, gval = KnapsackSolver._solve_greedy(values, costs, budget)
        ssel, sval = KnapsackSolver._solve_best_single(values, costs, budget)
        if gval >= sval:
            return gsel, gval, "2approx_greedy"
        return ssel, sval, "2approx_best_single"


class RigorousMCEstimator:
    """
    Monte Carlo mean estimator with a pilot-based sample-size rule.

    A pilot of ``pilot`` samples estimates the variance s^2. The total sample
    count is ``ceil(z^2 s^2 / epsilon^2)`` clamped to [nmin, nmax], or nmin when
    s^2 == 0. The result carries a Student-t (n < 30) or normal confidence
    interval at level ``1 - delta`` and the Hoeffding radius
    ``(b - a) * sqrt(ln(2/delta) / (2n))``.
    """

    def __init__(self, epsilon: float = 0.05, delta: float = 0.10, pilot: int = 10, nmin: int = 50, nmax: int = 200):
        self.eps, self.delta, self.pilot, self.nmin, self.nmax = epsilon, delta, pilot, nmin, nmax

    def estimate(self, sample_fn, value_range: Tuple[float, float]) -> Dict:
        """Estimate E[sample_fn(i)]; ``sample_fn`` is called with indices 0..n-1.

        Args:
            sample_fn: callable mapping a sample index to a float.
            value_range: ``(a, b)`` bounds on samples, used for the Hoeffding radius.

        Returns:
            Dict with "estimate", "ci", "n", "sem" and "hoeffding_radius".
            When the samples have zero spread (or n < 2), "ci" is the
            degenerate interval (mean, mean) rather than NaNs.
        """
        a, b = value_range
        pilot_vals = [sample_fn(i) for i in range(self.pilot)]
        var = float(np.var(pilot_vals, ddof=1)) if len(pilot_vals) > 1 else 0.0
        if var > 0:
            z = float(st.norm.ppf(1 - self.delta / 2))
            n = int(np.ceil((z * z * var) / (self.eps * self.eps)))
            n = max(self.nmin, min(self.nmax, n))
        else:
            n = self.nmin
        vals = pilot_vals + [sample_fn(i) for i in range(self.pilot, n)]
        mean = float(np.mean(vals))
        vhat = float(np.var(vals, ddof=1)) if len(vals) > 1 else 0.0
        sem = float(np.sqrt(vhat / max(1, len(vals))))
        if len(vals) < 2 or sem == 0.0:
            # scipy returns (nan, nan) for scale == 0 or df == 0; the interval is a point here.
            lo, hi = mean, mean
        elif len(vals) < 30:
            lo, hi = st.t.interval(1 - self.delta, len(vals) - 1, loc=mean, scale=sem)
        else:
            z = st.norm.ppf(1 - self.delta / 2)
            lo, hi = mean - z * sem, mean + z * sem
        R = (b - a) * np.sqrt(np.log(2 / self.delta) / (2 * len(vals)))
        return {
            "estimate": mean,
            "ci": (float(lo), float(hi)),
            "n": len(vals),
            "sem": float(sem),
            "hoeffding_radius": float(R),
        }


def solve_knapsack_dp(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """
    Solve a 0/1 knapsack exactly by dynamic programming over integer-scaled costs.

    Costs and budget are multiplied by a scale factor (at most 1000, reduced so
    scaled integers stay below ~2**30) and converted to integers. Costs are
    rounded *up* and the budget *down*, with a small tolerance for float
    representation error. This keeps the returned selection within the real
    budget; plain truncation of costs could accept sets that overspend.

    Instances whose (n + 1) x (budget_int + 1) DP grid exceeds 50M cells use the
    O(budget)-memory solver; smaller ones use the full decision table.

    Args:
        values: per-item value; items with value <= 0 are never selected.
        costs: per-item cost (expected to be non-negative).
        budget: total cost allowed.

    Returns:
        Set of selected item indices (empty for no items or a negative budget).
    """
    n = len(values)
    if n == 0 or budget < 0:
        return set()

    costs = np.asarray(costs, dtype=np.float64)
    max_value = max(budget, float(costs.max())) if len(costs) > 0 else budget

    # Overflow protection: keep scaled integers below ~2**30. Take the min
    # before int() so a tiny max_value cannot overflow the conversion, and never
    # let the factor reach 0 (which would make every item free).
    scale_limit = 2**30 / max_value if max_value > 0 else 1000
    safe_scale_factor = max(1, int(min(1000, scale_limit)))

    tol = 1e-6  # absorbs float representation error in the scaled values
    budget_int = int(np.floor(budget * safe_scale_factor + tol))
    costs_int = np.ceil(costs * safe_scale_factor - tol).astype(np.int64)

    memory_threshold = 50_000_000  # DP grid cells
    use_space_optimized = (n + 1) * (budget_int + 1) > memory_threshold

    if use_space_optimized:
        return _solve_knapsack_space_optimized(values, costs_int, budget_int, safe_scale_factor)
    return _solve_knapsack_full_table(values, costs_int, budget_int)


def _solve_knapsack_full_table(values: np.ndarray, costs_int: np.ndarray, budget_int: int) -> Set[int]:
    n = len(values)
    dp = np.zeros((n + 1, budget_int + 1), dtype=np.float32)

    for i in range(1, n + 1):
        for w in range(budget_int + 1):
            dp[i][w] = dp[i - 1][w]
            if costs_int[i - 1] <= w and values[i - 1] > 0:
                dp[i][w] = max(dp[i][w], dp[i - 1][w - costs_int[i - 1]] + values[i - 1])

    selected = set()
    w = budget_int
    for i in range(n, 0, -1):
        if w >= costs_int[i - 1] and abs(dp[i][w] - (dp[i - 1][w - costs_int[i - 1]] + values[i - 1])) < 1e-9:
            selected.add(i - 1)
            w -= costs_int[i - 1]

    return selected


def _solve_knapsack_space_optimized(values: np.ndarray, costs_int: np.ndarray, budget_int: int, scale_factor: int) -> Set[int]:
    n = len(values)
    prev_row = np.zeros(budget_int + 1, dtype=np.float32)
    curr_row = np.zeros(budget_int + 1, dtype=np.float32)

    for i in range(1, n + 1):
        for w in range(budget_int + 1):
            curr_row[w] = prev_row[w]
            if costs_int[i - 1] <= w and values[i - 1] > 0:
                curr_row[w] = max(curr_row[w], prev_row[w - costs_int[i - 1]] + values[i - 1])
        prev_row, curr_row = curr_row, prev_row
        curr_row.fill(0)

    optimal_value = prev_row[budget_int]

    selected = set()
    remaining_budget = budget_int
    remaining_value = optimal_value

    for i in range(n - 1, -1, -1):
        if values[i] <= 0 or costs_int[i] > remaining_budget:
            continue
        temp_budget = remaining_budget - costs_int[i]
        temp_value = remaining_value - values[i]

        if abs(temp_value - _compute_optimal_value_without_item(
            values, costs_int, temp_budget, exclude_item=i)) < 1e-9:
            selected.add(i)
            remaining_budget = temp_budget
            remaining_value = temp_value

    return selected


def _compute_optimal_value_without_item(values: np.ndarray, costs_int: np.ndarray,
                                      budget: int, exclude_item: int) -> float:
    if budget <= 0:
        return 0.0

    prev_row = np.zeros(budget + 1, dtype=np.float32)
    curr_row = np.zeros(budget + 1, dtype=np.float32)

    for i in range(len(values)):
        if i == exclude_item:
            continue
        for w in range(budget + 1):
            curr_row[w] = prev_row[w]
            if costs_int[i] <= w and values[i] > 0:
                curr_row[w] = max(curr_row[w], prev_row[w - costs_int[i]] + values[i])
        prev_row, curr_row = curr_row, prev_row
        curr_row.fill(0)

    return prev_row[budget]


def solve_knapsack_greedy(values: np.ndarray, costs: np.ndarray, budget: float) -> Set[int]:
    """
    Greedy 0/1 knapsack heuristic.

    Items are scanned in descending value/cost order, and each item with
    positive value that still fits in the budget is added. Zero-cost items are
    ranked with efficiency 0 but are still added when reached if they fit.

    Args:
        values: per-item value (int or float).
        costs: per-item cost (int or float).
        budget: total cost allowed.

    Returns:
        Set of selected item indices.
    """
    values = np.asarray(values, dtype=np.float64)
    costs = np.asarray(costs, dtype=np.float64)
    # The output buffer must be float: np.zeros_like(values) is an int array for
    # integer inputs, and np.divide cannot cast its float result into it.
    efficiency = np.divide(values, costs, out=np.zeros(values.shape, dtype=np.float64), where=costs != 0)
    sorted_indices = np.argsort(efficiency)[::-1]

    selected = set()
    current_cost = 0.0

    for idx in sorted_indices:
        if current_cost + costs[idx] <= budget and values[idx] > 0:
            selected.add(int(idx))
            current_cost += costs[idx]

    return selected

# ======================== GAME THEORY CLASSES ========================

class BudgetConstrainedAttacker:
    """
    Attacker that chooses targets under a cost budget while seeing the defense
    only through a noisy observation channel and facing probabilistic defense.

    For each client the attacker forms a Bayesian posterior that the client is
    defended, given its 0/1 observation and a prior coverage estimate. It scores
    the client as ``damage_weight * P(attack succeeds)`` and then picks targets
    with ``KnapsackSolver`` (or the single best affordable target in
    single-attacker mode).
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray,
                 attack_costs: np.ndarray, attack_budget: float,
                 partial_obs_model: PartialObservabilityModel,
                 defense_model: ProbabilisticDefenseModel,
                 estimated_defense_coverage: float = 0.5,
                 use_greedy_solver: bool = False,
                 single_attacker_mode: bool = False):
        self.n = num_clients
        self.obs = partial_obs_model
        self.defense_model = defense_model
        self.use_greedy = use_greedy_solver
        self.single_attacker_mode = single_attacker_mode
        # Private copies so later caller-side mutation does not leak in.
        self.w = damage_weights.copy()
        self.c = attack_costs.copy()
        self.B = attack_budget
        # Prior probability that any given client is defended, kept away from 0/1.
        self.prior = float(np.clip(estimated_defense_coverage, 0.01, 0.99))

    def update_parameters(self, damage_weights: np.ndarray, attack_costs: np.ndarray,
                         attack_budget: float, estimated_defense_coverage: Optional[float] = None,
                         single_attacker_mode: Optional[bool] = None):
        """Replace weights, costs and budget; optionally update the prior and mode."""
        self.w, self.c, self.B = damage_weights.copy(), attack_costs.copy(), attack_budget
        if estimated_defense_coverage is not None:
            self.prior = float(np.clip(estimated_defense_coverage, 0.01, 0.99))
        if single_attacker_mode is not None:
            self.single_attacker_mode = single_attacker_mode

    def optimal_response_to_observation(self, observations: np.ndarray) -> Set[int]:
        """Pick the attack set maximizing expected damage given 0/1 defense observations."""
        accuracy = self.obs.observation_accuracy
        strength = self.defense_model.defense_strength
        prior = self.prior

        # Marginal probability of observing "defended" (1) and "open" (0).
        evidence_defended = accuracy * prior + (1.0 - accuracy) * (1.0 - prior)
        evidence_open = (1.0 - accuracy) * prior + accuracy * (1.0 - prior)
        # Posterior P(defended | observation) for each possible reading.
        post_if_seen_defended = (accuracy * prior) / max(evidence_defended, 1e-12)
        post_if_seen_open = ((1.0 - accuracy) * prior) / max(evidence_open, 1e-12)

        expected_gain = np.zeros(self.n, dtype=np.float64)
        for client in range(self.n):
            p_def = post_if_seen_defended if observations[client] == 1 else post_if_seen_open
            hit_prob = (1.0 - strength) * p_def + 1.0 * (1.0 - p_def)
            expected_gain[client] = self.w[client] * hit_prob

        positive = expected_gain > 0
        if not positive.any():
            return set()

        candidate_ids = np.flatnonzero(positive)
        gains = expected_gain[positive]
        costs = self.c[positive]

        if self.single_attacker_mode:
            chosen, top = None, -float('inf')
            for pos, (gain, cost) in enumerate(zip(gains, costs)):
                if cost <= self.B and gain > top:
                    chosen, top = pos, gain
            if chosen is None:
                return set()
            return {int(candidate_ids[chosen])}

        picked, _, _ = KnapsackSolver.solve(gains, costs, self.B, force_greedy=self.use_greedy)
        return {int(candidate_ids[k]) for k in picked}

    def get_attack_utilities(self, defended_clients: Set[int]) -> np.ndarray:
        """Per-client ``weight - cost``, with ``-inf`` for clients in ``defended_clients``."""
        utilities = np.zeros(self.n)
        for client in range(self.n):
            utilities[client] = float('-inf') if client in defended_clients else self.w[client] - self.c[client]
        return utilities


class ProperStackelbergSolver:
    """
    Pure-strategy Stackelberg solver with partial observability and probabilistic defense.

    The defender commits to a budget-feasible set of defended clients. The
    attacker sees that set through ``PartialObservabilityModel`` noise and
    best-responds via ``BudgetConstrainedAttacker``. Each candidate defense is
    scored by a Monte Carlo estimate of expected damage, and the candidate with
    the lowest estimate is returned.
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray,
                 defense_costs: np.ndarray, attack_costs: np.ndarray,
                 defense_budget: float, attack_budget: float,
                 defense_strength: float, observation_accuracy: float):
        self.n = num_clients
        self.w = damage_weights
        self.dc = defense_costs
        self.ac = attack_costs
        self.DB = defense_budget
        self.AB = attack_budget
        self.defense_model = ProbabilisticDefenseModel(defense_strength)
        self.obs = PartialObservabilityModel(observation_accuracy)

        # Attacker's prior on coverage: the share of total defense cost the budget could pay for.
        total_defense_cost = np.sum(self.dc)
        est_coverage = self.DB / total_defense_cost if total_defense_cost > 0 else 0.0

        self.attacker = BudgetConstrainedAttacker(
            num_clients, damage_weights, attack_costs,
            attack_budget, self.obs, self.defense_model,
            estimated_defense_coverage=est_coverage,
            use_greedy_solver=True
        )

        self.mc = RigorousMCEstimator(epsilon=0.10, nmax=100)

    def _evaluate_pure_defense(self, D: Set[int], mc_estimator=None) -> float:
        """Monte Carlo estimate of expected damage when exactly the clients in ``D`` are defended."""
        if mc_estimator is None:
            mc_estimator = self.mc

        def sample_damage(seed: int) -> float:
            # Observation noise is seeded by the sample index, so every candidate
            # D sees the same noise stream. The XOR gives the attack-outcome RNG a
            # different seed from the observation RNG.
            rng = np.random.RandomState(seed ^ 0x5F3759DF)
            observations = self.obs.observe_defense_discrete(D, self.n, seed=seed)
            attack_set = self.attacker.optimal_response_to_observation(observations)
            dmg = 0.0
            for i in attack_set:
                defended = (i in D)
                success_prob = (1.0 - self.defense_model.defense_strength) if defended else 1.0
                if rng.random() < success_prob:
                    dmg += self.w[i]
            return dmg

        result = mc_estimator.estimate(sample_fn=sample_damage, value_range=(0.0, float(np.sum(self.w))))
        return result["estimate"]

    def _generate_candidate_defenses(self, max_candidates: int = 100) -> List[Set[int]]:
        """Greedy damage/cost candidate plus budget-filling random permutations (fixed seed 42)."""
        candidates = []
        ratios = self.w / np.maximum(self.dc, 1e-9)
        sorted_idx = np.argsort(ratios)[::-1]

        current = set()
        current_cost = 0.0
        for idx in sorted_idx:
            if current_cost + self.dc[idx] <= self.DB:
                current.add(int(idx))
                current_cost += self.dc[idx]
        candidates.append(current)

        rng = np.random.RandomState(42)
        for _ in range(max_candidates - 1):
            perm = rng.permutation(self.n)
            subset = set()
            cost = 0.0
            for idx in perm:
                if cost + self.dc[idx] <= self.DB:
                    subset.add(int(idx))
                    cost += self.dc[idx]
            if subset not in candidates:
                candidates.append(subset)

        return candidates

    def _enumerate_all_feasible_defenses(self) -> List[Set[int]]:
        """Every subset of clients whose total defense cost fits the budget (2**n subsets scanned)."""
        feasible = []
        for i in range(2**self.n):
            subset = {j for j in range(self.n) if (i >> j) & 1}
            cost = sum(self.dc[j] for j in subset)
            if cost <= self.DB:
                feasible.append(subset)
        if len(feasible) == 0:
            feasible.append(set())
        return feasible

    def solve(self, use_exact: bool = True, verbose: bool = False) -> Tuple[Set[int], Dict]:
        """Return the candidate defense with the lowest estimated damage, plus summary statistics.

        Args:
            use_exact: enumerate every feasible defense when True and n <= 15;
                otherwise evaluate up to 50 heuristic candidates.
            verbose: print periodic progress and final statistics.

        Returns:
            ``(best_defense_set, info)``. ``info["method"]`` reports the mode that
            actually ran.
        """
        high_precision_mc = RigorousMCEstimator(
            epsilon=0.02,
            delta=0.05,
            pilot=20,
            nmin=100,
            nmax=1000
        )

        # Exact enumeration is only used for small n even when requested.
        exact_mode = use_exact and self.n <= 15
        if exact_mode:
            candidates = self._enumerate_all_feasible_defenses()
            print(f"   EXACT mode: Evaluating all {len(candidates)} feasible defenses with high-precision MC")
        else:
            candidates = self._generate_candidate_defenses(max_candidates=50)
            print(f"   HEURISTIC mode: Evaluating {len(candidates)} candidate defenses with high-precision MC")

        best_D = None
        best_damage = float('inf')
        all_damages = []

        if verbose:
            print(f"   Evaluating {len(candidates)} candidate pure defense sets...")

        progress_every = max(1, len(candidates) // 10)
        for idx, D in enumerate(candidates):
            damage = self._evaluate_pure_defense(D, mc_estimator=high_precision_mc)
            all_damages.append(damage)
            if damage < best_damage:
                best_damage = damage
                best_D = D
            # Report at regular intervals, not only when an improvement lands on one.
            if verbose and idx % progress_every == 0:
                print(f"      Progress: {idx}/{len(candidates)}, current best damage: {best_damage:.2f}")

        all_damages = np.array(all_damages)
        damage_std = np.std(all_damages)
        damage_median = np.median(all_damages)
        damage_min = np.min(all_damages)

        info = {
            "method": "exact_high_precision" if exact_mode else "heuristic_high_precision",
            "best_damage": best_damage,
            "defense_cost": sum(self.dc[i] for i in best_D),
            "num_candidates_evaluated": len(candidates),
            "damage_std": float(damage_std),
            "damage_median": float(damage_median),
            "damage_min": float(damage_min),
            "selection_bias": float(damage_median - damage_min),
        }

        if verbose:
            print(f"\n   Final Statistics:")
            print(f"      Best damage: {best_damage:.2f}")
            print(f"      Median damage: {damage_median:.2f}")
            print(f"      Std dev: {damage_std:.2f}")
            print(f"      Selection bias: {info['selection_bias']:.2f}")

        return best_D, info


class EfficientStackelbergSolver:
    """
    Pure-strategy Stackelberg solver for the client defense/attack game.

    Timeline: the defender commits to a budget-feasible set of defended
    clients; the attacker observes it and best-responds with a knapsack over
    the profitable undefended clients; the defender keeps the commitment whose
    best response causes the least damage.
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray,
                 defense_costs: np.ndarray, attack_costs: np.ndarray,
                 defense_budget: float, attack_budget: float):
        self.num_clients = num_clients
        self.damage_weights = damage_weights
        self.defense_costs = defense_costs
        self.attack_costs = attack_costs
        self.defense_budget = defense_budget
        self.attack_budget = attack_budget

    def _damage_under_defense(self, attack_set, defense_set) -> float:
        """Return the total damage weight of attacked clients that are not defended.

        The attacker's knapsack already excludes defended clients; the explicit
        filter keeps this helper correct for arbitrary caller-supplied sets.
        """
        return sum(self.damage_weights[i] for i in attack_set if i not in defense_set)

    def _budget_violation_penalty(self, cost: float) -> float:
        """Return the penalty for a defense whose ``cost`` exceeds the budget (0.0 otherwise).

        The violation is normalised by the defense budget and scaled by ten
        times the total damage weight. When the budget is not positive the raw
        violation is used instead, avoiding a division by zero.
        """
        if cost > self.defense_budget:
            violation = cost - self.defense_budget
            scale = self.defense_budget if self.defense_budget > 0 else 1.0
            max_possible_damage = sum(self.damage_weights)
            return max_possible_damage * 10 * (violation / scale)
        return 0.0

    def solve_stackelberg_exact(self) -> Set[int]:
        """
        Exact Stackelberg solution by enumeration.

        1. Enumerate feasible defense sets (the generator is called with
           max_combinations=5000).
        2. For each defense, compute the attacker's knapsack best response.
        3. Keep the defense whose response causes the least undefended damage.

        Returns:
            The best defense set found, or an empty set when the generator
            yields no feasible defenses.
        """
        print("Solving Stackelberg game exactly...")

        feasible_defenses = generate_feasible_defense_combinations(
            self.defense_costs, self.defense_budget, max_combinations=5000
        )
        print(f"   Generated {len(feasible_defenses)} feasible defense combinations")

        best_defense = set()
        best_damage = float('inf')
        for defense_set in feasible_defenses:
            # Defender commits first; the attacker observes and best-responds.
            attack_set = self._solve_attacker_knapsack(defense_set)
            damage = self._damage_under_defense(attack_set, defense_set)
            if damage < best_damage:
                best_damage = damage
                best_defense = defense_set

        print(f"   Best defense found: {len(best_defense)} clients defended")
        print(f"   Resulting damage: {best_damage:.2f}")
        return best_defense

    def solve_stackelberg_bilevel(self) -> Set[int]:
        """
        Approximate solution through an SLSQP relaxation of the defender's problem.

        Each objective evaluation rounds the relaxed vector at 0.5, solves the
        attacker's best response, and returns the resulting damage plus a
        budget-violation penalty. Because of that rounding the objective is
        piecewise constant, so SLSQP's finite-difference gradients are
        typically zero and the result tends to stay at the greedy start point.

        Returns:
            The rounded SLSQP solution if SLSQP succeeds and it is within
            budget. Otherwise the exact solution (< 25 clients) or the greedy
            defense.
        """

        def defender_objective(defense_vars):
            defense_set = {i for i in range(self.num_clients) if defense_vars[i] > 0.5}
            cost = sum(self.defense_costs[i] for i in defense_set)
            attack_set = self._solve_attacker_knapsack(defense_set)
            damage = self._damage_under_defense(attack_set, defense_set)
            return damage + self._budget_violation_penalty(cost)

        def budget_constraint(defense_vars):
            # SLSQP 'ineq' constraints must be >= 0 when satisfied.
            return self.defense_budget - np.dot(defense_vars, self.defense_costs)

        x0 = self._greedy_defense_vector()
        result = minimize(
            defender_objective, x0, method='SLSQP',
            bounds=[(0, 1) for _ in range(self.num_clients)],  # binary relaxation
            constraints={'type': 'ineq', 'fun': budget_constraint},
            options={'maxiter': 200, 'disp': False}
        )

        if result.success:
            defense_set = {i for i in range(self.num_clients) if result.x[i] > 0.5}
            # Rounding can push the relaxed solution over budget; re-check.
            actual_cost = sum(self.defense_costs[i] for i in defense_set)
            if actual_cost <= self.defense_budget:
                return defense_set

        if self.num_clients < 25:
            return self.solve_stackelberg_exact()
        return self._greedy_defense_initialization()

    def solve_stackelberg_smart(self) -> Set[int]:
        """Use exact enumeration for up to 20 clients and the bilevel relaxation above that."""
        if self.num_clients <= 20:
            return self.solve_stackelberg_exact()
        return self.solve_stackelberg_bilevel()

    def _solve_attacker_knapsack(self, defended_clients: Set[int]) -> Set[int]:
        """Return the attacker's best-response set against ``defended_clients``.

        Defended clients get utility -inf (the attack is fully blocked). The
        others get ``damage - attack_cost``. Only clients with positive utility
        are considered. Up to 25 of them are solved with the DP knapsack,
        more with the greedy one.
        """
        utilities = np.zeros(self.num_clients)
        for i in range(self.num_clients):
            if i in defended_clients:
                utilities[i] = float('-inf')
            else:
                utilities[i] = self.damage_weights[i] - self.attack_costs[i]

        profitable = utilities > 0
        if not np.any(profitable):
            return set()

        profitable_indices = np.where(profitable)[0]
        profitable_utilities = utilities[profitable]
        profitable_costs = self.attack_costs[profitable]

        if len(profitable_indices) <= 25:
            selected = solve_knapsack_dp(profitable_utilities, profitable_costs, self.attack_budget)
        else:
            selected = solve_knapsack_greedy(profitable_utilities, profitable_costs, self.attack_budget)

        # Map knapsack positions back to client indices.
        return {profitable_indices[i] for i in selected}

    def _greedy_defense_initialization(self) -> Set[int]:
        """Greedily defend clients by descending damage/defense-cost ratio while the budget allows."""
        efficiency = self.damage_weights / self.defense_costs
        defense = set()
        current_cost = 0
        for idx in np.argsort(efficiency)[::-1]:
            if current_cost + self.defense_costs[idx] <= self.defense_budget:
                defense.add(idx)
                current_cost += self.defense_costs[idx]
        return defense

    def _greedy_defense_vector(self) -> np.ndarray:
        """Return the greedy defense as a 0/1 float vector (SLSQP start point)."""
        x0 = np.zeros(self.num_clients)
        chosen = list(self._greedy_defense_initialization())
        if chosen:
            x0[chosen] = 1
        return x0

class MathematicalStackelbergDefender:
    """
    Stackelberg defender that wraps ``EfficientStackelbergSolver``.

    It keeps its own copies of the game parameters, builds the solver from
    those same copies (so caller-side in-place edits cannot desynchronise the
    two), times each solve, and rebuilds the solver when parameters change.
    """

    def __init__(self, num_clients: int, damage_weights: np.ndarray, defense_costs: np.ndarray,
                 defense_budget: float, attack_costs: np.ndarray, attack_budget: float,
                 defense_effectiveness: float = 1.0):
        self.num_clients = num_clients
        self.damage_weights = damage_weights.copy()
        self.defense_costs = defense_costs.copy()
        self.defense_budget = defense_budget
        # NOTE(review): stored only; it is not passed to the solver.
        self.defense_effectiveness = defense_effectiveness
        self.solver = self._build_solver(attack_costs, attack_budget)

    def _build_solver(self, attack_costs: np.ndarray, attack_budget: float) -> "EfficientStackelbergSolver":
        """Create a solver over this defender's own parameter snapshot."""
        return EfficientStackelbergSolver(
            self.num_clients, self.damage_weights, self.defense_costs,
            attack_costs.copy(), self.defense_budget, attack_budget
        )

    def solve_stackelberg_equilibrium(self) -> Set[int]:
        """Solve the game with the solver's hybrid (exact/bilevel) strategy and report the wall time.

        Returns:
            The defended-client set chosen by the solver.
        """
        start_time = time.time()
        solution = self.solver.solve_stackelberg_smart()
        solve_time = time.time() - start_time
        print(f"   Stackelberg solved in {solve_time:.3f}s")
        return solution

    def update_parameters(self, damage_weights: np.ndarray, defense_costs: np.ndarray,
                          defense_budget: float, attack_costs: np.ndarray, attack_budget: float):
        """Replace all game parameters (dynamic reshuffling) and rebuild the solver."""
        self.damage_weights = damage_weights.copy()
        self.defense_costs = defense_costs.copy()
        self.defense_budget = defense_budget
        self.solver = self._build_solver(attack_costs, attack_budget)

class AttackProbabilityEstimator:
    """Track a smoothed per-client attack probability from round-by-round history.

    Every round appends a 0/1 attack indicator per client. Each client's
    estimate is then moved toward its attack rate over its last 10 rounds using
    exponential smoothing with weight ``alpha``. Estimates start at 0.1.
    """

    def __init__(self, num_clients: int, alpha: float = 0.3):
        self.num_clients = num_clients
        self.alpha = alpha
        self.reset_for_new_era()

    def update_history(self, round_num: int, attackers: Set[int]):
        """Record which clients attacked this round and refresh every estimate.

        Args:
            round_num: current round (accepted for interface symmetry; unused).
            attackers: ids of clients that attacked this round.
        """
        clients = range(self.num_clients)
        for cid in clients:
            self.attack_history[cid].append(int(cid in attackers))

        window_rates = np.array([np.mean(self.attack_history[cid][-10:]) for cid in clients])
        # Update in place so the existing array object keeps being used.
        self.current_probs[:] = self.alpha * window_rates + (1 - self.alpha) * self.current_probs

    def get_probabilities(self) -> np.ndarray:
        """Return a copy of the current estimates (callers cannot mutate internal state)."""
        return self.current_probs.copy()

    def reset_for_new_era(self):
        """Forget all history and return every estimate to the 0.1 prior."""
        self.attack_history = defaultdict(list)
        self.current_probs = np.ones(self.num_clients) * 0.1

class DefenseStrategy:
    """Common interface for defenders that choose clients under a cost budget.

    Subclasses implement ``select_defended_clients``. Bandit-style subclasses
    may also override ``observe_round`` to learn from per-round feedback.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float):
        self.num_clients = num_clients
        # Private copy: later in-place edits by the caller must not leak in.
        self.defense_costs = defense_costs.copy()
        self.defense_budget = defense_budget

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float):
        """Swap in new defense costs (copied) and budget after a reshuffle."""
        self.defense_costs, self.defense_budget = defense_costs.copy(), defense_budget

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Return the set of clients to defend in ``round_num``; subclasses must override."""
        raise NotImplementedError

    def observe_round(self, attackers: Set[int], defended_clients: Set[int] = None):
        """Feedback hook for learning strategies; the base implementation ignores it."""
        return None

class RandomDefense(DefenseStrategy):
    """Random baseline that stays within budget.

    Clients are visited in a freshly shuffled order. Each one is kept if its
    defense cost still fits in the budget.
    """

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        visit_order = list(range(self.num_clients))
        random.shuffle(visit_order)

        chosen: Set[int] = set()
        spent = 0.0
        for cid in visit_order:
            price = self.defense_costs[cid]
            if spent + price <= self.defense_budget:
                chosen.add(cid)
                spent += price
        return chosen

class NoDefense(DefenseStrategy):
    """Baseline that never defends any client, whatever the budget or round."""

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Return a fresh empty set."""
        nobody: Set[int] = set()
        return nobody

class HighestRiskDefense(DefenseStrategy):
    """Greedy defense ordered only by estimated attack probability.

    Clients are visited from most to least likely attacker. Each one is kept
    while its defense cost still fits in the budget. Without ``attack_probs``
    a uniform 0.1 estimate is used.
    """

    def select_defended_clients(self, round_num: int, attack_probs: np.ndarray = None, **kwargs) -> Set[int]:
        probs = np.ones(self.num_clients) * 0.1 if attack_probs is None else attack_probs

        picked: Set[int] = set()
        spent = 0.0
        for cid in np.argsort(probs)[::-1]:
            cost = self.defense_costs[cid]
            if spent + cost <= self.defense_budget:
                picked.add(cid)
                spent += cost
        return picked

class HighestValueDefense(DefenseStrategy):
    """Greedy defense ordered only by potential damage (``damage_weights``).

    It ignores attack likelihood: the most damaging clients are defended
    first, as long as their defense costs fit in the budget.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.damage_weights = damage_weights.copy()

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray = None):
        """Refresh costs/budget and, when given, the damage weights (copied)."""
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is None:
            return
        self.damage_weights = damage_weights.copy()

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        ranking = np.argsort(self.damage_weights)[::-1]

        picked: Set[int] = set()
        spent = 0.0
        for cid in ranking:
            cost = self.defense_costs[cid]
            if spent + cost <= self.defense_budget:
                picked.add(cid)
                spent += cost
        return picked

class UCBDefense(DefenseStrategy):
    """UCB-style client selection under a defense budget.

    score_i = (mu_i + c * sqrt(ln(t + 1) / n_i)) * damage_weight_i

    Definitions:
    - n_i counts the rounds in which client i was observed. Those are the
      defended clients passed to ``observe_round``, or every client when no
      defended set is given.
    - mu_i is the fraction of those rounds in which client i attacked.
    - t is the number of ``observe_round`` calls since the last reset.

    Clients that have never been observed use the optimistic prior
    mu_i = bonus_i = 1.0. Clients are then defended greedily by descending
    score while the budget allows. Counts reset on every parameter update.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float,
                 damage_weights: np.ndarray, c: float = 2.0):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.damage_weights = damage_weights.copy()
        self.c = c
        self._reset_counts()

    def _reset_counts(self):
        """Clear attack counts, observation counts and the round counter."""
        self.attack_count = np.zeros(self.num_clients)
        self.n_obs = np.zeros(self.num_clients)
        self.t = 0

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float,
                          damage_weights: np.ndarray = None):
        """Refresh costs/budget (and optionally damage weights), then restart learning."""
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is not None:
            self.damage_weights = damage_weights.copy()
        self._reset_counts()

    def observe_round(self, attackers: Set[int], defended_clients: Set[int] = None):
        """Update counts for the observed clients (the "pulled arms")."""
        self.t += 1
        # UCB1 (Auer et al. 2002): n_i increments only for the pulled arm.
        # Here "pulling arm i" = defending client i.
        observed = defended_clients if defended_clients is not None else set(range(self.num_clients))
        for i in observed:
            self.n_obs[i] += 1
            if i in attackers:
                self.attack_count[i] += 1

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Defend the highest-scoring clients that fit in the budget."""
        seen = self.n_obs > 0
        # np.where evaluates both branches, so clamp the denominator for unseen
        # arms to avoid 0/0 and x/0; those entries are replaced by the prior.
        safe_n = np.where(seen, self.n_obs, 1.0)
        mu = np.where(seen, self.attack_count / safe_n, 1.0)
        bonus = np.where(seen, self.c * np.sqrt(np.log(self.t + 1) / safe_n), 1.0)
        scores = (mu + bonus) * self.damage_weights

        selected = set()
        current_cost = 0.0
        for idx in np.argsort(scores)[::-1]:
            if current_cost + self.defense_costs[idx] <= self.defense_budget:
                selected.add(idx)
                current_cost += self.defense_costs[idx]
        return selected


class ThompsonSamplingDefense(DefenseStrategy):
    """Beta-Bernoulli Thompson Sampling over per-client attack rates.

    Each round draws theta_i ~ Beta(alpha_i, beta_i) and scores clients by
    theta_i * damage_weight_i. The highest-scoring clients that fit in the
    budget are defended. ``observe_round`` updates every client's posterior
    from the attacker set; the defended set is ignored.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float,
                 damage_weights: np.ndarray):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.damage_weights = damage_weights.copy()
        self._reset_counts()

    def _reset_counts(self):
        """Restore the uniform Beta(1, 1) prior for every client."""
        self.alpha = np.ones(self.num_clients)  # 1 + number of rounds attacked
        self.beta = np.ones(self.num_clients)   # 1 + number of clean rounds

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float,
                          damage_weights: np.ndarray = None):
        """Refresh costs/budget (and optionally damage weights), then reset the posteriors."""
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is not None:
            self.damage_weights = damage_weights.copy()
        self._reset_counts()

    def observe_round(self, attackers: Set[int], defended_clients: Set[int] = None):
        """Add one success (attack) or failure (clean round) to each client's posterior."""
        attacked = np.fromiter(
            (cid in attackers for cid in range(self.num_clients)),
            dtype=bool, count=self.num_clients,
        )
        self.alpha[attacked] += 1
        self.beta[~attacked] += 1

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        draws = np.random.beta(self.alpha, self.beta)
        ranking = np.argsort(draws * self.damage_weights)[::-1]

        picked: Set[int] = set()
        spent = 0.0
        for cid in ranking:
            cost = self.defense_costs[cid]
            if spent + cost <= self.defense_budget:
                picked.add(cid)
                spent += cost
        return picked


class NaiveStackelbergDefense(DefenseStrategy):
    """Risk-ranked greedy defense (a heuristic, not a true Stackelberg solve).

    The risk of a client is its attack probability times its damage weight.
    Clients are defended from highest to lowest risk while the budget allows.
    Without ``attack_probs`` a uniform 0.1 estimate is used.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.damage_weights = damage_weights.copy()

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float, damage_weights: np.ndarray = None):
        """Refresh costs/budget and, when given, the damage weights (copied)."""
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is None:
            return
        self.damage_weights = damage_weights.copy()

    def select_defended_clients(self, round_num: int, attack_probs: np.ndarray = None, **kwargs) -> Set[int]:
        probs = np.ones(self.num_clients) * 0.1 if attack_probs is None else attack_probs
        ranking = np.argsort(probs * self.damage_weights)[::-1]

        picked: Set[int] = set()
        spent = 0.0
        for cid in ranking:
            cost = self.defense_costs[cid]
            if spent + cost <= self.defense_budget:
                picked.add(cid)
                spent += cost
        return picked

class StackelbergDefense(DefenseStrategy):
    """Stackelberg defense that delegates to ``ProperStackelbergSolver``.

    The solver runs lazily on the first ``select_defended_clients`` call. Its
    defense set is then returned unchanged until ``update_parameters`` marks
    the cache stale, for example after a reshuffle.
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float,
                 damage_weights: np.ndarray, attack_costs: np.ndarray, attack_budget: float,
                 defense_strength: float = 0.7, observation_accuracy: float = 0.8, use_exact: bool = True):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.use_exact = use_exact
        # _make_solver reads these two, so they must be set first.
        self.observation_accuracy = observation_accuracy
        self.defense_strength = defense_strength
        self._make_solver(damage_weights, attack_costs, attack_budget)
        self._invalidate_cache()

    def _invalidate_cache(self):
        """Forget any cached equilibrium so the next selection re-solves."""
        self._cache_defense = None
        self._dirty = True

    def _make_solver(self, damage_weights, attack_costs, attack_budget):
        """(Re)build the solver from the current defense costs/budget and the given attacker parameters."""
        self.solver = ProperStackelbergSolver(
            self.num_clients, damage_weights, self.defense_costs, attack_costs,
            self.defense_budget, attack_budget, self.defense_strength, self.observation_accuracy,
        )

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float,
                          damage_weights: np.ndarray = None, attack_costs: np.ndarray = None,
                          attack_budget: float = None):
        """Refresh parameters, rebuild the solver, and invalidate the cached defense.

        Unless all three attacker-side values are supplied, the solver's
        current values for all three are reused.
        """
        super().update_parameters(defense_costs, defense_budget)
        if any(v is None for v in (damage_weights, attack_costs, attack_budget)):
            damage_weights, attack_costs, attack_budget = self.solver.w, self.solver.ac, self.solver.AB
        self._make_solver(damage_weights, attack_costs, attack_budget)
        self._invalidate_cache()

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        if not self._dirty and self._cache_defense is not None:
            return self._cache_defense

        print(f"   Computing Stackelberg equilibrium (round {round_num})...")
        D, info = self.solver.solve(use_exact=self.use_exact, verbose=True)
        print(f"      Best damage: {info['best_damage']:.2f} via {info['method']}")
        self._cache_defense = D
        self._dirty = False
        return self._cache_defense

class EfficientTrueStackelbergDefense(DefenseStrategy):
    """Stackelberg defense backed by ``MathematicalStackelbergDefender``.

    The equilibrium is recomputed on every ``select_defended_clients`` call
    (no caching).
    """

    def __init__(self, num_clients: int, defense_costs: np.ndarray, defense_budget: float,
                 damage_weights: np.ndarray, attack_costs: np.ndarray, attack_budget: float,
                 defense_effectiveness: float = 1.0):
        super().__init__(num_clients, defense_costs, defense_budget)
        self.defender = MathematicalStackelbergDefender(
            num_clients, damage_weights, defense_costs, defense_budget,
            attack_costs, attack_budget, defense_effectiveness
        )

    def update_parameters(self, defense_costs: np.ndarray, defense_budget: float,
                          damage_weights: np.ndarray = None, attack_costs: np.ndarray = None,
                          attack_budget: float = None):
        """Update parameters for dynamic reshuffling.

        The defender is always rebuilt with the new defense costs and budget,
        so its solver never keeps stale values. Any of ``damage_weights``,
        ``attack_costs`` or ``attack_budget`` left as None keeps the
        defender's current value.
        """
        super().update_parameters(defense_costs, defense_budget)
        if damage_weights is None:
            damage_weights = self.defender.damage_weights
        if attack_costs is None:
            attack_costs = self.defender.solver.attack_costs
        if attack_budget is None:
            attack_budget = self.defender.solver.attack_budget
        self.defender.update_parameters(damage_weights, defense_costs, defense_budget,
                                        attack_costs, attack_budget)

    def select_defended_clients(self, round_num: int, **kwargs) -> Set[int]:
        """Solve the Stackelberg game for the current parameters and return the defended set."""
        return self.defender.solve_stackelberg_equilibrium()

class GameTheoryIntegrator:
    """Integrates game theory logic into the FRL system"""
    
    def __init__(self, opts):
        """
        Build the Stackelberg game state for the FRL run from ``opts``.

        The following are set up:
        - per-client damage weights, defense/attack costs and budgets;
        - the probabilistic defense model and partial-observability model;
        - the attack-probability estimator and the budget-constrained attacker;
        - a registry of defense strategies keyed by name.

        If ``opts.enable_game_theory`` is falsy, only ``self.opts`` is set and
        the method returns immediately.

        Args:
            opts: run configuration object. Attributes read here include
                enable_game_theory, damage_scenario, seed, num_worker,
                initial_defense_budget_ratio and initial_attack_budget_ratio.
                Also defense_effectiveness, with defense_strength,
                observation_accuracy and current_strategy read via getattr
                with defaults. Finally ucb_c, single_best_attacker_only and
                reshuffle_frequency.
        """
        self.opts = opts
        
        if not opts.enable_game_theory:
            return
            
        print(f"\n{'='*80}")
        print(f"INITIALIZING STACKELBERG GAME THEORY FOR FRL")
        print(f"{'='*80}")
        
        if opts.damage_scenario == "engineered":
            # Hand-built three-tier scenario, seeded by opts.seed for reproducibility.
            rng = np.random.RandomState(opts.seed)
            n = opts.num_worker
            # Define group sizes
            # NOTE(review): k_mid has a floor of 4, so for small num_worker the
            # mid slice below is truncated and fewer than k_mid clients land in it.
            k_top = max(1, n // 4)
            k_mid = max(4, n // 2 - 1)
            # Base weights (the remaining "low" tier keeps 8.0)
            damage_weights = np.ones(n) * 8.0
            # Assign groups (shuffle to avoid index bias)
            idx = rng.permutation(n)
            top_idx = idx[:k_top]
            mid_idx = idx[k_top:k_top + k_mid]
            damage_weights[top_idx] = 55.0   # HIGH damage but attacker can't afford them
            damage_weights[mid_idx] = 30.0
            self.current_damage_weights = damage_weights
            # Defense costs: top group expensive, mid group cheaper
            defense_costs = np.ones(n) * 8.0
            defense_costs[top_idx] = 45.0
            defense_costs[mid_idx] = 10.0
            self.current_defense_costs = defense_costs
            # Engineered attack costs: top tier costs 60, exceeds attacker budget (~44)
            # so attacker NEVER targets top tier -- but UCB/TS don't know this upfront
            # NOTE(review): whether 60 exceeds the attack budget depends on
            # initial_attack_budget_ratio and num_worker; "~44" is not computed here — confirm.
            attack_costs = np.ones(n) * 4.0
            attack_costs[mid_idx] = 3.0
            attack_costs[top_idx] = 60.0
            self.current_attack_costs = attack_costs
            # Budgets: ratio-based (same as non-engineered / ResNet setup)
            total_defense_cost = float(np.sum(self.current_defense_costs))
            total_attack_cost = float(np.sum(self.current_attack_costs))
            self.current_defense_budget = opts.initial_defense_budget_ratio * total_defense_cost
            self.current_attack_budget = opts.initial_attack_budget_ratio * total_attack_cost
        else:
            # Create damage weights based on scenario
            self.current_damage_weights = create_realistic_unequal_damage_weights(
                opts.num_worker, opts.damage_scenario, opts.seed
            )
            # Create strategic variety in costs (defense uses seed, attack uses seed + 1)
            self.current_defense_costs = create_strategic_variety_costs(
                opts.num_worker, self.current_damage_weights, (0.5, 1.5), opts.seed
            )
            self.current_attack_costs = create_strategic_variety_costs(
                opts.num_worker, self.current_damage_weights, (0.5, 1.5), opts.seed + 1
            )
            # Set budgets as a fraction of the cost of defending/attacking everyone
            total_defense_cost = sum(self.current_defense_costs)
            total_attack_cost = sum(self.current_attack_costs)
            self.current_defense_budget = opts.initial_defense_budget_ratio * total_defense_cost
            self.current_attack_budget = opts.initial_attack_budget_ratio * total_attack_cost
        
        # Store initial budgets to keep them fixed
        # (reshuffle_game_parameters restores these in the non-engineered scenario)
        self.fixed_defense_budget = self.current_defense_budget
        self.fixed_attack_budget = self.current_attack_budget
        
        # Initialize observability/defense models
        # defense_strength falls back to opts.defense_effectiveness when absent.
        defense_strength = getattr(opts, 'defense_strength', opts.defense_effectiveness)
        observation_accuracy = getattr(opts, 'observation_accuracy', 0.8)
        self.defense_model = ProbabilisticDefenseModel(defense_strength)
        self.partial_obs = PartialObservabilityModel(observation_accuracy)

        # Initialize game components
        self.prob_estimator = AttackProbabilityEstimator(opts.num_worker)
        # Fraction of the cost of defending every client that the budget covers;
        # given to the attacker as its estimate of defense coverage.
        estimated_coverage = (self.current_defense_budget / total_defense_cost) if total_defense_cost > 0 else 0.0
        self.attacker = BudgetConstrainedAttacker(
            opts.num_worker,
            self.current_damage_weights,
            self.current_attack_costs,
            self.current_attack_budget,
            self.partial_obs,
            self.defense_model,
            estimated_defense_coverage=estimated_coverage,
            use_greedy_solver=False,
            single_attacker_mode=opts.single_best_attacker_only
        )
        
        # Track eras
        self.current_era = 0
        
        # Initialize defense strategies (registry keyed by name; the active one
        # is selected by self.current_strategy)
        self.strategies = {
            'no_defense': NoDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget),
            'random': RandomDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget),
            'highest_risk': HighestRiskDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget),
            'highest_value': HighestValueDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, self.current_damage_weights),
            'naive_stackelberg': NaiveStackelbergDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, self.current_damage_weights),
            'ucb': UCBDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, self.current_damage_weights, c=opts.ucb_c),
            'thompson_sampling': ThompsonSamplingDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget, self.current_damage_weights),
            'true_stackelberg': StackelbergDefense(opts.num_worker, self.current_defense_costs, self.current_defense_budget,
                                                   self.current_damage_weights, self.current_attack_costs, self.current_attack_budget,
                                                   defense_strength=defense_strength, observation_accuracy=observation_accuracy)
        }
        
        # Set default strategy (allow override from opts)
        self.current_strategy = getattr(opts, 'current_strategy', 'true_stackelberg')
        
        print(f"GAME THEORY PARAMETERS:")
        print(f"   Scenario: {opts.damage_scenario}")
        print(f"   Damage weights: {self.current_damage_weights.min():.1f} to {self.current_damage_weights.max():.1f}")
        print(f"   Defense budget: {self.current_defense_budget:.1f}")
        print(f"   Attack budget: {self.current_attack_budget:.1f}")
        print(f"   Defense strength: {defense_strength:.1%}")
        print(f"   Observation accuracy: {observation_accuracy:.1%}")
        print(f"   Reshuffle frequency: every {opts.reshuffle_frequency} rounds")
        print(f"   Current strategy: {self.current_strategy}")
        
    def reshuffle_game_parameters(self, round_num: int):
        """Begin a new era by redrawing the game's cost structure.

        The era index is ``round_num // opts.reshuffle_frequency`` and, offset
        from ``opts.seed``, seeds this era's randomness. In the "engineered"
        scenario clients are randomly split into top / middle / base tiers
        with fixed damage, defense-cost and attack-cost values, and both
        budgets are recomputed from the configured ratios. Every other
        scenario keeps the scenario's damage weights and the fixed budgets and
        only redraws defense/attack costs. Afterwards the attacker, all defense
        strategies and the attack-probability estimator are refreshed.
        No-op when game theory is disabled.
        """
        if not self.opts.enable_game_theory:
            return

        opts = self.opts
        self.current_era = round_num // opts.reshuffle_frequency
        era_seed = opts.seed + self.current_era * 1000

        rule = '=' * 60
        print(f"\n{rule}")
        print(f"STRATEGIC RESHUFFLING! Round {round_num}, Era {self.current_era}")
        print(rule)

        if opts.damage_scenario == "engineered":
            n_clients = opts.num_worker
            n_top = max(1, n_clients // 4)
            n_mid = max(4, n_clients // 2 - 1)
            order = np.random.RandomState(era_seed).permutation(n_clients)
            top = order[:n_top]
            mid = order[n_top:n_top + n_mid]

            def _tiered(base_value, mid_value, top_value):
                # Tiers are disjoint, so assignment order does not matter.
                values = np.full(n_clients, base_value)
                values[top] = top_value
                values[mid] = mid_value
                return values

            # Top tier: highest damage, but attack cost (60) is meant to exceed
            # the attacker's budget so those clients are unaffordable targets.
            self.current_damage_weights = _tiered(8.0, 30.0, 55.0)
            self.current_defense_costs = _tiered(8.0, 10.0, 45.0)
            self.current_attack_costs = _tiered(4.0, 3.0, 60.0)
            self.current_defense_budget = opts.initial_defense_budget_ratio * float(np.sum(self.current_defense_costs))
            self.current_attack_budget = opts.initial_attack_budget_ratio * float(np.sum(self.current_attack_costs))
        else:
            n_clients = opts.num_worker
            # Damage weights stay tied to the scenario's base seed across eras
            self.current_damage_weights = create_realistic_unequal_damage_weights(
                n_clients, opts.damage_scenario, opts.seed
            )
            self.current_defense_costs = create_strategic_variety_costs(
                n_clients, self.current_damage_weights, (0.5, 1.5), era_seed
            )
            self.current_attack_costs = create_strategic_variety_costs(
                n_clients, self.current_damage_weights, (0.5, 1.5), era_seed + 1
            )
            # Budgets do not change between eras in these scenarios
            self.current_defense_budget = self.fixed_defense_budget
            self.current_attack_budget = self.fixed_attack_budget

        # Fraction of total defense cost the defender can afford this era
        total_defense_cost = sum(self.current_defense_costs)
        estimated_coverage = (self.current_defense_budget / total_defense_cost) if total_defense_cost > 0 else 0.0
        self.attacker.update_parameters(
            self.current_damage_weights,
            self.current_attack_costs,
            self.current_attack_budget,
            estimated_defense_coverage=estimated_coverage,
            single_attacker_mode=self.attacker.single_attacker_mode
        )

        # Each strategy family takes a different amount of game information
        value_aware = ('highest_value', 'naive_stackelberg', 'ucb', 'thompson_sampling')
        for name, strat in self.strategies.items():
            if name == 'true_stackelberg':
                strat.update_parameters(self.current_defense_costs, self.current_defense_budget,
                                        self.current_damage_weights, self.current_attack_costs,
                                        self.current_attack_budget)
            elif name in value_aware:
                strat.update_parameters(self.current_defense_costs, self.current_defense_budget,
                                        self.current_damage_weights)
            else:
                strat.update_parameters(self.current_defense_costs, self.current_defense_budget)

        # Attack statistics from the previous era no longer apply
        self.prob_estimator.reset_for_new_era()

        print(f"Reshuffling complete for era {self.current_era}")
        
    def get_attack_defense_decisions(self, round_num: int):
        """Play one Stackelberg round and return ``(attackers, defended_clients)``.

        The defender commits first. The attacker then best-responds to a noisy
        observation of that defense. Finally, the probability estimator and
        every strategy receive a noisy observation of who attacked. Once
        ``round_num`` reaches ``opts.max_trajectories`` only the defense is
        computed and no attackers are returned. With game theory disabled,
        both sets are empty.
        """
        if not self.opts.enable_game_theory:
            # No game theory - return empty sets
            return set(), set()

        def choose_defense():
            # 'true_stackelberg' plans from its own model; the other strategies
            # consume the estimator's current attack probabilities.
            if self.current_strategy == 'true_stackelberg':
                return self.strategies[self.current_strategy].select_defended_clients(round_num)
            probs = self.prob_estimator.get_probabilities()
            return self.strategies[self.current_strategy].select_defended_clients(
                round_num, attack_probs=probs
            )

        no_attack_threshold = 0
        in_final_phase = round_num >= (self.opts.max_trajectories - no_attack_threshold)
        if in_final_phase:
            print(f"FINAL PHASE: No attacks enabled (step {round_num}/{self.opts.max_trajectories})")
            # Defense still runs, but attackers are suppressed in this phase
            defense = choose_defense()
            return set(), defense

        # Era boundary (never on round 0): redraw costs/budgets first
        if round_num > 0 and round_num % self.opts.reshuffle_frequency == 0:
            self.reshuffle_game_parameters(round_num)

        # STEP 1: defender commits
        defense = choose_defense()

        # STEP 2: attacker best-responds to a noisy view of the defense
        round_seed = round_num + self.opts.seed
        seen_defense = self.partial_obs.observe_defense_discrete(
            defense, self.opts.num_worker, seed=round_seed
        )
        attackers = self.attacker.optimal_response_to_observation(seen_defense)

        # STEP 3: defender side only ever learns from a noisy record of attacks
        seen_attackers = self.partial_obs.observe_attack_discrete(
            attackers, self.opts.num_worker, seed=round_seed + 9999
        )
        self.prob_estimator.update_history(round_num, seen_attackers)
        for strat in self.strategies.values():
            strat.observe_round(seen_attackers, defense)

        return attackers, defense
    
    def calculate_damage_metrics(self, round_num: int, attackers: Set[int], defended_clients: Set[int],
                               baseline_performance: float, actual_performance: float,
                               successful_attackers: Optional[Set[int]] = None):
        """Summarize one round's attack/defense outcome as a flat metrics dict.

        When ``successful_attackers`` is given, damage is the summed damage
        weight of exactly those clients. Otherwise every attacker counts, and
        defended attackers contribute only ``1 - defense_strength`` of their
        weight. Returns ``{}`` when game theory is disabled.
        """
        if not self.opts.enable_game_theory:
            return {}

        weights = self.current_damage_weights
        performance_drop = max(0, baseline_performance - actual_performance)

        if successful_attackers is None:
            leak = 1.0 - self.defense_model.defense_strength
            theoretical_damage = 0.0
            for client in attackers:
                w = weights[client]
                # Defended attackers only land the unmitigated fraction
                theoretical_damage += leak * w if client in defended_clients else w
        else:
            theoretical_damage = sum(weights[client] for client in successful_attackers)

        defense_cost_used = sum(self.current_defense_costs[c] for c in defended_clients)
        attack_cost_used = sum(self.current_attack_costs[c] for c in attackers)

        utilities = self.attacker.get_attack_utilities(defended_clients)
        if attackers:
            avg_attacker_utility = np.mean([utilities[c] for c in attackers])
        else:
            avg_attacker_utility = 0

        if successful_attackers is None:
            successful_attacks = len(attackers)
            blocked_attacks = len(attackers & defended_clients)
        else:
            successful_attacks = len(successful_attackers)
            blocked_attacks = len(attackers) - successful_attacks

        def _ratio(numerator, denominator):
            # Zero (int) whenever the denominator is not positive
            return numerator / denominator if denominator > 0 else 0

        return {
            'performance_drop': performance_drop,
            'theoretical_damage': theoretical_damage,
            'num_attackers': len(attackers),
            'num_defended': len(defended_clients),
            'successful_attacks': successful_attacks,
            'blocked_attacks': blocked_attacks,
            'deterred_attacks': len(defended_clients),
            'avg_attacker_utility': avg_attacker_utility,
            'defense_cost_used': defense_cost_used,
            'attack_cost_used': attack_cost_used,
            'defense_budget_utilization': _ratio(defense_cost_used, self.current_defense_budget),
            'attack_budget_utilization': _ratio(attack_cost_used, self.current_attack_budget),
            'current_era': self.current_era,
            'era_round': round_num % self.opts.reshuffle_frequency,
            'deterrence_efficiency': _ratio(len(defended_clients), defense_cost_used),
            'damage_weight_range': f"{weights.min():.1f}-{weights.max():.1f}",
            'current_strategy': self.current_strategy
        }

# =============================================================================
# POLICY NETWORKS
# =============================================================================

def mlp(sizes, activation=nn.Tanh, output_activation=nn.Identity):
    """Build a feedforward network of Linear layers with activations.

    ``sizes`` lists the layer widths, input first. Every Linear layer is
    followed by ``activation()``, except the last, which is followed by
    ``output_activation()``. Fewer than two sizes yields an empty Sequential.
    """
    n_linear = len(sizes) - 1
    modules = []
    for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        act_cls = output_activation if idx == n_linear - 1 else activation
        modules.extend([nn.Linear(fan_in, fan_out), act_cls()])
    return nn.Sequential(*modules)

class MlpPolicy(nn.Module):
    """Categorical policy over discrete actions, parameterized by an MLP.

    Args:
        sizes: layer widths passed to ``mlp`` (observation dim first, number
            of actions last).
        activation: hidden activation name, 'Tanh' or 'ReLU'.
        output_activation: activation after the last layer, one of
            'Identity', 'Tanh', 'ReLU', 'Softmax'. Its output is used as the
            Categorical *logits*.

    Raises:
        NotImplementedError: for an unsupported activation name.
    """

    _HIDDEN_ACTIVATIONS = {'Tanh': nn.Tanh, 'ReLU': nn.ReLU}
    _OUTPUT_ACTIVATIONS = {
        'Identity': nn.Identity,
        'Tanh': nn.Tanh,
        'ReLU': nn.ReLU,
        # NOTE(review): nn.Softmax is built without an explicit dim; for the
        # 1-D inputs used in forward() PyTorch picks dim 0 (with a warning).
        'Softmax': nn.Softmax,
    }

    def __init__(self, sizes, activation='Tanh', output_activation='Identity'):
        super(MlpPolicy, self).__init__()

        if activation not in self._HIDDEN_ACTIVATIONS:
            raise NotImplementedError(f"Unsupported activation: {activation!r}")
        if output_activation not in self._OUTPUT_ACTIVATIONS:
            raise NotImplementedError(f"Unsupported output_activation: {output_activation!r}")

        # Store the resolved nn.Module classes (not the config strings)
        self.activation = self._HIDDEN_ACTIVATIONS[activation]
        self.output_activation = self._OUTPUT_ACTIVATIONS[output_activation]

        # Make policy network
        self.sizes = sizes
        self.logits_net = mlp(self.sizes, self.activation, self.output_activation)

        # Init parameters
        self.init_parameters()

    def init_parameters(self):
        """Uniform init in +/- 1/sqrt(last dim) for every parameter.

        For Linear weights the last dim is the fan-in; for biases it is the
        layer's output width.
        """
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, obs, sample=True, fixed_action=None):
        """Pick an action for ``obs`` (flattened to 1-D).

        ``fixed_action`` takes precedence; otherwise the action is sampled
        (``sample=True``) or chosen greedily.

        Returns:
            ``(action as a Python int, log-probability tensor of that action)``.
        """
        obs = obs.view(-1)

        # Forward pass the policy net
        logits = self.logits_net(obs)

        # Get the policy dist
        policy = Categorical(logits=logits)

        if fixed_action is not None:
            # Evaluate a pre-set action (e.g. for importance sampling)
            action = torch.tensor(fixed_action, device=obs.device)
        elif sample:
            try:
                action = policy.sample()
            except Exception:
                # Dump the inputs for debugging (presumably a degenerate
                # distribution, e.g. NaN logits), then propagate the real error
                # instead of failing later on an unbound `action`.
                print(logits, obs)
                raise
        else:
            # Greedy action
            action = policy.probs.argmax()

        return action.item(), policy.log_prob(action)

class DiagonalGaussianMlpPolicy(nn.Module):
    """Gaussian policy with a state-dependent diagonal covariance.

    A shared trunk (``mlp(sizes[:-1])``) feeds two bias-free linear heads:
      * mean: ``tanh(head) * geer``, i.e. bounded in [-geer, geer];
      * std:  ``tanh(exp(clamp(head, LOG_SIGMA_MIN, LOG_SIGMA_MAX)))``, so
        sigma is at most tanh(e^-2) (about 0.134).

    Args:
        sizes: layer widths (observation dim first, action dim last).
        activation: trunk activation name, 'Tanh' or 'ReLU'.
        output_activation: stored for reference only; the heads use fixed
            squashing as described above.
        geer: scale of the mean head's output range.

    Raises:
        NotImplementedError: for an unsupported activation name.
    """

    _HIDDEN_ACTIVATIONS = {'Tanh': nn.Tanh, 'ReLU': nn.ReLU}

    def __init__(self, sizes, activation='Tanh', output_activation='Tanh', geer=1):
        super(DiagonalGaussianMlpPolicy, self).__init__()

        if activation not in self._HIDDEN_ACTIVATIONS:
            raise NotImplementedError(f"Unsupported activation: {activation!r}")

        # Store parameters (activation resolved to its nn.Module class)
        self.activation = self._HIDDEN_ACTIVATIONS[activation]
        self.output_activation = output_activation

        # Make policy network; registration order fixes parameter order
        self.sizes = sizes
        self.geer = geer
        self.logits_net = mlp(self.sizes[:-1], self.activation, nn.Identity)
        self.mu_net = nn.Linear(self.sizes[-2], self.sizes[-1], bias=False)
        self.log_sigma_net = nn.Linear(self.sizes[-2], self.sizes[-1], bias=False)
        self.LOG_SIGMA_MIN = -20
        self.LOG_SIGMA_MAX = -2

        # Init parameters
        self.init_parameters()

    def init_parameters(self):
        """Uniform init in +/- 1/sqrt(last dim) for every parameter."""
        for param in self.parameters():
            stdv = 1. / math.sqrt(param.size(-1))
            param.data.uniform_(-stdv, stdv)

    def forward(self, obs, sample=True, fixed_action=None):
        """Pick a continuous action for ``obs``.

        ``fixed_action`` takes precedence; otherwise the action is sampled
        (``sample=True``) or set to the mean.

        Returns:
            ``(action as a CPU numpy array, summed log-likelihood tensor)``.
        """
        features = self.logits_net(obs)

        mu = torch.tanh(self.mu_net(features)) * self.geer
        log_sigma = torch.clamp(self.log_sigma_net(features), self.LOG_SIGMA_MIN, self.LOG_SIGMA_MAX)
        sigma = torch.tanh(log_sigma.exp())

        policy = Normal(mu, sigma)

        if fixed_action is not None:
            action = torch.tensor(fixed_action, device=obs.device)
        elif sample:
            action = policy.sample()
        else:
            action = mu.detach()

        # Floor extremely negative log-likelihoods (e.g. a fixed action far
        # outside the tiny sigma) at -1e5 so they don't dominate the loss.
        ll = torch.clamp(policy.log_prob(action), min=-1e5)

        # .cpu() is required before .numpy() when the policy lives on a GPU
        return action.detach().cpu().numpy(), ll.sum()

# =============================================================================
# ENHANCED WORKER CLASS WITH GAME THEORY
# =============================================================================

class Worker:
    """One federated RL client: owns an environment and a policy network.

    ``train_one_epoch`` collects ``B`` episodes with the current policy and
    returns REINFORCE gradients. When ``opts.enable_game_theory`` is set and
    ``update_game_status`` has marked this worker with ``attack_succeeded``,
    the worker corrupts its observations, actions, rewards or gradients
    according to ``opts.attack_style``, scaled by ``damage_weight / 10``.
    The ``_aggregate_*`` helpers combine a list of per-client gradient lists
    into one robust update.
    """

    def __init__(self, id, env_name, hidden_units, gamma,
                 activation='Tanh', output_activation='Identity',
                 max_epi_len=0, opts=None):
        super(Worker, self).__init__()

        # Identity, discount factor and a private environment instance
        self.id = id
        self.gamma = gamma
        self.env_name = env_name
        self.env = gym.make(env_name)
        # Episodes are cut after this many steps in collect_experience_for_training
        self.max_epi_len = max_epi_len

        assert opts is not None
        self.opts = opts

        # Game-theory status; only consulted when opts.enable_game_theory is set
        self.damage_weight = 1.0
        self.is_currently_attacked = False
        self.is_currently_defended = False
        self.attack_succeeded = False
        self.attack_intensity = 0.0

        # Network input/output sizes come from the environment's spaces
        obs_dim = self.env.observation_space.shape[0]
        is_discrete = isinstance(self.env.action_space, Discrete)
        if is_discrete:
            n_acts = self.env.action_space.n
        else:
            n_acts = self.env.action_space.shape[0]

        # NOTE(review): `hidden_units` is evaluated as Python source (e.g.
        # "[64, 64]"); it must only ever come from trusted configuration.
        hidden_sizes = list(eval(hidden_units))
        self.sizes = [obs_dim] + hidden_sizes + [n_acts]

        # Categorical policy for discrete actions, diagonal Gaussian otherwise
        if is_discrete:
            self.logits_net = MlpPolicy(self.sizes, activation, output_activation)
        else:
            self.logits_net = DiagonalGaussianMlpPolicy(self.sizes, activation)

        if self.id == 1:
            print(self.logits_net)

    def update_game_status(self, damage_weight: float, is_attacked: bool, is_defended: bool,
                           attack_intensity: float, attack_succeeded: bool = False):
        """Record this round's game-theory status for the worker."""
        self.damage_weight = damage_weight
        self.is_currently_attacked = is_attacked
        self.is_currently_defended = is_defended
        self.attack_intensity = attack_intensity
        self.attack_succeeded = attack_succeeded

    def load_param_from_master(self, param):
        """Overlay ``param`` (a possibly partial state dict) onto the local policy."""
        model_actor = get_inner_model(self.logits_net)
        model_actor.load_state_dict({**model_actor.state_dict(), **param})

    def _reset_env(self):
        """Reset ``self.env`` and return the initial observation.

        ``reset()`` may return ``(obs, info)``; only ``obs`` is kept.
        """
        obs = self.env.reset()
        if isinstance(obs, tuple):
            obs = obs[0]
        return obs

    def rollout(self, device, max_steps=1000, render=False, env=None, obs=None,
                sample=True, mode='human', save_dir='./', filename='.'):
        """Run one episode (at most ``max_steps`` steps) with the current policy.

        Uses ``self.env`` with a fresh reset unless ``env``/``obs`` are given.
        With ``render`` set, frames are displayed (``mode='human'``) or
        captured (``mode='rgb'``) and saved as a GIF.

        Returns:
            ``(total_reward, episode_length, per_step_rewards)``.
        """
        if env is None and obs is None:
            env = self.env
            obs = env.reset()
            if isinstance(obs, tuple):
                obs = obs[0]

        done = False
        ep_rew = []
        frames = []
        step = 0
        while not done and step < max_steps:
            step += 1
            if render:
                if mode == 'rgb':
                    frames.append(env.render(mode="rgb_array"))
                else:
                    env.render()

            obs = env_wrapper(env.unwrapped.spec.id, obs)
            action = self.logits_net(torch.as_tensor(obs, dtype=torch.float32).to(device), sample=sample)[0]
            step_out = env.step(action)
            if len(step_out) == 5:  # (obs, reward, terminated, truncated, info)
                obs, rew, terminated, truncated, _ = step_out
                done = terminated or truncated
            else:
                obs, rew, done, _ = step_out
            ep_rew.append(rew)

        # Only write a GIF when frames were actually captured
        if mode == 'rgb' and frames:
            save_frames_as_gif(frames, save_dir, filename)
        return np.sum(ep_rew), len(ep_rew), ep_rew

    @staticmethod
    def _whitened_returns(rewards, gamma):
        """Discounted reward-to-go of one episode, standardized within it.

        A single-step episode yields a zero advantage: ``Tensor.std`` (unbiased)
        is NaN for one element, which would poison the whole batch loss.
        """
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = r + gamma * running
            returns.append(running)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)
        if returns.numel() > 1:
            spread = returns.std()
        else:
            spread = torch.zeros((), dtype=returns.dtype)
        return (returns - returns.mean()) / (spread + 1e-20)

    def collect_experience_for_training(self, B, device, record=False, sample=True):
        """Collect ``B`` complete episodes with the current policy.

        Episodes end on ``done`` or after ``max_epi_len`` steps. Each step's
        log-probability is stored together with a per-episode whitened
        discounted return that serves as its REINFORCE weight.

        Under a successful attack, ``opts.attack_style`` selects the corruption
        (all scaled by ``damage_weight / 10``):
          * "obs_noise": Gaussian noise on the observation the policy sees;
          * "action_flip": flip a discrete action to ``1 - act``;
          * "reward_poison": multiply rewards by ``1 - scale``.

        Returns:
            ``(weights, logp, batch_rets, batch_lens)``, plus
            ``batch_states, batch_actions`` when ``record`` is set.
        """
        batch_weights = []   # one whitened-return tensor per episode
        batch_rets = []
        batch_lens = []
        batch_log_prob = []
        batch_states = []
        batch_actions = []

        attack_active = self.opts.enable_game_theory and self.attack_succeeded
        dmg_scale = self.damage_weight / 10.0
        is_discrete = isinstance(self.env.action_space, Discrete)

        obs = self._reset_env()
        ep_rews = []

        while True:
            if record:
                batch_states.append(obs)

            obs = env_wrapper(self.env_name, obs)

            # Observation noise must perturb what the policy sees, so it is
            # applied BEFORE the forward pass (applying it afterwards was a
            # no-op, since `obs` is immediately replaced by env.step).
            if attack_active and self.opts.attack_style == "obs_noise":
                scaled_std = self.opts.obs_noise_std * dmg_scale
                clean = np.array(obs)
                obs = (clean + np.random.randn(*clean.shape) * scaled_std).astype(np.float32)

            act, log_prob = self.logits_net(torch.as_tensor(obs, dtype=torch.float32).to(device), sample=sample)

            if attack_active and self.opts.attack_style == "action_flip" and is_discrete:
                scaled_flip_prob = min(1.0, self.opts.attack_flip_prob * dmg_scale)
                if np.random.rand() < scaled_flip_prob:
                    act = 1 - int(act)  # assumes two actions, e.g. CartPole

            step_out = self.env.step(act)
            if len(step_out) == 5:
                obs, rew, terminated, truncated, _info = step_out
                done = terminated or truncated
            else:
                obs, rew, done, _info = step_out

            if attack_active and self.opts.attack_style == "reward_poison":
                scaled_poison = min(2.0, self.opts.reward_poison_scale * dmg_scale)
                rew = rew * (1.0 - scaled_poison)

            batch_log_prob.append(log_prob)
            ep_rews.append(rew)
            if record:
                batch_actions.append(act)

            if done or len(ep_rews) >= self.max_epi_len:
                batch_rets.append(sum(ep_rews))
                batch_lens.append(len(ep_rews))
                batch_weights.append(self._whitened_returns(ep_rews, self.gamma))

                if len(batch_lens) >= B:
                    break

                obs = self._reset_env()
                ep_rews = []

        weights = torch.cat(batch_weights).to(device)
        logp = torch.stack(batch_log_prob)

        if record:
            return weights, logp, batch_rets, batch_lens, batch_states, batch_actions
        return weights, logp, batch_rets, batch_lens

    def train_one_epoch(self, B, device, sample):
        """Collect ``B`` episodes, backprop the REINFORCE loss, return gradients.

        Under a successful "grad_noise" attack each gradient gets uniform
        noise in +/- (grad range) * base_attack_intensity * damage_weight / 10.

        Returns:
            ``(grad_list, loss, mean_episode_return, mean_episode_length)``.
        """
        weights, logp, batch_rets, batch_lens = self.collect_experience_for_training(
            B, device, sample=sample)

        # Policy-gradient loss
        batch_loss = -(logp * weights).mean()

        self.logits_net.zero_grad()
        batch_loss.backward()

        if self.opts.enable_game_theory and self.attack_succeeded and self.opts.attack_style == "grad_noise":
            # Attack strength scales with damage weight relative to 10
            reference_weight = 10.0
            attack_multiplier = self.opts.base_attack_intensity * (self.damage_weight / reference_weight)
            grad = []
            for item in self.parameters():
                grad_range = item.grad.max().data - item.grad.min().data
                rnd = (torch.rand(item.grad.shape, device=item.device) * 2 - 1) * grad_range * attack_multiplier
                grad.append(item.grad + rnd)
        else:
            grad = [item.grad for item in self.parameters()]

        return grad, batch_loss.item(), np.mean(batch_rets), np.mean(batch_lens)

    def _flatten_worker_grad(self, grad_list):
        """Concatenate one client's gradient tensors into a flat CPU vector."""
        return torch.cat([g.detach().reshape(-1).cpu() for g in grad_list])

    def _unflatten_to_params(self, flat_vec, template_grads):
        """Split ``flat_vec`` back into tensors shaped/placed like ``template_grads``."""
        out = []
        offset = 0
        for g in template_grads:
            numel = g.numel()
            out.append(flat_vec[offset:offset + numel].view_as(g).to(g.device))
            offset += numel
        return out

    def _grad_matrix(self, gradients):
        """Stack the clients' flattened gradients into an (n_clients, n_params) array.

        Uses ``self.world_size`` clients when set, otherwise all of ``gradients``.
        """
        n_clients = getattr(self, 'world_size', len(gradients))
        return np.stack(
            [self._flatten_worker_grad(gradients[i]).numpy() for i in range(n_clients)],
            axis=0)

    def _from_flat(self, vec, gradients):
        """Convert an aggregated numpy vector back to per-parameter tensors."""
        return self._unflatten_to_params(torch.tensor(vec, dtype=gradients[0][0].dtype), gradients[0])

    def _aggregate_huber(self, gradients):
        """Huber-style robust mean that iteratively down-weights distant clients.

        Starting from the plain mean, each of ``opts.huber_iters`` iterations
        weights client i by ``min(1, tau / dist_i)`` with
        ``tau = median(dist) + huber_mad_k * MAD(dist)``.
        """
        grad_mat = self._grad_matrix(gradients)
        mu = grad_mat.mean(axis=0)
        for _ in range(self.opts.huber_iters):
            dists = np.linalg.norm(grad_mat - mu, axis=1)
            med = np.median(dists)
            mad = np.median(np.abs(dists - med)) + 1e-12
            tau = med + self.opts.huber_mad_k * mad
            weights = np.minimum(1.0, tau / (dists + 1e-12))
            mu = (weights[:, None] * grad_mat).sum(axis=0) / (weights.sum() + 1e-12)
        return self._from_flat(mu, gradients)

    def _aggregate_foundationfl(self, gradients):
        """FoundationFL-style aggregation: add synthetic updates, then trim.

        Synthetic updates are the coordinate-wise median plus Gaussian noise
        scaled by the MAD of client distances. The augmented set is reduced by
        a trimmed mean (``foundationfl_trim_ratio > 0``) or a median.
        """
        grad_mat = self._grad_matrix(gradients)
        n_clients = grad_mat.shape[0]
        center = np.median(grad_mat, axis=0)
        dist = np.linalg.norm(grad_mat - center, axis=1)
        med = np.median(dist)
        mad = np.median(np.abs(dist - med)) + 1e-12
        synth_n = int(np.ceil(n_clients * self.opts.foundationfl_synth_ratio))
        noise_scale = self.opts.foundationfl_noise_scale * (mad + 1e-12)
        synth = center + np.random.randn(synth_n, grad_mat.shape[1]) * noise_scale
        aug = np.concatenate([grad_mat, synth], axis=0)

        trim = self.opts.foundationfl_trim_ratio
        if trim > 0:
            n_rows = aug.shape[0]
            # Cap k so at least one row survives (trim >= 0.5 would otherwise
            # leave an empty slice and a NaN mean).
            k = min(int(np.floor(trim * n_rows)), (n_rows - 1) // 2)
            trimmed = np.sort(aug, axis=0)[k:n_rows - k, :]
            agg = trimmed.mean(axis=0)
        else:
            agg = np.median(aug, axis=0)
        return self._from_flat(agg, gradients)

    def _aggregate_fltg(self, gradients):
        """FLTG-style aggregation with ReLU-clipped cosine trust scores.

        FLTG (Wen et al., arXiv:2505.12851) scores clients by cosine similarity
        to a trusted reference. In FRL there is no labeled root dataset, so the
        coordinate-wise median of received gradients is the reference proxy.
        Client updates are rescaled to the reference norm and averaged with
        their trust scores. If all scores are zero, the plain mean is used.
        """
        grad_mat = self._grad_matrix(gradients)
        trusted_flat = np.median(grad_mat, axis=0)
        trusted_norm = np.linalg.norm(trusted_flat) + 1e-12
        norms = np.linalg.norm(grad_mat, axis=1) + 1e-12
        cos = (grad_mat @ trusted_flat) / (norms * trusted_norm)
        ts = np.maximum(0.0, cos)
        grad_clipped = (trusted_norm / norms)[:, None] * grad_mat
        if ts.sum() == 0:
            agg = grad_mat.mean(axis=0)
        else:
            agg = (ts[:, None] * grad_clipped).sum(axis=0) / (ts.sum() + 1e-12)
        return self._from_flat(agg, gradients)

    def _aggregate_fedgreed(self, gradients):
        """FedGreed-style aggregation: average the clients closest to a reference.

        FedGreed (Kritharakis et al., arXiv:2508.18060) ranks clients by
        distance from a trusted reference. Here the coordinate-wise median is
        the reference proxy. The closest ``fedgreed_fraction`` of clients (at
        least one) are averaged.
        """
        grad_mat = self._grad_matrix(gradients)
        reference = np.median(grad_mat, axis=0)
        dists = np.linalg.norm(grad_mat - reference, axis=1)
        n_select = max(1, int(self.opts.fedgreed_fraction * grad_mat.shape[0]))
        selected = np.argsort(dists)[:n_select]
        agg = grad_mat[selected].mean(axis=0)
        return self._from_flat(agg, gradients)

    def _aggregate_with_defense(self, gradients):
        """Dispatch to the aggregator named by ``opts.aggregation_defense``.

        Only 'fltg' and 'fedgreed' are routed here. Any other value returns
        ``None`` (presumably the caller then falls back to its own default
        aggregation; confirm at the call site).
        """
        defense = self.opts.aggregation_defense
        if defense == 'fltg':
            return self._aggregate_fltg(gradients)
        if defense == 'fedgreed':
            return self._aggregate_fedgreed(gradients)
        return None

    def to(self, device):
        """Move the policy network to ``device``; returns self for chaining."""
        self.logits_net.to(device)
        return self

    def eval(self):
        """Put the policy network in eval mode; returns self."""
        self.logits_net.eval()
        return self

    def train(self):
        """Put the policy network in train mode; returns self."""
        self.logits_net.train()
        return self

    def parameters(self):
        """Iterator over the policy network's parameters."""
        return self.logits_net.parameters()

# =============================================================================
# AGENT AND MEMORY CLASSES WITH GAME THEORY
# =============================================================================

class Memory:
    """Bag of per-run logging dictionaries, all starting empty.

    Each attribute is an independent dict owned by this instance. Key
    layout is decided by whoever fills them (not visible here).
    """

    def __init__(self):
        self.steps: dict = {}                # step counters / indices logged per run
        self.eval_values: dict = {}          # evaluation results
        self.training_values: dict = {}      # training-time statistics
        self.game_theory_metrics: dict = {}  # attack/defense metrics
        self.debug_summary: dict = {}        # free-form debugging info

def worker_run(worker, param, opts, Batch_size, seed):
    """Sync ``worker`` with the master, seed its environment, train one epoch.

    Returns whatever ``worker.train_one_epoch`` returns.
    """
    worker.load_param_from_master(param)

    env = worker.env
    # Prefer reset(seed=...); fall back to env.seed() when reset rejects it
    try:
        env.reset(seed=seed)
    except TypeError:
        if hasattr(env, "seed"):
            env.seed(seed)

    # Seeding the action space is deliberately best-effort
    try:
        action_space = getattr(env, "action_space", None)
        if action_space is not None and hasattr(action_space, "seed"):
            action_space.seed(seed)
    except Exception:
        pass

    return worker.train_one_epoch(Batch_size, opts.device, opts.do_sample_for_training)

class Agent:
    def __init__(self, opts):
        """Build the master, its importance-sampling snapshot, the workers, and the pool.

        Args:
            opts: run options. Read here: ``num_worker``, ``env_name``, ``gamma``,
                ``hidden_units``, ``activation``, ``output_activation``,
                ``max_epi_len``, ``device``, ``enable_game_theory``,
                ``damage_scenario``, ``eval_only`` and ``lr_model``.
        """
        self.opts = opts
        self.game_integrator = GameTheoryIntegrator(opts)
        self.world_size = opts.num_worker

        def build_worker(node_id):
            # Every node shares the same env/network settings; only the id differs.
            return Worker(
                id=node_id,
                env_name=opts.env_name,
                gamma=opts.gamma,
                hidden_units=opts.hidden_units,
                activation=opts.activation,
                output_activation=opts.output_activation,
                max_epi_len=opts.max_epi_len,
                opts=opts,
            ).to(opts.device)

        self.master = build_worker(0)
        # Frozen copy of the master used for importance sampling.
        self.old_master = build_worker(-1)

        # Workers start honest; when game theory is active they also receive
        # their initial damage weight.
        seed_game_status = opts.enable_game_theory and hasattr(
            self.game_integrator, 'current_damage_weights')
        self.workers = []
        for slot in range(self.world_size):
            node = build_worker(slot + 1)
            if seed_game_status:
                node.update_game_status(
                    damage_weight=self.game_integrator.current_damage_weights[slot],
                    is_attacked=False,
                    is_defended=False,
                    attack_intensity=0.0,
                    attack_succeeded=False,
                )
            self.workers.append(node)

        print(f'{opts.num_worker} workers initialized (all honest by default).')
        if opts.enable_game_theory:
            print(f'Game theory integration enabled with {opts.damage_scenario} scenario.')

        if not opts.eval_only:
            self.optimizer = optim.Adam(self.master.logits_net.parameters(), lr=opts.lr_model)

        self.pool = Pool(self.world_size)
        self.memory = Memory()

    def _flatten_worker_grad(self, grad_list):
        flats = []
        for g in grad_list:
            flats.append(g.detach().view(-1).cpu())
        return torch.cat(flats)

    def _unflatten_to_params(self, flat_vec, template_grads):
        out = []
        offset = 0
        for g in template_grads:
            numel = g.numel()
            out.append(flat_vec[offset:offset + numel].view_as(g).to(g.device))
            offset += numel
        return out

    def _aggregate_huber(self, gradients):
        grad_mat = []
        for i in range(self.world_size):
            grad_mat.append(self._flatten_worker_grad(gradients[i]).numpy())
        grad_mat = np.stack(grad_mat, axis=0)
        mu = grad_mat.mean(axis=0)
        for _ in range(self.opts.huber_iters):
            diffs = grad_mat - mu
            dists = np.linalg.norm(diffs, axis=1)
            med = np.median(dists)
            mad = np.median(np.abs(dists - med)) + 1e-12
            tau = med + self.opts.huber_mad_k * mad
            weights = np.minimum(1.0, tau / (dists + 1e-12))
            mu = (weights[:, None] * grad_mat).sum(axis=0) / (weights.sum() + 1e-12)
        return self._unflatten_to_params(torch.tensor(mu, dtype=gradients[0][0].dtype), gradients[0])

    def _aggregate_foundationfl(self, gradients):
        grad_mat = []
        for i in range(self.world_size):
            grad_mat.append(self._flatten_worker_grad(gradients[i]).numpy())
        grad_mat = np.stack(grad_mat, axis=0)
        center = np.median(grad_mat, axis=0)
        diffs = grad_mat - center
        dist = np.linalg.norm(diffs, axis=1)
        med = np.median(dist)
        mad = np.median(np.abs(dist - med)) + 1e-12
        synth_n = int(np.ceil(self.world_size * self.opts.foundationfl_synth_ratio))
        noise_scale = self.opts.foundationfl_noise_scale * (mad + 1e-12)
        synth = center + np.random.randn(synth_n, grad_mat.shape[1]) * noise_scale
        aug = np.concatenate([grad_mat, synth], axis=0)
        trim = self.opts.foundationfl_trim_ratio
        if trim > 0:
            k = int(np.floor(trim * aug.shape[0]))
            sorted_aug = np.sort(aug, axis=0)
            trimmed = sorted_aug[k:sorted_aug.shape[0] - k, :]
            agg = trimmed.mean(axis=0)
        else:
            agg = np.median(aug, axis=0)
        return self._unflatten_to_params(torch.tensor(agg, dtype=gradients[0][0].dtype), gradients[0])

    def _aggregate_fltg(self, gradients):
        # FLTG (Wen et al., arXiv:2505.12851): ReLU-clipped cosine similarity trust scores
        # with magnitude normalisation. In FRL (no labeled root dataset), the coordinate-wise
        # median of received gradients serves as a Byzantine-robust trusted reference proxy.
        grad_mat = np.stack(
            [self._flatten_worker_grad(gradients[i]).numpy() for i in range(self.world_size)],
            axis=0)
        trusted_flat = np.median(grad_mat, axis=0)
        trusted_norm = np.linalg.norm(trusted_flat) + 1e-12
        norms = np.linalg.norm(grad_mat, axis=1) + 1e-12
        # Trust score = ReLU(cosine similarity with trusted reference)
        cos = (grad_mat @ trusted_flat) / (norms * trusted_norm)
        ts = np.maximum(0.0, cos)
        # Magnitude clipping: normalise each client gradient to the trusted norm
        grad_clipped = (trusted_norm / norms)[:, None] * grad_mat
        if ts.sum() == 0:
            agg = grad_mat.mean(axis=0)
        else:
            agg = (ts[:, None] * grad_clipped).sum(axis=0) / (ts.sum() + 1e-12)
        return self._unflatten_to_params(torch.tensor(agg, dtype=gradients[0][0].dtype), gradients[0])

    def _aggregate_fedgreed(self, gradients):
        # FedGreed (Kritharakis et al., arXiv:2508.18060): rank clients by distance
        # from a trusted reference and greedily select the closest fraction for aggregation.
        # In FRL (no labeled reference dataset), the coordinate-wise median serves
        # as a Byzantine-robust reference proxy.
        grad_mat = np.stack(
            [self._flatten_worker_grad(gradients[i]).numpy() for i in range(self.world_size)],
            axis=0)
        reference = np.median(grad_mat, axis=0)
        dists = np.linalg.norm(grad_mat - reference, axis=1)
        n_select = max(1, int(self.opts.fedgreed_fraction * self.world_size))
        selected = np.argsort(dists)[:n_select]
        agg = grad_mat[selected].mean(axis=0)
        return self._unflatten_to_params(torch.tensor(agg, dtype=gradients[0][0].dtype), gradients[0])

    def _aggregate_with_defense(self, gradients):
        if self.opts.aggregation_defense == 'fltg':
            return self._aggregate_fltg(gradients)
        if self.opts.aggregation_defense == 'fedgreed':
            return self._aggregate_fedgreed(gradients)
        return None

    def load(self, load_path):
        """Restore the master policy from a checkpoint written by ``save``.

        Weights under ``'master'`` are merged over the current state dict, so
        keys missing from the checkpoint keep their current values. Unless
        ``opts.eval_only`` is set, this also restores the optimizer state, which
        is moved to ``opts.device``, and the torch RNG state. The CUDA RNG state
        is restored too when ``opts.use_cuda`` is set.
        """
        assert load_path is not None
        checkpoint = torch_load_cpu(load_path)

        actor = get_inner_model(self.master.logits_net)
        merged_state = dict(actor.state_dict())
        merged_state.update(checkpoint.get('master', {}))
        actor.load_state_dict(merged_state)

        if not self.opts.eval_only:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            # Move any tensor-valued optimizer state onto the training device.
            for per_param in self.optimizer.state.values():
                for key, value in per_param.items():
                    if torch.is_tensor(value):
                        per_param[key] = value.to(self.opts.device)
            torch.set_rng_state(checkpoint['rng_state'])
            if self.opts.use_cuda:
                torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state'])

        print(' [*] Loading data from {}'.format(load_path))
        
    def save(self, epoch, run_id):
        """Write a checkpoint to ``<opts.save_dir>/r<run_id>-epoch-<epoch>.pt``.

        The checkpoint holds the master weights, the optimizer state and the
        CPU and CUDA RNG states, which is the layout ``load`` reads back.
        """
        print('Saving model and state...')
        checkpoint = {
            'master': get_inner_model(self.master.logits_net).state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'rng_state': torch.get_rng_state(),
            'cuda_rng_state': torch.cuda.get_rng_state_all(),
        }
        target = os.path.join(self.opts.save_dir, 'r{}-epoch-{}.pt'.format(run_id, epoch))
        torch.save(checkpoint, target)
    def save_performance_to_csv(self, strategy_name='default'):
        """Export the recorded curves in ``self.memory`` to CSV.

        Output goes under ``opts.save_dir``, or ``./results`` when that is unset.
        Game-theory metrics are included when ``self.memory`` has them.
        """
        out_dir = self.opts.save_dir or './results'
        saver = PerformanceCSVSaver(out_dir)

        mem = self.memory
        payload = {
            'steps': mem.steps,
            'training_values': mem.training_values,
            'eval_values': mem.eval_values,
        }
        saver.save_strategy_performance_csv(
            strategy_name, payload, getattr(mem, 'game_theory_metrics', None))

    def eval(self):
        # Turn model to eval mode
        self.master.eval()
        
    def train(self):
        # Turn model to training mode
        self.master.train()
        
    def start_training(self, tb_logger=None, run_id=None):
        # Parameters of running
        opts = self.opts

        # For storing number of trajectories sampled
        step = 0
        epoch = 0
        ratios_step = 0
        
        # Store baseline performance for game theory damage calculation
        baseline_performance = 0
        debug_lines = []
        debug_records = []
        debug_rounds = []
        
        # Start the training loop
        while step <= opts.max_trajectories:
            # Epoch for storing checkpoints of model
            epoch += 1
            
            # Turn model into training mode
            print('\n\n')
            print("|", format(f" Training step {step} run_id {run_id} in {opts.seeds}", "*^60"), "|")
            self.train()
            
            # Get game theory decisions
            attackers = set()
            defended_clients = set()
            successful_attackers = set()
            if opts.enable_game_theory:
                attackers, defended_clients = self.game_integrator.get_attack_defense_decisions(step)
                print(f"Game Theory Round {step}: {len(attackers)} attackers, {len(defended_clients)} defended")
                
                # Update worker game status
                for i, worker in enumerate(self.workers):
                    is_attacked = i in attackers
                    is_defended = i in defended_clients
                    damage_weight = self.game_integrator.current_damage_weights[i]
                    attack_succeeded = False
                    if is_attacked:
                        attack_succeeded = self.game_integrator.defense_model.sample_attack_outcome(
                            is_defended, seed=(step + 1) * 1000 + i
                        )
                        if attack_succeeded:
                            successful_attackers.add(i)
                    
                    # Scale attack intensity by damage weight for game theory
                    scaled_attack_intensity = opts.base_attack_intensity * (damage_weight / 10.0)
                    
                    worker.update_game_status(
                        damage_weight=damage_weight,
                        is_attacked=is_attacked,
                        is_defended=is_defended,
                        attack_intensity=scaled_attack_intensity,
                        attack_succeeded=attack_succeeded
                    )
                if getattr(opts, "debug_summary", False) and step < getattr(opts, "debug_summary_rounds", 0):
                    overlap = attackers & defended_clients
                    debug_lines.append(
                        f"round={step} attackers={sorted(list(attackers))} "
                        f"defended={sorted(list(defended_clients))} "
                        f"overlap={sorted(list(overlap))} "
                        f"successful={sorted(list(successful_attackers))}"
                    )
                    debug_records.append({
                        "round": int(step),
                        "attackers": sorted(list(attackers)),
                        "defended": sorted(list(defended_clients)),
                        "overlap": sorted(list(overlap)),
                        "successful": sorted(list(successful_attackers)),
                        "defense_budget": float(self.game_integrator.current_defense_budget),
                        "attack_budget": float(self.game_integrator.current_attack_budget),
                        "defense_costs": self.game_integrator.current_defense_costs.tolist(),
                        "attack_costs": self.game_integrator.current_attack_costs.tolist(),
                        "damage_weights": self.game_integrator.current_damage_weights.tolist(),
                        "current_strategy": self.game_integrator.current_strategy,
                    })
            
            # Setup lr_scheduler
            print("Training with lr={:.3e}".format(self.optimizer.param_groups[0]['lr']), flush=True)
            
            # Some empty list for training and logging purpose
            gradient = []
            batch_loss = []
            batch_rets = []
            batch_lens = []
            
            # Distribute current params and Batch_Size to all workers
            param = get_inner_model(self.master.logits_net).state_dict()
            
            if opts.FedPG_BR:
                Batch_size = np.random.randint(opts.Bmin, opts.Bmax + 1)
            else:
                Batch_size = opts.B
        
            seeds = np.random.randint(1, 100000, self.world_size).tolist()
            args = zip(self.workers, repeat(param), repeat(opts), repeat(Batch_size), seeds)
            
            results = self.pool.starmap(worker_run, args)

            # Collect the gradient(for training), loss(for logging only), returns(for logging only), and epi_length(for logging only) from workers         
            for out in tqdm(results, desc='Worker node'):
                grad, loss, rets, lens = out
                
                # Store all values
                gradient.append(grad)
                batch_loss.append(loss)
                batch_rets.append(rets)
                batch_lens.append(lens)
            
            # Simulate FedScsPG-attack (if needed) on server for demo
            if opts.attack_type == 'FedScsPG-attack' and opts.num_Byzantine > 0:  
                for idx, _ in enumerate(self.master.parameters()):
                    tmp = []
                    for bad_worker in range(opts.num_Byzantine):
                        tmp.append(gradient[bad_worker][idx].view(-1))
                    tmp = torch.stack(tmp)

                    estimated_2sigma = euclidean_dist(tmp, tmp).max()
                    estimated_mean = tmp.mean(0)
                    
                    # Change the gradient to be estimated_mean + 3sigma (with a random direction rnd)
                    rnd = torch.rand(gradient[0][idx].shape) * 2. - 1.
                    rnd = rnd / rnd.norm()
                    attacked_gradient = estimated_mean.view(gradient[bad_worker][idx].shape) + rnd * estimated_2sigma * 3. / 2.
                    for bad_worker in range(opts.num_Byzantine):
                        gradient[bad_worker][idx] = attacked_gradient
                        
            # Make the old policy as a copy of the current master node
            self.old_master.load_param_from_master(param)
            
            # Do Aggregate Algorithm to detect Byzantine worker on master node
            if opts.FedPG_BR:
                
                # Flatten the gradient vectors of each worker and put them together, shape [num_worker, -1]
                mu_vec = None
                for idx, item in enumerate(self.old_master.parameters()):
                    # Stack gradient[idx] from all worker nodes
                    grad_item = []
                    for i in range(self.world_size):
                         grad_item.append(gradient[i][idx])
                    grad_item = torch.stack(grad_item).view(self.world_size, -1)
                
                    # Concat stacked grad vector
                    if mu_vec is None:
                        mu_vec = grad_item.clone()
                    else:
                        mu_vec = torch.cat((mu_vec, grad_item.clone()), -1)
                    
                # Calculate the norm distance between each worker's gradient vector, shape [num_worker, num_worker]
                dist = euclidean_dist(mu_vec, mu_vec)
                
                # Calculate C, Variance Bound V, threshold, and alpha
                V = 2 * np.log(2 * opts.num_worker / opts.delta)
                sigma = opts.sigma

                threshold = 2 * sigma * np.sqrt(V / Batch_size)
                alpha = opts.alpha
                
                # To find MOM: |dist <= threshold| > 0.5 * num_worker
                mu_med_vec = None
                k_prime = (dist <= threshold).sum(-1) > (0.5 * self.world_size)
                
                # Computes the mom of the gradients, mu_med_vec, and
                # filter the gradients it believes to be Byzantine and store the index of non-Byzantine gradients in Good_set
                if torch.sum(k_prime) > 0:
                    mu_mean_vec = torch.mean(mu_vec[k_prime], 0).view(1, -1)
                    mu_med_vec = mu_vec[k_prime][euclidean_dist(mu_mean_vec, mu_vec[k_prime]).argmin()].view(1, -1)
                    # Applying R1 to filter
                    Good_set = euclidean_dist(mu_vec, mu_med_vec) <= 1 * threshold
                else:
                    Good_set = k_prime  # skip this step if k_prime is empty (i.e., all False)
                
                # Avoid the scenarios that Good_set is empty or can have |Gt| < (1 − α)K.
                if torch.sum(Good_set) < (1 - alpha) * self.world_size or torch.sum(Good_set) == 0:
                    
                    # Re-calculate mom of the gradients
                    k_prime = (dist <= 2 * sigma).sum(-1) > (0.5 * self.world_size)
                    if torch.sum(k_prime) > 0:
                        mu_mean_vec = torch.mean(mu_vec[k_prime], 0).view(1, -1)
                        mu_med_vec = mu_vec[k_prime][euclidean_dist(mu_mean_vec, mu_vec[k_prime]).argmin()].view(1, -1)
                        # Re-filter with R2
                        Good_set = euclidean_dist(mu_vec, mu_med_vec) <= 2 * sigma
                    else:
                        Good_set = torch.zeros(self.world_size, 1).to(opts.device).bool()
            
            # Else will treat all nodes as non-Byzantine nodes
            else:
                Good_set = torch.ones(self.world_size, 1).to(opts.device).bool()
            
            # Calculate number of good gradients for logging
            N_good = torch.sum(Good_set)
            
            # Aggregate gradients with optional system-level defense
            if self.opts.aggregation_defense != 'none':
                mu = self._aggregate_with_defense(gradient)
            else:
                if N_good > 0:
                    mu = []
                    for idx, item in enumerate(self.old_master.parameters()):
                        grad_item = []
                        for i in range(self.world_size):
                            if Good_set[i]:  # only aggregate non-Byzantine gradients
                                grad_item.append(gradient[i][idx])
                        mu.append(torch.stack(grad_item).mean(0))
                else:  # if still all nodes are detected to be Byzantine, check the sigma. If sigma is set properly, this situation will not happen.
                    mu = None
            
            # Perform gradient update in master node
            grad_array = []  # store gradients for logging

            if opts.FedPG_BR or opts.SVRPG:
                
                if opts.FedPG_BR:
                    # For n=1 to Nt ~ Geom(B/B+b) do grad update
                    b = opts.b
                    N_t = np.random.geometric(p=1 - Batch_size/(Batch_size + b))
                    
                elif opts.SVRPG:
                    b = opts.b
                    N_t = opts.N
                    
                for n in tqdm(range(N_t), desc='Master node'):
                   
                    # Calculate new gradient in master node
                    self.optimizer.zero_grad()

                    # Sample b trajectory using the latest policy (\theta_n) of master node
                    weights, new_logp, batch_rets, batch_lens, batch_states, batch_actions = self.master.collect_experience_for_training(
                        b, opts.device, record=True, sample=opts.do_sample_for_training)
                        
                    # Calculate gradient for the new policy (\theta_n)
                    loss_new = -(new_logp * weights).mean()
                    self.master.logits_net.zero_grad()
                    loss_new.backward()
                    
                    if mu:
                        old_logp = []
                        for idx, obs in enumerate(batch_states):
                            # Act in the environment with the fixed action
                            obs = env_wrapper(opts.env_name, obs)
                            _, old_log_prob = self.old_master.logits_net(torch.as_tensor(obs, dtype=torch.float32).to(opts.device), 
                                                                         fixed_action=batch_actions[idx])
                            # Store in the old_logp
                            old_logp.append(old_log_prob)
                        old_logp = torch.stack(old_logp)
                        
                        # Finding the ratio (pi_theta / pi_theta__old):
                        ratios = torch.exp(old_logp.detach() - new_logp.detach())
                        ratios_step += 1
                        
                        # Calculate gradient for the old policy (\theta_0)
                        loss_old = -(old_logp * weights * ratios).mean()
                        self.old_master.logits_net.zero_grad()
                        loss_old.backward()
                        grad_old = [item.grad for item in self.old_master.parameters()]   
                    
                        # Early stop if ratio is not within [0.995, 1.005]
                        if torch.abs(ratios.mean()) < 0.995 or torch.abs(ratios.mean()) > 1.005:
                            N_t = n
                            break
                        
                        if tb_logger is not None:
                            tb_logger.add_scalar(f'params/ratios_{run_id}', ratios.mean(), ratios_step)
                        
                        # Adjust and set the gradient for latest policy (\theta_n)
                        for idx, item in enumerate(self.master.parameters()):
                            item.grad = item.grad - grad_old[idx] + mu[idx]  # if mu is None, use grad from master 
                            grad_array += (item.grad.data.view(-1).cpu().tolist())
                        
                    # Take a gradient step
                    self.optimizer.step()
        
            else:  # GOMDP in this case
                
                b = 0
                N_t = 0
                
                # Perform gradient descent with mu vector
                for idx, item in enumerate(self.master.parameters()):
                    item.grad = mu[idx]
                    grad_array += (item.grad.data.view(-1).cpu().tolist())
                    
                # Take a gradient step
                self.optimizer.step()  
            
            # Calculate current performance for game theory metrics
            current_performance = np.mean(batch_rets)
            if step == 0:
                baseline_performance = current_performance
            
            # Calculate game theory metrics
            game_metrics = {}
            if opts.enable_game_theory:
                game_metrics = self.game_integrator.calculate_damage_metrics(
                    step, attackers, defended_clients, baseline_performance, current_performance, successful_attackers
                )
                if getattr(opts, "debug_summary", False) and step < getattr(opts, "debug_summary_rounds", 0):
                    debug_rounds.append({
                        "round": int(step),
                        "train_return": float(np.mean(batch_rets)),
                        "train_ep_len": float(np.mean(batch_lens)),
                        "performance_drop": float(game_metrics.get("performance_drop", 0.0)),
                        "theoretical_damage": float(game_metrics.get("theoretical_damage", 0.0)),
                        "num_attackers": int(game_metrics.get("num_attackers", 0)),
                        "num_defended": int(game_metrics.get("num_defended", 0)),
                        "successful_attacks": int(game_metrics.get("successful_attacks", 0)),
                        "blocked_attacks": int(game_metrics.get("blocked_attacks", 0)),
                        "defense_budget_utilization": float(game_metrics.get("defense_budget_utilization", 0.0)),
                        "attack_budget_utilization": float(game_metrics.get("attack_budget_utilization", 0.0)),
                        "avg_attacker_utility": float(game_metrics.get("avg_attacker_utility", 0.0)),
                        "current_strategy": game_metrics.get("current_strategy", ""),
                    })
                
            print('\nepoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f \t N_good: %d' %
                (epoch, np.mean(batch_loss), np.mean(batch_rets), np.mean(batch_lens), N_good))
            
            if opts.enable_game_theory and game_metrics:
                print(f'Game Theory: attackers: {game_metrics["num_attackers"]} \t defended: {game_metrics["num_defended"]} \t damage: {game_metrics["performance_drop"]:.3f} \t strategy: {game_metrics["current_strategy"]}')
            
            # Current step: number of trajectories sampled
            step += round((Batch_size * self.world_size + b * N_t) / (1 + self.world_size)) if self.world_size > 1 else Batch_size + b * N_t
            
            # Logging to tensorboard
            if tb_logger is not None:
                
                # Training log
                tb_logger.add_scalar(f'train/total_rewards_{run_id}', np.mean(batch_rets), step)
                tb_logger.add_scalar(f'train/epi_length_{run_id}', np.mean(batch_lens), step)
                tb_logger.add_scalar(f'train/loss_{run_id}', np.mean(batch_loss), step)
                # Grad log
                tb_logger.add_scalar(f'grad/grad_{run_id}', np.mean(grad_array), step)
                # Optimizer log
                tb_logger.add_scalar(f'params/lr_{run_id}', self.optimizer.param_groups[0]['lr'], step)
                tb_logger.add_scalar(f'params/N_t_{run_id}', N_t, step)

                # Game theory logging
                if opts.enable_game_theory and game_metrics:
                    tb_logger.add_scalar(f'game/num_attackers_{run_id}', game_metrics['num_attackers'], step)
                    tb_logger.add_scalar(f'game/num_defended_{run_id}', game_metrics['num_defended'], step)
                    tb_logger.add_scalar(f'game/performance_drop_{run_id}', game_metrics['performance_drop'], step)
                    tb_logger.add_scalar(f'game/theoretical_damage_{run_id}', game_metrics['theoretical_damage'], step)
                    tb_logger.add_scalar(f'game/successful_attacks_{run_id}', game_metrics['successful_attacks'], step)
                    tb_logger.add_scalar(f'game/defense_budget_util_{run_id}', game_metrics['defense_budget_utilization'], step)
                    tb_logger.add_scalar(f'game/attack_budget_util_{run_id}', game_metrics['attack_budget_utilization'], step)
                    tb_logger.add_scalar(f'game/avg_attacker_utility_{run_id}', game_metrics['avg_attacker_utility'], step)
                    tb_logger.add_scalar(f'game/current_era_{run_id}', game_metrics['current_era'], step)
                    tb_logger.add_scalar(f'game/deterrence_efficiency_{run_id}', game_metrics['deterrence_efficiency'], step)

                # Byzantine filtering log (for FedPG_BR algorithm)
                if opts.FedPG_BR:
                    # No predetermined Byzantine workers, but log filtering results
                    y_pred = (~ Good_set).view(-1).cpu().tolist()
                    
                    tb_logger.add_scalar(f'Byzantine/threshold_{run_id}', threshold, step)
                    tb_logger.add_scalar(f'grad_norm_mean/ALL_{run_id}', torch.mean(dist), step)
                    tb_logger.add_scalar(f'grad_norm_max/ALL_{run_id}', torch.max(dist), step)
                    tb_logger.add_scalar(f'Byzantine/N_good_pred_{run_id}', N_good, step)
                        
                # For performance plot
                if run_id not in self.memory.steps.keys():
                    self.memory.steps[run_id] = []
                    self.memory.eval_values[run_id] = []
                    self.memory.training_values[run_id] = []
                    if opts.enable_game_theory:
                        self.memory.game_theory_metrics[run_id] = []
                
                self.memory.steps[run_id].append(step)
                self.memory.training_values[run_id].append(np.mean(batch_rets))
                if opts.enable_game_theory and game_metrics:
                    self.memory.game_theory_metrics[run_id].append(game_metrics)
                             
            # Do validating
            eval_reward = self.start_validating(tb_logger, step, max_steps=opts.val_max_steps, 
                                              render=opts.render, run_id=run_id)
            if tb_logger is not None:
                 self.memory.eval_values[run_id].append(eval_reward)
                            
            # Save current model
            if not opts.no_saving:
                self.save(epoch, run_id)

        if getattr(opts, "debug_summary", False) and debug_lines:
            self.memory.debug_summary[run_id] = debug_lines
            print("\n=== DEBUG SUMMARY (Attack/Defense Decisions) ===")
            for line in debug_lines:
                print(line)
            if not opts.no_saving:
                debug_path = os.path.join(
                    self.opts.save_dir if self.opts.save_dir else "./results",
                    f"debug_summary_run_{run_id}.txt"
                )
                os.makedirs(os.path.dirname(debug_path), exist_ok=True)
                with open(debug_path, "w") as f:
                    f.write("\n".join(debug_lines))
                print(f"Debug summary saved to: {debug_path}")
                full_debug_path = os.path.join(
                    self.opts.save_dir if self.opts.save_dir else "./results",
                    f"debug_full_run_{run_id}.txt"
                )
                with open(full_debug_path, "w") as f:
                    f.write(f"run_id: {run_id}\n")
                    f.write(f"env_name: {opts.env_name}\n")
                    f.write(f"num_worker: {opts.num_worker}\n")
                    f.write(f"initial_defense_budget_ratio: {opts.initial_defense_budget_ratio}\n")
                    f.write(f"initial_attack_budget_ratio: {opts.initial_attack_budget_ratio}\n")
                    f.write(f"defense_strength: {getattr(opts, 'defense_strength', None)}\n")
                    f.write(f"observation_accuracy: {getattr(opts, 'observation_accuracy', None)}\n")
                    f.write(f"base_attack_intensity: {opts.base_attack_intensity}\n")
                    f.write(f"aggregation_defense: {opts.aggregation_defense}\n")
                    f.write(f"use_client_defense: {opts.use_client_defense}\n")
                    f.write(f"current_strategy: {self.game_integrator.current_strategy}\n")
                    f.write("\n--- RECORDS ---\n")
                    for rec in debug_records:
                        f.write(
                            f"round={rec['round']} attackers={rec['attackers']} defended={rec['defended']} "
                            f"overlap={rec['overlap']} successful={rec['successful']} "
                            f"defense_budget={rec['defense_budget']:.4f} attack_budget={rec['attack_budget']:.4f}\n"
                        )
                        f.write(f"  defense_costs={rec['defense_costs']}\n")
                        f.write(f"  attack_costs={rec['attack_costs']}\n")
                        f.write(f"  damage_weights={rec['damage_weights']}\n")
                        f.write(f"  current_strategy={rec['current_strategy']}\n")
                    f.write("\n--- ROUND METRICS ---\n")
                    for rec in debug_rounds:
                        f.write(
                            f"round={rec['round']} train_return={rec['train_return']:.4f} "
                            f"train_ep_len={rec['train_ep_len']:.4f} "
                            f"perf_drop={rec['performance_drop']:.4f} "
                            f"theory_damage={rec['theoretical_damage']:.4f} "
                            f"attackers={rec['num_attackers']} defended={rec['num_defended']} "
                            f"success={rec['successful_attacks']} blocked={rec['blocked_attacks']} "
                            f"def_budget_util={rec['defense_budget_utilization']:.4f} "
                            f"atk_budget_util={rec['attack_budget_utilization']:.4f} "
                            f"avg_attacker_util={rec['avg_attacker_utility']:.4f} "
                            f"strategy={rec['current_strategy']}\n"
                        )
                print(f"Full debug dump saved to: {full_debug_path}")
            if not opts.no_saving:
                summary_path = os.path.join(
                    self.opts.save_dir if self.opts.save_dir else "./results",
                    f"strategy_defense_summary_run_{run_id}.txt"
                )
                os.makedirs(os.path.dirname(summary_path), exist_ok=True)
                # Summarize defended set stats per strategy (from debug records)
                with open(summary_path, "w") as f:
                    f.write(f"run_id: {run_id}\n")
                    f.write(f"env_name: {opts.env_name}\n")
                    f.write(f"num_worker: {opts.num_worker}\n")
                    f.write(f"current_strategy: {self.game_integrator.current_strategy}\n")
                    f.write("\n--- DEFENSE SUMMARY ---\n")
                    if debug_records:
                        last = debug_records[-1]
                        defended = last["defended"]
                        damage_weights = last["damage_weights"]
                        defense_costs = last["defense_costs"]
                        total_damage = sum(damage_weights[i] for i in defended) if defended else 0.0
                        total_cost = sum(defense_costs[i] for i in defended) if defended else 0.0
                        f.write(f"defended_clients: {defended}\n")
                        f.write(f"total_defended_damage: {total_damage:.4f}\n")
                        f.write(f"total_defended_cost: {total_cost:.4f}\n")
                    else:
                        f.write("no debug records\n")
                print(f"Strategy defense summary saved to: {summary_path}")
                
    # Validate the new model   
    def start_validating(self, tb_logger=None, id=0, max_steps=1000, render=False, run_id=0, mode='human'):
        """Evaluate the master policy without sampling and return the mean episode return.

        Runs ``self.opts.val_size`` rollouts of ``self.master`` with
        ``sample=False`` and averages their returns and episode lengths.

        Args:
            tb_logger: Optional TensorBoard writer. When given, the averaged
                return and episode length are logged under
                ``validate/*_{run_id}`` at global step ``id`` and flushed.
            id: Global step used for printing and logging (the name shadows
                the builtin but is kept for caller compatibility).
            max_steps: Step cap forwarded to ``rollout``.
            render: Forwarded to ``rollout``.
            run_id: Run/seed identifier used in tag names and GIF filenames.
            mode: Render mode forwarded to ``rollout``.

        Returns:
            The mean validation return over ``val_size`` episodes.

        Raises:
            ValueError: If ``self.opts.val_size`` is not positive.
        """
        print('\nValidating...', flush=True)

        num_episodes = self.opts.val_size
        if num_episodes <= 0:
            raise ValueError(f"opts.val_size must be positive, got {num_episodes}")

        total_ret = 0.0
        total_len = 0.0
        for episode in range(num_episodes):
            epi_ret, epi_len, _ = self.master.rollout(
                self.opts.device, max_steps=max_steps, render=render,
                sample=False, mode=mode, save_dir='./outputs/',
                filename=f'gym_{run_id}_{episode}.gif')
            total_ret += epi_ret
            total_len += epi_len

        val_ret = np.mean(total_ret / num_episodes)
        val_len = np.mean(total_len / num_episodes)

        print('\nGradient step: %3d \t return: %.3f \t ep_len: %.3f' %
                (id, val_ret, val_len))

        if tb_logger is not None:
            tb_logger.add_scalar(f'validate/total_rewards_{run_id}', val_ret, id)
            tb_logger.add_scalar(f'validate/epi_length_{run_id}', val_len, id)
            # The caller keeps using this writer after validation (every epoch
            # and for figure logging), so only flush pending events here
            # instead of closing it.
            tb_logger.flush()

        return val_ret
    
    def plot_graph(self, array):
        """Plot the cross-run mean reward curve with a 90% normal confidence band.

        Each run's values ``array[run_id]`` are interpolated (``Rbf`` with
        ``function='linear'``) from ``self.memory.steps[run_id]`` onto the
        integer grid ``0 .. int(opts.max_trajectories) - 1``. The band is
        ``st.norm.interval(0.90)`` around the mean, scaled by the standard
        error across runs.

        Args:
            array: Mapping ``run_id -> values`` aligned with ``self.memory.steps``.

        Returns:
            The matplotlib Figure containing the plot.
        """
        plt.ioff()
        fig, ax = plt.subplots(figsize=(8, 4))
        # One grid shared by the interpolation and the band so their lengths match.
        grid = np.arange(int(self.opts.max_trajectories))

        curves = [
            Rbf(steps, array[run_id], function='linear')(grid)
            for run_id, steps in self.memory.steps.items()
        ]
        mean = np.mean(curves, axis=0)
        low, high = st.norm.interval(0.90, loc=mean, scale=st.sem(curves, axis=0))

        ax.plot(grid, mean)
        ax.fill_between(grid, low, high, alpha=0.5)
        # Use the figure's own Axes: a bare plt.axes() call adds a second,
        # empty Axes on top, and the y-limits would land on that one.
        ax.set_ylim([self.opts.min_reward, self.opts.max_reward])

        ax.set_xlabel("Number of Trajectories")
        ax.set_ylabel("Reward")
        ax.grid(True)
        fig.tight_layout()
        return fig
    
    def plot_game_theory_metrics(self):
        """Return a 2x3 figure of per-run game-theory metrics over training steps.

        Panels, row-major: performance drop, theoretical damage, number of
        attackers, number of defended clients, defense budget utilization and
        attack budget utilization. One line per run.

        Returns:
            The matplotlib Figure, or ``None`` when game theory is disabled or
            no metrics were recorded.
        """
        if not self.opts.enable_game_theory or not self.memory.game_theory_metrics:
            return None

        # (metric key, panel title, y-axis label), in row-major panel order.
        panels = [
            ('performance_drop', 'Performance Drop', 'Performance Drop'),
            ('theoretical_damage', 'Theoretical Damage', 'Theoretical Damage'),
            ('num_attackers', 'Number of Attackers', 'Number of Attackers'),
            ('num_defended', 'Number of Defended Clients', 'Number Defended'),
            ('defense_budget_utilization', 'Defense Budget Utilization', 'Budget Utilization'),
            ('attack_budget_utilization', 'Attack Budget Utilization', 'Budget Utilization'),
        ]

        plt.ioff()
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        for run_id, metrics_list in self.memory.game_theory_metrics.items():
            if not metrics_list:
                continue
            # Steps are recorded every epoch, metrics only on epochs that
            # produced them, so the two lists can differ in length. Pair them
            # positionally and drop the unmatched tail rather than let plot() raise.
            pairs = list(zip(self.memory.steps[run_id], metrics_list))
            if not pairs:
                continue
            steps = [step for step, _ in pairs]
            for ax, (key, _, _) in zip(axes.flat, panels):
                ax.plot(steps, [m[key] for _, m in pairs], alpha=0.7, label=f'Run {run_id}')

        for ax, (_, title, ylabel) in zip(axes.flat, panels):
            ax.set_title(title)
            ax.set_ylabel(ylabel)
            ax.set_xlabel('Training Steps')
            ax.grid(True)
            # Only draw a legend when the panel has labelled lines.
            if ax.get_legend_handles_labels()[0]:
                ax.legend()

        plt.tight_layout()
        return fig
    
    def log_performance(self, tb_logger):
        """Add the reward-curve figures (and game-theory panels) to TensorBoard.

        Both the tag suffix and the global step are the number of runs
        currently recorded in ``self.memory.steps``.
        """
        eval_fig = self.plot_graph(self.memory.eval_values)
        train_fig = self.plot_graph(self.memory.training_values)
        run_count = len(self.memory.steps)

        for prefix, figure in (('validate', eval_fig), ('train', train_fig)):
            tb_logger.add_figure(f'{prefix}/performance_until_{run_count}_runs',
                                 figure, run_count)

        if not self.opts.enable_game_theory:
            return
        metrics_fig = self.plot_game_theory_metrics()
        if metrics_fig:
            run_count = len(self.memory.steps)
            tb_logger.add_figure(f'game/metrics_until_{run_count}_runs',
                                 metrics_fig, run_count)
    
    def save_training_plots(self, save_dir=None, strategy_results=None):
        """Write reward-curve images (and game-theory panels) into ``save_dir``.

        When ``strategy_results`` is given, the multi-strategy comparison is
        drawn. Otherwise this agent's per-run training and validation curves
        are drawn side by side and, if game theory is enabled, its metric
        panels too. The standalone training/validation plots are written in
        both cases. Does nothing when no training steps were recorded.
        """
        if not self.memory.steps:
            print("No training data to plot")
            return

        if save_dir is None:
            save_dir = self.opts.save_dir or './plots'
        os.makedirs(save_dir, exist_ok=True)

        if strategy_results is None:
            plt.figure(figsize=(12, 5))
            side_by_side = (
                (self.memory.training_values, 'Training Returns', 'Training'),
                (self.memory.eval_values, 'Validation Returns', 'Validation'),
            )
            for position, (series, title, label) in enumerate(side_by_side, start=1):
                plt.subplot(1, 2, position)
                self.plot_returns(series, title, label)

            plt.tight_layout()
            curves_path = os.path.join(save_dir, 'training_curves.png')
            plt.savefig(curves_path, dpi=300, bbox_inches='tight')
            print(f"Training plots saved to: {curves_path}")
            plt.show()  # Display the plot
            plt.close()

            if self.opts.enable_game_theory:
                metrics_fig = self.plot_game_theory_metrics()
                if metrics_fig:
                    metrics_path = os.path.join(save_dir, 'game_theory_metrics.png')
                    metrics_fig.savefig(metrics_path, dpi=300, bbox_inches='tight')
                    print(f"Game theory plots saved to: {metrics_path}")
                    plt.close(metrics_fig)
        else:
            self.plot_strategy_comparison(strategy_results, save_dir)

        # Standalone per-metric plots are produced in every mode.
        self.save_individual_plots(save_dir)
    
    def plot_strategy_comparison(self, strategy_results, save_dir):
        """Draw a 2x3 comparison of strategies and save ``strategy_comparison.png``.

        Panels 1-3 compare training returns, validation returns and final
        returns. Panels 4-6 are drawn only when some non-``'baseline'``
        strategy carries ``game_theory_metrics``; they compare performance
        drop, successful attacks and deterrence efficiency. Finally the CSV
        summary is written via ``save_strategy_summary_table``.

        Each series is paired with its ``steps`` positionally and truncated to
        the shorter length. A strategy with no data is left out of the line
        panels and shows as NaN (no bar) in the final-performance panel.

        Args:
            strategy_results: Mapping ``strategy name -> dict`` with keys
                ``'steps'``, ``'training_values'``, ``'eval_values'`` and
                optionally ``'game_theory_metrics'``.
            save_dir: Directory the image and CSV are written to.
        """
        def _aligned(steps, values):
            """Return (steps, values) truncated to their common length."""
            n = min(len(steps), len(values))
            return list(steps[:n]), list(values[:n])

        def _last(values):
            """Last element of ``values``, or NaN when it is empty."""
            return values[-1] if len(values) else float('nan')

        def _line_panel(position, title, ylabel, series):
            """Plot one line per (name, steps, values) into subplot ``position``."""
            plt.subplot(2, 3, position)
            for name, steps, values in series:
                xs, ys = _aligned(steps, values)
                if xs:
                    plt.plot(xs, ys, label=name, linewidth=2, alpha=0.8)
            plt.title(title)
            plt.xlabel('Training Steps')
            plt.ylabel(ylabel)
            if plt.gca().get_legend_handles_labels()[0]:
                plt.legend()
            plt.grid(True, alpha=0.3)

        plt.figure(figsize=(15, 10))

        # Plots 1-2: return curves. eval_values may be empty or shorter than
        # steps, because validation returns are not always recorded.
        _line_panel(1, 'Training Returns Comparison', 'Episode Return',
                    [(name, r['steps'], r['training_values'])
                     for name, r in strategy_results.items()])
        _line_panel(2, 'Validation Returns Comparison', 'Episode Return',
                    [(name, r['steps'], r['eval_values'])
                     for name, r in strategy_results.items()])

        # Plot 3: Final Performance Bar Chart
        plt.subplot(2, 3, 3)
        final_training = [_last(r['training_values']) for r in strategy_results.values()]
        final_validation = [_last(r['eval_values']) for r in strategy_results.values()]

        x = np.arange(len(strategy_results))
        width = 0.35
        plt.bar(x - width / 2, final_training, width, label='Training', alpha=0.8)
        plt.bar(x + width / 2, final_validation, width, label='Validation', alpha=0.8)

        plt.title('Final Performance Comparison')
        plt.xlabel('Strategy')
        plt.ylabel('Episode Return')
        plt.xticks(x, list(strategy_results.keys()), rotation=45)
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Plots 4-6: game-theory metrics (if available)
        game_theory_strategies = {k: v for k, v in strategy_results.items()
                                  if k != 'baseline' and 'game_theory_metrics' in v}

        if game_theory_strategies:
            for position, key, title, ylabel in (
                (4, 'performance_drop', 'Performance Drop Comparison', 'Performance Drop'),
                (5, 'successful_attacks', 'Successful Attacks Comparison',
                 'Number of Successful Attacks'),
                (6, 'deterrence_efficiency', 'Defense Efficiency Comparison',
                 'Deterrence Efficiency'),
            ):
                _line_panel(position, title, ylabel,
                            [(name, r['steps'], [m[key] for m in r['game_theory_metrics']])
                             for name, r in game_theory_strategies.items()
                             if r['game_theory_metrics']])

        plt.tight_layout()
        comparison_path = os.path.join(save_dir, 'strategy_comparison.png')
        plt.savefig(comparison_path, dpi=300, bbox_inches='tight')
        print(f"Strategy comparison plots saved to: {comparison_path}")
        plt.show()
        plt.close()

        # Summary statistics table
        self.save_strategy_summary_table(strategy_results, save_dir)
    
    def plot_returns(self, data_dict, title, label):
        """Plot per-run return curves on the current axes, plus mean ± 1 std.

        Each run is drawn faintly. When two or more runs have data, every run
        with at least two points is interpolated onto a common 100-point step
        grid, and the mean curve with a ±1 standard-deviation band is overlaid.

        Args:
            data_dict: Mapping ``run_id -> returns``. The x values for each run
                come from ``self.memory.steps[run_id]``.
            title: Axes title.
            label: Currently unused; kept for call-site compatibility.
        """
        if not data_dict:
            plt.title(f"{title} - No Data")
            return

        runs = []
        for run_id, returns in data_dict.items():
            steps = self.memory.steps[run_id]
            # Pair positionally: a series recorded on fewer epochs than
            # ``steps`` would otherwise make plt.plot raise.
            n = min(len(steps), len(returns))
            if n == 0:
                continue
            steps, returns = list(steps[:n]), list(returns[:n])
            plt.plot(steps, returns, alpha=0.3, linewidth=1)
            runs.append((steps, returns))

        # Interpolation needs at least two points per run.
        curves = [(steps, returns) for steps, returns in runs if len(steps) > 1]
        if len(runs) > 1 and curves:
            max_step = max(max(steps) for steps, _ in runs)
            common_steps = np.linspace(0, max_step, 100)
            interpolated = np.array([np.interp(common_steps, steps, returns)
                                     for steps, returns in curves])
            mean_returns = interpolated.mean(axis=0)
            std_returns = interpolated.std(axis=0)

            plt.plot(common_steps, mean_returns, 'b-', linewidth=2,
                     label=f'Mean ({len(curves)} runs)')
            plt.fill_between(common_steps,
                             mean_returns - std_returns,
                             mean_returns + std_returns,
                             alpha=0.2, color='blue', label='±1 std')
            plt.legend()

        plt.title(title)
        plt.xlabel('Training Steps')
        plt.ylabel('Episode Return')
        plt.grid(True, alpha=0.3)

        # Set y-axis limits based on environment
        if hasattr(self.opts, 'min_reward') and hasattr(self.opts, 'max_reward'):
            plt.ylim(self.opts.min_reward, self.opts.max_reward)
    
    def save_individual_plots(self, save_dir):
        """Write standalone training- and validation-return images into ``save_dir``."""
        outputs = (
            (self.memory.training_values, 'Training Returns vs Steps', 'Training',
             'training_returns.png'),
            (self.memory.eval_values, 'Validation Returns vs Steps', 'Validation',
             'validation_returns.png'),
        )
        for series, heading, tag, filename in outputs:
            plt.figure(figsize=(10, 6))
            self.plot_returns(series, heading, tag)
            plt.savefig(os.path.join(save_dir, filename), dpi=300, bbox_inches='tight')
            plt.close()
        
    def save_strategy_summary_table(self, strategy_results, save_dir):
        """Write ``strategy_comparison_summary.csv`` with one row per strategy.

        Columns:
          * final training/validation return and the improvement between the
            first and last recorded values (NaN when a series is empty);
          * when ``game_theory_metrics`` is present and non-empty, the summed
            performance drop and theoretical damage, plus the mean number of
            successful attacks and mean deterrence efficiency.

        Args:
            strategy_results: Mapping ``strategy name -> results dict``.
            save_dir: Directory the CSV is written to.
        """
        import pandas as pd

        def _first_last(values):
            """(first, last) of ``values``, or (NaN, NaN) when missing or empty."""
            if values is None or len(values) == 0:
                return float('nan'), float('nan')
            return values[0], values[-1]

        summary_data = []
        for strategy_name, results in strategy_results.items():
            train_first, train_last = _first_last(results.get('training_values'))
            val_first, val_last = _first_last(results.get('eval_values'))

            row = {
                'Strategy': strategy_name,
                'Final_Training_Return': train_last,
                'Final_Validation_Return': val_last,
                'Training_Improvement': train_last - train_first,
                'Validation_Improvement': val_last - val_first,
            }

            # Add game theory metrics if available
            metrics = results.get('game_theory_metrics')
            if metrics:
                row.update({
                    'Total_Performance_Drop': sum(m['performance_drop'] for m in metrics),
                    'Avg_Successful_Attacks': np.mean([m['successful_attacks'] for m in metrics]),
                    'Avg_Defense_Efficiency': np.mean([m['deterrence_efficiency'] for m in metrics]),
                    'Total_Theoretical_Damage': sum(m['theoretical_damage'] for m in metrics),
                })

            summary_data.append(row)

        df = pd.DataFrame(summary_data)
        csv_path = os.path.join(save_dir, 'strategy_comparison_summary.csv')
        df.to_csv(csv_path, index=False)
        print(f"Strategy comparison summary saved to: {csv_path}")

        print(f"Individual plots saved to: {save_dir}/training_returns.png and validation_returns.png")

    def save_game_theory_results(self, save_dir=None):
        """Write one CSV of per-epoch game-theory metrics for each run.

        Files are named ``game_theory_run_{run_id}.csv`` inside ``save_dir``
        (default: ``opts.save_dir`` or ``'./results'``). Row ``i`` pairs
        ``self.memory.steps[run_id][i]`` with the ``i``-th recorded metrics
        dict; steps without a matching metrics entry are skipped. Does nothing
        when game theory is disabled or no metrics were recorded.
        """
        if not self.opts.enable_game_theory or not self.memory.game_theory_metrics:
            return

        if save_dir is None:
            save_dir = self.opts.save_dir if self.opts.save_dir else './results'
        os.makedirs(save_dir, exist_ok=True)

        header = [
            "step", "era", "era_round", "performance_drop", "theoretical_damage",
            "num_attackers", "num_defended", "successful_attacks", "blocked_attacks",
            "avg_attacker_utility", "defense_cost_used", "attack_cost_used",
            "defense_budget_util", "attack_budget_util", "deterrence_efficiency",
            "current_strategy", "damage_weight_range",
        ]

        for run_id, metrics_list in self.memory.game_theory_metrics.items():
            csv_filename = os.path.join(save_dir, f'game_theory_run_{run_id}.csv')

            with open(csv_filename, mode="w", newline="") as csv_file:
                writer = csv.writer(csv_file)
                writer.writerow(header)
                for step, m in zip(self.memory.steps[run_id], metrics_list):
                    writer.writerow([
                        step, m['current_era'], m['era_round'], m['performance_drop'],
                        m['theoretical_damage'], m['num_attackers'], m['num_defended'],
                        m['successful_attacks'], m['blocked_attacks'], m['avg_attacker_utility'],
                        m['defense_cost_used'], m['attack_cost_used'], m['defense_budget_utilization'],
                        m['attack_budget_utilization'], m['deterrence_efficiency'],
                        m['current_strategy'], m['damage_weight_range'],
                    ])

            print(f"Game theory results for run {run_id} saved to: {csv_filename}")

# Strategies that only set opts.aggregation_defense (client defense off).
_SYSTEM_LEVEL_DEFENSES = frozenset({'fltg', 'fedgreed'})

# Combined strategies: name -> (client-side strategy, aggregation defense).
_COMBINED_STRATEGIES = {
    'stackelberg_fltg': ('true_stackelberg', 'fltg'),
    'ucb_fltg': ('ucb', 'fltg'),
    'thompson_fltg': ('thompson_sampling', 'fltg'),
    'stackelberg_fedgreed': ('true_stackelberg', 'fedgreed'),
    'ucb_fedgreed': ('ucb', 'fedgreed'),
    'thompson_fedgreed': ('thompson_sampling', 'fedgreed'),
}


def _apply_comparison_strategy(opts, strategy_name):
    """Set the opts flags for ``strategy_name`` in place; return the client strategy name."""
    if strategy_name == 'clean_baseline':
        # Clean run: no game theory attacks or defenses
        opts.enable_game_theory = False
        opts.aggregation_defense = 'none'
        opts.use_client_defense = False
        return 'no_defense'

    opts.enable_game_theory = True
    if strategy_name in _SYSTEM_LEVEL_DEFENSES:
        opts.aggregation_defense = strategy_name
        opts.use_client_defense = False
        return 'no_defense'
    if strategy_name == 'baseline':
        opts.aggregation_defense = 'none'
        opts.use_client_defense = True
        return 'no_defense'
    if strategy_name in _COMBINED_STRATEGIES:
        client_strategy, agg_defense = _COMBINED_STRATEGIES[strategy_name]
        opts.aggregation_defense = agg_defense
        opts.use_client_defense = True
        return client_strategy
    # Any other name is treated as a client-side strategy without aggregation defense.
    opts.aggregation_defense = 'none'
    opts.use_client_defense = True
    return strategy_name


def _first_seed_results(agent, opts):
    """Collect the steps/returns (and game-theory metrics) recorded for ``opts.seeds[0]``."""
    seed = opts.seeds[0]
    memory = agent.memory
    results = {
        'steps': memory.steps[seed] if seed in memory.steps else [],
        'training_values': memory.training_values[seed] if seed in memory.training_values else [],
        'eval_values': memory.eval_values[seed] if seed in memory.eval_values else [],
    }
    metrics = getattr(memory, 'game_theory_metrics', None)
    if opts.enable_game_theory and metrics is not None and seed in metrics:
        results['game_theory_metrics'] = metrics[seed]
    return results


def run_strategy_comparison(opts):
    """Train one agent per strategy in ``opts.strategies_to_compare`` and compare them.

    For each strategy, ``opts`` is mutated in place (``enable_game_theory``,
    ``aggregation_defense``, ``use_client_defense`` and, when saving, a
    per-strategy ``save_dir``). A fresh ``Agent`` is trained for every seed in
    ``opts.seeds``. Only the results recorded for ``opts.seeds[0]`` are kept
    for comparison. Comparative plots are then written using the last agent
    created, and a summary is printed.

    Returns:
        dict mapping strategy name -> results dict (``steps``,
        ``training_values``, ``eval_values`` and optionally
        ``game_theory_metrics``).
    """
    import pprint
    pprint.pprint(vars(opts))

    all_strategy_results = {}
    base_save_dir = opts.save_dir if opts.save_dir else './results'
    agent = None

    print(f"\n{'='*80}")
    print(f"RUNNING STRATEGY COMPARISON: {len(opts.strategies_to_compare)} strategies")
    print(f"{'='*80}")

    for strategy_name in opts.strategies_to_compare:
        print(f"\n{'='*60}")
        print(f"RUNNING STRATEGY: {strategy_name.upper()}")
        print(f"{'='*60}")

        current_strategy = _apply_comparison_strategy(opts, strategy_name)

        # Setup tensorboard for this strategy
        tb_writer = None if opts.no_tb else SummaryWriter(os.path.join(opts.log_dir, strategy_name))

        # Setup save directory for this strategy and record its arguments
        if not opts.no_saving:
            strategy_save_dir = os.path.join(base_save_dir, strategy_name)
            os.makedirs(strategy_save_dir, exist_ok=True)
            opts.save_dir = strategy_save_dir

            opts_dict = vars(opts).copy()
            opts_dict['device'] = str(opts_dict['device'])
            opts_dict['current_strategy'] = current_strategy
            with open(os.path.join(strategy_save_dir, "args.json"), 'w') as f:
                json.dump(opts_dict, f, indent=True)

        agent = Agent(opts)

        # Set specific strategy if game theory is enabled
        if opts.enable_game_theory and hasattr(agent, 'game_integrator'):
            agent.game_integrator.current_strategy = current_strategy

        for run_id in opts.seeds:
            torch.manual_seed(run_id)
            np.random.seed(run_id)

            # Fresh worker whose randomly initialised policy seeds the master.
            nn_parms_worker = Worker(
                id=0, is_Byzantine=False, env_name=opts.env_name, gamma=opts.gamma,
                hidden_units=opts.hidden_units, activation=opts.activation,
                output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                opts=opts
            ).to(opts.device)

            model_actor = get_inner_model(agent.master.logits_net)
            model_actor.load_state_dict({**model_actor.state_dict(),
                                         **get_inner_model(nn_parms_worker.logits_net).state_dict()})

            agent.start_training(tb_writer, run_id)
            if tb_writer:
                agent.log_performance(tb_writer)

        if tb_writer is not None:
            tb_writer.close()

        strategy_results = _first_seed_results(agent, opts)
        all_strategy_results[strategy_name] = strategy_results

        print(f"\nStrategy {strategy_name.upper()} completed!")
        if strategy_results['training_values']:
            print(f"Final training return: {strategy_results['training_values'][-1]:.2f}")
        if strategy_results['eval_values']:
            print(f"Final validation return: {strategy_results['eval_values'][-1]:.2f}")

        # Restore base save dir for next strategy
        opts.save_dir = base_save_dir

    if agent is None:
        print("No strategies to compare; nothing to plot.")
        return all_strategy_results

    print(f"\n{'='*60}")
    print("GENERATING COMPARATIVE ANALYSIS")
    print(f"{'='*60}")

    # Plot with the last agent created (plotting does not depend on which one).
    agent.save_training_plots(opts.save_dir, all_strategy_results)

    print(f"\n{'='*80}")
    print("FINAL STRATEGY COMPARISON RESULTS")
    print(f"{'='*80}")

    for strategy_name, results in all_strategy_results.items():
        if results['training_values'] and results['eval_values']:
            print(f"{strategy_name:15s}: Train={results['training_values'][-1]:6.2f}, "
                  f"Val={results['eval_values'][-1]:6.2f}, "
                  f"Improvement={results['training_values'][-1] - results['training_values'][0]:+6.2f}")

    return all_strategy_results

# =============================================================================
# MAIN EXECUTION FUNCTION
# =============================================================================

def run(opts):
    """Main entry point: strategy comparison, evaluation-only, or single-strategy training.

    * ``opts.compare_all_strategies`` truthy: delegate to
      ``run_strategy_comparison`` and return its results.
    * ``opts.eval_only``: seed with ``opts.seed``, optionally load
      ``opts.load_path``, and run one validation pass.
    * Otherwise: train one agent for every seed in ``opts.seeds``, then save
      plots and, when game theory is enabled, the game-theory CSVs.

    Returns:
        The comparison results dict in comparison mode, else ``None``.
    """
    if getattr(opts, 'compare_all_strategies', False):
        return run_strategy_comparison(opts)

    # Run single strategy (original behavior)
    import pprint
    pprint.pprint(vars(opts))

    # Setup tensorboard
    tb_writer = None if opts.no_tb else SummaryWriter(opts.log_dir)

    if not opts.no_saving:
        # Create the output directory (exist_ok avoids a check-then-create race)
        # and save the arguments so the exact configuration can always be found.
        os.makedirs(opts.save_dir, exist_ok=True)
        opts_dict = vars(opts).copy()
        opts_dict['device'] = str(opts_dict['device'])  # device object is not JSON-serializable
        with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
            json.dump(opts_dict, f, indent=True)

    agent = Agent(opts)

    if opts.eval_only:
        # Validation only
        torch.manual_seed(opts.seed)
        np.random.seed(opts.seed)

        if opts.load_path is not None:
            agent.load(opts.load_path)

        agent.start_validating(tb_writer, 0, opts.val_max_steps, opts.render, mode=opts.mode)
    else:
        for run_id in opts.seeds:
            torch.manual_seed(run_id)
            np.random.seed(run_id)

            nn_parms_worker = Worker(
                id=0,
                is_Byzantine=False,
                env_name=opts.env_name,
                gamma=opts.gamma,
                hidden_units=opts.hidden_units,
                activation=opts.activation,
                output_activation=opts.output_activation,
                max_epi_len=opts.max_epi_len,
                opts=opts
            ).to(opts.device)

            # Start from the random policy of a freshly created worker
            model_actor = get_inner_model(agent.master.logits_net)
            model_actor.load_state_dict({**model_actor.state_dict(),
                                         **get_inner_model(nn_parms_worker.logits_net).state_dict()})

            agent.start_training(tb_writer, run_id)
            if tb_writer:
                agent.log_performance(tb_writer)

        print("\n" + "="*60)
        print("Training completed! Generating plots...")
        agent.save_training_plots()

        if opts.enable_game_theory:
            # Best-effort export: do not crash a finished training run if the
            # agent does not provide this method.
            save_game_theory = getattr(agent, 'save_game_theory_results', None)
            if callable(save_game_theory):
                save_game_theory()
                print("Game theory results saved!")
            else:
                print("Warning: agent has no save_game_theory_results(); skipping game theory export.")

        print("="*60)

    if tb_writer is not None:
        tb_writer.close()

# =============================================================================
# STRATEGY CHECKPOINTING AND MAIN EXECUTION
# =============================================================================
import os
import json
import pickle
import csv
import numpy as np
from typing import Dict, Any, Optional

class StrategyCheckpointManager:
    """Save and reload per-strategy results so completed experiments are not re-run.

    Files live under ``<base_dir>/checkpoints/<experiment_id>/``:

    * ``<strategy>_checkpoint.pkl`` - pickled dict holding the results, the
      config and a timestamp; this is what ``load_strategy_results`` reads.
    * ``<strategy>_results.csv`` - human-readable table of the same curves.

    A strategy counts as existing (``strategy_exists``) only when both files
    are present.
    """

    _CHECKPOINT_SUFFIX = "_checkpoint.pkl"
    _CSV_SUFFIX = "_results.csv"
    # (CSV column name, key read from each per-step game-theory metrics dict).
    _GAME_THEORY_COLUMNS = (
        ('num_attackers', 'num_attackers'),
        ('num_defended', 'num_defended'),
        ('performance_drop', 'performance_drop'),
        ('theoretical_damage', 'theoretical_damage'),
        ('defense_budget_util', 'defense_budget_utilization'),
        ('attack_budget_util', 'attack_budget_utilization'),
    )

    def __init__(self, base_dir: str, experiment_id: Optional[str] = None):
        """
        Args:
            base_dir: Base directory for saving results.
            experiment_id: Unique identifier for this experiment run. Defaults
                to ``exp_<unix seconds>``, i.e. a new directory per call.
        """
        if experiment_id is None:
            experiment_id = f"exp_{int(time.time())}"

        self.checkpoint_dir = os.path.join(base_dir, "checkpoints", experiment_id)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        print(f"Checkpoint system initialized: {self.checkpoint_dir}")

    def get_strategy_checkpoint_path(self, strategy_name: str) -> str:
        """Return the pickle checkpoint path for ``strategy_name``."""
        return os.path.join(self.checkpoint_dir, f"{strategy_name}{self._CHECKPOINT_SUFFIX}")

    def get_strategy_csv_path(self, strategy_name: str) -> str:
        """Return the CSV results path for ``strategy_name``."""
        return os.path.join(self.checkpoint_dir, f"{strategy_name}{self._CSV_SUFFIX}")

    def strategy_exists(self, strategy_name: str) -> bool:
        """Return True when both the checkpoint and the CSV exist for the strategy."""
        exists = (os.path.exists(self.get_strategy_checkpoint_path(strategy_name))
                  and os.path.exists(self.get_strategy_csv_path(strategy_name)))
        if exists:
            print(f"Found existing results for strategy: {strategy_name}")
        return exists

    def save_strategy_results(self, strategy_name: str, results: Dict[str, Any],
                              config_dict: Optional[Dict[str, Any]] = None):
        """Persist ``results`` (plus ``config_dict``) as a pickle checkpoint and a CSV.

        Args:
            strategy_name: Name used to build the output file names.
            results: Dict with optional ``steps``, ``training_values``,
                ``eval_values`` and ``game_theory_metrics`` lists.
            config_dict: Optional configuration snapshot stored alongside.
        """
        checkpoint_path = self.get_strategy_checkpoint_path(strategy_name)
        csv_path = self.get_strategy_csv_path(strategy_name)

        checkpoint_data = {
            'strategy_name': strategy_name,
            'results': results,
            'config': config_dict,
            'timestamp': time.time()
        }

        # Write to a temporary file and atomically swap it in, so an
        # interrupted save never replaces a good checkpoint with a truncated one.
        tmp_path = checkpoint_path + ".tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(checkpoint_data, f)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

        # Human-readable copy (best-effort; failures are only reported).
        self._save_results_csv(csv_path, strategy_name, results)

        print(f"Saved results for strategy: {strategy_name}")
        print(f"  Checkpoint: {checkpoint_path}")
        print(f"  CSV: {csv_path}")

    def load_strategy_results(self, strategy_name: str) -> Optional[Dict[str, Any]]:
        """Return the stored results for ``strategy_name``, or None if missing/unreadable."""
        checkpoint_path = self.get_strategy_checkpoint_path(strategy_name)

        if not os.path.exists(checkpoint_path):
            return None

        try:
            # NOTE: pickle can execute arbitrary code on load; only point this
            # manager at checkpoint directories this tool produced.
            with open(checkpoint_path, 'rb') as f:
                checkpoint_data = pickle.load(f)

            print(f"Loaded existing results for strategy: {strategy_name}")
            return checkpoint_data['results']

        except Exception as e:
            print(f"Error loading checkpoint for {strategy_name}: {e}")
            return None

    def _save_results_csv(self, csv_path: str, strategy_name: str, results: Dict[str, Any]):
        """Write ``results`` as a CSV table; errors are printed, never raised.

        The table has one row per index up to the longest of ``steps``,
        ``training_values`` and ``eval_values``; shorter series are padded
        with ''. Game-theory columns are added only when ``results`` has a
        ``game_theory_metrics`` key, filled from the i-th metrics dict when
        one exists and left blank otherwise.
        """

        def _cell(series, index):
            return series[index] if index < len(series) else ''

        try:
            include_game_theory = 'game_theory_metrics' in results
            steps = results.get('steps', [])
            training_values = results.get('training_values', [])
            eval_values = results.get('eval_values', [])
            game_metrics = results.get('game_theory_metrics', [])

            with open(csv_path, 'w', newline='') as f:
                writer = csv.writer(f)

                header = ['step', 'training_return', 'validation_return']
                if include_game_theory:
                    header.extend(column for column, _ in self._GAME_THEORY_COLUMNS)
                writer.writerow(header)

                max_len = max(len(steps), len(training_values), len(eval_values))
                for i in range(max_len):
                    row = [_cell(steps, i), _cell(training_values, i), _cell(eval_values, i)]

                    if game_metrics and i < len(game_metrics):
                        step_metrics = game_metrics[i]
                        row.extend(step_metrics.get(key, '')
                                   for _, key in self._GAME_THEORY_COLUMNS)
                    elif include_game_theory:
                        row.extend([''] * len(self._GAME_THEORY_COLUMNS))

                    writer.writerow(row)

        except Exception as e:
            print(f"Warning: Could not save CSV for {strategy_name}: {e}")

    def list_completed_strategies(self) -> list:
        """Return (sorted) names of strategies that have a checkpoint file."""
        if not os.path.exists(self.checkpoint_dir):
            return []
        suffix = self._CHECKPOINT_SUFFIX
        # Slice off only the trailing suffix; str.replace would also strip
        # occurrences inside the strategy name itself.
        return sorted(
            filename[:-len(suffix)]
            for filename in os.listdir(self.checkpoint_dir)
            if filename.endswith(suffix)
        )

    def get_experiment_summary(self) -> Dict[str, Any]:
        """Return the checkpoint directory and the completed strategies in it."""
        completed = self.list_completed_strategies()
        return {
            'experiment_dir': self.checkpoint_dir,
            'completed_strategies': completed,
            'total_completed': len(completed)
        }


def _apply_checkpoint_strategy_config(opts, strategy_name):
    """Mutate ``opts`` for ``strategy_name`` and return the client-side strategy name.

    Sets ``opts.enable_game_theory``, ``opts.aggregation_defense`` and
    ``opts.use_client_defense``. The return value is what gets assigned to
    ``agent.game_integrator.current_strategy``.

    - ``clean_baseline``: game theory off, no defenses.
    - ``fltg`` / ``fedgreed``: game theory on, aggregation defense only.
    - ``baseline``: game theory on, client-defense flag on, ``no_defense`` strategy.
    - ``<client>_<fltg|fedgreed>`` combos: client strategy plus aggregation defense.
    - anything else: used directly as the client-side strategy name.
    """
    combined_map = {
        'stackelberg_fltg':     ('true_stackelberg',  'fltg'),
        'ucb_fltg':             ('ucb',               'fltg'),
        'thompson_fltg':        ('thompson_sampling', 'fltg'),
        'stackelberg_fedgreed': ('true_stackelberg',  'fedgreed'),
        'ucb_fedgreed':         ('ucb',               'fedgreed'),
        'thompson_fedgreed':    ('thompson_sampling', 'fedgreed'),
    }

    if strategy_name == 'clean_baseline':
        opts.enable_game_theory = False
        opts.aggregation_defense = 'none'
        opts.use_client_defense = False
        return 'no_defense'

    opts.enable_game_theory = True
    if strategy_name in ('fltg', 'fedgreed'):
        opts.aggregation_defense = strategy_name
        opts.use_client_defense = False
        return 'no_defense'
    if strategy_name == 'baseline':
        opts.aggregation_defense = 'none'
        opts.use_client_defense = True
        return 'no_defense'
    if strategy_name in combined_map:
        client_strategy, agg_defense = combined_map[strategy_name]
        opts.aggregation_defense = agg_defense
        opts.use_client_defense = True
        return client_strategy

    opts.aggregation_defense = 'none'
    opts.use_client_defense = True
    return strategy_name


def _extract_checkpoint_results(agent, opts):
    """Collect the first seed's curves (and game-theory metrics) from ``agent.memory``.

    NOTE(review): only ``opts.seeds[0]`` is persisted even though every seed in
    ``opts.seeds`` is trained; confirm whether the other seeds should be kept.
    """
    first_seed = opts.seeds[0]
    memory = agent.memory

    def _for_first_seed(per_seed, default):
        return per_seed[first_seed] if first_seed in per_seed else default

    results = {
        'steps': _for_first_seed(memory.steps, []),
        'training_values': _for_first_seed(memory.training_values, []),
        'eval_values': _for_first_seed(memory.eval_values, []),
    }
    if (opts.enable_game_theory and hasattr(memory, 'game_theory_metrics')
            and first_seed in memory.game_theory_metrics):
        results['game_theory_metrics'] = memory.game_theory_metrics[first_seed]
    return results


def run_strategy_comparison_with_checkpoints(opts):
    """Run (or reload) every strategy in ``opts.strategies_to_compare``.

    Results already checkpointed under
    ``<save_dir or ./results>/checkpoints/<env>_workers<N>_seed<seed>`` are
    reloaded. Otherwise the following happens:

    1. ``opts`` is reconfigured for the strategy.
    2. A fresh ``Agent`` is trained on every seed in ``opts.seeds``.
    3. The first seed's curves are checkpointed.

    Afterwards, comparative plots are drawn with the last freshly trained
    agent (if any). An ``experiment_summary.json`` is written next to the
    checkpoints.

    Side effects:
        * Mutates ``opts.enable_game_theory``, ``opts.aggregation_defense`` and
          ``opts.use_client_defense``; they keep the last trained strategy's
          values.
        * Resets ``opts.save_dir`` to the base save directory after each
          trained strategy.

    Returns:
        dict mapping strategy name -> results dict with ``steps``,
        ``training_values``, ``eval_values`` (and optionally
        ``game_theory_metrics``).
    """
    import pprint
    pprint.pprint(vars(opts))

    base_save_dir = opts.save_dir if opts.save_dir else './results'
    checkpoint_manager = StrategyCheckpointManager(
        base_dir=base_save_dir,
        experiment_id=f"{opts.env_name}_workers{opts.num_worker}_seed{opts.seed}"
    )

    wide_rule = '=' * 80
    narrow_rule = '=' * 60

    summary = checkpoint_manager.get_experiment_summary()
    print(f"\n{wide_rule}")
    print("CHECKPOINT SYSTEM STATUS")
    print(wide_rule)
    print(f"Experiment directory: {summary['experiment_dir']}")
    print(f"Completed strategies: {summary['completed_strategies']}")
    print(f"Total completed: {summary['total_completed']}")

    all_strategy_results = {}
    # Last freshly trained agent; reused below to draw the comparative plots.
    agent = None

    print(f"\n{wide_rule}")
    print(f"RUNNING STRATEGY COMPARISON: {len(opts.strategies_to_compare)} strategies")
    print(wide_rule)

    for strategy_name in opts.strategies_to_compare:
        print(f"\n{narrow_rule}")
        print(f"CHECKING STRATEGY: {strategy_name.upper()}")
        print(narrow_rule)

        if checkpoint_manager.strategy_exists(strategy_name):
            print(f"LOADING EXISTING RESULTS for {strategy_name}...")
            existing_results = checkpoint_manager.load_strategy_results(strategy_name)
            if existing_results:
                all_strategy_results[strategy_name] = existing_results
                print(f"Successfully loaded {strategy_name} results!")
                continue
            # Unreadable or empty checkpoint: fall through and re-run it.

        print(f"RUNNING NEW EXPERIMENT for {strategy_name}...")
        current_strategy = _apply_checkpoint_strategy_config(opts, strategy_name)

        tb_writer = None if opts.no_tb else SummaryWriter(os.path.join(opts.log_dir, strategy_name))
        try:
            if not opts.no_saving:
                strategy_save_dir = os.path.join(base_save_dir, strategy_name)
                os.makedirs(strategy_save_dir, exist_ok=True)
                opts.save_dir = strategy_save_dir

                opts_dict = vars(opts).copy()
                opts_dict['device'] = str(opts_dict['device'])
                opts_dict['current_strategy'] = current_strategy
                with open(os.path.join(strategy_save_dir, "args.json"), 'w') as f:
                    json.dump(opts_dict, f, indent=True)

            agent = Agent(opts)
            if opts.enable_game_theory and hasattr(agent, 'game_integrator'):
                agent.game_integrator.current_strategy = current_strategy

            for run_id in opts.seeds:
                torch.manual_seed(run_id)
                np.random.seed(run_id)

                # A freshly constructed (seeded) honest worker supplies the
                # initial policy weights; its entries override the master's.
                nn_parms_worker = Worker(
                    id=0, is_Byzantine=False, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                agent.start_training(tb_writer, run_id)
                if tb_writer:
                    agent.log_performance(tb_writer)

            strategy_results = _extract_checkpoint_results(agent, opts)

            config_dict = vars(opts).copy()
            config_dict['device'] = str(config_dict['device'])  # Make serializable
            checkpoint_manager.save_strategy_results(strategy_name, strategy_results, config_dict)

            all_strategy_results[strategy_name] = strategy_results

            print(f"\nStrategy {strategy_name.upper()} completed and saved!")
            if strategy_results['training_values']:
                print(f"Final training return: {strategy_results['training_values'][-1]:.2f}")
            if strategy_results['eval_values']:
                print(f"Final validation return: {strategy_results['eval_values'][-1]:.2f}")
        finally:
            # Flush/close this strategy's event files and restore the base save
            # dir even if training raised.
            if tb_writer is not None:
                tb_writer.close()
            opts.save_dir = base_save_dir

    print(f"\n{narrow_rule}")
    print("GENERATING COMPARATIVE ANALYSIS")
    print(narrow_rule)

    if agent is not None:
        agent.save_training_plots(opts.save_dir, all_strategy_results)
    else:
        print("No strategy was trained in this run; skipping comparative plots.")

    print(f"\n{wide_rule}")
    print("FINAL STRATEGY COMPARISON RESULTS")
    print(wide_rule)

    for strategy_name, results in all_strategy_results.items():
        train_curve = results.get('training_values', [])
        val_curve = results.get('eval_values', [])
        if train_curve and val_curve:
            print(f"{strategy_name:15s}: Train={train_curve[-1]:6.2f}, "
                  f"Val={val_curve[-1]:6.2f}, "
                  f"Improvement={train_curve[-1] - train_curve[0]:+6.2f}")

    def _final_or_zero(values):
        # float() turns numpy/torch scalars into plain floats json can encode.
        return float(values[-1]) if values else 0

    final_summary = checkpoint_manager.get_experiment_summary()
    final_summary['final_results'] = {
        name: {
            'final_train': _final_or_zero(res.get('training_values', [])),
            'final_val': _final_or_zero(res.get('eval_values', [])),
        }
        for name, res in all_strategy_results.items()
    }

    summary_path = os.path.join(checkpoint_manager.checkpoint_dir, 'experiment_summary.json')
    with open(summary_path, 'w') as f:
        json.dump(final_summary, f, indent=2)

    print(f"\nExperiment complete! All results saved to: {checkpoint_manager.checkpoint_dir}")

    return all_strategy_results


# Entry point: single-strategy run, or the checkpointed multi-strategy comparison
def run_with_checkpoints(opts):
    """Dispatch to the checkpointed strategy comparison or a plain single run.

    A truthy ``opts.compare_all_strategies`` routes to
    ``run_strategy_comparison_with_checkpoints``. Otherwise ``run(opts)`` is
    used, which has no checkpoint support. Returns whatever the chosen
    runner returns.
    """
    compare_all = getattr(opts, 'compare_all_strategies', False)
    runner = run_strategy_comparison_with_checkpoints if compare_all else run
    return runner(opts)


def run_with_checkpoints_and_csv(opts):
    """Run a comparison with CSV export, or a single strategy that saves its own CSV.

    With a truthy ``opts.compare_all_strategies`` this delegates to
    ``run_strategy_comparison_with_csv_export`` and returns its result.

    Otherwise a single ``Agent`` is built. The rest depends on
    ``opts.eval_only``:

    * Set: the agent is optionally loaded from ``opts.load_path`` and
      validated.
    * Unset: the agent is trained on every seed in ``opts.seeds``. Plots,
      the performance CSV and, when enabled, game-theory results are then
      saved.

    Returns:
        The trained/evaluated ``Agent`` (single run), or the comparison result.
    """
    if getattr(opts, 'compare_all_strategies', False):
        return run_strategy_comparison_with_csv_export(opts)

    import pprint
    pprint.pprint(vars(opts))

    tb_writer = None if opts.no_tb else SummaryWriter(opts.log_dir)
    try:
        if not opts.no_saving:
            os.makedirs(opts.save_dir, exist_ok=True)
            opts_dict = vars(opts).copy()
            opts_dict['device'] = str(opts_dict['device'])
            with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
                json.dump(opts_dict, f, indent=True)

        agent = Agent(opts)

        # Name used for the CSV file. The game integrator may be absent;
        # run_strategy_comparison_with_checkpoints also guards it with hasattr.
        # NOTE(review): 'baseline' here means game theory disabled, whereas the
        # comparison runner calls that configuration 'clean_baseline'.
        if opts.enable_game_theory:
            game_integrator = getattr(agent, 'game_integrator', None)
            strategy_name = getattr(game_integrator, 'current_strategy', 'game_theory')
        else:
            strategy_name = 'baseline'

        if opts.eval_only:
            torch.manual_seed(opts.seed)
            np.random.seed(opts.seed)

            if opts.load_path is not None:
                agent.load(opts.load_path)

            agent.start_validating(tb_writer, 0, opts.val_max_steps, opts.render, mode=opts.mode)
        else:
            for run_id in opts.seeds:
                torch.manual_seed(run_id)
                np.random.seed(run_id)

                # Fresh (seeded) honest worker supplies the initial policy
                # weights; its entries override the master's.
                nn_parms_worker = Worker(
                    id=0, is_Byzantine=False, env_name=opts.env_name, gamma=opts.gamma,
                    hidden_units=opts.hidden_units, activation=opts.activation,
                    output_activation=opts.output_activation, max_epi_len=opts.max_epi_len,
                    opts=opts
                ).to(opts.device)

                model_actor = get_inner_model(agent.master.logits_net)
                model_actor.load_state_dict({**model_actor.state_dict(),
                                             **get_inner_model(nn_parms_worker.logits_net).state_dict()})

                agent.start_training(tb_writer, run_id)
                if tb_writer:
                    agent.log_performance(tb_writer)

            print("\n" + "="*60)
            print("Training completed! Generating plots and saving CSV...")
            agent.save_training_plots()
            agent.save_performance_to_csv(strategy_name)

            if opts.enable_game_theory:
                agent.save_game_theory_results()
                print("Game theory results saved!")

            print("CSV performance data saved!")
            print("="*60)
    finally:
        # Flush and close TensorBoard event files even if the run raised.
        if tb_writer is not None:
            tb_writer.close()

    return agent

if __name__ == "__main__":

    # ------------------------------------------------------------------ #
    #  CLI argument parser – lets ablation scripts override Config fields  #
    # ------------------------------------------------------------------ #
    import argparse

    def _str2bool(v):
        """Parse a CLI boolean value (case-insensitive).

        'true'/'1'/'yes' -> True, 'false'/'0'/'no' -> False. Anything else
        raises, so a typo such as ``--single-attacker ture`` is reported by
        argparse instead of silently becoming False.
        """
        text = str(v).lower()
        if text in ('true', '1', 'yes'):
            return True
        if text in ('false', '0', 'no'):
            return False
        raise argparse.ArgumentTypeError(
            f"expected true/false, 1/0 or yes/no, got {v!r}"
        )

    _parser = argparse.ArgumentParser(add_help=True)
    _parser.add_argument('--defense-budget',      type=float,     default=None, dest='initial_defense_budget_ratio')
    _parser.add_argument('--attack-budget',       type=float,     default=None, dest='initial_attack_budget_ratio')
    _parser.add_argument('--attack-intensity',    type=float,     default=None, dest='base_attack_intensity')
    _parser.add_argument('--reshuffle-freq',      type=int,       default=None, dest='reshuffle_frequency')
    _parser.add_argument('--attack-style',        type=str,       default=None, dest='attack_style',
                         choices=['action_flip', 'reward_poison', 'obs_noise', 'grad_noise'])
    # Per-attack-style intensity knobs
    _parser.add_argument('--attack-flip-prob',    type=float,     default=None, dest='attack_flip_prob')
    _parser.add_argument('--obs-noise-std',       type=float,     default=None, dest='obs_noise_std')
    _parser.add_argument('--reward-poison-scale', type=float,     default=None, dest='reward_poison_scale')
    _parser.add_argument('--obs-accuracy',        type=float,     default=None, dest='observation_accuracy')
    _parser.add_argument('--defense-strength',    type=float,     default=None, dest='defense_strength')
    _parser.add_argument('--single-attacker',     type=_str2bool, default=None, dest='single_best_attacker_only')
    _parser.add_argument('--num-clients',         type=int,       default=None, dest='num_worker')
    _parser.add_argument('--multiple-run',        type=int,       default=None, dest='multiple_run')
    _parser.add_argument('--max-rounds',          type=int,       default=None, dest='max_trajectories')
    _parser.add_argument('--seed',                type=int,       default=None, dest='seed')
    _parser.add_argument('--label',               type=str,       default='',   dest='run_label',
                         help='Suffix appended to output dir for ablation disambiguation')
    _args = _parser.parse_args()

    # Create base configuration
    opts = Config()

    # Apply only the args that were explicitly provided; run_label is not a
    # Config field and is handled separately below.
    for _key, _val in vars(_args).items():
        if _key != 'run_label' and _val is not None:
            setattr(opts, _key, _val)

    # Rebuild derived fields that depend on overridden params
    opts.seeds = (np.arange(opts.multiple_run) + opts.seed).tolist()

    def _run_dir_name(prefix):
        """Directory name encoding the main experiment parameters."""
        return (
            f'{prefix}_{opts.env_name}_workers{opts.num_worker}'
            f'_rounds{opts.max_trajectories}'
            f'_shuffle_{opts.reshuffle_frequency}'
            f'_single_attacker_{opts.single_best_attacker_only}'
            f'_attack_{opts.base_attack_intensity}'
            f'_seed{opts.seed}'
            f'_attack_style_{opts.attack_style}'
        )

    if not opts.no_saving:
        opts.save_dir = _run_dir_name('outputs')
    if not opts.no_tb:
        opts.log_dir = _run_dir_name('logs')

    # Append label for params not captured in the dir name
    # (defense_budget, attack_budget, obs_accuracy, damage_scenario, defense_strength)
    if _args.run_label:
        if opts.save_dir:
            opts.save_dir = opts.save_dir + f'__{_args.run_label}'
        if opts.log_dir:
            opts.log_dir  = opts.log_dir  + f'__{_args.run_label}'

    # SVRPG and FedPG_BR are mutually exclusive. An explicit raise, unlike
    # assert, is not stripped under `python -O`.
    if opts.SVRPG + opts.FedPG_BR > 1:
        raise ValueError("opts.SVRPG and opts.FedPG_BR are mutually exclusive; enable at most one")

    if opts.SVRPG + opts.FedPG_BR == 0:
        _algorithm = 'GPMDP'
    elif opts.FedPG_BR:
        _algorithm = 'FT-FedScsPG'
    else:
        _algorithm = 'SVRPG'

    # Print which algorithm is running
    if opts.enable_game_theory:
        print('='*80)
        print('FEDERATED REINFORCEMENT LEARNING WITH STACKELBERG GAME THEORY')
        print('='*80)
        print(f'Environment: {opts.env_name}')
        print(f'Workers: {opts.num_worker}')
        print(f'Byzantine Workers: {opts.num_Byzantine}')
        print(f'Game Theory Scenario: {opts.damage_scenario}')
        print(f'Defense Budget Ratio: {opts.initial_defense_budget_ratio}')
        print(f'Attack Budget Ratio: {opts.initial_attack_budget_ratio}')
        print(f'Base Attack Intensity: {opts.base_attack_intensity}')
        print(f'Observation Accuracy: {getattr(opts, "observation_accuracy", 0.8)}')
        print(f'Reshuffle Frequency: every {opts.reshuffle_frequency} rounds')
        print('='*80)
        print(f'Running: {_algorithm} with Stackelberg Game Theory')
    else:
        print(f'run {_algorithm}\n')

    if getattr(opts, 'compare_all_strategies', False):
        # Multi-strategy comparison with CSV export.
        all_strategy_results = run_strategy_comparison_with_csv_export(opts)

        print(f"\n{'='*80}")
        print("CSV FILES SAVED!")
        print(f"{'='*80}")
        print("Check the following directory for CSV files:")
        csv_dir = os.path.join(opts.save_dir if opts.save_dir else './results', 'performance_csv')
        print(f"  {csv_dir}")
        print("Files generated:")
        print("  - Individual strategy files: [strategy_name]_performance.csv")
        print("  - Combined file: all_strategies_performance.csv")
        print("  - Training data: training_returns_all_strategies.csv")
        print("  - Validation data: validation_returns_all_strategies.csv")
        print("  - Plotting ready: all_returns_plotting_ready.csv")

    else:
        # Single-strategy path: run_with_checkpoints() -> run(opts).
        # NOTE(review): run() does not call agent.save_performance_to_csv();
        # run_with_checkpoints_and_csv() does. Confirm which entry point is
        # intended for single runs.
        run_with_checkpoints(opts)
