"""
Compare tracking performance of five models:
- mode 0: delayed_siss (best_microgrid_model)
- mode 1: compositional_iss (best_microgrid_model_ISS)
- mode 2: siss (best_microgrid_model_sISS)
- pre_trained: Original pre-trained controller (baseline without delay compensation)
- predictor_feedback: Predictor Feedback controller (baseline with delay compensation)
"""

import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
import os
import lightning.pytorch as pl
from training_exp_comb import (
    train_model, 
    MicrogridFormationDynamics,
    StringStabilityTrainer
)
from networks import VectorLyapunovNetwork, CombinedController
from baseline import (
    PredictorFeedbackController,
    load_pretrained_controller as load_pretrained_controller_baseline,
    simulate_with_predictor_feedback
)


def load_trained_model(mode, num_inverters, controlled_indices, state_dims, control_dims, 
                      dynamics_params, device, delay_time_step=0.1, 
                      pre_trained_model=None):
    """
    Build the microgrid system and networks, then restore trained weights.

    The controllers are first initialised from the pre-trained weight files
    under ``pre_train_model/`` (each file is optional and silently skipped when
    missing). The Lightning checkpoint selected by ``mode`` is then loaded on
    top through a throw-away ``StringStabilityTrainer``. When the checkpoint
    exists, the system's connection weights are replaced by the coupling
    matrix returned by the loaded ``V_net``.

    Args:
        mode: 0 for delayed_siss, 1 for compositional_iss, any other value
            selects the siss checkpoint.
        num_inverters: Number of inverters in the chain.
        controlled_indices: Not used by this function.
        state_dims: Not used by this function.
        control_dims: Not used by this function.
        dynamics_params: Parameter dict handed to MicrogridFormationDynamics;
            its ``'dt'`` entry (default 0.01) converts the delay to steps.
        device: torch device for all networks and loaded weights.
        delay_time_step: Delay that is divided by ``dt`` to obtain the step
            count passed to the trainer.
            NOTE(review): compare_models passes a step count (e.g. 10) here,
            which this conversion turns into 10 / dt steps -- confirm the
            intended unit against StringStabilityTrainer.
        pre_trained_model: Not used; the default ``pre_train_model/`` paths
            are always tried.

    Returns:
        Tuple ``(controller, system, V_net)``. If the checkpoint file is
        missing, a warning is printed and the controller (with whatever
        pre-trained weights were found) and the untrained V_net are returned.
    """
    # Chain topology: each inverter is coupled to itself and to its immediate
    # neighbours. Insertion order (self, predecessor, successor) is preserved.
    connection_matrix = {}
    for i in range(num_inverters):
        connection_matrix[i] = {i: 0.01}
        if i > 0:
            connection_matrix[i][i - 1] = 0.01
        if i < num_inverters - 1:
            connection_matrix[i][i + 1] = 0.01

    # Initialize system dynamics
    system = MicrogridFormationDynamics(dynamics_params, connection_matrix)

    # Weighted adjacency matrix: 0.01 wherever the system reports a connection
    # (note: not binary, despite what the name "adjacency" may suggest).
    G = torch.zeros(len(system.connections), len(system.connections)).to(device)
    for i in system.connections:
        for j in system.connections[i]:
            G[i, j] = 0.01

    # Initialize networks
    V_net = VectorLyapunovNetwork(
        input_dim=3,
        hidden_dim=64,
        G=G,
        num_inverters=num_inverters
    ).to(device)

    controller = CombinedController(
        input_dim=3,
        output_dim=1,
        num_inverters=num_inverters
    ).to(device)

    original_controller = CombinedController(
        input_dim=3,
        output_dim=1,
        num_inverters=num_inverters
    ).to(device)

    first_path = "pre_train_model/control_model_0.pth"
    middle_path = "pre_train_model/control_model_1.pth"
    last_path = "pre_train_model/control_model_2.pth"

    def _load_into_both(path, pick):
        """Load `path` once, if it exists, into the sub-module `pick` selects on both controllers."""
        if not os.path.exists(path):
            return
        state = torch.load(path, map_location=device)
        # load_state_dict copies the values, so one deserialised dict can
        # safely initialise both the trainable and the reference controller.
        pick(controller).load_state_dict(state)
        pick(original_controller).load_state_dict(state)

    # Pre-trained weights are always attempted, regardless of `pre_trained_model`.
    if num_inverters == 3:
        _load_into_both(first_path, lambda c: c.controller_1)
        _load_into_both(middle_path, lambda c: c.controller_2)
        _load_into_both(last_path, lambda c: c.controller_3)
    else:
        # Other sizes: file 0 -> first inverter, file 2 -> last inverter,
        # file 1 -> every inverter in between.
        _load_into_both(first_path, lambda c: c.controllers[0])
        if num_inverters > 1:
            _load_into_both(last_path, lambda c: c.controllers[num_inverters - 1])
        if num_inverters > 2 and os.path.exists(middle_path):
            middle_state = torch.load(middle_path, map_location=device)
            for i in range(1, num_inverters - 1):
                controller.controllers[i].load_state_dict(middle_state)
                original_controller.controllers[i].load_state_dict(middle_state)

    # Select the Lightning checkpoint for the requested mode.
    checkpoint_path = {
        0: 'model_weights/best_microgrid_model.ckpt',
        1: 'model_weights/best_microgrid_model_ISS.ckpt',
    }.get(mode, 'model_weights/best_microgrid_model_sISS.ckpt')

    if not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint {checkpoint_path} not found. Using untrained model.")
        return controller, system, V_net

    # Round instead of truncating: float division such as 0.3 / 0.01 yields
    # 29.999..., which int() alone would turn into 29 steps.
    delay_steps = int(round(delay_time_step / dynamics_params.get('dt', 0.01)))

    # Create a dummy trainer purely as a container to load the checkpoint into
    trainer = StringStabilityTrainer(
        controller, system, V_net,
        learning_rate=1e-3,
        original_controller=original_controller,
        device=device,
        mode=mode,
        delay_time_step=delay_steps
    )

    # Accept either a checkpoint dict with a 'state_dict' entry or a bare state dict.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if 'state_dict' in checkpoint:
        trainer.load_state_dict(checkpoint['state_dict'])
    else:
        trainer.load_state_dict(checkpoint)

    # Extract loaded models
    controller = trainer.controller.to(device)
    V_net = trainer.V_net.to(device)

    # Replace the system's connection weights with the coupling matrix from
    # V_net, dropping entries that are numerically zero.
    new_matrix = V_net.coupling_matrix(G)
    new_connections = {}
    for i in range(new_matrix.size(0)):
        row = {}
        for j in range(new_matrix.size(1)):
            val = new_matrix[i, j].item()
            if abs(val) > 1e-9:
                row[j] = val
        new_connections[i] = row
    system.connections = new_connections

    return controller, system, V_net


