#!/usr/bin/env python3
import os
import json
import glob
import numpy as np
import math
import sys
import csv
from collections import defaultdict

# GLUE tasks, listed in the order their columns are displayed
TASKS = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb"]

# Independent copy of the task list: every task is iterated even when no
# results exist for it
ALL_TASKS = list(TASKS)

# Seeds of the runs that are aggregated
SEEDS = [42, 2025, 777]

# Primary metric of each task, used for the GLUE score calculation.
# Insertion order is the order in which tasks are listed in that section.
PRIMARY_TASK_METRICS = dict(
    mnli="eval_combined_score",
    qqp="eval_accuracy",
    qnli="eval_accuracy",
    sst2="eval_accuracy",
    cola="eval_matthews_correlation",
    mrpc="eval_accuracy",
    rte="eval_accuracy",
    stsb="eval_pearson",
)

def find_latest_output_dir(task, seed):
    """Return the newest output directory for ``task`` and ``seed``.

    Candidates are the paths matching
    ``./output/{task}_optimal_lora_seed{seed}_*``. "Newest" means the
    lexicographically greatest path, so the suffix after the seed is
    presumably a sortable timestamp (TODO confirm against the training
    script). Returns ``None`` when nothing matches.
    """
    candidates = glob.glob(f"./output/{task}_optimal_lora_seed{seed}_*")
    if not candidates:
        return None
    # The greatest string is the last element of the sorted list.
    return max(candidates)

def extract_metrics_from_json(json_file):
    """Load a metrics mapping from a JSON file on a best-effort basis.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON object as a dict. An empty dict is returned (and a
        message is printed to stderr) when the file cannot be opened or read,
        is not valid UTF-8 / JSON, or does not contain a JSON object at the
        top level.
    """
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            metrics = json.load(f)
    # OSError covers missing files, permission problems and directories;
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    except (OSError, ValueError) as e:
        print(f"Error reading {json_file}: {e}", file=sys.stderr)
        return {}

    # Callers iterate the result with .items(), so anything but an object
    # (e.g. a top-level list) is treated like an unreadable file.
    if not isinstance(metrics, dict):
        print(
            f"Error reading {json_file}: expected a JSON object, "
            f"got {type(metrics).__name__}",
            file=sys.stderr,
        )
        return {}

    return metrics

def format_number(value):
    """Format a metric value for display.

    Args:
        value: A number, a numeric string, ``None`` or any other object.

    Returns:
        ``"N/A"`` for ``None`` and for NaN (including NaN given as a string
        such as ``"nan"``); the number with two decimals when its magnitude
        is at least 1000 and with four decimals otherwise. Input that cannot
        be interpreted as a number is returned as text (strings unchanged,
        other objects via ``str``).
    """
    if value is None:
        return "N/A"

    # Strings and other number-like objects are converted first so that the
    # NaN check below also applies to them.
    if not isinstance(value, (int, float)):
        try:
            value = float(value)
        except (TypeError, ValueError):
            return value if isinstance(value, str) else str(value)

    if isinstance(value, float) and math.isnan(value):
        return "N/A"

    # Large values get two decimals, everything else four.
    precision = 2 if abs(value) >= 1000 else 4
    return f"{value:.{precision}f}"

def calculate_statistics(values):
    """Return ``(mean, stddev)`` of the usable entries in ``values``.

    Entries that are ``None`` or cannot be converted with ``float()`` are
    skipped. The standard deviation is the sample standard deviation
    (``ddof=1``); with a single usable entry it is ``0``. When no entry is
    usable, or the computation fails, ``(None, None)`` is returned.
    """
    numbers = []
    for raw in values:
        if raw is None:
            continue
        try:
            number = float(raw)
        except (ValueError, TypeError):
            # Not a number: leave it out of the statistics.
            continue
        numbers.append(number)

    if len(numbers) == 0:
        return None, None

    try:
        average = np.mean(numbers)
        if len(numbers) > 1:
            spread = np.std(numbers, ddof=1)
        else:
            # A sample standard deviation needs at least two points.
            spread = 0
    except Exception as e:
        print(f"Error calculating statistics: {e}", file=sys.stderr)
        return None, None

    return average, spread

