import os
import re
import matplotlib.pyplot as plt
import argparse

import numpy as np

def parse_log_file(log_file):
    """
    Parse a log file and return its raw per-step metric lists.

    Lines in the detailed format, e.g.
    "Reward: 1.7578 (Acc=1.00*w0.5, Ver=0.50*w0.5, eFLOPs=7.13e+13*w1.0)",
    contribute one reward, one accuracy and one eFLOPs value each. If no line
    of the file matches the detailed format, every line that starts with
    "Reward: <number>" contributes its reward, and the accuracy / eFLOPs
    entries are filled with 0.0 placeholders.

    Numbers may be plain decimals or scientific notation with either exponent
    sign (e.g. "0", "2.5", "7.13e+13", "2.50e-01").

    Args:
        log_file: Path of the log file to read (decoded as UTF-8).

    Returns:
        A tuple (rewards, accuracies, eflops) of three equally long lists of
        floats. All three are empty when the file is missing or nothing matches.
    """
    if not os.path.exists(log_file):
        print(f"Warning: Log file {log_file} not found.")
        return [], [], []

    with open(log_file, 'r', encoding='utf-8') as f:
        content = f.read()

    # One numeric token: optional sign, digits, optional fraction, optional
    # exponent with either sign. Accepting only "e+NN" would silently drop
    # every line whose eFLOPs value is below 1 or printed without an exponent.
    number = r"-?\d+\.?\d*(?:[eE][-+]?\d+)?"

    # Detailed format: "Reward: 1.7578 (Acc=1.00*w0.5, Ver=0.50*w0.5, eFLOPs=7.13e+13*w1.0)"
    pattern_en = re.compile(
        r"(?m)^Reward:\s*(" + number + r")\s*\(Acc=(" + number + r")\*.*eFLOPs=(" + number + r")\*"
    )

    # Fallback: only the reward value at the start of a line.
    pattern_simple = re.compile(r"(?m)^Reward:\s*(" + number + r")")

    rewards = []
    accuracies = []
    eflops = []

    matches = pattern_en.findall(content)

    if matches:
        for reward_txt, acc_txt, eflops_txt in matches:
            # Convert all three values before appending any of them, so a bad
            # token can never leave the lists with different lengths.
            try:
                parsed = (float(reward_txt), float(acc_txt), float(eflops_txt))
            except ValueError:
                continue
            rewards.append(parsed[0])
            accuracies.append(parsed[1])
            eflops.append(parsed[2])
    else:
        print(f"Warning: Detailed format not found in {log_file}, falling back to simple reward parsing.")
        for reward_txt in pattern_simple.findall(content):
            try:
                reward = float(reward_txt)
            except ValueError:
                continue
            rewards.append(reward)
            # Placeholders keep the three lists aligned step by step.
            accuracies.append(0.0)
            eflops.append(0.0)

    print(f"Found {len(rewards)} entries in {log_file}")

    return rewards, accuracies, eflops

def compute_cumulative(raw_data):
    """Return the running total of raw_data as a plain list ([] for empty input)."""
    if not raw_data:
        return []
    running_total = np.cumsum(raw_data)
    return running_total.tolist()

