#!/usr/bin/env python3
"""
Clean, configurable network circular plot generator.
"""

import json
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from pathlib import Path
from collections import defaultdict
import matplotlib.image as mpimg
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from matplotlib.patches import FancyArrowPatch
import math

from network_config import *

def get_short_model_name(full_name: str) -> str:
    """Return the display name used for a model.

    An entry in ``NAME_MAPPING`` takes precedence; when there is none, the
    text after the last ``/`` in ``full_name`` is used (the whole string if it
    contains no ``/``).
    """
    fallback = full_name.rsplit('/', 1)[-1]
    return NAME_MAPPING.get(full_name, fallback)

def load_prediction_data(predictions_file: str):
    """Tally evaluator -> predicted-model counts from a JSONL file.

    Every non-blank line is parsed as JSON. The ``evaluator_model`` field
    (default ``'Unknown'``) and the ``predicted_model`` field (default ``''``)
    are read, and the record is counted only when both values are truthy.

    Returns:
        A ``(prediction_data, model_totals, all_models)`` tuple:
        ``prediction_data[evaluator][predicted]`` is a count,
        ``model_totals[evaluator]`` is the number of counted records for that
        evaluator, and ``all_models`` is the set of names seen in either role.
        If any exception occurs while reading or parsing, the error is printed
        and ``(None, None, None)`` is returned instead.
    """
    counts = defaultdict(lambda: defaultdict(int))
    totals = defaultdict(int)
    seen = set()

    try:
        with open(predictions_file, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue  # blank lines carry no record

                record = json.loads(raw_line)
                evaluator = record.get('evaluator_model', 'Unknown')
                predicted = record.get('predicted_model', '')
                if not (evaluator and predicted):
                    continue  # both roles are required for a record to count

                counts[evaluator][predicted] += 1
                totals[evaluator] += 1
                seen.update((evaluator, predicted))
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with Nones.
        print(f"Error reading {predictions_file}: {e}")
        return None, None, None

    return counts, totals, seen

def create_graph_from_data(prediction_data, model_totals, all_models, threshold):
    """Build the directed prediction graph plus a flat list of its edges.

    Nodes are the short display names of every model in ``all_models``. An
    edge evaluator -> predicted is added when the share of the evaluator's
    predictions that named ``predicted`` (as a percentage of
    ``model_totals[evaluator]``) is strictly greater than ``threshold``.

    Returns:
        ``(G, edge_data)`` where ``G`` is an ``nx.DiGraph`` whose edges carry
        ``weight`` (the percentage), ``count`` and ``total`` attributes, and
        ``edge_data`` is a list of dicts with keys ``from``, ``to``,
        ``percentage`` and ``count``.
    """
    graph = nx.DiGraph()

    # Register one node per model, visiting full names in sorted order.
    ordered_models = sorted(all_models)
    short_names = {}
    for full_name in ordered_models:
        short_names[full_name] = get_short_model_name(full_name)
        graph.add_node(short_names[full_name])

    edge_records = []
    for evaluator in ordered_models:
        if evaluator not in prediction_data:
            continue  # this model never acted as an evaluator

        total = model_totals[evaluator]
        if not (total > 0):
            continue  # no percentage can be computed without predictions

        per_target = prediction_data[evaluator]
        source = short_names[evaluator]
        for predicted in ordered_models:
            hits = per_target.get(predicted, 0)
            share = (hits / total) * 100
            if share > threshold:
                target = short_names[predicted]
                graph.add_edge(source, target, weight=share, count=hits, total=total)
                edge_records.append({
                    'from': source,
                    'to': target,
                    'percentage': share,
                    'count': hits,
                })

    return graph, edge_records

def create_circular_layout(G):
    """Place the graph's nodes evenly on a circle of radius ``CIRCLE_RADIUS``.

    Nodes listed in ``GPT_MODELS`` come first (in ``GPT_MODELS`` order, and
    only if they are present in the graph), followed by all remaining nodes in
    graph order. Slot ``i`` sits at angle ``2*pi*i / number_of_nodes``.

    Returns:
        A dict mapping each node to its ``(x, y)`` position.
    """
    graph_nodes = list(G.nodes())
    total = len(graph_nodes)

    # Grouped ordering: configured GPT models first, everything else after.
    ring = [name for name in GPT_MODELS if name in graph_nodes]
    ring.extend(name for name in graph_nodes if name not in GPT_MODELS)

    layout = {}
    for slot, name in enumerate(ring):
        theta = 2 * math.pi * slot / total
        layout[name] = (CIRCLE_RADIUS * math.cos(theta),
                        CIRCLE_RADIUS * math.sin(theta))
    return layout

def _logo_with_alpha(img, alpha):
    """Return a float RGBA copy of ``img`` with its opacity scaled by ``alpha``.

    The input is never modified, so read-only arrays are fine. Integer images
    are rescaled by their dtype's maximum so all channels end up in 0-1;
    2-D (grayscale) images are expanded to three identical colour channels.
    Arrays with any other channel count are returned as a float copy without
    an alpha adjustment.
    """
    arr = np.asarray(img)
    if np.issubdtype(arr.dtype, np.integer):
        # e.g. uint8 data spans 0-255; float image data is expected in 0-1.
        arr = arr.astype(np.float64) / np.iinfo(arr.dtype).max
    else:
        arr = arr.astype(np.float64)  # astype copies, leaving ``img`` untouched

    if arr.ndim == 2:
        arr = np.stack((arr, arr, arr), axis=-1)
    if arr.ndim != 3:
        return arr

    channels = arr.shape[2]
    if channels == 3:
        # No alpha channel yet: append a constant one.
        opacity = np.full(arr.shape[:2] + (1,), float(alpha))
        return np.concatenate((arr, opacity), axis=2)
    if channels == 4:
        # Existing alpha channel: scale it.
        arr[:, :, 3] *= alpha
    return arr


def add_logo_or_emoji(ax, node, x, y):
    """Draw a node's logo at ``(x, y)``, or an emoji on a circle as fallback.

    The logo file name is looked up in ``MODEL_LOGO_FILES`` and resolved
    against ``LOGOS_DIR``. When there is no entry, the file is missing, or
    loading/drawing setup raises, a filled circle and the emoji from
    ``MODEL_EMOJI_FALLBACKS`` (default ``'🤖'``) are drawn instead.

    Args:
        ax: Matplotlib axes to draw on.
        node: Node name used as the key into the logo/emoji/size tables.
        x, y: Data coordinates of the node.

    Returns:
        ``True`` if the logo image was added, ``False`` if the emoji fallback
        was used.
    """
    logo_filename = MODEL_LOGO_FILES.get(node)
    logo_loaded = False

    if logo_filename:
        logo_path = Path(LOGOS_DIR) / logo_filename
        if logo_path.exists():
            try:
                # Normalise to float RGBA and apply the configured opacity.
                img = _logo_with_alpha(mpimg.imread(logo_path), LOGO_ALPHA)

                # Each model may override the default zoom factor.
                logo_size = MODEL_LOGO_SIZES.get(node, DEFAULT_LOGO_SIZE)
                imagebox = OffsetImage(img, zoom=logo_size)
                ab = AnnotationBbox(imagebox, (x, y), frameon=False, pad=0)
                ab.set_zorder(LOGO_Z_ORDER)
                ax.add_artist(ab)
                logo_loaded = True
            except Exception as e:
                # Best effort: any failure falls through to the emoji below.
                print(f"Warning: Could not load logo {logo_path}: {e}")

    if not logo_loaded:
        # Background disc behind the emoji.
        circle = plt.Circle((x, y), EMOJI_CIRCLE_RADIUS,
                            color=EMOJI_CIRCLE_COLOR, alpha=EMOJI_CIRCLE_ALPHA,
                            zorder=EMOJI_CIRCLE_Z_ORDER)
        ax.add_patch(circle)

        emoji = MODEL_EMOJI_FALLBACKS.get(node, '🤖')
        ax.text(x, y, emoji, fontsize=EMOJI_FONT_SIZE, ha='center', va='center',
                weight='bold', zorder=EMOJI_Z_ORDER)

    return logo_loaded

def add_text_label(ax, node, x, y):
    """Write the node's name next to its position.

    The placement depends on which configured group contains ``node``; the
    groups are checked in the order ``MODELS_ABOVE``, ``MODELS_RIGHT``,
    ``MODELS_LEFT``, ``MODELS_RIGHT_DOWN``, ``MODELS_BELOW_FAR``. A node in
    none of them gets its label centred below at ``LABEL_OFFSET_STANDARD``.
    """
    def placement():
        # Returns (text_x, text_y, horizontal alignment, vertical alignment).
        if node in MODELS_ABOVE:
            return x, y + LABEL_OFFSET_FAR, 'center', 'bottom'
        if node in MODELS_RIGHT:
            return x + LABEL_OFFSET_RIGHT_EXTRA, y, 'left', 'center'
        if node in MODELS_LEFT:
            return x - LABEL_OFFSET_FAR, y, 'right', 'center'
        if node in MODELS_RIGHT_DOWN:
            # Shifted by the same amount to the right and downwards.
            return (x + LABEL_OFFSET_RIGHT_DOWN, y - LABEL_OFFSET_RIGHT_DOWN,
                    'left', 'center')
        if node in MODELS_BELOW_FAR:
            return x, y - LABEL_OFFSET_FAR, 'center', 'top'
        return x, y - LABEL_OFFSET_STANDARD, 'center', 'top'

    label_x, label_y, horizontal, vertical = placement()

    # No bbox is passed, so the label is drawn without a surrounding box.
    ax.text(label_x, label_y, node, fontsize=LABEL_FONT_SIZE,
            ha=horizontal, va=vertical, weight=LABEL_WEIGHT, alpha=LABEL_ALPHA)

def get_edge_color(normalized_weight):
    """Map a normalized edge weight to a hex colour and an alpha value.

    The colour is interpolated linearly, channel by channel, from light gray
    (``#c0c0c0``) at weight 0 to blue (``#1f77b4``) at weight 1. Alpha runs
    linearly from 0.4 to 0.95 over the same range.

    Args:
        normalized_weight: Edge weight, expected in ``[0, 1]``. Values outside
            that range are clamped so the result is always a valid
            ``#rrggbb`` string and an alpha within 0.4-0.95.

    Returns:
        ``(color, alpha)`` where ``color`` is a ``#rrggbb`` string and
        ``alpha`` is a float.
    """
    # Without clamping, out-of-range input produces channel values outside
    # 0-255 and therefore a malformed hex string.
    t = min(1.0, max(0.0, normalized_weight))

    light_gray = (0xc0, 0xc0, 0xc0)
    blue = (0x1f, 0x77, 0xb4)

    # int() truncates the interpolated channel value.
    channels = [int(lo + (hi - lo) * t) for lo, hi in zip(light_gray, blue)]
    color = "#{:02x}{:02x}{:02x}".format(*channels)

    alpha = 0.4 + 0.55 * t

    return color, alpha

def _draw_self_loop(ax, x0, y0, normalized_weight, line_width, arrow_size, color, alpha):
    """Draw a self-loop: an unfilled circle above ``(x0, y0)`` plus an arrowhead on it."""
    from matplotlib.patches import Circle

    # Radius grows linearly with weight, from 0.08 to 0.12.
    circle_radius = 0.08 + normalized_weight * 0.04

    # The circle sits above the node; its lowest point is 0.05 above (x0, y0).
    center_x = x0
    center_y = y0 + circle_radius + 0.05

    ax.add_patch(Circle(
        (center_x, center_y),
        circle_radius,
        fill=False,
        linewidth=line_width,
        edgecolor=color,
        alpha=alpha,
        zorder=EDGE_Z_ORDER
    ))

    # Arrowhead anchor on the circumference. sin is used for x and cos for y,
    # so the 60-degree angle is measured from the top of the circle towards
    # the right.
    angle = math.pi / 3
    head_x = center_x + circle_radius * math.sin(angle)
    head_y = center_y + circle_radius * math.cos(angle)

    # Arrowhead direction: start from the radius angle turned by -90 degrees,
    # then add an offset that goes from 165 degrees (weight 0) to 170 degrees
    # (weight 1).
    tangent_angle = angle - math.pi / 2
    base_rotation = 11 * math.pi / 12
    max_rotation = 17 * math.pi / 18
    rotation_offset = base_rotation + normalized_weight * (max_rotation - base_rotation)
    arrow_direction = tangent_angle + rotation_offset

    dx = 0.02 * math.cos(arrow_direction)
    dy = 0.02 * math.sin(arrow_direction)

    # A near-zero-length arrow: the tiny offset only supplies a direction, so
    # effectively just the head is visible.
    ax.add_patch(FancyArrowPatch(
        (head_x, head_y),
        (head_x - dx * 0.1, head_y - dy * 0.1),
        arrowstyle="->",
        linewidth=line_width,
        alpha=alpha,
        color=color,
        joinstyle="round",
        capstyle="round",
        mutation_scale=arrow_size,
        zorder=EDGE_Z_ORDER + 1
    ))


def _draw_straight_edge(ax, x0, y0, x1, y1, line_width, arrow_size, color, alpha):
    """Draw a straight line from ``(x0, y0)`` to ``(x1, y1)`` with a small arrow at 3/4 of its length."""
    # The shaft is drawn without a head; the head is a separate short patch.
    ax.add_patch(FancyArrowPatch(
        (x0, y0), (x1, y1),
        connectionstyle="Arc3,rad=0",
        arrowstyle="-",
        linewidth=line_width,
        alpha=alpha,
        color=color,
        joinstyle="round",
        capstyle="round",
        zorder=EDGE_Z_ORDER
    ))

    # Arrow is centred on the point three quarters of the way to the target.
    t = 0.75
    arrow_x = x0 + t * (x1 - x0)
    arrow_y = y0 + t * (y1 - y0)

    # Unit vector along the edge (left as zero if the endpoints coincide).
    arrow_dx = x1 - x0
    arrow_dy = y1 - y0
    arrow_length = math.hypot(arrow_dx, arrow_dy)
    if arrow_length > 0:
        arrow_dx /= arrow_length
        arrow_dy /= arrow_length

    arrow_scale = 0.015  # total length of the short arrow segment
    start = (arrow_x - arrow_dx * arrow_scale * 0.5,
             arrow_y - arrow_dy * arrow_scale * 0.5)
    end = (arrow_x + arrow_dx * arrow_scale * 0.5,
           arrow_y + arrow_dy * arrow_scale * 0.5)

    ax.add_patch(FancyArrowPatch(
        start,
        end,
        arrowstyle="->",
        linewidth=line_width,
        alpha=alpha,
        color=color,
        joinstyle="round",
        capstyle="round",
        mutation_scale=arrow_size,
        zorder=EDGE_Z_ORDER + 1  # above the shaft
    ))


def draw_edges(G, pos, ax):
    """Draw every edge of ``G`` with weight-dependent colour, width and arrow size.

    Each edge's ``weight`` attribute is divided by the largest weight in the
    graph; the result drives the colour/alpha (via ``get_edge_color``), the
    line width and the arrowhead scale. Self-loops are drawn as a circle above
    the node, all other edges as straight lines with an arrow at 3/4 length.

    Args:
        G: Graph whose edges carry a ``weight`` attribute.
        pos: Mapping from node to ``(x, y)`` position.
        ax: Matplotlib axes to draw on.
    """
    edges = list(G.edges(data=True))
    if not edges:
        return

    max_weight = max(data['weight'] for _, _, data in edges)

    for u, v, data in edges:
        # If every weight is zero there is nothing to scale by; treat all
        # edges as minimum weight instead of dividing by zero.
        normalized_weight = data['weight'] / max_weight if max_weight > 0 else 0.0
        x0, y0 = pos[u]
        x1, y1 = pos[v]

        color, alpha = get_edge_color(normalized_weight)

        line_width = EDGE_BASE_WIDTH + normalized_weight * (EDGE_MAX_WIDTH_MULTIPLIER - EDGE_BASE_WIDTH)

        # Arrowhead scale runs from 50% to 150% of the configured base.
        arrow_size = ARROW_MUTATION_SCALE * (0.5 + normalized_weight)

        if u == v:
            _draw_self_loop(ax, x0, y0, normalized_weight, line_width, arrow_size, color, alpha)
        else:
            _draw_straight_edge(ax, x0, y0, x1, y1, line_width, arrow_size, color, alpha)

def create_circular_network_plot(predictions_file: str, threshold: float = None, save_path: str = None):
    """Create a circular network plot of prediction patterns with model logos.

    Args:
        predictions_file: Path to the JSONL predictions file.
        threshold: Minimum percentage (exclusive) for an edge to be drawn.
            ``None`` means ``DEFAULT_THRESHOLD``.
        save_path: Optional output path. The figure is written there and,
            unless that path already has a ``.png`` suffix, a PNG companion
            with the same stem is written next to it.

    Returns:
        The Matplotlib figure, or ``None`` when the input file is missing,
        cannot be loaded, or yields no edge above the threshold. The figure
        is left open; closing it is up to the caller.
    """
    if threshold is None:
        threshold = DEFAULT_THRESHOLD

    if not Path(predictions_file).exists():
        print(f"Warning: File not found {predictions_file}")
        return None

    prediction_data, model_totals, all_models = load_prediction_data(predictions_file)
    if prediction_data is None:
        print("No prediction data found")
        return None

    G, edge_data = create_graph_from_data(prediction_data, model_totals, all_models, threshold)
    if G.number_of_edges() == 0:
        print(f"No prediction patterns above {threshold}% threshold found")
        return None

    # Equal aspect ratio keeps the node ring circular.
    fig, ax = plt.subplots(figsize=FIGURE_SIZE)
    ax.set_aspect('equal')

    pos = create_circular_layout(G)

    # Node markers (logo or emoji fallback) and their text labels.
    successful_logos = 0
    for node, (x, y) in pos.items():
        if add_logo_or_emoji(ax, node, x, y):
            successful_logos += 1
        add_text_label(ax, node, x, y)

    draw_edges(G, pos, ax)

    ax.axis('off')
    ax.set_xlim(PLOT_LIMITS[0], PLOT_LIMITS[1])
    ax.set_ylim(PLOT_LIMITS[0], PLOT_LIMITS[1])
    fig.tight_layout()

    print(f"\n📊 Network Statistics:")
    print(f"   • Nodes (models): {G.number_of_nodes()}")
    print(f"   • Edges (prediction patterns >{threshold}%): {G.number_of_edges()}")
    print(f"   • Real logos loaded: {successful_logos}/{len(pos)}")
    print(f"   • Emoji fallbacks: {len(pos) - successful_logos}")

    if save_path:
        target = Path(save_path)
        target.parent.mkdir(parents=True, exist_ok=True)

        # Save through the figure object rather than pyplot's "current
        # figure", so the right figure is written regardless of global state.
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"\nNetwork graph saved to {save_path}")

        # PNG companion: swap only the final suffix (a plain substring
        # replace would also rewrite '.pdf' inside directory names), and skip
        # it when the primary output is already a PNG.
        if target.suffix.lower() != '.png':
            png_path = target.with_suffix('.png')
            fig.savefig(png_path, dpi=300, bbox_inches='tight', facecolor='white')
            print(f"PNG version saved to {png_path}")

    return fig

