import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import argparse
import os
from pathlib import Path

def load_detailed_logs(detailed_logs_dir, dataset, strategy, seed):
    """Read the AdaHedge and Hedge detailed-log CSVs for one run.

    The files are looked up as
    ``<detailed_logs_dir>/adacoma_adahedge_<dataset>_<strategy>_seed<seed>.csv``
    and ``<detailed_logs_dir>/adacoma_hedge_<dataset>_<strategy>_seed<seed>.csv``.

    Returns:
        ``(df_ada, df_hedge)``: the two logs as DataFrames, in that order.

    Raises:
        FileNotFoundError: if either CSV is missing (checked before any read).
    """
    run_id = f"{dataset}_{strategy}_seed{seed}"

    # Only the adacoma logs are needed for the visualizations.
    log_paths = [
        os.path.join(detailed_logs_dir, f"adacoma_{variant}_{run_id}.csv")
        for variant in ("adahedge", "hedge")
    ]

    if not all(os.path.exists(path) for path in log_paths):
        raise FileNotFoundError(f"Could not find detailed logs for {run_id}")

    ada_log, hedge_log = (pd.read_csv(path) for path in log_paths)
    return ada_log, hedge_log

def plot_forecaster_lengths(df, title_prefix, output_dir):
    """Plot each forecaster's interval length against ``time_step``.

    Columns named ``fc_<index>_len`` (``<index>`` a non-negative integer) are
    discovered in ``df`` and drawn in ascending numeric index order. Indices
    therefore need not start at 0 or be contiguous, and ``fc_10_len`` is drawn
    after ``fc_9_len``. Columns that match the prefix/suffix but have a
    non-integer middle are ignored. Each series is labelled
    ``Forecaster <index + 1>``.

    Args:
        df: Detailed-log DataFrame; must contain a ``time_step`` column.
        title_prefix: Prefix for the plot title. Its lowercased form also
            starts the output filename, so any spaces or punctuation in it
            carry into the filename.
        output_dir: Existing directory the PNG is written to.

    Returns:
        Path of the saved PNG file.
    """
    prefix, suffix = 'fc_', '_len'

    # Collect (index, column) pairs instead of assuming fc_0_len..fc_{K-1}_len
    # exist; the old assumption raised KeyError on 1-based or gapped indices.
    indexed_cols = []
    for col in df.columns:
        if not (isinstance(col, str) and col.startswith(prefix) and col.endswith(suffix)):
            continue
        middle = col[len(prefix):-len(suffix)]
        if middle.isdecimal():
            indexed_cols.append((int(middle), col))
    indexed_cols.sort()

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for idx, col in indexed_cols:
            ax.plot(df['time_step'], df[col], label=f'Forecaster {idx + 1}', alpha=0.7)

        ax.set_xlabel('Time Step')
        ax.set_ylabel('Interval Length')
        ax.set_title(f'{title_prefix} - Individual Forecaster Lengths')
        # Skip the legend when nothing was plotted (avoids matplotlib's
        # "No artists with labels" warning).
        if indexed_cols:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        output_file = os.path.join(output_dir, f'{title_prefix.lower()}_forecaster_lengths.png')
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return output_file

