import json
import argparse
from pathlib import Path

from collections import Counter, defaultdict
import numpy as np

# Model names accepted by the --model command-line option; the first entry
# is used as that option's default.
AVAILABLE_MODELS = [
    "Meta-Llama-3.1-8B-Instruct",
    "Llama-3.3-70B-Instruct",
    "Llama-3.1-70B-Instruct",
    "Meta-Llama-3.1-405B-Instruct-FP8",
    "Qwen2.5-32B-Instruct",
    "gpt-4o-v2",
]

# Level identifiers 'level_0' .. 'level_12'; selected together by --level all.
LEVELS = ['level_' + str(index) for index in range(13)]

# Agent setups accepted by the --agent_type command-line option.
AGENT_TYPES = ['individual', 'debate', 'orchestrator', 'planner']

def read_results(args):
    """Load the result JSON files selected by the parsed command-line options.

    Files are looked up under
    ``<results_dir>/<model>/<agent_type>/<num_agents>/`` and are named
    ``result_<level>.json``, or ``result_<level>_<budget>.json`` when a
    budget is given. Levels whose file does not exist are skipped.

    Args:
        args: Namespace-like object with the attributes ``results_dir``,
            ``model``, ``agent_type``, ``num_agents``, ``level`` and
            ``budget``. ``level`` may be ``'all'`` to select every entry of
            ``LEVELS``.

    Returns:
        tuple: ``(results, levels)`` where ``results`` is the list of decoded
        JSON payloads and ``levels`` the level names they belong to, in the
        same order and restricted to the files that were found.
    """
    levels = LEVELS if args.level == 'all' else [args.level]

    # Every file of one run lives in the same directory; build it once.
    run_dir = Path(args.results_dir) / args.model / args.agent_type / str(args.num_agents)
    # Identity check against None so that a budget of 0 still adds a suffix.
    budget_suffix = f'_{args.budget}' if args.budget is not None else ''

    results, found_levels = [], []
    for level in levels:
        file = run_dir / f'result_{level}{budget_suffix}.json'
        print(f"Loading results from {file}")
        print("Exists? : ", file.exists())
        if file.exists():
            # JSON text is UTF-8; do not depend on the platform default.
            with open(file, 'r', encoding='utf-8') as f:
                results.append(json.load(f))
            found_levels.append(level)

    return results, found_levels

def print_task_statistics(data, alpha):
    """
    Summarise how often each task was accomplished per episode and print it.

    Args:
        data (dict): Maps an alpha key to a list of episode dicts; each
            episode provides an 'acomplished_task_list' entry.
        alpha (str): Key of ``data`` whose episodes are analysed.

    Returns:
        dict: Maps each task to a dict with the 'mean' and 'std' of its
        per-episode counts and the raw 'counts' list (one entry per episode,
        0 where the task does not occur).
    """
    # One tally of accomplished tasks per episode
    episode_tallies = [
        Counter(episode['acomplished_task_list']) for episode in data[alpha]
    ]

    # Every task that shows up in at least one episode
    seen_tasks = set()
    for tally in episode_tallies:
        seen_tasks.update(tally.keys())

    task_stats = {}
    for task in seen_tasks:
        # Episodes without the task contribute a count of 0
        per_episode = [tally.get(task, 0) for tally in episode_tallies]
        task_stats[task] = dict(
            mean=np.mean(per_episode),
            std=np.std(per_episode),
            counts=per_episode,
        )

    print("\nTask Statistics Across Episodes:")
    print("-" * 50)
    for task, summary in task_stats.items():
        print(f"{task}:")
        print(f"  Mean: {summary['mean']:.2f} | Std: {summary['std']:.2f}")

    return task_stats

