"""
Visualize trajectories and Lyapunov functions for different models
"""

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
import lightning.pytorch as pl
from compare_tracking_performance import (
    load_trained_model,
    load_pretrained_model,
    simulate_microgrid
)


def compute_lyapunov_values(trajectory, V_net, num_inverters, system, device):
    """
    Compute Lyapunov function values for every time step of a trajectory.

    Each state is converted to an error state relative to the equilibrium
    (delta = 0, omega = system.omega_star, xi = 0). Phase errors are then made
    relative to inverter 0, because only phase differences between inverters
    matter for power flow. All time steps are evaluated in one batched
    forward pass of ``V_net``.

    Args:
        trajectory: Array-like or tensor of shape
            [time_steps+1, num_inverters, 3] holding (delta, omega, xi)
        V_net: VectorLyapunovNetwork mapping a [batch, num_inverters*3] error
            state to [batch, num_inverters] values, or None
        num_inverters: Number of inverters
        system: MicrogridFormationDynamics system (only ``omega_star`` is read)
        device: torch device

    Returns:
        V_values: numpy array [time_steps+1, num_inverters] with the Lyapunov
        value of each inverter, or None when ``V_net`` is None
    """
    if V_net is None:
        return None

    # Flattened equilibrium: entries (delta, omega, xi) per inverter, with
    # omega = omega_star and everything else zero.
    equilibrium = torch.zeros(num_inverters * 3, device=device)
    equilibrium[1::3] = system.omega_star

    # Convert the whole trajectory once. torch.tensor([ndarray]) per step is
    # very slow and triggers a PyTorch warning.
    if not isinstance(trajectory, torch.Tensor):
        trajectory = np.asarray(trajectory, dtype=np.float32)

    with torch.no_grad():
        states = torch.as_tensor(trajectory, dtype=torch.float32, device=device)
        num_steps = states.shape[0]
        if num_steps == 0:
            return np.empty((0, num_inverters), dtype=np.float32)

        error_state = states.reshape(num_steps, -1) - equilibrium  # [T, num_inverters*3]

        # Relative phase: subtract inverter 0's delta error from every
        # inverter's delta error. Clone first because column 0 is itself updated.
        delta_ref = error_state[:, 0:1].clone()
        error_state[:, 0::3] -= delta_ref

        V = V_net(error_state)  # [T, num_inverters]

    return V.cpu().numpy()


