#!/usr/bin/env python3
"""
Focused script for generating circular network plots with logos.
"""

import json
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from pathlib import Path
from collections import defaultdict
import matplotlib.image as mpimg
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

def get_short_model_name(full_name: str) -> str:
    """Return the compact display label for a ``provider/model`` identifier.

    Known identifiers map to a curated short label. Any other identifier falls
    back to the text after its last ``/``, or to the whole string when it
    contains no ``/``.
    """
    known_aliases = {
        'anthropic/claude-sonnet-4': 'claude-sonnet-4',
        'deepseek/deepseek-chat-v3-0324': 'deepseek-v3',
        'google/gemini-2.5-flash': 'gemini-2.5-flash',
        'moonshotai/kimi-k2': 'kimi-k2',
        'openai/gpt-4.1': 'gpt-4.1',
        'openai/gpt-4.1-mini': 'gpt-4.1-mini',
        'openai/gpt-5': 'gpt-5',
        'qwen/qwen3-235b-a22b-2507': 'qwen3-235b',
        'x-ai/grok-4': 'grok-4',
        'z-ai/glm-4.5': 'glm-4.5',
    }
    try:
        return known_aliases[full_name]
    except KeyError:
        # rpartition()[2] is the segment after the last '/', or the whole string.
        return full_name.rpartition('/')[2]

# Short names placed first on the circle so the OpenAI models sit next to each other.
_GPT_MODELS = ('gpt-4.1', 'gpt-4.1-mini', 'gpt-5')

# Logo file (inside the logos directory) for each short model name.
# 'qwen3-325b' looks like a transposition of 'qwen3-235b'; both keys are kept
# so either spelling resolves.
_MODEL_LOGO_FILES = {
    'gpt-4.1': 'openai.png',
    'gpt-4.1-mini': 'openai.png',
    'gpt-5': 'openai.png',
    'claude-sonnet-4': 'claude.png',
    'deepseek-v3': 'deepseek.png',
    'qwen3-325b': 'qwen.png',
    'qwen3-235b': 'qwen.png',
    'glm-4.5': 'zhipu.png',
    'grok-4': 'x.png',
    'kimi-k2': 'kimi.png',
    'gemini-2.5-flash': 'gemini.png',
}

# OffsetImage zoom per short model name (logo source images differ in pixel size).
_MODEL_LOGO_SIZES = {
    'gpt-4.1': 0.48,
    'gpt-4.1-mini': 0.48,
    'gpt-5': 0.48,
    'claude-sonnet-4': 0.24,
    'deepseek-v3': 0.12,
    'qwen3-325b': 0.045,
    'qwen3-235b': 0.045,
    'glm-4.5': 0.48,
    'grok-4': 0.06,
    'kimi-k2': 0.24,
    'gemini-2.5-flash': 0.24,
}

# Emoji drawn when a node's logo is missing or fails to load.
_MODEL_EMOJI_FALLBACKS = {
    'gpt-4.1': '🧠',
    'gpt-4.1-mini': '⚡',
    'claude-sonnet-4': '🎭',
    'deepseek-v3': '🌊',
    'qwen3-325b': '🏮',
    'qwen3-235b': '🏮',
    'glm-4.5': '⭐',
    'grok-4': '🚀',
    'kimi-k2': '🤖',
    'gemini-2.5-flash': '♊',
}

# Label placement groups (short model names); all other labels sit just below the node.
_LABELS_ABOVE = frozenset()
_LABELS_RIGHT = frozenset({'gpt-4.1', 'gpt-5', 'gpt-4.1-mini'})
_LABELS_LEFT = frozenset({'claude-sonnet-4'})
_LABELS_BELOW_FAR = frozenset({'glm-4.5', 'qwen3-235b', 'deepseek-v3'})

# Fraction of each logo's alpha kept, so edges remain visible through logos.
_LOGO_OPACITY = 0.8


