import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def _shorten(name, limit=30):
    """Return *name* as a string, truncated to *limit* characters plus '...'."""
    name = str(name)
    return name[:limit] + "..." if len(name) > limit else name


def _metric_values(entries, key):
    """Collect *key* from every entry, treating missing or null values as 0."""
    values = []
    for entry in entries:
        value = entry.get(key)
        # A JSON ``null`` would otherwise break the arithmetic and np.mean below.
        values.append(0 if value is None else value)
    return values


def _save_and_close(fig, path, label):
    """Write *fig* to *path* at 300 dpi, then release the figure."""
    try:
        fig.savefig(path, dpi=300)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed;
        # closing here stops repeated calls from accumulating open figures.
        plt.close(fig)
    print(f"📊 {label} saved to: {path}")


def _barh_panel(ax, names, values, *, title, color, edgecolor, label_offset, fmt):
    """Draw one horizontal bar panel with a value label next to each bar."""
    y_pos = np.arange(len(names))
    ax.barh(y_pos, values, color=color, edgecolor=edgecolor, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(names)
    ax.invert_yaxis()  # first entry at the top
    ax.set_title(title, fontweight='bold', fontsize=14)
    ax.grid(axis='x', linestyle='--', alpha=0.7)
    for i, v in enumerate(values):
        ax.text(v + label_offset, i, fmt(v), va='center', fontweight='bold')


def _plot_performance_bars(names, hit_rates, yields, similarities, path):
    """Save the three-panel bar chart (hit rate, yield, similarity)."""
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(14, 18), constrained_layout=True)
    _barh_panel(
        ax1, names, hit_rates,
        title='Predictive Hit Rate (%)',
        color='skyblue', edgecolor='navy',
        label_offset=0.5, fmt=lambda v: f"{v}%",
    )
    _barh_panel(
        ax2, names, yields,
        title='Hypothesis Yield (Count per Topic)',
        color='lightgreen', edgecolor='darkgreen',
        label_offset=0.1, fmt=str,
    )
    _barh_panel(
        ax3, names, similarities,
        title='Average Cosine Similarity (Scientific Fit)',
        color='salmon', edgecolor='darkred',
        label_offset=0.01, fmt=lambda v: f"{v:.4f}",
    )
    _save_and_close(fig, path, "Performance bars")


def _plot_novelty_scatter(names, hit_rates, novelty, temporal_lead, path):
    """Save the novelty vs. temporal-lead scatter, coloured by hit rate."""
    fig = plt.figure(figsize=(12, 10), constrained_layout=True)
    ax = fig.add_subplot(1, 1, 1)
    scatter = ax.scatter(novelty, temporal_lead, s=150, c=hit_rates, cmap='viridis',
                         edgecolors='white', alpha=0.9)
    ax.set_xlabel('Novelty Score (Semantic Distance)')
    ax.set_ylabel('Temporal Lead (Days)')
    ax.set_title('Novelty vs Forecasting Lead Distribution', fontweight='bold', fontsize=16)
    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label('Hit Rate (%)')
    ax.grid(True, linestyle=':', alpha=0.6)

    # Label only points with a hit rate above 80 or above-average novelty.
    mean_novelty = np.mean(novelty)  # loop-invariant, computed once
    for i, txt in enumerate(names):
        if hit_rates[i] > 80 or novelty[i] > mean_novelty:
            ax.annotate(txt, (novelty[i], temporal_lead[i]), xytext=(5, 5),
                        textcoords='offset points', fontsize=8, alpha=0.7)

    _save_and_close(fig, path, "Novelty analysis")


def _plot_summary_radar(names, hit_rates, yields, novelty, temporal_lead,
                        cross_domain, total_topics, path):
    """Save the global-average radar chart next to a summary table."""
    fig = plt.figure(figsize=(16, 10), constrained_layout=True)
    gs = fig.add_gridspec(1, 2)

    # Radar: each average is rescaled as stated in its category label.
    ax_radar = fig.add_subplot(gs[0, 0], polar=True)
    categories = ['Hit Rate (%)', 'Novelty (x100)', 'Temporal Lead (d/10)', 'Cross-Domain (x10)']
    values = [
        np.mean(hit_rates),
        np.mean(novelty) * 100,
        np.mean(temporal_lead) / 10,
        np.mean(cross_domain) * 10,
    ]
    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    # Repeat the first point so the outline closes on itself.
    values += values[:1]
    angles += angles[:1]

    ax_radar.fill(angles, values, color='teal', alpha=0.25)
    ax_radar.plot(angles, values, color='teal', linewidth=2)
    ax_radar.set_xticks(angles[:-1])
    ax_radar.set_xticklabels(categories)
    ax_radar.set_title('Global Average Performance Radar', fontweight='bold', fontsize=14, pad=20)

    # Summary table
    ax_table = fig.add_subplot(gs[0, 1])
    ax_table.axis('off')
    summary_text = [
        ["Metric", "Global Average"],
        ["Completed Topics", f"{len(names)} / {total_topics}"],
        ["Avg Hit Rate", f"{np.mean(hit_rates):.1f}%"],
        ["Avg Temporal Lead", f"{np.mean(temporal_lead):.0f} days"],
        ["Avg Novelty", f"{np.mean(novelty):.4f}"],
        ["Avg Hypothesis Yield", f"{np.mean(yields):.1f}"],
        ["Best Performing Topic", names[int(np.argmax(hit_rates))]],
    ]
    table = ax_table.table(cellText=summary_text, loc='center', cellLoc='left', colWidths=[0.4, 0.6])
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 2.5)
    ax_table.set_title('Batch Execution Statistics Summary', fontweight='bold', fontsize=14)

    _save_and_close(fig, path, "Summary radar/table")