def load_pretrained_model(num_inverters, controlled_indices, state_dims, control_dims, 
                          dynamics_params, device, delay_time_step=0.1, 
                          pre_trained_model=None):
    """
    Build the baseline: a microgrid system plus a controller holding only pre-trained weights.

    Args:
        num_inverters: Number of inverters in the chain.
        controlled_indices: Indices of controlled inverters (not used here).
        state_dims: State dimensions (not used here).
        control_dims: Control dimensions (not used here).
        dynamics_params: Parameter dict handed to MicrogridFormationDynamics.
        device: torch device for the controller and the loaded weights.
        delay_time_step: Delay (not used here).
        pre_trained_model: Path to a pre-trained controller (not used; the
            default ``pre_train_model/`` paths are read instead).

    Returns:
        Tuple ``(controller, system, None)``; the baseline has no V_net.

    Raises:
        FileNotFoundError: If none of the three pre-trained weight files exists.
    """
    head_file = "pre_train_model/control_model_0.pth"
    mid_file = "pre_train_model/control_model_1.pth"
    tail_file = "pre_train_model/control_model_2.pth"

    # Chain topology: self link first, then predecessor, then successor.
    links = {}
    for idx in range(num_inverters):
        row = {idx: 0.01}
        if idx > 0:
            row[idx - 1] = 0.01
        if idx < num_inverters - 1:
            row[idx + 1] = 0.01
        links[idx] = row

    system = MicrogridFormationDynamics(dynamics_params, links)

    controller = CombinedController(
        input_dim=3, output_dim=1, num_inverters=num_inverters
    ).to(device)

    def restore(path, pick_module):
        """Load `path` into the sub-module returned by `pick_module`, if the file exists."""
        if os.path.exists(path):
            pick_module().load_state_dict(torch.load(path, map_location=device))

    if num_inverters == 3:
        # The three-inverter controller exposes its sub-controllers by name.
        restore(head_file, lambda: controller.controller_1)
        restore(mid_file, lambda: controller.controller_2)
        restore(tail_file, lambda: controller.controller_3)
    else:
        # Otherwise: file 0 -> first, file 2 -> last, file 1 -> all in between.
        restore(head_file, lambda: controller.controllers[0])
        if num_inverters > 1:
            restore(tail_file, lambda: controller.controllers[num_inverters - 1])
        if os.path.exists(mid_file):
            for idx in range(1, num_inverters - 1):
                controller.controllers[idx].load_state_dict(
                    torch.load(mid_file, map_location=device)
                )

    # A baseline with no pre-trained weights at all is treated as an error.
    if all(not os.path.exists(path) for path in (head_file, mid_file, tail_file)):
        raise FileNotFoundError("Pre-trained model files not found in pre_train_model/")

    return controller, system, None