def _load_prediction_counts(predictions_file):
    """Count evaluator -> predicted-model pairs in a JSONL predictions file.

    Each non-blank line should be a JSON object. A missing ``evaluator_model``
    is counted as ``'Unknown'``. Records whose evaluator or
    ``predicted_model`` value is empty or absent are skipped, as are lines
    that are not JSON objects.

    Returns:
        ``(prediction_data, model_totals, all_models)``, where:

        - ``prediction_data[evaluator][predicted]`` is a pair count;
        - ``model_totals[evaluator]`` is that evaluator's number of counted
          records;
        - ``all_models`` holds every model seen on either side.

        Returns ``None`` (after printing an error) if the file cannot be read
        or decoded.
    """
    prediction_data = defaultdict(lambda: defaultdict(int))
    model_totals = defaultdict(int)
    all_models = set()

    try:
        with open(predictions_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                pred = json.loads(line)
                if not isinstance(pred, dict):
                    continue  # skip non-object records rather than aborting the file

                evaluator_model = pred.get('evaluator_model', 'Unknown')
                predicted_model = pred.get('predicted_model', '')
                if evaluator_model and predicted_model:
                    prediction_data[evaluator_model][predicted_model] += 1
                    model_totals[evaluator_model] += 1
                    all_models.add(evaluator_model)
                    all_models.add(predicted_model)
    except (OSError, ValueError) as e:
        # OSError: unreadable file. ValueError covers JSONDecodeError and UnicodeDecodeError.
        print(f"Error reading {predictions_file}: {e}")
        return None

    return prediction_data, model_totals, all_models


def _build_prediction_graph(prediction_data, model_totals, all_models, threshold):
    """Build a directed graph of evaluator -> predicted-model edges.

    Nodes are the short names (``get_short_model_name``) of every model in
    ``all_models``, added in sorted full-name order. An edge is added when a
    predicted model accounts for more than ``threshold`` percent of an
    evaluator's predictions. Edges carry ``weight`` (that percentage),
    ``count`` and ``total`` attributes.
    """
    graph = nx.DiGraph()
    models = sorted(all_models)
    short_names = {model: get_short_model_name(model) for model in models}
    graph.add_nodes_from(short_names[model] for model in models)

    for evaluator_model in models:
        if evaluator_model not in prediction_data:
            continue
        total_predictions = model_totals[evaluator_model]
        if total_predictions <= 0:
            continue
        counts = prediction_data[evaluator_model]
        for predicted_model in models:
            count = counts.get(predicted_model, 0)
            percentage = (count / total_predictions) * 100
            if percentage > threshold:
                graph.add_edge(short_names[evaluator_model], short_names[predicted_model],
                               weight=percentage, count=count, total=total_predictions)
    return graph


def _circular_layout(nodes, radius=1.0):
    """Place ``nodes`` evenly on a circle, with GPT models first so they are adjacent.

    Only GPT models actually present in ``nodes`` are placed. The angular step
    uses the number of placed nodes, so no phantom nodes appear and positions
    never wrap past a full turn.
    """
    import math

    present = set(nodes)
    ordered_nodes = [node for node in _GPT_MODELS if node in present]
    ordered_nodes += [node for node in nodes if node not in _GPT_MODELS]
    n_nodes = len(ordered_nodes)
    return {
        node: (radius * math.cos(2 * math.pi * i / n_nodes),
               radius * math.sin(2 * math.pi * i / n_nodes))
        for i, node in enumerate(ordered_nodes)
    }


def _load_faded_logo(logo_path, opacity=_LOGO_OPACITY):
    """Read an image and return it as an RGBA array with alpha scaled by ``opacity``.

    - Integer-valued images are rescaled to [0, 1] so they can be mixed with
      the float alpha channel.
    - Greyscale images are expanded to RGB.
    - RGB images get a constant alpha equal to ``opacity``; RGBA images have
      their existing alpha multiplied by it.
    """
    img = mpimg.imread(logo_path)
    if np.issubdtype(img.dtype, np.integer):
        max_value = np.iinfo(img.dtype).max
        img = img.astype(np.float64) / max_value
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)

    if img.shape[2] == 3:
        rgba = np.empty(img.shape[:2] + (4,), dtype=np.float64)
        rgba[:, :, :3] = img
        rgba[:, :, 3] = opacity
        return rgba
    if img.shape[2] == 4:
        img[:, :, 3] = img[:, :, 3] * opacity
    return img


def _label_placement(node, x, y):
    """Return ``(text_x, text_y, ha, va)`` for the text label of ``node`` drawn at ``(x, y)``."""
    if node in _LABELS_ABOVE:
        return x, y + 0.12, 'center', 'bottom'
    if node in _LABELS_RIGHT:
        return x + 0.12, y, 'left', 'center'
    if node in _LABELS_LEFT:
        return x - 0.12, y, 'right', 'center'
    if node in _LABELS_BELOW_FAR:
        return x, y - 0.12, 'center', 'top'
    return x, y - 0.08, 'center', 'top'


def _draw_node(ax, node, x, y, logos_dir):
    """Draw one node's logo (or emoji fallback) and its text label on ``ax``.

    Returns True if a logo image was drawn, False if the emoji fallback was used.
    """
    logo_loaded = False
    logo_filename = _MODEL_LOGO_FILES.get(node)
    if logo_filename:
        logo_path = logos_dir / logo_filename
        if logo_path.exists():
            try:
                img = _load_faded_logo(logo_path)
                imagebox = OffsetImage(img, zoom=_MODEL_LOGO_SIZES.get(node, 0.12))
                ab = AnnotationBbox(imagebox, (x, y), frameon=False, pad=0)
                ab.set_zorder(1)  # behind the edges, which are drawn at zorder 2
                ax.add_artist(ab)
                logo_loaded = True
            except Exception as e:
                # Best effort: one bad logo file should not abort the whole plot.
                print(f"Warning: Could not load logo {logo_path}: {e}")

    if not logo_loaded:
        # White disc behind the emoji so it stays legible over edges.
        circle = plt.Circle((x, y), 0.05, color='white', alpha=0.9, zorder=2)
        ax.add_patch(circle)
        emoji = _MODEL_EMOJI_FALLBACKS.get(node, '🤖')
        ax.text(x, y, emoji, fontsize=105, ha='center', va='center',
                weight='bold', zorder=3)

    text_x, text_y, ha_align, va_align = _label_placement(node, x, y)
    ax.text(text_x, text_y, node, fontsize=42, ha=ha_align, va=va_align,
            weight='bold', alpha=0.9,
            bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.8))
    return logo_loaded