def visualize_batch_results(summary_path, *, output_dir=None, total_topics=20):
    """Render three PNG reports from a batch summary JSON file.

    The file is read as JSON and iterated entry by entry; only dict entries
    that contain a ``"hit_rate"`` key are plotted. For each kept entry the
    keys ``name``, ``hit_rate``, ``yield``, ``novelty``, ``temporal_lead``
    and ``cross_domain`` are read; a missing or null metric counts as 0.

    Files written into the output directory:
      * ``batch_performance_bars.png`` - hit rate, yield and ``1 - novelty``
        as horizontal bars, one bar per entry.
      * ``batch_novelty_analysis.png`` - novelty vs. temporal lead scatter,
        coloured by hit rate.
      * ``batch_summary_radar.png`` - radar of the global averages plus a
        summary table.

    Args:
        summary_path: Path of the JSON summary file to read.
        output_dir: Directory for the PNG files. Defaults to the ``output``
            directory two levels above this file. Created if missing.
        total_topics: Denominator shown in the "Completed Topics" table row.

    Returns:
        None. If no entry carries metrics, a message is printed and nothing
        is written.
    """
    with open(summary_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Filter only completed data (with metrics). Non-dict items are skipped
    # so an unexpected JSON shape yields the "no data" message, not a crash.
    data = [d for d in data if isinstance(d, dict) and "hit_rate" in d]

    if not data:
        print("No valid metric data to visualize.")
        return

    # Extract metrics
    names = [_shorten(d.get("name", "unnamed")) for d in data]
    hit_rates = _metric_values(data, "hit_rate")
    yields = _metric_values(data, "yield")
    novelty = _metric_values(data, "novelty")
    similarities = [1.0 - n for n in novelty]
    temporal_lead = _metric_values(data, "temporal_lead")
    cross_domain = _metric_values(data, "cross_domain")

    if output_dir is None:
        output_dir = Path(__file__).parent.parent / "output"
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*";
    # fall back to the old name (or the default style) on other versions.
    style = 'seaborn-v0_8-muted'
    if style not in plt.style.available:
        style = 'seaborn-muted' if 'seaborn-muted' in plt.style.available else 'default'
    plt.style.use(style)

    # --- 1. Performance Overview (Bars) ---
    _plot_performance_bars(
        names, hit_rates, yields, similarities,
        output_dir / "batch_performance_bars.png",
    )

    # --- 2. Novelty Analysis (Scatter) ---
    _plot_novelty_scatter(
        names, hit_rates, novelty, temporal_lead,
        output_dir / "batch_novelty_analysis.png",
    )

    # --- 3. Global Radar & Summary ---
    _plot_summary_radar(
        names, hit_rates, yields, novelty, temporal_lead, cross_domain,
        total_topics, output_dir / "batch_summary_radar.png",
    )

def resolve_summary_path(base_dir, argv):
    """Pick the batch summary file the script should load.

    Resolution order:
      1. ``argv[1]`` when a command-line argument was given (returned as-is,
         without checking that it exists).
      2. ``<base_dir>/batch_runs/batch_summary.json`` when that file exists.
      3. The ``batch_summary.json`` of the last sub-directory (sorted by
         path) of ``<base_dir>/batch_vault`` that actually contains one.

    If none of these apply, the path from step 2 is returned even though it
    does not exist, so the caller can report it.

    Args:
        base_dir: ``Path`` of the directory holding ``batch_runs`` / ``batch_vault``.
        argv: Argument list in ``sys.argv`` form (program name first).

    Returns:
        A ``Path``; the caller must still check that it exists.
    """
    if len(argv) > 1:
        return Path(argv[1])

    default_summary = base_dir / "batch_runs" / "batch_summary.json"
    if default_summary.exists():
        return default_summary

    vault_path = base_dir / "batch_vault"
    if vault_path.is_dir():
        batches = sorted(d for d in vault_path.iterdir() if d.is_dir())
        # Walk from the last directory backwards so a batch directory
        # without a summary does not hide an earlier one that has it.
        for batch in reversed(batches):
            candidate = batch / "batch_summary.json"
            if candidate.exists():
                return candidate

    return default_summary


if __name__ == "__main__":
    import sys

    base_dir = Path(__file__).parent.parent.absolute()
    batch_summary = resolve_summary_path(base_dir, sys.argv)

    if not batch_summary.exists():
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # so shells and schedulers can see that the run failed.
        sys.exit(f"Error: Summary file {batch_summary} not found.")

    print(f"📂 Loading summary from: {batch_summary}")
    visualize_batch_results(batch_summary)