def simulate_microgrid(controller, system, num_inverters, controlled_indices, 
                       initial_states, disturbances, time_steps, delay_time_step, device):
    """
    Simulate the microgrid in closed loop with a controller acting on delayed states.

    At step ``t`` the control is zero while ``t < delay_time_step``; afterwards
    it is computed from the state recorded ``delay_time_step`` steps earlier.
    The controller input is the error with respect to the equilibrium
    (delta=0, omega=system.omega_star, xi=0), with the first inverter's delta
    error subtracted from every delta entry so only relative phase is seen.

    Args:
        controller: Callable mapping an error tensor [1, num_inverters*3] to a
            control tensor [1, num_inverters] (a 1-D output is unsqueezed).
        system: Object exposing ``omega_star`` and
            ``next_state(states, control, dist, eval=True)``.
        num_inverters: Number of inverters.
        controlled_indices: Indices of controlled inverters; the control of
            every other inverter is forced to zero.
        initial_states: Initial states [num_inverters, 3] (delta, omega, xi),
            as a torch tensor, NumPy array or nested sequence.
        disturbances: Optional per-step disturbances, indexable by step, each
            entry [num_inverters, 3] (NumPy array, sequence or tensor). Steps
            beyond its length, or ``None``, use a zero disturbance.
        time_steps: Number of simulation steps.
        delay_time_step: Input delay expressed in simulation steps.
        device: torch device.

    Returns:
        trajectory: NumPy array [time_steps+1, num_inverters, 3].
        controls: NumPy array [time_steps, num_inverters] of applied controls
            (zero for uncontrolled inverters and during the initial delay).
    """
    # Equilibrium in the flattened layout [delta_0, omega_0, xi_0, delta_1, ...]:
    # only the omega entries (every third element, offset 1) are non-zero.
    equilibrium = torch.zeros(num_inverters * 3, device=device)
    equilibrium[1::3] = system.omega_star

    # Normalise the initial state to a tensor on the target device.
    if torch.is_tensor(initial_states):
        initial_states = initial_states.to(device)
    else:
        initial_states = torch.tensor(
            np.asarray(initial_states), dtype=torch.float32, device=device
        )
    current_states = initial_states.unsqueeze(0)  # [1, num_inverters, 3]
    trajectory = [current_states.squeeze(0).detach().cpu().numpy()]
    controls_history = []
    # Full state history, indexed by step, so delayed states can be looked up.
    state_history = [current_states.squeeze(0).clone()]

    # Inverters whose control is always forced to zero (e.g. the reference).
    uncontrolled = [i for i in range(num_inverters) if i not in controlled_indices]

    with torch.no_grad():
        for t in range(time_steps):
            if t < delay_time_step:
                # No delayed measurement is available yet: apply zero control.
                control = torch.zeros(1, num_inverters, device=device)
                control_values = np.zeros(num_inverters)
            else:
                delayed_state = state_history[t - delay_time_step]
                # The subtraction creates a new tensor, so the in-place edit
                # below never touches the stored history.
                error_state = delayed_state.reshape(1, -1) - equilibrium.unsqueeze(0)

                # Only relative phase matters: remove the first inverter's
                # delta error from every delta entry (clone avoids aliasing
                # the element being overwritten).
                error_state[:, 0::3] -= error_state[0, 0].clone()

                control = controller(error_state)
                if control.dim() == 1:
                    control = control.unsqueeze(0)

                if uncontrolled:
                    control[0, uncontrolled] = 0.0

                control_values = control[0].cpu().numpy()

            # Disturbance for this step; convert the array directly rather
            # than wrapping it in a Python list, which is slow and warns.
            if disturbances is not None and t < len(disturbances):
                step_disturbance = disturbances[t]
                if torch.is_tensor(step_disturbance):
                    dist = step_disturbance.to(device=device, dtype=torch.float32).unsqueeze(0)
                else:
                    dist = torch.tensor(
                        np.asarray(step_disturbance), dtype=torch.float32, device=device
                    ).unsqueeze(0)
            else:
                dist = torch.zeros(1, num_inverters, 3, device=device)

            # Advance the system one step and record the results.
            current_states = system.next_state(current_states, control, dist, eval=True)
            state_history.append(current_states.squeeze(0).clone())
            trajectory.append(current_states.squeeze(0).detach().cpu().numpy())
            controls_history.append(control_values)

    return np.array(trajectory), np.array(controls_history)


