import json
import argparse
from pathlib import Path

from collections import Counter, defaultdict
import numpy as np

# Model identifiers offered as choices for the --model and --executor_model
# command-line options.
AVAILABLE_MODELS = [
    "Meta-Llama-3.1-8B-Instruct",
    "Llama-3.3-70B-Instruct",
    "Llama-3.1-70B-Instruct",
    "Meta-Llama-3.1-405B-Instruct-FP8",
    "Qwen2.5-Coder-32B-Instruct",
    "Qwen2.5-32B-Instruct",
    "gpt-4o-v2",
    "gpt-4o-mini",
    "claude-37",
]

# Level names 'level_0' .. 'level_12'; each one maps to a result_<level>.json file.
LEVELS = ["level_" + str(index) for index in range(13)]

# Agent setups offered as choices for --agent_type.
AGENT_TYPES = ["individual", "debate", "orchestrator", "planner"]

def read_results(args):
    """Load the per-level result JSON files selected by the parsed arguments.

    The files are looked up at
    ``<results_dir>/<model>/<agent_type>/[<executor_model>/]<num_agents>/result_<level>.json``;
    the ``<executor_model>`` component is only present when
    ``args.agent_type == 'planner'``. When ``args.ablation`` is set the
    directory ``'ablation/'`` is used instead of ``args.results_dir``.

    Args:
        args: Namespace providing ``level``, ``ablation``, ``results_dir``,
            ``model``, ``agent_type``, ``num_agents`` and, for the planner
            agent type, ``executor_model``.

    Returns:
        tuple: ``(results, levels)`` where ``results`` holds the decoded JSON
        content of every file that exists and ``levels`` holds the matching
        level names, in the same order. Missing files are skipped.
    """
    levels = LEVELS if args.level == 'all' else [args.level]

    # The directory does not depend on the level, so build it once.
    results_dir = 'ablation/' if args.ablation else args.results_dir
    run_dir = Path(results_dir) / args.model / args.agent_type
    if args.agent_type == 'planner':
        run_dir = run_dir / args.executor_model
    run_dir = run_dir / str(args.num_agents)

    results, found_levels = [], []
    for level in levels:
        file = run_dir / f'result_{level}.json'
        print(f"Loading results from {file}")
        # Check once so the printed answer and the decision cannot disagree.
        exists = file.exists()
        print("Exists? : ", exists)
        if exists:
            # JSON is UTF-8 by specification; do not rely on the locale default.
            with open(file, 'r', encoding='utf-8') as f:
                results.append(json.load(f))
            found_levels.append(level)

    return results, found_levels

def print_task_statistics(data, alpha):
    """
    Summarise how often each accomplished task occurs per episode and print it.

    Args:
        data (dict): Maps an alpha key to a list of episode dicts; every
            episode must carry an 'acomplished_task_list' entry.
        alpha (str): Key of ``data`` whose episodes are analysed.

    Returns:
        dict: task -> {'mean', 'std', 'counts'}, where 'counts' lists the
        number of occurrences of the task in each episode (0 when absent).
    """
    # One occurrence tally per episode.
    per_episode = [Counter(episode['acomplished_task_list']) for episode in data[alpha]]

    # Every task that shows up in at least one episode.
    seen_tasks = set()
    for tally in per_episode:
        for name in tally:
            seen_tasks.add(name)

    task_stats = {}
    for name in seen_tasks:
        occurrences = [tally.get(name, 0) for tally in per_episode]
        task_stats[name] = dict(
            mean=np.mean(occurrences),
            std=np.std(occurrences),
            counts=occurrences,
        )

    print("\nTask Statistics Across Episodes:")
    print("-" * 50)
    for name, summary in task_stats.items():
        print(f"{name}:")
        print(f"  Mean: {summary['mean']:.2f} | Std: {summary['std']:.2f}")

    return task_stats