def print_noop_statistics(data, alpha):
    """
    Summarise noop counts per agent and per episode, and print the summary.

    Each episode is read from 'agent_noop_counts' (dict of agent id to a
    number or a list of numbers) when present, otherwise from 'noop_count'
    (a single number, or a list indexed by agent position). Episodes with
    neither field count as 0.

    Args:
        data (dict): Maps an alpha key to a list of episode dicts.
        alpha (str): Key of ``data`` whose episodes are analysed.

    Returns:
        dict: {'per_agent': {agent_id: {'mean', 'std', 'total'}},
               'total': {'mean', 'std', 'total'}} where 'total' describes the
        per-episode sums (all zeros when there are no episodes).
    """
    per_agent_history = defaultdict(list)
    episode_totals = []

    for episode in data[alpha]:
        running = 0

        if 'agent_noop_counts' in episode:
            for agent_id, raw in episode['agent_noop_counts'].items():
                if isinstance(raw, (int, float)):
                    amount = raw
                elif isinstance(raw, list):
                    # A list is collapsed to its sum before being recorded
                    amount = sum(raw)
                else:
                    # Any other value type is ignored
                    continue
                per_agent_history[agent_id].append(amount)
                running += amount

        elif 'noop_count' in episode:
            recorded = episode['noop_count']
            if isinstance(recorded, (int, float)):
                # A single number cannot be attributed to an agent, so it
                # only feeds the episode total
                running = recorded
            elif isinstance(recorded, list):
                # The list position serves as the agent id
                for position, amount in enumerate(recorded):
                    per_agent_history[str(position)].append(amount)
                    running += amount

        episode_totals.append(running)

    print("\nNoop Count Statistics:")
    print("-" * 50)

    # Per-agent lines (nothing is printed when no agent was recorded)
    for agent_id, history in sorted(per_agent_history.items()):
        agent_mean = np.mean(history)
        agent_std = np.std(history)
        agent_total = sum(history)
        print(f"Agent {agent_id}:")
        print(f"  Mean: {agent_mean:.2f} | Std: {agent_std:.2f} | Total: {agent_total}")

    # Totals over episodes
    if episode_totals:
        overall_mean = np.mean(episode_totals)
        overall_std = np.std(episode_totals)
        overall_total = sum(episode_totals)
        print("Total Noop Count:")
        print(f"  Mean: {overall_mean:.2f} | Std: {overall_std:.2f} | Total: {overall_total}")

    per_agent_summary = {}
    for agent_id, history in per_agent_history.items():
        per_agent_summary[agent_id] = {
            'mean': np.mean(history),
            'std': np.std(history),
            'total': sum(history),
        }

    if episode_totals:
        overall_summary = {
            'mean': np.mean(episode_totals),
            'std': np.std(episode_totals),
            'total': sum(episode_totals),
        }
    else:
        overall_summary = {'mean': 0, 'std': 0, 'total': 0}

    return {'per_agent': per_agent_summary, 'total': overall_summary}

def print_soc_statistics(data, alpha):
    """
    Compute the per-episode SOC (Success Over Completion) ratio and print
    its mean and standard deviation.

    The ratio of an episode is success / (success + failed), or 0 when both
    counts add up to zero.

    Args:
        data (dict): Maps an alpha key to a list of episode dicts carrying
            'success' and 'failed' entries.
        alpha (str): Key of ``data`` whose episodes are analysed.

    Returns:
        dict: {'mean': ..., 'std': ..., 'values': [ratio per episode]}
    """
    def _episode_ratio(episode):
        attempts = episode['success'] + episode['failed']
        # Guard the division for episodes with no recorded outcome
        return 0 if attempts == 0 else episode['success'] / attempts

    soc_values = [_episode_ratio(episode) for episode in data[alpha]]

    summary = {
        'mean': np.mean(soc_values),
        'std': np.std(soc_values),
        'values': soc_values,
    }

    print("SOC Statistics:")
    print(f"  Mean: {summary['mean']:.2f} | Std: {summary['std']:.2f}")

    return summary
    

def calculate_statistics(values):
    """Return ``(mean, std)`` of ``values``, or ``(0.0, 0.0)`` when it is empty."""
    if values:
        return np.mean(values), np.std(values)
    # Nothing to average: report zeros rather than calling numpy on no data
    return 0.0, 0.0