def plot_single_metric(results, metric_idx, title, ylabel, output_file):
    """
    Plot one metric as a line per experiment and save the figure to a file.

    Args:
        results: Mapping of experiment label -> sequence of metric series;
            the series at position metric_idx is the one that gets plotted.
        metric_idx: 0 for reward, 1 for accuracy, 2 for eFLOPs.
        title: Figure title (also used in the "no data" message).
        ylabel: Label of the y axis.
        output_file: Path the figure is written to (300 dpi).

    Series that are empty or consist only of zeros are skipped. If nothing is
    left to plot, a message is printed and no file is written.
    """
    fig = plt.figure(figsize=(12, 8))

    has_data = False
    for label, metrics in results.items():
        data = metrics[metric_idx]
        if not data or all(x == 0 for x in data):
            # Skip empty or all-zero (fallback) data
            continue

        steps = range(1, len(data) + 1)
        plt.plot(steps, data, label=label, linewidth=2)
        has_data = True

    if not has_data:
        print(f"No data available to plot for {title}.")
        plt.close(fig)
        return

    plt.xlabel("Steps (Queries)", fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    # Let matplotlib pick the least obstructive legend position inside the axes
    plt.legend(fontsize=10, loc='best')
    plt.grid(True, linestyle='--', alpha=0.7)

    # Adjust layout
    plt.tight_layout()

    plt.savefig(output_file, dpi=300)
    # pyplot keeps every figure alive until it is closed explicitly; release it
    # here so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Plot saved to {output_file}")

def plot_final_bar_chart(results, metric_idx, title, ylabel, output_file, log_scale=False):
    """
    Plot a horizontal bar chart of each experiment's final value and save it.

    Args:
        results: Mapping of experiment label -> sequence of metric series.
        metric_idx: Position of the series to compare; the last element of
            that series is the value shown for the experiment.
        title: Figure title (also used in the "no data" message).
        ylabel: Text used for the value axis (the x axis of the bar chart).
        output_file: Path the figure is written to (300 dpi).
        log_scale: When True, the value axis is logarithmic and the bar
            labels are printed in scientific notation.

    Experiments whose series is empty are left out; zero values are kept. If
    no experiment has data, a message is printed and no file is written.
    """
    # (label, final value) for every experiment with a non-empty series.
    finals = [
        (label, metrics[metric_idx][-1])
        for label, metrics in results.items()
        if metrics[metric_idx]
    ]

    if not finals:
        print(f"No data available for bar chart {title}.")
        return

    # Sort by value for better readability (stable, so ties keep input order).
    finals.sort(key=lambda pair: pair[1])
    sorted_labels = [label for label, _ in finals]
    sorted_values = [value for _, value in finals]

    fig = plt.figure(figsize=(14, 8))
    bars = plt.barh(sorted_labels, sorted_values, color='skyblue', edgecolor='black')

    plt.xlabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    plt.grid(axis='x', linestyle='--', alpha=0.7)

    if log_scale:
        plt.xscale('log')

    # Padding between a bar's end and its text label on a linear axis. It is
    # based on the largest magnitude so it stays positive even when every
    # value is negative, and it is computed once rather than per bar.
    linear_pad = max(abs(value) for value in sorted_values) * 0.01

    # Add value labels to the end of bars
    for bar in bars:
        width = bar.get_width()
        if not log_scale:
            label_x_pos = width + linear_pad
        elif width <= 0:
            # Zero/negative values have no position on a log axis; park the text near the axis origin.
            label_x_pos = 1e-10
        else:
            label_x_pos = width * 1.02

        plt.text(label_x_pos, bar.get_y() + bar.get_height() / 2,
                 f'{width:.2e}' if log_scale or abs(width) > 1000 else f'{width:.2f}',
                 va='center', fontsize=9)

    plt.tight_layout()
    plt.savefig(output_file, dpi=300)
    # Release the figure; pyplot otherwise keeps it open for the process lifetime.
    plt.close(fig)
    print(f"Bar chart saved to {output_file}")

def compute_average_rate(raw_data):
    """Return the running mean of raw_data: entry i is sum(raw_data[:i + 1]) / (i + 1)."""
    if raw_data:
        running_total = np.cumsum(raw_data)
        step_counts = np.arange(1, len(raw_data) + 1)
        running_mean = running_total / step_counts
        return running_mean.tolist()
    return []

def plot_results(results, output_prefix="comparison"):
    """
    Plot cumulative reward, correct-count and eFLOPs curves plus final bar charts.

    Args:
        results: Mapping of experiment label -> sequence of series indexed as
            0 = cumulative reward, 1 = cumulative correct count,
            2 = cumulative eFLOPs and, optionally, 3 = average accuracy rate.
        output_prefix: Prefix of every image file written by this function
            (e.g. "<prefix>_reward.png", "<prefix>_eflops_log.png").

    An empty results mapping is accepted: the helpers report that there is no
    data and no files are written.
    """
    # Plot Cumulative Reward
    plot_single_metric(results, 0, "Cumulative Reward Comparison", "Cumulative Reward", f"{output_prefix}_reward.png")

    # Plot Cumulative Correct Count (was Cumulative Accuracy)
    plot_single_metric(results, 1, "Cumulative Correct Count Comparison", "Total Correct Answers", f"{output_prefix}_correct_count.png")

    # Plot Cumulative eFLOPs (Log Scale)
    # Use distinct markers/linestyles for better visibility
    markers = ['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h']
    linestyles = ['-', '--', '-.', ':']

    fig = plt.figure(figsize=(12, 8))
    has_data = False

    for i, (label, metrics) in enumerate(results.items()):
        data = metrics[2]  # eFLOPs
        if not data or all(x == 0 for x in data):
            continue
        steps = range(1, len(data) + 1)

        # Cycle styles by experiment position
        ms = markers[i % len(markers)]
        ls = linestyles[i % len(linestyles)]

        # markevery thins the markers to roughly ten per curve.
        plt.plot(steps, data, label=label, linewidth=1.5, linestyle=ls, marker=ms, markevery=len(data)//10 + 1, alpha=0.8)
        has_data = True

    if has_data:
        plt.yscale('log')
        plt.xlabel("Steps (Queries)", fontsize=12)
        plt.ylabel("Cumulative eFLOPs (Log Scale)", fontsize=12)
        plt.title("Cumulative eFLOPs Comparison (Log Scale)", fontsize=14)
        plt.legend(fontsize=10, loc='best')
        plt.grid(True, linestyle='--', alpha=0.7, which="both")  # which="both" for log scale grid
        plt.tight_layout()
        output_file_log = f"{output_prefix}_eflops_log.png"
        plt.savefig(output_file_log, dpi=300)
        print(f"Plot saved to {output_file_log}")
    # Close the figure whether or not it was saved, so figures do not pile up.
    plt.close(fig)

    # --- Bar Charts for Final Comparison ---
    plot_final_bar_chart(results, 0, "Final Cumulative Reward Comparison", "Total Reward", f"{output_prefix}_reward_bar.png")
    plot_final_bar_chart(results, 1, "Final Cumulative Correct Count Comparison", "Total Correct Answers", f"{output_prefix}_correct_count_bar.png")

    # The average-accuracy series is optional; look at the first entry to see
    # whether it is present. The default keeps an empty mapping from raising
    # StopIteration.
    first_metrics = next(iter(results.values()), ())
    if len(first_metrics) > 3:
        plot_final_bar_chart(results, 3, "Final Average Accuracy Rate Comparison", "Average Accuracy", f"{output_prefix}_accuracy_rate_bar.png")

    plot_final_bar_chart(results, 2, "Final Cumulative eFLOPs Comparison", "Total eFLOPs (Log Scale)", f"{output_prefix}_eflops_bar.png", log_scale=True)


# Legacy run hashes -> readable experiment names (for log files that were never renamed).
_HASH_TO_NAME = {
    "ec395f6c": "Fixed 0.6B",
    "1b7601d3": "Fixed 32B",
    "4f878473": "Routing (NoTTS)",
    "6c3a9bce": "Routing (BoN)"
}

# Semantic experiment names as they appear in filenames -> legend display names.
# Checked in this order; the first name contained in the filename wins.
_SEMANTIC_DISPLAY_NAMES = {
    "Fixed_0.6B": "Fixed 0.6B",
    "Fixed_32B": "Fixed 32B",
    "Routing_NoTTS": "Routing (NoTTS)",
    "Routing_Full": "Routing (Full)",
    "Mixed_0.6BFull_32BFixed": "Mixed (0.6B+32B)",
}

# Parameter patterns; both long ("alpha1.0") and short ("_a1.0") spellings are accepted.
_ALPHA_RE = re.compile(r"(?:alpha|_a)([\d\.]+)")
_BETA_RE = re.compile(r"(?:beta|_b)([\d\.]+)")


def _label_from_filename(file_name):
    """
    Build the legend label for a "bandit_process_*.log" filename.

    Filename format: bandit_process_{algo}_alpha{alpha}_beta{beta}_lambda{lambda}_{diag}_{fusion}.log
    Example: bandit_process_lin_ucb_alpha0.01_beta1.0_lambda1.0_diag_average.log
    An experiment name (semantic such as "Fixed_0.6B", or legacy "fix_{hash}")
    may precede the algorithm part.
    """
    # Remove prefix and suffix
    params_part = file_name[len("bandit_process_"):-len(".log")]

    label_parts = []
    remaining_params = params_part

    for semantic, display_name in _SEMANTIC_DISPLAY_NAMES.items():
        if semantic not in params_part:
            continue
        label_parts.append(display_name)
        # Parse the algorithm/parameters only from what follows the semantic
        # name, so the name itself is not parsed a second time.
        remaining_params = params_part.split(semantic)[1]
        if remaining_params.startswith("_"):
            remaining_params = remaining_params[1:]
        break
    else:
        # Legacy hash handling (files that were not renamed): fix_{hash}_{algo}_...
        if params_part.startswith("fix_"):
            parts = params_part.split('_')
            hash_val = parts[1]
            label_parts.append(_HASH_TO_NAME.get(hash_val, f"fix_{hash_val}"))
            # Drop "fix_{hash}_" as well; otherwise the algorithm check below
            # sees "fix_..." and always reports "Unknown".
            remaining_params = "_".join(parts[2:])

    # Identify algorithm ("lin" also covers "lin_ucb")
    if remaining_params.startswith("lin"):
        algo = "LinUCB"
    elif remaining_params.startswith("neural_ucb"):
        algo = "NeuralUCB"
    elif remaining_params.startswith("random"):
        algo = "Random"
    elif remaining_params.startswith("fixed"):
        algo = "Fixed"
    else:
        algo = "Unknown"
    label_parts.append(algo)

    # A leading "_" lets the short forms ("_a", "_b") match at the very start too.
    searchable = "_" + remaining_params

    # Alpha is only shown for LinUCB, beta only for NeuralUCB.
    alpha_match = _ALPHA_RE.search(searchable)
    if alpha_match and algo == "LinUCB":
        label_parts.append(f"alpha={alpha_match.group(1)}")

    beta_match = _BETA_RE.search(searchable)
    if beta_match and algo == "NeuralUCB":
        label_parts.append(f"beta={beta_match.group(1)}")

    # Fusion / Mode
    if "average" in remaining_params or "_ave" in remaining_params:
        label_parts.append("avg")
    elif "concat" in remaining_params or "_con" in remaining_params:
        label_parts.append("concat")

    # Diag / Full
    if "diag" in remaining_params:
        label_parts.append("diag")
    elif "full" in remaining_params or "_f_" in remaining_params or remaining_params.endswith("_f"):
        label_parts.append("full")

    # Fixed model name handling: append whatever follows 'fixed_'
    if algo == "Fixed":
        if "fixed_" in remaining_params:
            label_parts.append(remaining_params.split("fixed_")[-1])
        else:
            label_parts.append(remaining_params)

    return " ".join(label_parts)


def _metric_series(rewards, accuracies, eflops):
    """Return (cumulative rewards, cumulative accuracies, cumulative eFLOPs, average accuracy rate)."""
    return (
        compute_cumulative(rewards),
        compute_cumulative(accuracies),
        compute_cumulative(eflops),
        compute_average_rate(accuracies),
    )


def main():
    """
    Command-line entry point: scan a log directory and plot the metrics found.

    Options:
        --output_prefix: Prefix of the generated image files (default "comparison").
        --log_dir: Directory searched recursively for "bandit_process_*.log"
            files (default "./logs"); hidden sub-directories are skipped.
        --warmup_steps: Number of leading steps dropped for the additional
            "<prefix>_no_warmup" plots (default 30); 0 or less disables them.
    """
    parser = argparse.ArgumentParser(description='Plot cumulative metrics from existing log files')
    parser.add_argument('--output_prefix', type=str, default='comparison', help='Output image file prefix')
    parser.add_argument('--log_dir', type=str, default='./logs', help='Directory containing log files')
    parser.add_argument('--warmup_steps', type=int, default=30, help='Number of warmup steps to remove from plot')
    args = parser.parse_args()

    # Walk through log directory to find all .log files
    log_root_dir = args.log_dir
    results = {}
    results_no_warmup = {}

    print(f"Scanning log directory: {log_root_dir}")
    for root, dirs, files in os.walk(log_root_dir):
        # Skip hidden directories like .ipynb_checkpoints. Assigning in place
        # prunes the walk, and sorting makes the traversal (and therefore the
        # legend/colour order) reproducible across runs.
        dirs[:] = sorted(d for d in dirs if not d.startswith('.'))

        for file in sorted(files):
            if not (file.endswith(".log") and file.startswith("bandit_process_")):
                continue

            full_path = os.path.join(root, file)
            base_label = _label_from_filename(file)

            # Different files can map to the same label (e.g. they differ only
            # in a parameter that is not part of the label). Give later ones a
            # numbered suffix instead of silently overwriting earlier data.
            label = base_label
            duplicate_no = 2
            while label in results:
                label = f"{base_label} [{duplicate_no}]"
                duplicate_no += 1
            if label != base_label:
                print(f"Warning: label '{base_label}' already in use; plotting {full_path} as '{label}'.")

            print(f"Processing {label}: {full_path}")
            rewards, accuracies, eflops = parse_log_file(full_path)

            # Compute full cumulative metrics
            series = _metric_series(rewards, accuracies, eflops)
            if not series[0]:
                continue

            results[label] = series
            print(f"  -> Loaded {len(series[0])} data points.")

            # Handle warmup filtering if requested; runs that are not longer
            # than the warmup window have nothing left and are left out.
            if args.warmup_steps > 0 and len(rewards) > args.warmup_steps:
                results_no_warmup[label] = _metric_series(
                    rewards[args.warmup_steps:],
                    accuracies[args.warmup_steps:],
                    eflops[args.warmup_steps:],
                )

    if not results:
        print(f"No usable log data found under {log_root_dir}; nothing to plot.")
        return

    plot_results(results, args.output_prefix)

    if args.warmup_steps > 0 and results_no_warmup:
        print(f"\nPlotting results with {args.warmup_steps} warmup steps removed...")
        plot_results(results_no_warmup, f"{args.output_prefix}_no_warmup")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