def plot_weights_evolution(df_ada, df_hedge, title_prefix, output_dir):
    """Plot AdaHedge and Hedge prior weights over time in two stacked panels.

    Weight columns are discovered independently in each log:
    ``adacoma_ada_prior_w_<index>`` in ``df_ada`` (top panel) and
    ``adacoma_hedge_prior_w_<index>`` in ``df_hedge`` (bottom panel), where
    ``<index>`` is a non-negative integer. Each panel draws its own columns
    in ascending numeric index order. The two logs therefore need not have
    the same number of weight columns, and indices need not be contiguous or
    start at 0. Each series is labelled ``Forecaster <index + 1>``.

    Args:
        df_ada: AdaHedge detailed log; must contain a ``time_step`` column.
        df_hedge: Hedge detailed log; must contain a ``time_step`` column.
        title_prefix: Prefix for both panel titles. Its lowercased form also
            starts the output filename.
        output_dir: Existing directory the PNG is written to.

    Returns:
        Path of the saved PNG file.
    """

    def _indexed_columns(df, prefix):
        # (index, column) pairs for columns named ``<prefix><integer>``,
        # sorted numerically so ``..._w_10`` follows ``..._w_9``.
        pairs = []
        for col in df.columns:
            if isinstance(col, str) and col.startswith(prefix):
                tail = col[len(prefix):]
                if tail.isdecimal():
                    pairs.append((int(tail), col))
        return sorted(pairs)

    # Previously K was taken from the AdaHedge columns only and reused for
    # Hedge, which raised KeyError or silently dropped series when the counts
    # differed. Each panel now uses its own log's columns.
    panels = (
        (df_ada, 'adacoma_ada_prior_w_', 'AdaHedge'),
        (df_hedge, 'adacoma_hedge_prior_w_', 'Hedge'),
    )

    fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True)
    try:
        for ax, (df, prefix, algo_name) in zip(axes, panels):
            cols = _indexed_columns(df, prefix)
            for idx, col in cols:
                ax.plot(df['time_step'], df[col], label=f'Forecaster {idx + 1}', alpha=0.7)
            ax.set_ylabel('Weight')
            ax.set_title(f'{title_prefix} - {algo_name} Weights Evolution')
            if cols:
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            ax.grid(True, alpha=0.3)

        # The x axis is shared, so only the bottom panel gets a label.
        axes[-1].set_xlabel('Time Step')
        fig.tight_layout()

        output_file = os.path.join(output_dir, f'{title_prefix.lower()}_weights_evolution.png')
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return output_file

def visualize_run(detailed_logs_dir, dataset, strategy, seed, output_dir):
    """Render all plots for a single (dataset, strategy, seed) run.

    Creates ``output_dir`` if needed, loads the AdaHedge/Hedge detailed logs
    via ``load_detailed_logs``, writes the forecaster-length and
    weight-evolution PNGs, then prints the path of each file written.
    """
    os.makedirs(output_dir, exist_ok=True)

    ada_log, hedge_log = load_detailed_logs(detailed_logs_dir, dataset, strategy, seed)

    # e.g. dataset="elec", strategy="random_switch" -> "ELEC - Random Switch (Seed <seed>)"
    pretty_strategy = strategy.replace('_', ' ').title()
    title_prefix = f"{dataset.upper()} - {pretty_strategy} (Seed {seed})"

    # Plots are generated in order (lengths first), before anything is printed.
    saved = (
        ("Forecaster lengths", plot_forecaster_lengths(ada_log, title_prefix, output_dir)),
        ("Weights evolution", plot_weights_evolution(ada_log, hedge_log, title_prefix, output_dir)),
    )

    print("Visualizations saved:")
    for label, path in saved:
        print(f"- {label}: {path}")

if __name__ == "__main__":
    # Command-line entry point: pick one run and render its visualizations.
    cli = argparse.ArgumentParser(description="Visualize detailed experiment logs")
    cli.add_argument("--dataset", type=str, required=True, choices=["elec", "aram"])
    cli.add_argument("--strategy", type=str, required=True,
                     choices=["random_switch"])
    cli.add_argument("--seed", type=int, required=True)
    cli.add_argument("--detailed_logs_dir", type=str,
                     default="experiment_results/detailed_logs")
    cli.add_argument("--output_dir", type=str,
                     default="experiment_results/visualizations")

    opts = cli.parse_args()

    visualize_run(
        detailed_logs_dir=opts.detailed_logs_dir,
        dataset=opts.dataset,
        strategy=opts.strategy,
        seed=opts.seed,
        output_dir=opts.output_dir,
    )