def print_noop_statistics(data, alpha):
    """
    Summarise noop counts per agent and per episode, print them, and return them.

    An episode may provide 'agent_noop_counts' (dict of agent id -> number or
    list of numbers; other value types are ignored) or, failing that,
    'noop_count' (a single number, or a list indexed by agent position).
    Episodes with neither contribute a total of 0.

    Args:
        data (dict): Maps an alpha key to a list of episode dicts.
        alpha (str): Key of ``data`` whose episodes are analysed.

    Returns:
        dict: {'per_agent': {agent_id: {'mean', 'std', 'total'}},
               'total': {'mean', 'std', 'total'}}; the 'total' entries are 0
        when there are no episodes.
    """
    per_agent = defaultdict(list)
    episode_totals = []

    for episode in data[alpha]:
        running = 0

        if 'agent_noop_counts' in episode:
            for agent_id, raw in episode['agent_noop_counts'].items():
                if isinstance(raw, (int, float)):
                    amount = raw
                elif isinstance(raw, list):
                    # A list of counts is collapsed to its sum.
                    amount = sum(raw)
                else:
                    continue
                per_agent[agent_id].append(amount)
                running += amount
        elif 'noop_count' in episode:
            recorded = episode['noop_count']
            if isinstance(recorded, (int, float)):
                # A bare number cannot be attributed to any agent.
                running = recorded
            elif isinstance(recorded, list):
                # The list position doubles as the agent id.
                for position, amount in enumerate(recorded):
                    per_agent[str(position)].append(amount)
                    running += amount

        episode_totals.append(running)

    def describe(counts):
        return {'mean': np.mean(counts), 'std': np.std(counts), 'total': sum(counts)}

    print("\nNoop Count Statistics:")
    print("-" * 50)

    if per_agent:
        for agent_id, counts in sorted(per_agent.items()):
            summary = describe(counts)
            print(f"Agent {agent_id}:")
            print(f"  Mean: {summary['mean']:.2f} | Std: {summary['std']:.2f} | Total: {summary['total']}")

    if episode_totals:
        overall = describe(episode_totals)
        print(f"Total Noop Count:")
        print(f"  Mean: {overall['mean']:.2f} | Std: {overall['std']:.2f} | Total: {overall['total']}")
    else:
        overall = {'mean': 0, 'std': 0, 'total': 0}

    return {
        'per_agent': {agent_id: describe(counts) for agent_id, counts in per_agent.items()},
        'total': overall,
    }

def print_soc_statistics(data, alpha):
    """
    Calculate and print SOC (Success Over Completion) statistics across episodes.

    The SOC of an episode is ``success / (success + failed)``; an episode with
    no successes and no failures counts as 0.

    Args:
        data (dict): Maps an alpha key to a list of episode dicts, each with
            numeric 'success' and 'failed' entries.
        alpha (str): Key of ``data`` whose episodes are analysed.

    Returns:
        dict: {'mean', 'std', 'values'} where 'values' is the per-episode SOC
        list. With no episodes, mean and std are 0.0.
    """
    soc_values = []
    for eps_data in data[alpha]:
        attempted = eps_data['success'] + eps_data['failed']
        soc_values.append(eps_data['success'] / attempted if attempted else 0)

    if soc_values:
        mean_soc = np.mean(soc_values)
        std_soc = np.std(soc_values)
    else:
        # np.mean([]) would yield NaN (plus a RuntimeWarning) and that NaN
        # would contaminate any average computed from this result; report
        # 0.0 instead, as the other statistics helpers do for empty input.
        mean_soc, std_soc = 0.0, 0.0

    print("SOC Statistics:")
    print(f"  Mean: {mean_soc:.2f} | Std: {std_soc:.2f}")

    return {'mean': mean_soc, 'std': std_soc, 'values': soc_values}
    

def calculate_statistics(values):
    """Return ``(mean, std)`` of ``values``, or ``(0.0, 0.0)`` when it is empty."""
    if values:
        return np.mean(values), np.std(values)
    return 0.0, 0.0