def _collect_metrics():
    """Load ``all_metrics.json`` for every task/seed combination.

    Returns:
        A pair ``(all_metrics, all_metric_names)``: ``all_metrics`` maps
        task -> metric name -> seed -> raw value, and ``all_metric_names`` is
        the set of every metric name that was seen.
    """
    all_metrics = defaultdict(lambda: defaultdict(dict))
    all_metric_names = set()

    for task in ALL_TASKS:
        print(f"Processing {task}...")
        for seed in SEEDS:
            output_dir = find_latest_output_dir(task, seed)
            if not output_dir:
                print(f"No output directory found for {task} with seed {seed}")
                continue

            metrics_file = os.path.join(output_dir, "all_metrics.json")
            metrics = extract_metrics_from_json(metrics_file)
            if not metrics:
                continue
            if not isinstance(metrics, dict):
                # Only a JSON object can be broken down into named metrics.
                print(f"Skipping {metrics_file}: expected a JSON object", file=sys.stderr)
                continue

            for metric_name, metric_value in metrics.items():
                all_metrics[task][metric_name][seed] = metric_value
                all_metric_names.add(metric_name)
                print(f"  Extracted {task}/{metric_name}/{seed} = {metric_value}")

    return all_metrics, all_metric_names


def _build_metric_stats(all_metrics, all_metric_names):
    """Reorganize the raw values as metric -> task -> statistics.

    Every task in ``ALL_TASKS`` gets an entry for every metric. Each entry
    holds ``'mean'``, ``'stddev'`` and ``'seeds'`` (seed -> raw value); missing
    data shows up as ``None``.
    """
    metric_stats = {}
    for metric_name in sorted(all_metric_names):
        per_task = {}
        for task in ALL_TASKS:
            # .get() keeps the defaultdict from growing an empty entry.
            seed_values = all_metrics[task].get(metric_name, {})
            mean, stddev = calculate_statistics(
                [seed_values.get(seed) for seed in SEEDS]
            )
            per_task[task] = {
                'mean': mean,
                'stddev': stddev,
                'seeds': {seed: seed_values.get(seed) for seed in SEEDS},
            }
        metric_stats[metric_name] = per_task
    return metric_stats


def _write_csv_summary(csv_file_path, metric_stats, available_tasks):
    """Write the CSV summary: one row per metric statistic, one column per task."""
    with open(csv_file_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["metric"] + available_tasks)

        for metric_name in sorted(metric_stats):
            readable_metric = metric_name.replace("eval_", "").replace("_", " ")
            per_task = metric_stats[metric_name]

            # Mean and standard deviation rows come first ...
            for stat in ('mean', 'stddev'):
                row = [f"{readable_metric}_{stat}"]
                for task in available_tasks:
                    row.append(format_number(per_task.get(task, {}).get(stat)))
                writer.writerow(row)

            # ... followed by one row per individual seed.
            for seed in SEEDS:
                row = [f"{readable_metric}_{seed}"]
                for task in available_tasks:
                    seed_value = per_task.get(task, {}).get('seeds', {}).get(seed)
                    row.append(format_number(seed_value))
                writer.writerow(row)


def _write_glue_score(f, metric_stats):
    """Append the overall GLUE score section to the open text file ``f``.

    The score is the average of the mean primary metric over the tasks that
    have data; tasks without data are reported as missing.
    """
    f.write("\n")
    f.write("OVERALL GLUE SCORE CALCULATION\n")
    f.write("--------------------------------\n")

    glue_scores = []
    for task, metric in PRIMARY_TASK_METRICS.items():
        mean_value = metric_stats.get(metric, {}).get(task, {}).get('mean')
        if mean_value is not None:
            glue_scores.append(mean_value)
            f.write(f"{task:<8} ({metric:<22}): {format_number(mean_value):>9}\n")
        else:
            f.write(f"{task:<8} ({metric:<22}): Missing data\n")

    f.write("\n")
    if glue_scores:
        glue_avg = np.mean(glue_scores)
        f.write(f"OVERALL GLUE SCORE: {format_number(glue_avg):>9} (average across {len(glue_scores)} tasks)\n")
    else:
        f.write("OVERALL GLUE SCORE: Could not calculate (insufficient data)\n")