def main():
    """Generate network plots for both corpora and both hint conditions.

    For each configured predictions file that exists, a plot is created with
    a 3.0% threshold and saved to the configured output path; missing files
    are reported and skipped.
    """
    plot_configs = [
        {
            "name": "100-word corpus (no hints)",
            "predictions": "results/predictions/predictions_exact_model.jsonl",
            "output": "results/plot_100/prediction_network_circular_clean.pdf"
        },
        {
            "name": "100-word corpus (with hints)",
            "predictions": "results/predictions/predictions_exact_model_with_hints.jsonl",
            "output": "results/plot_100/prediction_network_circular_clean_with_hints.pdf"
        },
        {
            "name": "500-word corpus (no hints)",
            "predictions": "results/predictions_500/predictions_exact_model.jsonl",
            "output": "results/plot_500/prediction_network_circular_clean.pdf"
        },
        {
            "name": "500-word corpus (with hints)",
            "predictions": "results/predictions_500/predictions_exact_model_with_hints.jsonl",
            "output": "results/plot_500/prediction_network_circular_clean_with_hints.pdf"
        }
    ]

    for config in plot_configs:
        if not Path(config["predictions"]).exists():
            print(f"⚠️ Skipping {config['name']} - file not found: {config['predictions']}")
            continue

        print(f"\nGenerating {config['name']} network plot...")
        fig = create_circular_network_plot(config["predictions"], threshold=3.0, save_path=config["output"])

        # The figure has already been written to disk; close it so pyplot
        # does not keep every figure alive for the rest of the run.
        if fig is not None:
            plt.close(fig)

    print("\n✅ Network plot generation complete!")

if __name__ == "__main__":
    main()