def parse_individual_tokens(step_tokens):
    """Sum the token usage of every agent recorded in one step.

    ``step_tokens`` maps agent ids to dicts holding 'reasoning' and 'actions'
    entries, each with 'input' and 'output' counts. Entries that are not dicts
    or lack either key are skipped.

    Returns:
        tuple: (input, output, executor_input, executor_output). Every agent
        is counted as an executor, so the last two equal the first two.
    """
    print(f"Parsing individual tokens: {step_tokens}")

    total_in = 0
    total_out = 0
    for agent_id, agent_tokens in step_tokens.items():
        print(f"Processing agent {agent_id} with tokens: {agent_tokens}")
        if not isinstance(agent_tokens, dict):
            continue
        if 'reasoning' not in agent_tokens or 'actions' not in agent_tokens:
            continue

        # An agent's usage is its reasoning call plus its action call.
        reasoning = agent_tokens['reasoning']
        actions = agent_tokens['actions']
        agent_in = reasoning['input'] + actions['input']
        agent_out = reasoning['output'] + actions['output']
        print(f"Agent {agent_id} - input: {agent_in}, output: {agent_out}")

        total_in += agent_in
        total_out += agent_out

    print(f"Step totals - input: {total_in}, output: {total_out}")
    return total_in, total_out, total_in, total_out

# Keys of the per-step token tallies, in the order they are stored and printed.
_TOKEN_KEYS = (
    'input',
    'output',
    'planner_input',
    'planner_output',
    'executor_input',
    'executor_output',
)


def _tally_step_tokens(step_idx, step_tokens):
    """Sum the token usage recorded for one step of a non-individual run.

    Returns a dict keyed by ``_TOKEN_KEYS``. A step that is neither a list nor
    a dict yields all-zero totals.
    """
    print(f"\nProcessing step {step_idx}")
    totals = dict.fromkeys(_TOKEN_KEYS, 0)

    if isinstance(step_tokens, list):
        print(f"Processing list structure with {len(step_tokens)} items")
        for step_dict in step_tokens:
            print(f"Processing step dict: {step_dict}")
            if isinstance(step_dict, dict):
                # Accumulate rather than overwrite, so every entry of the list
                # contributes to the step total and not only the last one.
                s_in, s_out, e_in, e_out = parse_individual_tokens(step_dict)
                totals['input'] += s_in
                totals['output'] += s_out
                totals['executor_input'] += e_in
                totals['executor_output'] += e_out
    elif isinstance(step_tokens, dict):
        print(f"Processing dict structure with keys: {step_tokens.keys()}")
        if 'reasoning' in step_tokens and 'actions' in step_tokens:
            # Orchestrator structure: a single reasoning/actions pair per step.
            totals['input'] = step_tokens['reasoning']['input'] + step_tokens['actions']['input']
            totals['output'] = step_tokens['reasoning']['output'] + step_tokens['actions']['output']
        else:
            # Planner structure: one {'input', 'output'} dict per agent, with
            # the agent id 'planner' kept apart from all other (executor) ids.
            for agent_id, agent_tokens in step_tokens.items():
                print(f"Processing agent {agent_id} with tokens: {agent_tokens}")
                if isinstance(agent_tokens, dict) and 'input' in agent_tokens and 'output' in agent_tokens:
                    role = 'planner' if agent_id == 'planner' else 'executor'
                    totals[f'{role}_input'] += agent_tokens['input']
                    totals[f'{role}_output'] += agent_tokens['output']
                    totals['input'] += agent_tokens['input']
                    totals['output'] += agent_tokens['output']

    print(f"Step {step_idx} totals:")
    print(f"  Input: {totals['input']}")
    print(f"  Output: {totals['output']}")
    print(f"  Executor Input: {totals['executor_input']}")
    print(f"  Executor Output: {totals['executor_output']}")
    return totals


