#!/usr/bin/env python3

import json
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
import glob
import re
from scipy.interpolate import interp1d
from scipy.ndimage import gaussian_filter1d
from utils import moving_average

def get_nested_value(entry, path):
    """Look up a nested dictionary value using a dot-separated path.

    Args:
        entry: The outermost dictionary to search (may be None).
        path: Dot-separated keys, e.g. "Training Metrics.Normalized Reward".

    Returns:
        The value stored at ``path``. ``np.nan`` is returned instead when a
        key is missing, when any value that still has to be descended into
        is not a dictionary, or when the final value is a string containing
        "NaN" (e.g. "NaN (no active examples)"). A final value of None is
        returned unchanged.
    """
    value = entry
    for key in path.split("."):
        # Only dictionaries can be descended into. Without this check a
        # numeric intermediate makes `key not in value` raise TypeError, and
        # a string intermediate silently performs a substring test.
        if not isinstance(value, dict) or key not in value:
            return np.nan
        value = value[key]

    # Handle special string indicators like "NaN (no active examples)"
    if isinstance(value, str) and "NaN" in value:
        return np.nan

    return value

def gp_smooth_with_confidence(x, y, sigma=30.0, confidence_scale=0.05):
    """Gaussian-smooth a series and derive a heuristic confidence band.

    This is not a real Gaussian process fit: the curve is a Gaussian-filtered
    copy of ``y`` and the band is a scaled, smoothed estimate of the local
    residual spread around that curve.

    Args:
        x: X-coordinates of the samples. Accepted for call-site symmetry but
            not used in the computation.
        y: 1-D sequence of sample values; converted to a float array.
        sigma: Standard deviation (in samples) of the Gaussian filter applied
            to ``y``. The local-std estimate is smoothed with ``sigma // 2``.
        confidence_scale: Scale factor for the band; also used as the local
            std wherever the residual window holds fewer than two samples.

    Returns:
        Tuple ``(smoothed_y, confidence_bands)`` of float arrays with the
        same length as ``y``.
    """
    # Work in floating point. gaussian_filter1d and zeros_like both inherit
    # the input dtype, so an integer series would otherwise be truncated.
    y = np.asarray(y, dtype=float)

    # Apply Gaussian smoothing to the main line
    smoothed_y = gaussian_filter1d(y, sigma=sigma)

    # Estimate local noise as the std of the residuals in a sliding window
    # of at most 50 samples (a tenth of the series for shorter inputs).
    n_points = len(y)
    half_window = min(50, n_points // 10) // 2
    residuals = y - smoothed_y
    local_std = np.zeros_like(y)

    for i in range(n_points):
        window = residuals[max(0, i - half_window):min(n_points, i + half_window)]
        local_std[i] = np.std(window) if len(window) > 1 else confidence_scale

    # Smooth the standard deviation as well. Floor division is kept so that
    # existing outputs do not change; for sigma < 2 it yields 0, which is a
    # degenerate kernel width, so the (identity) smoothing pass is skipped.
    std_sigma = sigma // 2
    if std_sigma > 0:
        smoothed_std = gaussian_filter1d(local_std, sigma=std_sigma)
    else:
        smoothed_std = local_std

    # Scale the confidence bands; the factor 50 only adjusts visibility.
    confidence_bands = smoothed_std * confidence_scale * 50

    return smoothed_y, confidence_bands

def load_and_plot_experiment(file_path, label, color, window_size=30, gaussian_sigma=30.0):
    """Load one experiment log and compute its smoothed normalized-reward curve.

    The log is read as JSON Lines: the first non-blank line holds the
    hyperparameters and every following non-blank line is one training
    entry. Despite the name, nothing is plotted here; the caller draws the
    returned arrays.

    Args:
        file_path: Path to the ``log.jsonl`` file.
        label: Unused; kept for interface compatibility.
        color: Unused; kept for interface compatibility.
        window_size: Window passed to ``moving_average`` for the first
            smoothing pass.
        gaussian_sigma: Sigma passed to ``gp_smooth_with_confidence``.

    Returns:
        Tuple ``(x, y, confidence, model_type)``. All four are None when the
        file is empty or contains no usable reward values.
    """
    print(f"Loading {file_path}...")

    with open(file_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) would make json.loads raise.
        lines = [line for line in f if line.strip()]

    if not lines:
        # Empty log: report "no data" instead of failing on lines[0].
        return None, None, None, None

    hyperparameters = json.loads(lines[0])
    entries = [json.loads(line) for line in lines[1:]]

    # Extract model type for legend
    model_type = hyperparameters.get('model_type', 'unknown')

    # Extract normalized reward data; anything that is not convertible to
    # float (None, non-numeric strings, ...) becomes NaN.
    rewards = []
    for entry in entries:
        raw = get_nested_value(entry, "Training Metrics.Normalized Reward")
        try:
            rewards.append(float(raw))
        except (ValueError, TypeError):
            rewards.append(np.nan)

    data_array = np.array(rewards, dtype=float)

    # Only proceed if we have at least one valid value
    if data_array.size == 0 or np.all(np.isnan(data_array)):
        return None, None, None, None

    # First apply moving average for initial smoothing
    initially_smoothed = moving_average(data_array, window_size)

    # Shift x so each averaged point sits at the centre of its window.
    # NOTE(review): assumes moving_average returns a "valid"-style array
    # shorter than the input by window_size - 1 -- confirm against utils.
    offset = (window_size - 1) // 2 if window_size > 1 else 0
    x_coords = np.arange(offset, offset + len(initially_smoothed))

    # Filter out NaN values
    mask = ~np.isnan(initially_smoothed)
    if not np.any(mask):
        return None, None, None, None

    clean_x = x_coords[mask]
    clean_y = initially_smoothed[mask]

    # Only apply the Gaussian smoothing if we have enough points
    if len(clean_y) > 50:
        gp_smooth, confidence = gp_smooth_with_confidence(clean_x, clean_y, sigma=gaussian_sigma)
        return clean_x, gp_smooth, confidence, model_type

    # Fallback for short series: unsmoothed curve with a thin constant band
    return clean_x, clean_y, np.ones_like(clean_y) * 0.01, model_type

def find_latest_experiment(task_type, machine_pattern):
    """Return the newest ``log.jsonl`` for a machine, or None if there is none.

    Run directories are searched under ``results/<task_type>/`` (relative to
    the current working directory). Directories whose name ends in
    ``_<machine_pattern>`` are tried first; if none match, any directory
    whose name contains ``machine_pattern`` is accepted. Only directories
    whose name starts with a ``YYYYMMDD_HHMMSS`` timestamp are considered,
    and the one with the greatest timestamp wins.
    """
    # Preferred layout: <timestamp>_<machine_pattern>
    candidates = glob.glob(f"results/{task_type}/*_{machine_pattern}/log.jsonl")
    if not candidates:
        # Looser fallback: the pattern may appear anywhere in the name
        candidates = glob.glob(f"results/{task_type}/*{machine_pattern}*/log.jsonl")

    # Pair every candidate with the timestamp prefix of its run directory
    stamp_matches = (
        (re.match(r'^(\d{8}_\d{6})', os.path.basename(os.path.dirname(log_path))), log_path)
        for log_path in candidates
    )
    dated_logs = [(found.group(1), log_path) for found, log_path in stamp_matches if found]

    if not dated_logs:
        return None

    # Timestamps in this format sort chronologically as plain strings
    _, newest_log = max(dated_logs)
    return newest_log

def main():
    """Command-line entry point: plot smoothed normalized-reward curves.

    For each configured machine pattern the most recent experiment log is
    located, its normalized reward is smoothed, and the curve is drawn with
    a shaded band. The figure is written to ``--output``. Progress and
    warnings are printed to stdout; if no experiment is found at all the
    function prints an error and returns without creating a figure.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Create Gaussian process-style smoothed plot of normalized rewards")
    parser.add_argument("--window_size", type=int, default=30, help="Moving average window size (default: 30)")
    parser.add_argument("--gaussian_sigma", type=float, default=30.0, help="Gaussian smoothing sigma (default: 30.0)")
    parser.add_argument("--output", type=str, default="results/figures/combined_normalized_reward_gp_smoothed.png", help="Output filename")
    parser.add_argument("--task_type", type=str, default="wiki_continuation", help="Task type to plot (default: wiki_continuation)")
    parser.add_argument("--use_max_x", action="store_true", help="Use maximum x extent instead of minimum (default: False)")
    parser.add_argument("--font_size", type=int, default=18, help="Font size for all text elements in the plot (default: 18)")
    args = parser.parse_args()

    # (directory-name pattern, label used in messages, line colour)
    machine_configs = [
        ("left", "left1", "#1f77b4"),
        ("mid2", "mid2", "#ff7f0e"),
        ("right2", "right2", "#2ca02c"),
        ("riight2", "riight2", "#d62728"),
    ]

    # Find the latest experiments for each machine
    experiments = []
    for machine_pattern, label, color in machine_configs:
        latest_file = find_latest_experiment(args.task_type, machine_pattern)
        if latest_file:
            experiments.append((latest_file, label, color))
            print(f"Found latest {machine_pattern}: {latest_file}")
        else:
            print(f"Warning: No experiments found for {machine_pattern}")

    if not experiments:
        print("Error: No experiments found for any of the target machines")
        return

    window_size = args.window_size
    gaussian_sigma = args.gaussian_sigma
    font_size = args.font_size

    # Create the plot
    plt.figure(figsize=(12, 8))

    # Final x of the longest and of the shortest plotted run; one of the two
    # becomes the right-hand axis limit below.
    longest_extent = 0
    shortest_extent = float('inf')
    valid_experiments = []

    # Process each experiment
    for file_path, label, color in experiments:
        if not os.path.exists(file_path):
            print(f"Warning: File not found: {file_path}")
            continue

        x_coords, y_smooth, confidence, model_type = load_and_plot_experiment(
            file_path, label, color, window_size, gaussian_sigma
        )
        if x_coords is None or y_smooth is None:
            print(f"Warning: No valid data found for {label}")
            continue

        # Plot the main smoothed line
        plt.plot(x_coords, y_smooth, label=model_type, color=color, linewidth=3, alpha=0.9)

        # Plot confidence bands (Gaussian process-like blur)
        plt.fill_between(x_coords,
                         y_smooth - confidence,
                         y_smooth + confidence,
                         color=color, alpha=0.15, linewidth=0)

        # Add a slightly thicker semi-transparent line for more blur effect
        plt.plot(x_coords, y_smooth, color=color, linewidth=6, alpha=0.2)

        run_extent = np.max(x_coords)
        longest_extent = max(longest_extent, run_extent)
        shortest_extent = min(shortest_extent, run_extent)
        valid_experiments.append((model_type, label))
        print(f"Successfully plotted {model_type} ({label}) with {len(x_coords)} points")

    # Set up the plot formatting similar to plot_training_metrics.py
    plt.xlabel("Training Batch No. []", fontsize=font_size)
    plt.ylabel("ln π(ans|cot) - ln π(ans|cot') []", fontsize=font_size)
    plt.title("Normalized Reward Comparison (GP-Style Smoothing)", fontsize=font_size + 2)
    plt.grid(True, linestyle='--', alpha=0.3, color='gray')

    if valid_experiments:
        # A legend only makes sense (and only avoids a matplotlib warning)
        # when at least one labelled curve was drawn.
        plt.legend(fontsize=font_size - 2, framealpha=0.9)

        # By default clip the axis to the shortest run so every curve spans
        # the whole plot; --use_max_x shows the longest run instead.
        x_limit = longest_extent if args.use_max_x else shortest_extent
        plt.xlim(0, x_limit)

    # Add smoothing info
    plt.text(
        0.95, 0.05,
        f"Moving avg window = {window_size}\nGaussian sigma = {gaussian_sigma:.1f}",
        transform=plt.gca().transAxes,
        horizontalalignment='right',
        verticalalignment='bottom',
        fontsize=font_size - 6,
        bbox=dict(
            facecolor='white',
            alpha=0.9,
            edgecolor='black',
            pad=5,
            boxstyle='round,pad=0.5'
        )
    )

    # Improve overall aesthetics
    plt.tick_params(axis='both', which='major', labelsize=font_size - 2)

    # Tight layout and save
    plt.tight_layout()
    output_file = args.output

    # Ensure the output directory exists. A bare filename has an empty
    # dirname, and os.makedirs("") raises FileNotFoundError.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"\nPlot saved to {output_file}")

    # Print experiment info
    print(f"\nSuccessfully plotted {len(valid_experiments)} experiments:")
    for model_type, label in valid_experiments:
        print(f"  {model_type} ({label})")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()