def compute_tracking_metrics(trajectory, omega_star, num_inverters):
    """
    Compute per-inverter RMSE tracking metrics against the reference state.

    The reference is delta=0, omega=omega_star, xi=0 for every inverter.

    Args:
        trajectory: Array [time_steps+1, num_inverters, 3] holding
            (delta, omega, xi) along the last axis.
        omega_star: Reference (nominal) frequency.
        num_inverters: Number of inverters; the value is not used, the
            inverter count comes from the trajectory's shape.

    Returns:
        Dict with ``'<name>_rmse'`` arrays of shape [num_inverters] and
        ``'<name>_errors'`` arrays of shape [time_steps+1, num_inverters]
        for name in delta, omega, xi.
    """
    # Unpacking the shape enforces a three-dimensional trajectory.
    _n_samples, _n_inverters, _n_components = trajectory.shape

    # (state name, reference value), listed in the order of the last axis.
    references = (('delta', 0.0), ('omega', omega_star), ('xi', 0.0))

    rmse_by_name = {}
    errors_by_name = {}
    for column, (label, target) in enumerate(references):
        deviation = trajectory[:, :, column] - target
        errors_by_name[f'{label}_errors'] = deviation
        # RMSE over time (axis 0), one value per inverter.
        rmse_by_name[f'{label}_rmse'] = np.sqrt(np.mean(deviation**2, axis=0))

    # RMSE entries first, then the raw error series.
    return {**rmse_by_name, **errors_by_name}