def _print_run_summary(stats):
    """Print the stop-step, price and token sections of one statistics mapping."""
    print("\nStop Step Statistics:")
    mean, std = calculate_statistics(stats['stop_step']['values'])
    print(f"{'Stop Step':20} - Mean: {mean:.3f}, Std: {std:.3f}")

    print("\nPrice Statistics:")
    prices = stats['price']['values']
    mean, std = calculate_statistics(prices)
    print(f"{'Price':20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {sum(prices):.3f}")

    print("\nToken Statistics:")
    labelled_keys = (
        ('Input Tokens', 'input'),
        ('Output Tokens', 'output'),
        ('Planner Input Tokens', 'planner_input'),
        ('Planner Output Tokens', 'planner_output'),
        ('Executor Input Tokens', 'executor_input'),
        ('Executor Output Tokens', 'executor_output'),
    )
    for label, key in labelled_keys:
        values = stats['tokens'][key]
        mean, std = calculate_statistics(values)
        print(f"{label:20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {sum(values)}")


def analyze_results(levels, results, agent_type=None):
    """Aggregate and print statistics for the loaded result files.

    For every level and every alpha key in its data this collects task, noop,
    SOC, stop-step, price, token and action-type statistics, printing verbose
    progress output along the way, then prints a per-level report followed by
    an aggregate report over all levels.

    Args:
        levels: Level names, aligned with ``results``.
        results: One dict per level mapping an alpha key to a list of episode
            dicts. 'stop_step' and 'price' are read from every episode;
            'tokens' and 'action_history' are optional.
        agent_type: Agent type the results were produced with. The value
            'individual' selects the individual token layout (each step maps
            agent ids to reasoning/actions token dicts). When omitted, the
            ``agent_type`` attribute of a module-level ``args`` object is used
            if one exists; otherwise the generic layout detection is applied.

    Returns:
        tuple: ``(level_stats, all_levels_stats, all_levels_noop_stats,
        all_levels_action_stats)``.
    """
    if agent_type is None:
        # Backward compatibility: this function used to read the global `args`
        # set by the script entry point. Look it up defensively so importing
        # the module and calling this function no longer raises NameError.
        agent_type = getattr(globals().get('args'), 'agent_type', None)

    # level -> category -> metric -> list of values
    level_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    # category -> metric -> list of values (pooled over all levels)
    all_levels_stats = defaultdict(lambda: defaultdict(list))
    all_levels_noop_stats = defaultdict(lambda: defaultdict(int))
    # Plain dict on purpose: only these action types are counted.
    all_levels_action_stats = {'goto': 0, 'get': 0, 'put': 0, 'noop': 0, 'activate': 0}

    for level, data in zip(levels, results):
        print(f"\nLevel {level}:")
        print("=" * 50)
        print(f"Data keys: {data.keys()}")
        for alpha in data:
            episodes = data[alpha]
            print(f"\nAlpha: {alpha}")
            print(f"Data[alpha] type: {type(episodes)}")
            print(f"Data[alpha] length: {len(episodes)}")
            if episodes:
                first_episode = episodes[0]
                print(f"First episode keys: {first_episode.keys()}")
                if 'tokens' in first_episode:
                    print(f"First episode tokens type: {type(first_episode['tokens'])}")
                    print(f"First episode tokens length: {len(first_episode['tokens'])}")
                    if first_episode['tokens']:
                        print(f"First step tokens: {first_episode['tokens'][0]}")

            # Task statistics
            task_stats = print_task_statistics(data, alpha)
            total_tasks = 0
            for task, stats in task_stats.items():
                level_stats[level]['task'][f"{task}_mean"].append(stats['mean'])
                level_stats[level]['task'][f"{task}_std"].append(stats['std'])
                all_levels_stats['task'][task].append(stats['mean'])
                total_tasks += stats['mean']
            level_stats[level]['task']['total_mean'].append(total_tasks)
            all_levels_stats['task']['total'].append(total_tasks)

            # Noop statistics
            noop_stats = print_noop_statistics(data, alpha)
            level_stats[level]['noop']['total_mean'].append(noop_stats['total']['mean'])
            level_stats[level]['noop']['total_std'].append(noop_stats['total']['std'])
            level_stats[level]['noop']['total_sum'].append(noop_stats['total']['total'])
            all_levels_stats['noop']['mean'].append(noop_stats['total']['mean'])
            all_levels_stats['noop']['total'].append(noop_stats['total']['total'])
            for agent_id, stats in noop_stats['per_agent'].items():
                all_levels_noop_stats['per_agent'][agent_id] += stats['total']

            # Action type counts from the action history
            for eps_data in episodes:
                if 'action_history' in eps_data:
                    for episode_actions in eps_data['action_history']:
                        for step_actions in episode_actions:
                            for action in step_actions:
                                # The action type is the text before the first
                                # underscore, e.g. 'goto' in 'goto_agent1_storage0'.
                                action_type = action.split('_')[0]
                                if action_type in all_levels_action_stats:
                                    all_levels_action_stats[action_type] += 1

            # SOC statistics
            soc_stats = print_soc_statistics(data, alpha)
            level_stats[level]['social']['mean'].append(soc_stats['mean'])
            level_stats[level]['social']['std'].append(soc_stats['std'])
            all_levels_stats['social']['soc'].append(soc_stats['mean'])

            # Stop-step values
            stop_steps = [eps_data['stop_step'] for eps_data in episodes]
            level_stats[level]['stop_step']['values'].extend(stop_steps)
            all_levels_stats['stop_step']['values'].extend(stop_steps)

            # Price values
            prices = [eps_data['price'] for eps_data in episodes]
            level_stats[level]['price']['values'].extend(prices)
            all_levels_stats['price']['values'].extend(prices)

            # Token values: one entry per step, pooled over this alpha's episodes
            tokens = {key: [] for key in _TOKEN_KEYS}
            for eps_data in episodes:
                print(f"\nProcessing episode data: {eps_data.keys()}")
                if 'tokens' not in eps_data:
                    print("No tokens found in episode data")
                    continue

                steps = eps_data['tokens']
                print(f"Found tokens in episode data: {len(steps)} steps")
                print(f"First step tokens: {steps[0] if steps else 'No tokens'}")

                if agent_type == 'individual':
                    print("Processing individual case tokens")
                    for step_dict in steps:
                        print(f"Processing step dict: {step_dict}")
                        if isinstance(step_dict, dict):
                            s_in, s_out, e_in, e_out = parse_individual_tokens(step_dict)
                            tokens['input'].append(s_in)
                            tokens['output'].append(s_out)
                            tokens['executor_input'].append(e_in)
                            tokens['executor_output'].append(e_out)
                else:
                    for step_idx, step_tokens in enumerate(steps):
                        totals = _tally_step_tokens(step_idx, step_tokens)
                        for key in _TOKEN_KEYS:
                            tokens[key].append(totals[key])

            print(f"\nEpisode totals:")
            print(f"Total input tokens: {sum(tokens['input'])}")
            print(f"Total output tokens: {sum(tokens['output'])}")
            print(f"Total executor input tokens: {sum(tokens['executor_input'])}")
            print(f"Total executor output tokens: {sum(tokens['executor_output'])}")

            for key in _TOKEN_KEYS:
                level_stats[level]['tokens'][key].extend(tokens[key])
            for key in _TOKEN_KEYS:
                all_levels_stats['tokens'][key].extend(tokens[key])

        print('-'*25)

    # Per-level report
    print("\nPer-Level Statistics:")
    print("=" * 50)

    for level in level_stats:
        print(f"\nLevel {level}:")
        print("-" * 25)

        print("\nTask Statistics:")
        for metric, values in level_stats[level]['task'].items():
            mean, std = calculate_statistics(values)
            print(f"{metric:20} - Mean: {mean:.3f}, Std: {std:.3f}")

        print("\nSoC Statistics:")
        for metric, values in level_stats[level]['social'].items():
            mean, std = calculate_statistics(values)
            print(f"{metric:20} - Mean: {mean:.3f}, Std: {std:.3f}")

        _print_run_summary(level_stats[level])

    # Aggregate report across all levels
    print("\n" + "=" * 50)
    print("AGGREGATE STATISTICS ACROSS ALL LEVELS")
    print("=" * 50)

    print("\nTask Statistics:")
    values = all_levels_stats['task']['total']
    mean, std = calculate_statistics(values)
    print(f"{'Total Tasks':20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {sum(values):.1f}")
    print("-" * 60)
    for task in sorted(all_levels_stats['task']):
        if task == 'total':
            continue
        values = all_levels_stats['task'][task]
        mean, std = calculate_statistics(values)
        print(f"{task:20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {sum(values):.1f}")

    print("\nSoC Statistics:")
    mean, std = calculate_statistics(all_levels_stats['social']['soc'])
    print(f"{'Success Over Completion':20} - Mean: {mean:.3f}, Std: {std:.3f}")

    _print_run_summary(all_levels_stats)

    print("\nNoop Count Statistics:")
    noop_means = all_levels_stats['noop']['mean']
    noop_totals = all_levels_stats['noop']['total']
    total_mean = np.mean(noop_means) if noop_means else 0
    total_sum = sum(noop_totals) if noop_totals else 0
    print(f"{'Total Noop Count':20} - Mean: {total_mean:.3f}, Sum: {total_sum}")

    if all_levels_noop_stats['per_agent']:
        print("\nPer-Agent Noop Count Totals:")
        for agent_id, total in sorted(all_levels_noop_stats['per_agent'].items()):
            print(f"Agent {agent_id}: {total}")

    print("\nAction Type Statistics:")
    print("-" * 60)
    for action_type in ['goto', 'get', 'put', 'noop', 'activate']:
        count = all_levels_action_stats.get(action_type, 0)
        print(f"{action_type:20} - Total: {count}")

    return level_stats, all_levels_stats, all_levels_noop_stats, all_levels_action_stats

