#!/usr/bin/env python3
"""
Plot SleepEDF time series samples from sleep_cot_data.csv.
Each sample is plotted as a PNG with EEG data and the full_prediction as text.
"""
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

CSV_PATH = "sleep_cot_data.csv"
OUTPUT_DIR = "sleep_cot_plots"

# Publication style
plt.style.use('seaborn-v0_8')
sns.set_palette("colorblind")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

display_label_map = {
    'W': 'Wake',
    'N1': 'Non-REM stage 1',
    'N2': 'Non-REM stage 2',
    'N3': 'Non-REM stage 3',
    'N4': 'Non-REM stage 4',
    'REM': 'REM sleep',
    'M': 'Movement',
    'Unknown': 'Unknown'
}

def plot_sample(row, idx):
    eeg_data = np.array(json.loads(row['eeg_data']))
    full_pred = row['full_prediction']
    gt_label = row['ground_truth_label']
    pred_label = row['predicted_label']
    sample_idx = row['sample_index']
    series_length = row['series_length']

    # Map labels to pretty names
    pretty_gt = display_label_map.get(gt_label, gt_label)
    pretty_pred = display_label_map.get(pred_label, pred_label)

    # Normalize text length to exactly 800 characters
    text_length = 900
    if len(full_pred) < text_length:
        # Pad with whitespace if shorter
        full_pred = full_pred + " " * (text_length - len(full_pred))
    elif len(full_pred) > text_length:
        # Truncate if longer
        full_pred = full_pred[:text_length]
    
    # Add extra newlines to ensure consistent text box height
    full_pred = full_pred + "\n"

    # Normalize EEG data for plotting
    mean = np.mean(eeg_data)
    std = np.std(eeg_data) if np.std(eeg_data) > 0 else 1.0
    eeg_plot = (eeg_data - mean) / std

    fig, ax1 = plt.subplots(figsize=(12, 7))
    t = np.arange(len(eeg_plot))
    # Use the same blue color as PAMAP2 plots (first color from colorblind palette)
    ax1.plot(t, eeg_plot, linewidth=2.5, color='#0173B2', alpha=0.8, label='EEG')
    ax1.set_xlabel('Time Step', fontsize=26)
    ax1.set_ylabel('Normalized EEG Amplitude', fontsize=26)
    ax1.set_title(f"Sample {sample_idx} | GT: {pretty_gt} | Pred: {pretty_pred}", fontsize=22, fontweight='bold')
    ax1.legend(fontsize=13, loc='upper right')
    ax1.grid(True, alpha=0.3)
    ax1.tick_params(axis='both', which='major', labelsize=26)
    ax1.set_ylim(-3, 3)
    ax1.set_yticks(np.linspace(-3, 3, 7))

    # Add full_prediction as a text box below the plot (same as PAMAP2)
    plt.gcf().text(0.01, -0.02, f"Prediction:\n{full_pred}", fontsize=30, ha='left', va='top', wrap=True,
                   bbox=dict(boxstyle='round', facecolor='whitesmoke', alpha=0.9, edgecolor='gray'))

    plt.tight_layout(rect=[0, 0.05, 1, 1])

    fname = f"sample_{idx+1:03d}_gt_{pretty_gt.lower().replace(' ', '_').replace('-', '_')}.png"
    plt.savefig(os.path.join(OUTPUT_DIR, fname), dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved {fname}")

def main():
    df = pd.read_csv(CSV_PATH)
    print(f"Loaded {len(df)} samples from {CSV_PATH}")
    for idx, row in df.iterrows():
        plot_sample(row, idx)
    print(f"All plots saved to {OUTPUT_DIR}/")

if __name__ == "__main__":
    main() 