def visualize_trajectory(trajectory, model_name, output_dir,
                         num_inverters, time_steps, dt=0.01, omega_star=None):
    """
    Plot the frequency (omega) trajectory of every inverter and save it.

    Writes ``<model_name>_trajectory.png`` (300 dpi) and
    ``<model_name>_trajectory.pdf`` into ``output_dir``.

    Args:
        trajectory: [time_steps+1, num_inverters, 3] (delta, omega, xi)
        model_name: Name of the model (used as the file-name prefix)
        output_dir: Output directory (not created here)
        num_inverters: Number of inverters to plot
        time_steps: Number of time steps (not used; the time axis is derived
            from ``trajectory.shape[0]``)
        dt: Time step size (s)
        omega_star: Reference (nominal) frequency; drawn as a dashed line when
            given. The y-axis window is centred on it, or on 2*pi*50 when None.
    """
    trajectory = np.asarray(trajectory)
    time = np.arange(trajectory.shape[0]) * dt
    frequencies = trajectory[:, :, 1]  # [time_steps+1, num_inverters]

    # Seaborn dark-grid style. The name changed to "seaborn-v0_8-*" in
    # Matplotlib 3.6, and plt.style.use raises OSError for unknown names.
    # NOTE: this updates the global rcParams, so later figures inherit it.
    for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
        try:
            plt.style.use(style_name)
            break
        except OSError:
            continue
    else:
        plt.style.use('default')

    # Matplotlib "tab10" palette, cycled when there are more than 10 inverters.
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        fig.patch.set_facecolor('white')

        for i in range(num_inverters):
            ax.plot(time, frequencies[:, i], label=f'Inverter {i}',
                    color=colors[i % len(colors)],
                    linewidth=2.5 if i == 0 else 2.0,  # emphasise inverter 0
                    alpha=0.85, marker='', linestyle='-')

        # Desired frequency reference line
        if omega_star is not None:
            ax.axhline(y=omega_star, color='#e74c3c', linestyle='--', linewidth=2.5,
                       alpha=0.8, label='Nominal Frequency', zorder=10)

        # No title, large fonts
        ax.set_xlabel('Time (s)', fontsize=27)
        ax.set_ylabel('Frequency (rad/s)', fontsize=27)
        ax.legend(loc='best', fontsize=27, framealpha=1.0,
                  fancybox=True, shadow=False, ncol=2, frameon=True)

        ax.grid(True, alpha=0.4, linestyle='-', linewidth=0.5)
        ax.set_axisbelow(True)

        # Y window: +/-5 around the nominal frequency, tightened to the data
        # (padded by 1). If the data lies entirely outside that window the
        # intersection is empty, so fall back to the padded data range instead
        # of producing an inverted axis.
        if frequencies.size:
            center = omega_star if omega_star is not None else 2 * np.pi * 50
            data_lo = np.nanmin(frequencies) - 1
            data_hi = np.nanmax(frequencies) + 1
            lower = max(center - 5, data_lo)
            upper = min(center + 5, data_hi)
            if not lower < upper:
                lower, upper = data_lo, data_hi
            # set_ylim rejects non-finite limits (e.g. a diverged simulation).
            if np.isfinite(lower) and np.isfinite(upper):
                ax.set_ylim([lower, upper])

        for side in ('top', 'right'):
            ax.spines[side].set_visible(False)
        for side in ('left', 'bottom'):
            ax.spines[side].set_linewidth(1.5)

        ax.tick_params(axis='both', which='major', labelsize=27, width=1.5, length=5)
        ax.tick_params(axis='both', which='minor', labelsize=23, width=1, length=3)

        fig.tight_layout()

        output_path_png = os.path.join(output_dir, f'{model_name}_trajectory.png')
        output_path_pdf = os.path.join(output_dir, f'{model_name}_trajectory.pdf')
        fig.savefig(output_path_png, dpi=300, bbox_inches='tight', facecolor='white', edgecolor='none')
        fig.savefig(output_path_pdf, bbox_inches='tight', facecolor='white', edgecolor='none')
        print(f"  Saved: {output_path_png}")
        print(f"  Saved: {output_path_pdf}")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def _lyapunov_grid(V_net, num_inverters, device, delta_range, omega_range,
                   resolution, inverter_idx, batch_size=8192):
    """Evaluate one inverter's Lyapunov value on a (delta, omega) error grid; return (delta_grid, omega_grid, V_grid)."""
    delta_vals = np.linspace(delta_range[0], delta_range[1], resolution)
    omega_vals = np.linspace(omega_range[0], omega_range[1], resolution)
    delta_grid, omega_grid = np.meshgrid(delta_vals, omega_vals)
    num_points = delta_grid.size

    # Build all grid points directly in error coordinates. Every inverter sits
    # at equilibrium (zero error) except `inverter_idx`, which gets the grid's
    # phase and frequency errors (xi error stays 0). Working in error space
    # avoids the float32 cancellation of (omega_star + e) - omega_star.
    error_states = torch.zeros((num_points, num_inverters, 3), dtype=torch.float32, device=device)
    error_states[:, inverter_idx, 0] = torch.as_tensor(delta_grid.ravel(), dtype=torch.float32, device=device)
    error_states[:, inverter_idx, 1] = torch.as_tensor(omega_grid.ravel(), dtype=torch.float32, device=device)
    error_states = error_states.reshape(num_points, -1)  # [points, num_inverters*3]

    # Batched evaluation, chunked to bound peak memory at high resolutions.
    values = []
    with torch.no_grad():
        for chunk in torch.split(error_states, batch_size):
            values.append(V_net(chunk)[:, inverter_idx].cpu())
    V_grid = torch.cat(values).numpy().astype(np.float64).reshape(delta_grid.shape)
    return delta_grid, omega_grid, V_grid