def main(args):
    """Load the result files selected by ``args`` and print their statistics."""
    loaded, found_levels = read_results(args)

    # Only the printed report matters here; the returned statistics are unused.
    analyze_results(found_levels, loaded)

if __name__ == "__main__":
    # Script entry point: parse the command-line options that select which
    # result files to load, then run the analysis.
    parser = argparse.ArgumentParser(description="Evaluation")

    # NOTE(review): the help text says "save", but read_results only reads
    # result files from this directory.
    parser.add_argument('--results_dir', type=str, default='results', help='Directory to save results')
    parser.add_argument('--model', type=str, default=AVAILABLE_MODELS[0], choices=AVAILABLE_MODELS, help='Model to use')
    # Only used by read_results when --agent_type is 'planner'.
    parser.add_argument('--executor_model', type=str, default=AVAILABLE_MODELS[2], choices=AVAILABLE_MODELS, help='Executor model to use')
    parser.add_argument('--agent_type', type=str, default='individual', choices=AGENT_TYPES, help='Type of agent')
    parser.add_argument('--num_agents', type=int, default=1, help='Number of agents')
    # 'all' makes read_results try every entry of LEVELS.
    parser.add_argument('--level', type=str, default='all', choices=LEVELS+['all'], help='Level to run')
    # NOTE(review): --budget is parsed but not read anywhere in the code visible here.
    parser.add_argument('--budget', type=int, default=None, help='budget')
    # When set, read_results loads from 'ablation/' instead of --results_dir.
    parser.add_argument('--ablation', action='store_true', help='Ablation')

    # NOTE: `args` is a module-level name here, and analyze_results looks up
    # the global `args` for the agent type, so this variable must keep its name.
    args = parser.parse_args()
    
    main(args)