def analyze_results(levels, results):
    """Print per-level and aggregate statistics for the loaded results.

    For every level and every alpha key of its data, the task, noop and SoC
    (Success Over Completion) statistics are printed and collected, together
    with the 'stop_step' and 'price' values of each episode. Afterwards a
    per-level summary and a summary across all levels are printed.

    Args:
        levels (list): Level names, paired position by position with
            ``results``.
        results (list): One dict per level, mapping an alpha key to a list
            of episode dicts. Each episode must provide 'stop_step' and
            'price' in addition to the fields read by the print_* helpers.

    Returns:
        tuple: ``(level_stats, all_levels_stats, all_levels_noop_stats)``:
            * ``level_stats[level][category][metric]`` -> list of values
            * ``all_levels_stats[category][metric]`` -> list of values
              pooled over every level
            * ``all_levels_noop_stats['per_agent'][agent_id]`` -> summed
              noop count over every level and alpha
    """
    # Nested defaultdicts so that categories/metrics can be filled on demand
    level_stats = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    all_levels_stats = defaultdict(lambda: defaultdict(list))
    all_levels_noop_stats = defaultdict(lambda: defaultdict(int))

    for level, data in zip(levels, results):
        print(f"\nLevel {level}:")
        print("=" * 50)
        for alpha in data:
            print(f"\nAlpha: {alpha}")
            # Store task statistics
            task_stats = print_task_statistics(data, alpha)
            total_tasks = 0
            for task, stats in task_stats.items():
                # Store the mean and std for each task
                level_stats[level]['task'][f"{task}_mean"].append(stats['mean'])
                level_stats[level]['task'][f"{task}_std"].append(stats['std'])
                # Store for all-levels aggregate
                all_levels_stats['task'][task].append(stats['mean'])
                total_tasks += stats['mean']

            # Store total tasks (sum of the per-task means for this alpha)
            level_stats[level]['task']['total_mean'].append(total_tasks)
            all_levels_stats['task']['total'].append(total_tasks)

            # Store noop statistics
            noop_stats = print_noop_statistics(data, alpha)
            level_stats[level]['noop']['total_mean'].append(noop_stats['total']['mean'])
            level_stats[level]['noop']['total_std'].append(noop_stats['total']['std'])
            level_stats[level]['noop']['total_sum'].append(noop_stats['total']['total'])

            # Store for all-levels aggregate
            all_levels_stats['noop']['mean'].append(noop_stats['total']['mean'])
            all_levels_stats['noop']['total'].append(noop_stats['total']['total'])

            # Accumulate per-agent noop totals in the all-levels aggregate
            for agent_id, stats in noop_stats['per_agent'].items():
                all_levels_noop_stats['per_agent'][agent_id] += stats['total']

            # Store SoC (Success Over Completion) statistics
            soc_stats = print_soc_statistics(data, alpha)
            level_stats[level]['social']['mean'].append(soc_stats['mean'])
            level_stats[level]['social']['std'].append(soc_stats['std'])
            # Store for all-levels aggregate
            all_levels_stats['social']['soc'].append(soc_stats['mean'])

            # Collect stop_step values
            stop_steps = [eps_data['stop_step'] for eps_data in data[alpha]]
            level_stats[level]['stop_step']['values'].extend(stop_steps)
            all_levels_stats['stop_step']['values'].extend(stop_steps)

            # Collect price values
            prices = [eps_data['price'] for eps_data in data[alpha]]
            level_stats[level]['price']['values'].extend(prices)
            all_levels_stats['price']['values'].extend(prices)

        print('-'*25)

    # Print per-level statistics
    print("\nPer-Level Statistics:")
    print("=" * 50)

    for level in level_stats:
        print(f"\nLevel {level}:")
        print("-" * 25)

        # Task statistics
        print("\nTask Statistics:")
        for metric in level_stats[level]['task']:
            values = level_stats[level]['task'][metric]
            mean, std = calculate_statistics(values)
            print(f"{metric:20} - Mean: {mean:.3f}, Std: {std:.3f}")

        # SoC statistics
        print("\nSoC Statistics:")
        for metric in level_stats[level]['social']:
            values = level_stats[level]['social'][metric]
            mean, std = calculate_statistics(values)
            print(f"{metric:20} - Mean: {mean:.3f}, Std: {std:.3f}")

        # Print stop_step statistics
        print("\nStop Step Statistics:")
        values = level_stats[level]['stop_step']['values']
        mean, std = calculate_statistics(values)
        print(f"{'Stop Step':20} - Mean: {mean:.3f}, Std: {std:.3f}")

        # Print price statistics
        print("\nPrice Statistics:")
        values = level_stats[level]['price']['values']
        mean, std = calculate_statistics(values)
        total_price = sum(values)
        print(f"{'Price':20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {total_price:.3f}")

    # Print aggregate statistics across ALL levels
    print("\n" + "=" * 50)
    print("AGGREGATE STATISTICS ACROSS ALL LEVELS")
    print("=" * 50)

    print("\nTask Statistics:")
    # First print total tasks
    values = all_levels_stats['task']['total']
    mean, std = calculate_statistics(values)
    total_sum = sum(values)
    print(f"{'Total Tasks':20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {total_sum:.1f}")
    print("-" * 60)
    # Then print individual tasks ('total' was already reported above)
    for task in sorted(all_levels_stats['task'].keys()):
        if task != 'total':
            values = all_levels_stats['task'][task]
            mean, std = calculate_statistics(values)
            task_sum = sum(values)
            print(f"{task:20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {task_sum:.1f}")

    print("\nSoC Statistics:")
    values = all_levels_stats['social']['soc']
    mean, std = calculate_statistics(values)
    print(f"{'Success Over Completion':20} - Mean: {mean:.3f}, Std: {std:.3f}")

    # Print aggregate stop_step statistics
    print("\nStop Step Statistics:")
    values = all_levels_stats['stop_step']['values']
    mean, std = calculate_statistics(values)
    print(f"{'Stop Step':20} - Mean: {mean:.3f}, Std: {std:.3f}")

    # Print aggregate price statistics
    print("\nPrice Statistics:")
    values = all_levels_stats['price']['values']
    mean, std = calculate_statistics(values)
    total_price = sum(values)
    print(f"{'Price':20} - Mean: {mean:.3f}, Std: {std:.3f}, Sum: {total_price:.3f}")

    # Print aggregate noop statistics
    print("\nNoop Count Statistics:")
    total_mean = np.mean(all_levels_stats['noop']['mean']) if all_levels_stats['noop']['mean'] else 0
    total_sum = sum(all_levels_stats['noop']['total']) if all_levels_stats['noop']['total'] else 0
    print(f"{'Total Noop Count':20} - Mean: {total_mean:.3f}, Sum: {total_sum}")

    # Print per-agent noop statistics
    if all_levels_noop_stats['per_agent']:
        print("\nPer-Agent Noop Count Totals:")
        for agent_id, total in sorted(all_levels_noop_stats['per_agent'].items()):
            print(f"Agent {agent_id}: {total}")

    return level_stats, all_levels_stats, all_levels_noop_stats