def _draw_weighted_edges(graph, pos, ax):
    """Draw every edge of ``graph`` with line width and arrow size scaled by its ``weight``.

    Heavier edges get thicker lines and slightly smaller arrow heads.
    """
    edges = list(graph.edges(data=True))
    if not edges:
        return
    max_weight = max(data['weight'] for _, _, data in edges) or 1

    for u, v, data in edges:
        normalized_weight = data['weight'] / max_weight
        edge_width = normalized_weight * 12 + 3
        arrow_size = 60 - (normalized_weight * 15)

        drawn = nx.draw_networkx_edges(graph, pos,
                                       edgelist=[(u, v)],
                                       edge_color='#2E3440',  # dark blue-gray
                                       arrows=True,
                                       arrowsize=arrow_size,
                                       arrowstyle='-|>',
                                       width=edge_width,
                                       alpha=0.8,
                                       connectionstyle="arc3,rad=0.1",
                                       ax=ax)
        if not drawn:
            continue
        # The return value is either a list of arrow patches or a single collection.
        for artist in (drawn if isinstance(drawn, list) else [drawn]):
            artist.set_zorder(2)


def create_circular_network_plot(predictions_file: str, threshold: float = 3.0, save_path: str = None):
    """Create a circular network plot with model logos.

    Nodes are models (short names) placed on a circle. A directed edge
    evaluator -> predicted is drawn when the predicted model accounts for more
    than ``threshold`` percent of that evaluator's predictions. Edge width and
    arrow size are scaled by that percentage.

    Args:
        predictions_file: JSONL file with ``evaluator_model`` / ``predicted_model`` records.
        threshold: Minimum percentage (exclusive) for an edge to be drawn.
        save_path: If given, the figure is saved there at 300 dpi; parent
            directories are created as needed.

    Returns:
        The matplotlib Figure. Returns ``None`` if the file is missing or
        unreadable, holds no usable records, or no edge exceeds ``threshold``;
        a message is printed in each of those cases.
    """
    if not Path(predictions_file).exists():
        print(f"Warning: File not found {predictions_file}")
        return None

    loaded = _load_prediction_counts(predictions_file)
    if loaded is None:
        return None
    prediction_data, model_totals, all_models = loaded

    if not prediction_data:
        print("No prediction data found")
        return None

    G = _build_prediction_graph(prediction_data, model_totals, all_models, threshold)
    if G.number_of_edges() == 0:
        print(f"No prediction patterns above {threshold}% threshold found")
        return None

    fig, ax = plt.subplots(figsize=(16, 16))
    pos = _circular_layout(list(G.nodes()))

    # Logos and labels are drawn first, then the edges on top of the logos.
    logos_dir = Path("results/logos")
    successful_logos = 0
    for node, (x, y) in pos.items():
        if _draw_node(ax, node, x, y, logos_dir):
            successful_logos += 1

    _draw_weighted_edges(G, pos, ax)

    ax.axis('off')
    ax.set_xlim(-1.3, 1.3)
    ax.set_ylim(-1.3, 1.3)
    fig.tight_layout()

    print(f"\n📊 Network Statistics:")
    print(f"   • Nodes (models): {G.number_of_nodes()}")
    print(f"   • Edges (prediction patterns >{threshold}%): {G.number_of_edges()}")
    print(f"   • Real logos loaded: {successful_logos}/{len(pos)}")
    print(f"   • Emoji fallbacks: {len(pos) - successful_logos}")

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nNetwork graph saved to {save_path}")

    return fig

def main():
    """Generate network plots for both corpora.

    For each corpus whose predictions file exists, build the circular network
    plot (3% threshold) and save it as a PDF. Each returned figure is closed
    afterwards so large figures do not accumulate in memory.
    """
    # (progress message, predictions JSONL, output PDF) per corpus.
    corpora = [
        ("Generating 100-word corpus network plot...",
         "results/predictions/predictions_exact_model.jsonl",
         "results/plot_100/prediction_network_circular_new.pdf"),
        ("\nGenerating 500-word corpus network plot...",
         "results/predictions_500/predictions_exact_model.jsonl",
         "results/plot_500/prediction_network_circular_new.pdf"),
    ]

    for message, predictions_file, output_file in corpora:
        if not Path(predictions_file).exists():
            continue
        print(message)
        fig = create_circular_network_plot(predictions_file, threshold=3.0, save_path=output_file)
        if fig is not None:
            plt.close(fig)

    print("\n✅ Network plot generation complete!")

if __name__ == "__main__":
    main()