def visualize_lyapunov_contour(V_net, model_name, output_dir, num_inverters, system, device,
                                delta_range=(-0.5, 0.5), omega_range=(-2.0, 2.0),
                                resolution=100, inverter_idx=1):
    """
    Visualize one inverter's Lyapunov function as a filled contour plot.

    Every other inverter is held at equilibrium. The plot is saved as
    ``<model_name>_lyapunov_inverter<inverter_idx>.png`` (300 dpi) and ``.pdf``
    in ``output_dir``. Nothing is done when ``V_net`` is None.

    Args:
        V_net: VectorLyapunovNetwork (or None)
        model_name: Name of the model (file-name prefix)
        output_dir: Output directory (not created here)
        num_inverters: Number of inverters
        system: MicrogridFormationDynamics system. Kept for interface
            compatibility; the grid is built directly in error coordinates, so
            omega_star is not needed.
        device: torch device
        delta_range: Tuple of (min, max) phase angle error values
        omega_range: Tuple of (min, max) frequency error values
        resolution: Number of grid points per axis
        inverter_idx: Index of the inverter to visualize (1 for first controlled inverter)
    """
    if V_net is None:
        return

    print(f"    Computing Lyapunov function values for inverter {inverter_idx}...")
    delta_grid, omega_grid, V_grid = _lyapunov_grid(
        V_net, num_inverters, device, delta_range, omega_range,
        resolution, inverter_idx
    )

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        fig.patch.set_facecolor('white')

        # Filled contours plus faint labelled contour lines
        contour = ax.contourf(delta_grid, omega_grid, V_grid, levels=50, cmap='viridis')
        contour_lines = ax.contour(delta_grid, omega_grid, V_grid, levels=20,
                                   colors='white', alpha=0.3, linewidths=0.5)
        ax.clabel(contour_lines, inline=True, fontsize=17, fmt='%.2f')

        # The equilibrium is the origin in error coordinates.
        ax.plot(0.0, 0.0, 'r*', markersize=20, markeredgecolor='white',
                markeredgewidth=2, label='Equilibrium', zorder=10)

        cbar = fig.colorbar(contour, ax=ax)
        cbar.ax.tick_params(labelsize=23)

        # No title, large fonts
        ax.set_xlabel('Phase Angle Error (rad)', fontsize=27)
        ax.set_ylabel('Frequency Error (rad/s)', fontsize=27)
        ax.legend(loc='upper right', fontsize=27, framealpha=1.0,
                  fancybox=True, shadow=False, frameon=True)

        ax.grid(True, alpha=0.3)
        ax.tick_params(axis='both', which='major', labelsize=23, width=1.5, length=5)
        ax.tick_params(axis='both', which='minor', labelsize=19, width=1, length=3)

        fig.tight_layout()

        output_path_png = os.path.join(output_dir, f'{model_name}_lyapunov_inverter{inverter_idx}.png')
        output_path_pdf = os.path.join(output_dir, f'{model_name}_lyapunov_inverter{inverter_idx}.pdf')
        fig.savefig(output_path_png, dpi=300, bbox_inches='tight')
        fig.savefig(output_path_pdf, bbox_inches='tight')
        print(f"  Saved: {output_path_png}")
        print(f"  Saved: {output_path_pdf}")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def _build_dynamics_params(num_inverters):
    """Return the microgrid dynamics parameters shared by every model."""
    return {
        'dt': 0.01,  # Time step (s)
        'n': num_inverters,  # Number of inverters
        'omega_star': 2 * np.pi * 50,  # Nominal frequency (50 Hz)
        'tau': [1.4895] * num_inverters,  # Time constants
        'eta': [6.3509e-4] * num_inverters,  # Droop gains
        'k': [4.9481] * num_inverters,  # Secondary control gains
        'V': [325.3] * num_inverters,  # Voltage magnitudes
        # Line susceptances, chain topology: 0.1 between adjacent inverters
        'B': np.array([[0.1 if abs(i - j) == 1 else 0 for j in range(num_inverters)]
                       for i in range(num_inverters)]),
        'P_L': [1260.0] * num_inverters,  # Load powers
        'P_star': [1260.0] * num_inverters,  # Desired power injections
    }


def _load_models(num_inverters, controlled_indices, state_dims, control_dims,
                 dynamics_params, device, delay_time_step, pre_trained_model):
    """Load the baseline and the three trained models; a failed load maps to None."""
    models = {}

    # Pre-trained baseline (has no Lyapunov network)
    try:
        controller, system, _ = load_pretrained_model(
            num_inverters, controlled_indices, state_dims, control_dims,
            dynamics_params, device, delay_time_step, pre_trained_model
        )
    except Exception as e:  # report and continue with the remaining models
        print(f"  ✗ Failed to load pre_trained: {e}")
        models['pre_trained'] = None
    else:
        models['pre_trained'] = {'controller': controller, 'system': system, 'V_net': None}
        print(f"  ✓ Loaded pre_trained (baseline)")

    # Trained models; the enumeration index is the `mode` expected by the loader.
    for mode, mode_name in enumerate(('delayed_siss', 'compositional_iss', 'siss')):
        try:
            controller, system, V_net = load_trained_model(
                mode, num_inverters, controlled_indices, state_dims, control_dims,
                dynamics_params, device, delay_time_step,
                pre_trained_model
            )
        except Exception as e:  # report and continue with the remaining models
            print(f"  ✗ Failed to load {mode_name}: {e}")
            models[mode_name] = None
        else:
            models[mode_name] = {'controller': controller, 'system': system, 'V_net': V_net}
            print(f"  ✓ Loaded {mode_name}")

    return models