def simulate_predictor_feedback(system, controller, num_inverters, controlled_indices,
                                 initial_states, disturbances, time_steps, 
                                 delay_time_step, device):
    """
    Simulate the microgrid system with the Predictor Feedback controller.

    Thin wrapper that forwards every argument, unchanged and in the same
    order, to simulate_with_predictor_feedback from baseline.py.

    Note: this function does NOT create a new system instance; it uses the
    ``system`` it is given. Callers that need isolation must pass a dedicated
    instance (compare_models builds a separate system for this purpose).

    Returns:
        Whatever simulate_with_predictor_feedback returns; compare_models
        unpacks it as ``(trajectory, controls)``.
    """
    return simulate_with_predictor_feedback(
        system, controller, num_inverters, controlled_indices,
        initial_states, disturbances, time_steps, delay_time_step, device
    )


def compare_models(num_inverters=3, controlled_indices=None, time_steps=500, 
                  delay_time_step=10, device=None, 
                  pre_trained_model=None):
    """
    Compare tracking performance of five models on a single test scenario.

    Loads the pre-trained baseline, the predictor-feedback baseline and the
    three trained models (delayed_siss, compositional_iss, siss), simulates
    each one from the same perturbed initial state, computes RMSE metrics and
    prints a summary. A model that fails to load or simulate is reported on
    stdout and skipped; it does not abort the comparison.
    
    Args:
        num_inverters: Number of inverters
        controlled_indices: Indices of controlled inverters; defaults to all
            inverters except inverter 0
        time_steps: Number of simulation steps
        delay_time_step: Delay in simulation steps
        device: torch device; defaults to CUDA when available, else CPU
        pre_trained_model: Path to pre-trained controller (not used, uses default paths)

    Returns:
        Dict mapping scenario name -> model name -> either a dict with keys
        'trajectory', 'controls' and 'metrics', or None if the simulation
        failed. Models that failed to load have no entry.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    if controlled_indices is None:
        # Inverter 0 is left uncontrolled by default.
        controlled_indices = list(range(1, num_inverters))
    
    # Dynamics parameters
    dynamics_params = {
        'dt': 0.01,  # Time step (s)
        'n': num_inverters,  # Number of inverters
        'omega_star': 2 * np.pi * 50,  # Nominal frequency (50 Hz)
        'tau': [1.4895] * num_inverters,  # Time constants
        'eta': [6.3509e-4] * num_inverters,  # Droop gains
        'k': [4.9481] * num_inverters,  # Secondary control gains
        'V': [325.3] * num_inverters,  # Voltage magnitudes
        'B': np.array([[0.1 if abs(i - j) == 1 else 0 for j in range(num_inverters)] for i in range(num_inverters)]),  # Line susceptances (chain topology)
        'P_L': [1260.0] * num_inverters,  # Load powers
        'P_star': [1260.0] * num_inverters  # Desired power injections
    }
    
    state_dims = [3] * num_inverters
    control_dims = [1] * num_inverters
    
    # Model names (including pre-trained baseline and predictor feedback).
    # This list also fixes the order of simulation and of the printed summary.
    model_names = ['pre_trained', 'predictor_feedback', 'delayed_siss', 'compositional_iss', 'siss']
    
    # Load all models; a failed load is recorded as None and skipped later.
    print("Loading models...")
    models = {}
    
    # Load pre-trained baseline model first
    try:
        controller, system, _ = load_pretrained_model(
            num_inverters, controlled_indices, state_dims, control_dims,
            dynamics_params, device, delay_time_step, pre_trained_model
        )
        models['pre_trained'] = {
            'controller': controller,
            'system': system,
            'V_net': None
        }
        print(f"  ✓ Loaded pre_trained (baseline without delay compensation)")
    except Exception as e:
        print(f"  ✗ Failed to load pre_trained: {e}")
        models['pre_trained'] = None
    
    # Load predictor feedback model (uses same pre-trained controller but with prediction)
    try:
        # Create a separate system instance for predictor feedback
        # (same chain topology: self, predecessor and successor links at 0.01).
        connection_matrix_pf = {}
        for i in range(num_inverters):
            connection_matrix_pf[i] = {}
            connection_matrix_pf[i][i] = 0.01
            if i > 0:
                connection_matrix_pf[i][i-1] = 0.01
            if i < num_inverters - 1:
                connection_matrix_pf[i][i+1] = 0.01
        system_pf = MicrogridFormationDynamics(dynamics_params, connection_matrix_pf)
        
        # Load controllers for predictor feedback
        controller_pf = load_pretrained_controller_baseline(
            num_inverters, controlled_indices, device
        )
        
        models['predictor_feedback'] = {
            'controller': controller_pf,
            'system': system_pf,
            'V_net': None,
            'use_predictor': True  # Flag to indicate predictor feedback
        }
        print(f"  ✓ Loaded predictor_feedback (baseline with delay compensation)")
    except Exception as e:
        print(f"  ✗ Failed to load predictor_feedback: {e}")
        models['predictor_feedback'] = None
    
    # Load the three trained models
    for mode in range(3):
        mode_name = ['delayed_siss', 'compositional_iss', 'siss'][mode]
        try:
            # NOTE(review): load_trained_model divides its delay argument by
            # dt before handing it to the trainer, while a step count is
            # passed here -- confirm which unit is intended.
            controller, system, V_net = load_trained_model(
                mode, num_inverters, controlled_indices, state_dims, control_dims,
                dynamics_params, device, delay_time_step, 
                pre_trained_model
            )
            models[mode_name] = {
                'controller': controller,
                'system': system,
                'V_net': V_net
            }
            print(f"  ✓ Loaded {mode_name}")
        except Exception as e:
            print(f"  ✗ Failed to load {mode_name}: {e}")
            models[mode_name] = None
    
    # Test scenario: initial error only (no disturbances)
    omega_star = dynamics_params['omega_star']
    initial_states = np.zeros((num_inverters, 3))
    initial_states[:, 1] = omega_star  # Initialize at nominal frequency
    
    # Set initial error for first inverter (inverter 0): only the frequency is perturbed
    initial_states[0, 0] = 0.0  # delta: 0 rad (no initial phase angle error)
    initial_states[0, 1] = omega_star + 1  # omega error: +1.0 rad/s (frequency error)
    initial_states[0, 2] = 0.0  # xi: 0 (integral term starts from 0)
    
    scenarios = {
        'initial_error': {
            'initial_states': initial_states,
            'disturbances': [
                np.zeros((num_inverters, 3))  # No disturbances, only initial error
                for t in range(time_steps)
            ]
        }
    }
    
    # Run simulations and compute metrics
    all_results = {}
    
    for scenario_name, scenario in scenarios.items():
        print(f"\nTesting scenario: {scenario_name}")
        all_results[scenario_name] = {}
        
        for model_name in model_names:
            # Skip models that failed to load.
            if models.get(model_name) is None:
                continue
            
            print(f"  Simulating {model_name}...")
            try:
                # Check if this is the predictor feedback model
                if model_name == 'predictor_feedback':
                    trajectory, controls = simulate_predictor_feedback(
                        models[model_name]['system'],
                        models[model_name]['controller'],
                        num_inverters,
                        controlled_indices,
                        scenario['initial_states'],
                        scenario['disturbances'],
                        time_steps,
                        delay_time_step,
                        device
                    )
                else:
                    trajectory, controls = simulate_microgrid(
                        models[model_name]['controller'],
                        models[model_name]['system'],
                        num_inverters,
                        controlled_indices,
                        scenario['initial_states'],
                        scenario['disturbances'],
                        time_steps,
                        delay_time_step,
                        device
                    )
                
                metrics = compute_tracking_metrics(trajectory, omega_star, num_inverters)
                all_results[scenario_name][model_name] = {
                    'trajectory': trajectory,
                    'controls': controls,
                    'metrics': metrics
                }
                print(f"    ✓ Completed")
            except Exception as e:
                # Best effort: report the failure with a traceback and move
                # on to the next model instead of aborting the comparison.
                print(f"    ✗ Failed: {e}")
                import traceback
                traceback.print_exc()
                all_results[scenario_name][model_name] = None
    
    # Print summary statistics (no plots)
    print_summary_statistics(all_results, model_names, num_inverters)
    
    return all_results


def print_summary_statistics(all_results, model_names, num_inverters):
    """
    Print RMSE tables for every scenario to stdout.

    For each scenario two tables are printed: one with RMSE values averaged
    over inverters (plus an "overall" figure, the mean of the omega and xi
    averages), and one with the per-inverter RMSE values. Models that have no
    entry, or a ``None`` entry, in a scenario are omitted.

    Args:
        all_results: Dict scenario name -> model name -> result dict with a
            'metrics' entry holding 'delta_rmse', 'omega_rmse', 'xi_rmse'.
        model_names: Model names, in the order rows should be printed.
        num_inverters: Number of inverters (columns of the per-inverter table).
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    rmse_keys = ('delta_rmse', 'omega_rmse', 'xi_rmse')

    print("\n" + heavy_rule)
    print("TRACKING ERROR COMPARISON (RMSE)")
    print(heavy_rule)

    for scenario_name, scenario_results in all_results.items():
        # Models with a usable result in this scenario, in display order.
        present = [
            name for name in model_names
            if scenario_results.get(name) is not None
        ]

        print(f"\nScenario: {scenario_name}")
        print(light_rule)

        # Table 1: RMSE averaged over inverters
        print(f"{'Model':<20} {'Delta RMSE (rad)':<20} {'Omega RMSE (rad/s)':<20} {'Xi RMSE':<20} {'Overall RMSE':<20}")
        print(light_rule)

        for name in present:
            metrics = scenario_results[name]['metrics']
            mean_delta, mean_omega, mean_xi = (np.mean(metrics[key]) for key in rmse_keys)
            # Overall RMSE: average of omega and xi RMSEs only (delta excluded)
            overall = (mean_omega + mean_xi) / 2.0
            cells = [f"{name:<20}"]
            cells.extend(
                f"{value:<20.6f}"
                for value in (mean_delta, mean_omega, mean_xi, overall)
            )
            print(" ".join(cells))

        # Table 2: RMSE for each inverter
        print("\nRMSE by Inverter:")
        print(light_rule)

        header_cells = [f"{'Model':<20}"]
        for i in range(num_inverters):
            for field in ('Delta', 'Omega', 'Xi'):
                label = 'Inv' + str(i) + ' ' + field
                header_cells.append(f"{label:<15}")
        print(" ".join(header_cells))
        print(light_rule)

        for name in present:
            metrics = scenario_results[name]['metrics']
            row_cells = [f"{name:<20}"]
            for i in range(num_inverters):
                row_cells.extend(f"{metrics[key][i]:<15.6f}" for key in rmse_keys)
            print(" ".join(row_cells))


if __name__ == "__main__":
    # Experiment configuration
    num_inverters = 3
    # Every inverter except the first one is controlled.
    controlled_indices = list(range(num_inverters))[1:]
    time_steps = 500  # 5 seconds at dt=0.01
    delay_time_step = 10  # 0.1 seconds delay

    # No explicit pre-trained model path: the loaders read their default paths.
    pre_trained_model = None

    # Warn up front when none of the default pre-trained weight files exists.
    weight_files_present = [
        os.path.exists(f"pre_train_model/control_model_{i}.pth") for i in range(3)
    ]
    if True not in weight_files_present:
        print("Warning: Pre-trained models not found. Using random initialization.")

    # Run the comparison with the configuration above.
    results = compare_models(
        num_inverters=num_inverters,
        controlled_indices=controlled_indices,
        time_steps=time_steps,
        delay_time_step=delay_time_step,
        pre_trained_model=pre_trained_model,
    )

    print("\nComparison completed!")

