#!/usr/bin/env python3
"""
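Generate architecture and comparison diagrams for the post-quantum stablecoin paper.
Writes PNGs to post-quantum-stablecoin-paper/figures/, resolved relative to the
current working directory (the directory is created when this module is imported).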
"""

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch, ConnectionPatch

# --- Module-level configuration ---
# Output location is a relative path, so figures land under the current
# working directory; it is created (if missing) as soon as the module loads.
OUTPUT_DIR = os.path.join('post-quantum-stablecoin-paper', 'figures')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Shared palette, keyed by the conceptual role a shape plays in a diagram.
COLORS = dict(
    classical='#FF6B6B',
    quantum='#4ECDC4',
    hybrid='#45B7D1',
    infrastructure='#96CEB4',
    user='#FFEAA7',
)


def _add_labeled_box(ax, x, y, width, height, color, label, *,
                     linewidth, fontsize, bold=False):
    """Add a rounded, black-edged box at (x, y) with *label* centred inside it."""
    ax.add_patch(FancyBboxPatch((x, y), width, height,
                                boxstyle="round,pad=0.1",
                                facecolor=color,
                                edgecolor='black', linewidth=linewidth))
    ax.text(x + width / 2, y + height / 2, label,
            ha='center', va='center', fontsize=fontsize,
            fontweight='bold' if bold else 'normal')


def create_architecture_diagram():
    """Create the post-quantum stablecoin architecture diagram.

    Draws five stacked layers (user, application, post-quantum security,
    smart contracts, blockchain infrastructure) on a 14 x 10 data-unit
    canvas, links them with downward arrows, puts a colour legend to the
    right of the axes, and saves ``architecture_diagram.png`` (300 dpi)
    into ``OUTPUT_DIR``.
    """
    # Constrained layout is the only layout engine for this figure; the
    # out-of-axes legend is kept in the saved image via bbox_inches='tight'.
    # A follow-up tight_layout() call would replace the constrained engine
    # (and make Matplotlib warn), so none is made.
    fig, ax = plt.subplots(1, 1, figsize=(14, 10), constrained_layout=True)

    # Every full-width layer spans x = left..right. Boxes inside a row are
    # spread evenly over the same span so none overhangs the layer frames.
    left, right = 1, 13
    span = right - left

    # User Layer
    _add_labeled_box(ax, left, 8, span, 1.5, COLORS['user'],
                     'User Layer\n(Wallets, DApps, Exchanges)',
                     linewidth=2, fontsize=12, bold=True)

    # Application Layer: three 3 x 1 boxes, first flush left, last flush right.
    app_labels = ['Wallet\nApplications', 'Trading\nPlatforms', 'DeFi\nProtocols']
    app_width = 3
    app_xs = np.linspace(left, right - app_width, len(app_labels))
    for x, label in zip(app_xs, app_labels):
        _add_labeled_box(ax, x, 6.5, app_width, 1, COLORS['hybrid'], label,
                         linewidth=1, fontsize=10)

    # Post-Quantum Security Layer
    _add_labeled_box(ax, left, 4.5, span, 1.5, COLORS['quantum'],
                     'Post-Quantum Security Layer\n(CRYSTALS-Dilithium Signatures)',
                     linewidth=2, fontsize=12, bold=True)

    # Smart Contract Layer: four 2 x 1.5 boxes spread over the same span.
    contract_labels = ['Stablecoin\nContract', 'Governance\nContract',
                       'Treasury\nContract', 'Bridge\nContract']
    contract_width = 2
    contract_xs = np.linspace(left, right - contract_width, len(contract_labels))
    for x, label in zip(contract_xs, contract_labels):
        _add_labeled_box(ax, x, 2.5, contract_width, 1.5,
                         COLORS['infrastructure'], label,
                         linewidth=1, fontsize=9)

    # Blockchain Infrastructure
    _add_labeled_box(ax, left, 0.5, span, 1.5, COLORS['infrastructure'],
                     'Blockchain Infrastructure\n(Ethereum, Consensus Layer)',
                     linewidth=2, fontsize=12, bold=True)

    # Flow arrows down the centre line, each bridging the gap between the
    # bottom of one layer and the top of the next.
    arrows = [((7, 8), (7, 7.5)), ((7, 6.5), (7, 6)),
              ((7, 4.5), (7, 4)), ((7, 2.5), (7, 2))]
    for start, end in arrows:
        arrow = ConnectionPatch(start, end, "data", "data",
                                arrowstyle="->", shrinkA=5, shrinkB=5,
                                mutation_scale=20, fc="black", lw=2)
        ax.add_patch(arrow)

    # Legend outside the axes, to the right.
    legend_elements = [
        patches.Patch(color=COLORS['user'], label='User Interface'),
        patches.Patch(color=COLORS['hybrid'], label='Application Layer'),
        patches.Patch(color=COLORS['quantum'], label='Post-Quantum Security'),
        patches.Patch(color=COLORS['infrastructure'], label='Infrastructure'),
    ]
    ax.legend(handles=legend_elements, loc='center left',
              bbox_to_anchor=(1.02, 0.5), frameon=False)

    # Canvas settings
    ax.set_xlim(0, 14)
    ax.set_ylim(0, 10)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('Post-Quantum Stablecoin Architecture',
                 fontsize=16, fontweight='bold', pad=20)

    fig.savefig(os.path.join(OUTPUT_DIR, 'architecture_diagram.png'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)


def _plot_grouped_comparison(ax, categories, ecdsa_values, dilithium_values,
                             ylabel, title, width=0.36):
    """Draw paired ECDSA/Dilithium bars for each category on *ax*.

    The two bars sit half a *width* either side of each category tick, and a
    two-column, frameless legend is anchored above the axes.
    """
    x = np.arange(len(categories))
    ax.bar(x - width / 2, ecdsa_values, width, label='ECDSA')
    ax.bar(x + width / 2, dilithium_values, width, label='Dilithium')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.set_xticks(x, categories)
    # Anchored above the axes so it does not cover the bars.
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.18),
              ncol=2, frameon=False)


def create_performance_comparison():
    """Create performance comparison charts (2x2) with legends outside.

    Panels: signature generation time, signature size (log y-axis),
    per-operation transaction throughput, and per-operation gas cost. The
    figure is saved as ``performance_comparison.png`` (300 dpi) in
    ``OUTPUT_DIR``.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
        2, 2, figsize=(16, 12), constrained_layout=True
    )

    algorithms = ['ECDSA', 'Dilithium']

    # Signature Generation Time
    times_ms = [0.35, 0.82]
    time_bars = ax1.bar(algorithms, times_ms)
    ax1.set_ylabel('Time (ms)')
    ax1.set_title('Signature Generation Time')
    ax1.set_ylim(0, 1.0)
    for bar, t in zip(time_bars, times_ms):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
                 f'{t:.2f} ms', ha='center', va='bottom')

    # Signature Size (log scale). Labels are offset multiplicatively (x1.1)
    # because an additive offset would look very different across decades.
    sizes = [64, 3293]
    size_bars = ax2.bar(algorithms, sizes)
    ax2.set_ylabel('Size (bytes)')
    ax2.set_title('Signature Size (log scale)')
    ax2.set_yscale('log')
    for bar, size in zip(size_bars, sizes):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() * 1.1,
                 f'{size} B', ha='center', va='bottom')

    # Per-operation comparisons share one grouped-bar layout.
    operations = ['Transfer', 'Mint', 'Burn']
    _plot_grouped_comparison(ax3, operations,
                             [465, 305, 339], [292, 206, 243],
                             'Transactions per Second',
                             'Transaction Throughput Comparison')
    _plot_grouped_comparison(ax4, operations,
                             [21000, 46000, 31000], [29400, 64400, 43400],
                             'Gas Cost', 'Gas Cost Comparison')

    fig.savefig(os.path.join(OUTPUT_DIR, 'performance_comparison.png'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)


def create_migration_timeline():
    """Create migration timeline diagram.

    Draws:
    - phase bars in the band y = 2..4, with centred labels;
    - milestone stems running from the month axis (y = 1) up to the bars;
    - a hand-drawn month axis, ticked every 6 months.

    The figure is saved as ``migration_timeline.png`` (300 dpi) in
    ``OUTPUT_DIR``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(16, 8), constrained_layout=True)

    # Phases: (label, start_month, end_month, color)
    phases = [
        ('Preparation\n(6–12 months)', 0, 12, '#FFEAA7'),
        ('Hybrid Operation\n(12–24 months)', 12, 36, '#74B9FF'),
        ('Full Transition\n(6–12 months)', 36, 42, '#00B894')
    ]

    # Milestones: (month, label, color)
    milestones = [
        (3, 'Infrastructure\nUpdates', '#E17055'),
        (6, 'Wallet\nModifications', '#E17055'),
        (9, 'Developer\nTools', '#E17055'),
        (12, 'Hybrid\nLaunch', '#0984E3'),
        (18, 'User\nMigration', '#0984E3'),
        (24, 'Performance\nOptimization', '#0984E3'),
        (30, 'Security\nAudits', '#0984E3'),
        (36, 'Legacy\nDeprecation', '#00B894'),
        (39, 'Full PQ\nOperation', '#00B894'),
        (42, 'Migration\nComplete', '#00B894')
    ]

    # The axis runs to the end of the last phase; every horizontal extent
    # below is derived from this value instead of repeating the literal.
    total_months = phases[-1][2]
    tick_step = 6

    # Draw phases
    for name, start, end, color in phases:
        rect = patches.Rectangle((start, 2), end - start, 2,
                                 facecolor=color, alpha=0.7, edgecolor='black')
        ax.add_patch(rect)
        ax.text((start + end) / 2, 3, name, ha='center', va='center',
                fontsize=12, fontweight='bold')

    # Milestones: stem from the axis up to the phase bars, marker mid-stem,
    # rotated label below the axis.
    for month, label, color in milestones:
        ax.plot([month, month], [1, 2], color=color, linewidth=3)
        ax.plot(month, 1.5, 'o', color=color, markersize=8)
        ax.text(month, 0.45, label, ha='center', va='center',
                fontsize=9, rotation=45)

    # Timeline axis + ticks
    ax.plot([0, total_months], [1, 1], 'k-', linewidth=2)
    for month in range(0, total_months + 1, tick_step):
        ax.plot([month, month], [0.9, 1.1], 'k-', linewidth=1)
        ax.text(month, 0.8, f'{month}m', ha='center', va='top', fontsize=8)

    ax.set_xlim(-2, total_months + 2)
    ax.set_ylim(0, 5)
    ax.set_title('Post-Quantum Stablecoin Migration Timeline',
                 fontsize=16, fontweight='bold')
    # axis('off') hides all axis decorations (including any xlabel), which is
    # why the month axis, its arrow head and its caption are drawn by hand.
    ax.axis('off')
    ax.annotate('', xy=(total_months, 1), xytext=(0, 1),
                arrowprops=dict(arrowstyle='->', lw=2, color='black'))
    ax.text(total_months / 2, 0.2, 'Migration Timeline (months)',
            ha='center', va='center', fontsize=12, fontweight='bold')

    fig.savefig(os.path.join(OUTPUT_DIR, 'migration_timeline.png'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)


def create_security_comparison():
    """Create the two-panel ECDSA-vs-Dilithium security comparison.

    Left panel: a secure (1) / vulnerable (0) bar per threat category for
    each algorithm. Every bar is labelled from its own value, and the legend
    sits above the axes.

    Right panel: security bits per algorithm and setting. A zero-bit entry
    draws no visible bar, so it is annotated in red as broken instead.

    Saved as ``security_comparison.png`` (300 dpi) in ``OUTPUT_DIR``.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), constrained_layout=True)

    # Panel 1: Security level (1 = secure, 0 = vulnerable)
    categories = ['Classical\nAttacks', 'Quantum\nAttacks', 'Long-term\nSecurity']
    ecdsa_security = [1, 0, 0]
    dilithium_security = [1, 1, 1]
    x = np.arange(len(categories))
    width = 0.36

    ecdsa_bars = ax1.bar(x - width / 2, ecdsa_security, width, label='ECDSA', alpha=0.9)
    dilithium_bars = ax1.bar(x + width / 2, dilithium_security, width, label='Dilithium', alpha=0.9)

    ax1.set_ylabel('Security Level')
    ax1.set_title('Security Comparison')
    ax1.set_xticks(x, categories)
    ax1.set_ylim(0, 1.25)
    # Both series use the same value-driven labelling, so a Dilithium entry of
    # 0 would read 'Vulnerable' rather than being hard-coded to 'Secure'.
    for bars, levels in ((ecdsa_bars, ecdsa_security),
                         (dilithium_bars, dilithium_security)):
        for bar, level in zip(bars, levels):
            # Zero-height bars get their label lifted just off the baseline.
            label_y = (level if level > 0 else 0.03) + 0.02
            ax1.text(bar.get_x() + bar.get_width() / 2, label_y,
                     'Secure' if level == 1 else 'Vulnerable',
                     ha='center', va='bottom', fontsize=9, rotation=35)

    ax1.legend(loc='upper center', bbox_to_anchor=(0.5, 1.18), ncol=2, frameon=False)

    # Panel 2: Security bits
    algorithms = ['ECDSA\n(Classical)', 'ECDSA\n(Quantum)', 'Dilithium\n(Classical)', 'Dilithium\n(Quantum)']
    security_bits = [128, 0, 192, 192]
    bit_bars = ax2.bar(algorithms, security_bits, alpha=0.9)
    ax2.set_ylabel('Security Bits')
    ax2.set_title('Security Strength Comparison')
    ax2.set_ylim(0, 210)
    for bar, bits in zip(bit_bars, security_bits):
        center = bar.get_x() + bar.get_width() / 2
        if bits > 0:
            ax2.text(center, bits + 5,
                     f'{bits:.0f} bits', ha='center', va='bottom', fontsize=9)
        else:
            # Nothing is drawn for a zero-height bar; annotate it instead.
            ax2.text(center, 10,
                     'Broken under\nquantum', ha='center', va='bottom',
                     fontsize=9, color='red')

    fig.savefig(os.path.join(OUTPUT_DIR, 'security_comparison.png'),
                dpi=300, bbox_inches='tight')
    plt.close(fig)


if __name__ == "__main__":
    # Render every figure in turn, announcing each one before it is built.
    for announcement, render in (
        ("Generating architecture diagram...", create_architecture_diagram),
        ("Generating performance comparison...", create_performance_comparison),
        ("Generating migration timeline...", create_migration_timeline),
        ("Generating security comparison...", create_security_comparison),
    ):
        print(announcement)
        render()

    print(f"All diagrams generated in: {OUTPUT_DIR}")