def main(args):
    """Load the result files selected by ``args`` and print their statistics.

    Args:
        args: Parsed command-line namespace; forwarded unchanged to
            ``read_results``.
    """
    results, levels = read_results(args)

    # The report is printed by analyze_results itself; the aggregates it
    # returns are not needed here.
    analyze_results(levels, results)

if __name__ == "__main__":
    # Command-line entry point. The options below select which result files
    # are loaded: <results_dir>/<model>/<agent_type>/<num_agents>/result_<level>[_<budget>].json
    parser = argparse.ArgumentParser(description="Evaluation")

    parser.add_argument('--results_dir', type=str, default='results', help='Directory to load results from')
    parser.add_argument('--model', type=str, default=AVAILABLE_MODELS[0], choices=AVAILABLE_MODELS, help='Model to use')
    parser.add_argument('--agent_type', type=str, default='individual', choices=AGENT_TYPES, help='Type of agent')
    parser.add_argument('--num_agents', type=int, default=1, help='Number of agents')
    parser.add_argument('--level', type=str, default='all', choices=LEVELS+['all'], help="Level to analyze, or 'all' for every level")
    parser.add_argument('--budget', type=int, default=None, help='Budget suffix of the result file names; omit to load files without a budget suffix')

    args = parser.parse_args()

    main(args)