def _make_initial_error_scenario(num_inverters, omega_star, time_steps):
    """Build a disturbance-free scenario whose only excitation is a frequency offset on inverter 0."""
    initial_states = np.zeros((num_inverters, 3))
    initial_states[:, 1] = omega_star  # all inverters start at the nominal frequency

    # Initial error on inverter 0 only
    initial_states[0, 0] = 0.0  # delta error: 0 rad (no phase angle error)
    initial_states[0, 1] = omega_star + 2  # omega error: +2 rad/s (frequency error)
    initial_states[0, 2] = 0.0  # xi: 0 (integral term starts from 0)

    return {
        'initial_states': initial_states,
        'disturbances': [np.zeros((num_inverters, 3)) for _ in range(time_steps)],
    }


def visualize_all_models(num_inverters=3, controlled_indices=None, time_steps=500,
                         delay_time_step=10, device=None,
                         pre_trained_model=None, output_dir='visualization_results'):
    """
    Simulate every model from one initial-error scenario and save its plots.

    For each successfully loaded model this saves a frequency-trajectory plot.
    For models with a Lyapunov network it also prints the Lyapunov values at
    the first and last time step, and saves one contour plot per controlled
    inverter. A model that fails to load or to process is reported and skipped.

    Args:
        num_inverters: Number of inverters
        controlled_indices: Indices of controlled inverters (default: all except 0)
        time_steps: Number of simulation steps
        delay_time_step: Delay time steps
        device: torch device (default: CUDA if available, else CPU)
        pre_trained_model: Path to a pre-trained controller, forwarded to
            load_pretrained_model / load_trained_model (None is passed as-is)
        output_dir: Directory for the generated figures (created if missing)
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if controlled_indices is None:
        controlled_indices = list(range(1, num_inverters))

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    dynamics_params = _build_dynamics_params(num_inverters)
    dt = dynamics_params['dt']
    omega_star = dynamics_params['omega_star']
    state_dims = [3] * num_inverters
    control_dims = [1] * num_inverters

    print("\nLoading models...")
    models = _load_models(
        num_inverters, controlled_indices, state_dims, control_dims,
        dynamics_params, device, delay_time_step, pre_trained_model
    )

    scenario = _make_initial_error_scenario(num_inverters, omega_star, time_steps)

    print("\nRunning simulation with initial error only...")

    for model_name in ('pre_trained', 'delayed_siss', 'compositional_iss', 'siss'):
        model = models.get(model_name)
        if model is None:
            continue

        print(f"\nProcessing {model_name}...")
        try:
            trajectory, _controls = simulate_microgrid(
                model['controller'],
                model['system'],
                num_inverters,
                controlled_indices,
                scenario['initial_states'],
                scenario['disturbances'],
                time_steps,
                delay_time_step,
                device
            )

            V_values = compute_lyapunov_values(
                trajectory, model['V_net'], num_inverters, model['system'], device
            )
            if V_values is not None and len(V_values) > 0:
                print(f"  Lyapunov values: t=0 {np.round(V_values[0], 4)}, "
                      f"t=end {np.round(V_values[-1], 4)}")

            visualize_trajectory(
                trajectory, model_name, output_dir,
                num_inverters, time_steps, dt=dt,
                omega_star=omega_star
            )

            # One contour plot per controlled inverter
            if model['V_net'] is not None:
                for inverter_idx in controlled_indices:
                    visualize_lyapunov_contour(
                        model['V_net'],
                        model_name,
                        output_dir,
                        num_inverters,
                        model['system'],
                        device,
                        delta_range=(-0.5, 0.5),
                        omega_range=(-2.0, 2.0),
                        resolution=100,
                        inverter_idx=inverter_idx
                    )

            print(f"  ✓ Completed visualization for {model_name}")

        except Exception as e:  # keep going with the remaining models
            print(f"  ✗ Failed: {e}")
            import traceback
            traceback.print_exc()

    print(f"\nAll visualizations saved to: {output_dir}/")


if __name__ == "__main__":
    # Configuration
    num_inverters = 3
    controlled_indices = list(range(1, num_inverters))  # Control all inverters except the first one
    time_steps = 500  # 5 seconds at dt=0.01
    delay_time_step = 10  # 0.1 seconds delay
    
    # Pre-trained model path (not used, uses default paths)
    pre_trained_model = None
    
    # Check if pre-trained models exist
    if not any(os.path.exists(f"pre_train_model/control_model_{i}.pth") for i in range(3)):
        pre_trained_model = None
        print("Warning: Pre-trained models not found. Skipping pre_trained baseline.")
    
    # Run visualization
    visualize_all_models(
        num_inverters=num_inverters,
        controlled_indices=controlled_indices,
        time_steps=time_steps,
        delay_time_step=delay_time_step,
        pre_trained_model=pre_trained_model
    )
    
    print("\nVisualization completed!")