def _write_text_summary(txt_file_path, metric_stats, available_tasks):
    """Write the formatted text table (metric rows, task columns) and GLUE score."""
    task_column_width = 14  # Width of each task column
    metric_column_width = 24  # Width of the metric name column

    heavy_rule = "=" * (metric_column_width + len(available_tasks) * (task_column_width + 1)) + "\n"
    row_separator = (
        "-" * metric_column_width
        + ("+" + "-" * task_column_width) * len(available_tasks)
        + "\n"
    )

    # The table contains "±", so the encoding is fixed instead of being left
    # to the platform default.
    with open(txt_file_path, 'w', encoding='utf-8') as f:
        f.write(heavy_rule)
        f.write("TASK PERFORMANCE SUMMARY ACROSS SEEDS (METRIC ROWS, TASK COLUMNS)\n")
        f.write(heavy_rule)

        # Column headers
        f.write(f"{'METRIC':{metric_column_width}}")
        for task in available_tasks:
            f.write(f"|{task.upper():^{task_column_width}}")
        f.write("\n")
        f.write(row_separator)

        last_seed_index = len(SEEDS) - 1
        for metric_name in sorted(metric_stats):
            readable_metric = metric_name.replace("eval_", "").replace("_", " ")
            per_task = metric_stats[metric_name]

            # First line of a metric: mean±stddev per task
            f.write(f"{readable_metric:{metric_column_width}}")
            for task in available_tasks:
                stats = per_task.get(task, {})
                mean = stats.get('mean')
                stddev = stats.get('stddev')
                if mean is not None and stddev is not None:
                    cell = f"{format_number(mean)}±{format_number(stddev)}"
                else:
                    cell = "N/A"
                f.write(f"|{cell:^{task_column_width}}")
            f.write("\n")

            # Following lines: one per seed, the group wrapped in [ ... ]
            for seed_index, seed in enumerate(SEEDS):
                opening = "[" if seed_index == 0 else ""
                closing = "]" if seed_index == last_seed_index else ""
                f.write(" " * metric_column_width)
                for task in available_tasks:
                    seed_value = per_task.get(task, {}).get('seeds', {}).get(seed)
                    cell = f"{opening}{seed}: {format_number(seed_value)}{closing}"
                    f.write(f"|{cell:^{task_column_width}}")
                f.write("\n")

            f.write(row_separator)

        f.write(heavy_rule)
        f.write("Note: Performance metrics vary by task (accuracy, F1, correlation, etc.)\n")
        f.write("Time values are in seconds, memory in GB, parameters in count or percentage\n")
        f.write(heavy_rule)

        _write_glue_score(f, metric_stats)


def main():
    """Aggregate per-seed metrics of all tasks and write the summaries.

    Reads ``all_metrics.json`` from the latest output directory of every
    task/seed pair, computes mean and standard deviation per metric and task,
    and writes ``results_summary/metrics_summary.csv`` and
    ``results_summary/formatted_summary.txt`` (the latter also contains the
    overall GLUE score). The ``results_summary`` directory is created when it
    does not exist yet.
    """
    all_metrics, all_metric_names = _collect_metrics()
    metric_stats = _build_metric_stats(all_metrics, all_metric_names)

    # Every task gets a column (filled with N/A when it has no data) as soon
    # as at least one metric was found anywhere; with no metrics at all the
    # tables have no task columns.
    available_tasks = list(ALL_TASKS) if metric_stats else []

    # open() does not create missing parent directories.
    os.makedirs("results_summary", exist_ok=True)

    csv_file_path = "results_summary/metrics_summary.csv"
    _write_csv_summary(csv_file_path, metric_stats, available_tasks)
    print(f"CSV summary saved to {csv_file_path}")

    _write_text_summary("results_summary/formatted_summary.txt", metric_stats, available_tasks)

# Run the summary only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
