"""
Predictor Feedback Controller for Delayed Microgrid Control Systems

Based on the rollout predictor approach for nonlinear delayed-input systems:
    x_{k+1} = f(x_k, u_{k-d}), d ∈ Z_{>0}

Define the d-step rollout predictor using the known past input buffer {u_{k-d}, ..., u_{k-1}}:
    x̂_{k|k} := x_k
    x̂_{k+r+1|k} := f(x̂_{k+r|k}, u_{k-d+r}), r = 0, ..., d-1

And set:
    u_k = κ(x̂_{k+d|k})

This means the control is computed based on the predicted future state.
"""

import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
import os
from networks import CombinedController


class PredictorFeedbackController:
    """
    Predictor Feedback Controller using rollout prediction for Microgrid.

    For a system with input delay d, this controller:
    1. Stores the last d control inputs in a buffer
    2. Rolls the system model forward d steps using those buffered inputs
    3. Evaluates the base state-feedback law on the predicted state

    NOTE(review): the rollout calls ``system_dynamics.next_state`` once per
    buffered control, so it presumably expects ``next_state`` to be the
    delay-free one-step map f(x, u) with no internal state - confirm against
    the dynamics implementation.
    """

    def __init__(self, base_controller, system_dynamics, delay_steps,
                 num_inverters, controlled_indices, device='cuda:0'):
        """
        Initialize the Predictor Feedback Controller.

        Args:
            base_controller: The underlying state feedback controller κ(x);
                called with the flattened error state
                [batch, num_inverters * 3]
            system_dynamics: The system dynamics model; must expose
                ``omega_star`` and ``next_state(state, control, disturbance,
                eval=True)``
            delay_steps: Number of delay steps d
            num_inverters: Number of inverters in the microgrid
            controlled_indices: Indices of controlled inverters; the control
                of every other inverter is forced to zero
            device: torch device
        """
        self.base_controller = base_controller
        self.system = system_dynamics
        self.delay_steps = delay_steps
        self.num_inverters = num_inverters
        self.controlled_indices = controlled_indices
        self.device = device
        self.dim = 3  # State dimension: (delta, omega, xi)

        # Control input buffer: stores past d control inputs in chronological
        # order, i.e. Buffer[i] = u_{k-d+i} for i = 0, ..., d-1 once it is full
        self.control_buffer = []

        # Reference state (equilibrium): omega = omega_star, delta = xi = 0.
        # In the flattened layout omega occupies every third entry from 1.
        self.omega_star = system_dynamics.omega_star
        self.equilibrium = torch.zeros(num_inverters * 3, device=device)
        self.equilibrium[1::3] = self.omega_star

    def reset(self):
        """Reset the control buffer at the start of a new simulation."""
        self.control_buffer = []

    def _predict_future_state(self, current_state, control_buffer):
        """
        Predict the state d steps into the future using the rollout predictor.

        x̂_{k|k} := x_k
        x̂_{k+r+1|k} := f(x̂_{k+r|k}, u_{k-d+r}), r = 0, ..., d-1

        Args:
            current_state: Current state x_k [batch, num_inverters, 3]
            control_buffer: Past controls in chronological order, newest last
                ([u_{k-d}, ..., u_{k-1}] when full). Each entry is a tensor of
                shape [batch, num_inverters]. If it holds fewer than d
                entries the missing (older) controls are taken to be zero; if
                it holds more, only the most recent d are used.

        Returns:
            predicted_state: Predicted state x̂_{k+d|k} [batch, num_inverters, 3]
        """
        # x̂_{k|k} := x_k
        predicted_state = current_state.clone()
        batch_size = current_state.shape[0]

        # Zero disturbance for prediction
        disturbance = torch.zeros(batch_size, self.num_inverters, self.dim,
                                  device=self.device)

        horizon = self.delay_steps
        # Guard horizon == 0 explicitly: a [-0:] slice would return everything.
        recent = list(control_buffer)[-horizon:] if horizon > 0 else []
        # The buffer is chronological, so whatever is missing are the OLDEST
        # controls: pad at the front to keep u_{k-1} as the last rollout step.
        missing = horizon - len(recent)

        # Iterate through r = 0, ..., d-1
        for r in range(horizon):
            if r < missing:
                control = torch.zeros(batch_size, self.num_inverters,
                                      device=self.device)
            else:
                control = recent[r - missing]

            # x̂_{k+r+1|k} := f(x̂_{k+r|k}, u_{k-d+r})
            predicted_state = self.system.next_state(
                predicted_state, control, disturbance, eval=True)

        return predicted_state

    def _compute_error_state(self, state):
        """
        Compute error state from absolute state.

        The equilibrium is subtracted and every delta error is then made
        relative to the first inverter's delta error (so the first inverter's
        delta entry is always zero). This is done independently for every
        batch row.

        Args:
            state: Absolute state [batch, num_inverters, 3]

        Returns:
            error_state: Error state [batch, num_inverters * 3]
        """
        batch_size = state.shape[0]
        # The subtraction allocates a new tensor, so the in-place update
        # below never touches the caller's state.
        error_state = state.reshape(batch_size, -1) - self.equilibrium.unsqueeze(0)

        # Use relative phase differences: only the phase differences between
        # inverters are fed to the controller, not the absolute angles.
        # Delta entries sit at flattened columns 0, 3, 6, ...
        delta_ref = error_state[:, 0:1].clone()  # [batch, 1]
        error_state[:, 0::3] -= delta_ref

        return error_state

    def compute_control(self, current_state, time_step):
        """
        Compute the control input using predictor feedback.

        u_k = κ(x̂_{k+d|k})

        For ``time_step < delay_steps`` a zero control is returned (and
        buffered) without evaluating the base controller.

        Args:
            current_state: Current system state [batch, num_inverters, 3]
            time_step: Current time step

        Returns:
            control: Control tensor [batch, num_inverters]
            control_values: numpy array of the first batch row's control
                values [num_inverters]
        """
        batch_size = current_state.shape[0]

        # For the initial phase (t < delay_steps), we use zero control
        if time_step < self.delay_steps:
            control = torch.zeros(batch_size, self.num_inverters,
                                  device=self.device)
            self.control_buffer.append(control.clone())
            return control, np.zeros(self.num_inverters)

        # Predict future state x̂_{k+d|k}
        predicted_state = self._predict_future_state(current_state, self.control_buffer)

        # Compute error state from predicted state
        error_state = self._compute_error_state(predicted_state)

        # Compute control based on predicted error state: u_k = κ(x̂_{k+d|k})
        control = self.base_controller(error_state)

        # Ensure control has correct shape [batch, num_inverters]
        if control.dim() == 1:
            control = control.unsqueeze(0)

        # Set control to zero for uncontrolled inverters (all batch rows)
        uncontrolled = [i for i in range(self.num_inverters)
                        if i not in self.controlled_indices]
        if uncontrolled:
            control[:, uncontrolled] = 0.0

        # Convert to control values for each inverter
        control_values = control[0].detach().cpu().numpy()

        # Update control buffer (sliding window over the last d controls)
        self.control_buffer.append(control.clone())
        while len(self.control_buffer) > self.delay_steps:
            self.control_buffer.pop(0)

        return control, control_values


def load_pretrained_controller(num_inverters, controlled_indices, device, pre_trained_model=None):
    """
    Load pre-trained controller for Microgrid.

    Builds a ``CombinedController`` and fills its per-inverter sub-controllers
    from the checkpoints under ``pre_train_model/``:

    * ``control_model_0.pth`` -> first inverter
    * ``control_model_2.pth`` -> last inverter (only when there is more than one)
    * ``control_model_1.pth`` -> every inverter in between

    A checkpoint that does not exist is skipped silently, leaving the
    corresponding sub-controller at its freshly constructed weights.

    Args:
        num_inverters: Number of inverters
        controlled_indices: Indices of controlled inverters (not used here)
        device: torch device
        pre_trained_model: Path to pre-trained controller (not used, uses default paths)

    Returns:
        controller: CombinedController
    """
    controller = CombinedController(
        input_dim=3,
        output_dim=1,
        num_inverters=num_inverters
    ).to(device)

    def _module_for(index):
        """Return the sub-controller that drives inverter ``index``."""
        # The three-inverter model exposes its sub-controllers as named
        # attributes; every other size is addressed through `controllers`.
        if num_inverters == 3:
            return getattr(controller, f"controller_{index + 1}")
        return controller.controllers[index]

    def _load_into(path, indices):
        """Load ``path`` once and copy it into each listed sub-controller."""
        if not indices or not os.path.exists(path):
            return
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(path, map_location=device)
        for index in indices:
            # load_state_dict copies values into the module, so one loaded
            # state dict can safely be reused for several sub-controllers.
            _module_for(index).load_state_dict(state_dict)

    last = num_inverters - 1
    _load_into("pre_train_model/control_model_0.pth", [0])
    _load_into("pre_train_model/control_model_2.pth", [last] if num_inverters > 1 else [])
    _load_into("pre_train_model/control_model_1.pth", list(range(1, last)))

    return controller


def simulate_with_predictor_feedback(system, controller, num_inverters, controlled_indices,
                                      initial_states, disturbances, time_steps, 
                                      delay_time_step, device):
    """
    Simulate Microgrid system with Predictor Feedback controller.

    NOTE(review): the control returned for step t is passed to
    ``system.next_state`` at that same step. Presumably the input delay is
    applied inside ``system.next_state`` or is intentionally left out of this
    simulation - confirm, since the predictor assumes x_{k+1} = f(x_k, u_{k-d}).

    Args:
        system: MicrogridFormationDynamics system
        controller: Pre-trained CombinedController
        num_inverters: Number of inverters
        controlled_indices: Indices of controlled inverters
        initial_states: Initial states [num_inverters, 3] (delta, omega, xi);
            torch tensor, numpy array or nested sequence
        disturbances: Disturbances for each time step [time_steps, num_inverters, 3],
            or None. Steps beyond ``len(disturbances)`` use zero disturbance.
        time_steps: Number of simulation steps
        delay_time_step: Delay time steps
        device: torch device

    Returns:
        trajectory: [time_steps+1, num_inverters, 3]
        controls: [time_steps, num_inverters]
    """
    # Initialize predictor feedback controller
    pf_controller = PredictorFeedbackController(
        base_controller=controller,
        system_dynamics=system,
        delay_steps=delay_time_step,
        num_inverters=num_inverters,
        controlled_indices=controlled_indices,
        device=device
    )
    pf_controller.reset()

    # Initialize trajectory. Tensors are only moved to the device; anything
    # else (numpy array, nested list) is copied into a float32 tensor.
    if isinstance(initial_states, torch.Tensor):
        initial_states = initial_states.to(device)
    else:
        initial_states = torch.tensor(np.asarray(initial_states),
                                      dtype=torch.float32, device=device)

    current_states = initial_states.unsqueeze(0)  # [1, num_inverters, 3]
    trajectory = [current_states.squeeze(0).detach().cpu().numpy()]
    controls_history = []

    dim = 3  # (delta, omega, xi)

    # Simulate
    with torch.no_grad():
        for t in range(time_steps):
            # Compute control using predictor feedback
            control, control_values = pf_controller.compute_control(current_states, t)

            # Get disturbance for this time step
            if disturbances is not None and t < len(disturbances):
                raw = disturbances[t]
                if isinstance(raw, torch.Tensor):
                    dist = raw.to(device=device, dtype=torch.float32).unsqueeze(0)
                else:
                    # Convert through one ndarray: building a tensor from a
                    # Python list of ndarrays is slow and triggers a warning.
                    dist = torch.tensor(np.asarray(raw), dtype=torch.float32,
                                        device=device).unsqueeze(0)
            else:
                dist = torch.zeros(1, num_inverters, dim, device=device)

            # Update states
            current_states = system.next_state(current_states, control, dist, eval=True)
            trajectory.append(current_states.squeeze(0).detach().cpu().numpy())
            controls_history.append(control_values)

    return np.array(trajectory), np.array(controls_history)


# Alternative: Truncation-based Predictor Feedback Controller
# This is a simplified version that replaces the model-based rollout with
# linear extrapolation of the measured state.
class TruncatedPredictorController:
    """
    Truncated Predictor Feedback Controller for Microgrid.

    Instead of using full rollout prediction, this uses a truncated/simplified
    predictor that accounts for delay through state extrapolation: the most
    recent one-step state change is extended linearly over ``delay_steps``
    steps and the base controller is evaluated on that extrapolated state.
    """

    def __init__(self, base_controller, system_dynamics, delay_steps,
                 num_inverters, controlled_indices, device='cuda:0'):
        """
        Initialize the Truncated Predictor Controller.

        Args:
            base_controller: State feedback law, called with the flattened
                error state [batch, num_inverters * 3]
            system_dynamics: Dynamics object; only ``omega_star`` is read
                from it by this class
            delay_steps: Number of delay steps to extrapolate over
            num_inverters: Number of inverters in the microgrid
            controlled_indices: Indices of controlled inverters; the control
                of every other inverter is forced to zero
            device: torch device
        """
        self.base_controller = base_controller
        self.system = system_dynamics
        self.delay_steps = delay_steps
        self.num_inverters = num_inverters
        self.controlled_indices = controlled_indices
        self.device = device
        self.dim = 3

        # State history buffer, oldest first; compute_control appends the
        # current state before anything else, so the last entry is always
        # the state of the current step.
        self.state_history = []

        # Reference state: omega = omega_star, delta = xi = 0. In the
        # flattened layout omega occupies every third entry starting at 1.
        self.omega_star = system_dynamics.omega_star
        self.equilibrium = torch.zeros(num_inverters * 3, device=device)
        self.equilibrium[1::3] = self.omega_star

    def reset(self):
        """Reset the state history buffer."""
        self.state_history = []

    def _extrapolate_state(self, current_state):
        """
        Extrapolate state forward using recent state history.

        Uses simple linear extrapolation based on the last one-step change:
            x̂ = x_k + delay_steps * (x_k - x_{k-1})

        Expects ``self.state_history[-1]`` to already be the current state
        (as ``compute_control`` guarantees), so the previous state is
        ``self.state_history[-2]``. With fewer than two stored states the
        current state is returned unchanged.
        """
        if len(self.state_history) < 2:
            return current_state

        # history[-1] is the current step itself; comparing against it would
        # always give a zero difference, so step back one more entry.
        prev_state = self.state_history[-2]

        # Compute rate of change per step
        state_diff = current_state - prev_state

        # Extrapolate forward by delay_steps
        return current_state + state_diff * self.delay_steps

    def _compute_error_state(self, state):
        """
        Compute error state from absolute state.

        Subtracts the equilibrium and makes every delta error relative to the
        first inverter's delta error, independently for each batch row.

        Args:
            state: Absolute state [batch, num_inverters, 3]

        Returns:
            error_state: Error state [batch, num_inverters * 3]
        """
        batch_size = state.shape[0]
        # The subtraction allocates a new tensor, so the in-place update
        # below never touches the caller's state.
        error_state = state.reshape(batch_size, -1) - self.equilibrium.unsqueeze(0)

        # Use relative phase differences; delta entries sit at flattened
        # columns 0, 3, 6, ...
        delta_ref = error_state[:, 0:1].clone()  # [batch, 1]
        error_state[:, 0::3] -= delta_ref

        return error_state

    def compute_control(self, current_state, time_step):
        """
        Compute control using truncated predictor.

        The current state is always recorded in the history. For
        ``time_step < delay_steps`` a zero control is returned without
        evaluating the base controller.

        Args:
            current_state: Current system state [batch, num_inverters, 3]
            time_step: Current time step

        Returns:
            control: Control tensor [batch, num_inverters]
            control_values: numpy array of the first batch row's control
                values [num_inverters]
        """
        # Store current state (bounded history)
        self.state_history.append(current_state.clone())
        if len(self.state_history) > self.delay_steps + 2:
            self.state_history.pop(0)

        # For initial phase, use zero control
        if time_step < self.delay_steps:
            control = torch.zeros(current_state.shape[0], self.num_inverters,
                                  device=self.device)
            return control, np.zeros(self.num_inverters)

        # Extrapolate state
        predicted_state = self._extrapolate_state(current_state)

        # Compute error state
        error_state = self._compute_error_state(predicted_state)

        # Compute control based on extrapolated state
        control = self.base_controller(error_state)

        # Ensure control has correct shape [batch, num_inverters]
        if control.dim() == 1:
            control = control.unsqueeze(0)

        # Set control to zero for uncontrolled inverters (all batch rows)
        uncontrolled = [i for i in range(self.num_inverters)
                        if i not in self.controlled_indices]
        if uncontrolled:
            control[:, uncontrolled] = 0.0

        control_values = control[0].detach().cpu().numpy()

        return control, control_values


def simulate_with_truncated_predictor(system, controller, num_inverters, controlled_indices,
                                       initial_states, disturbances, time_steps, 
                                       delay_time_step, device):
    """
    Simulate Microgrid system with Truncated Predictor controller.

    Args:
        system: Dynamics object providing ``omega_star`` and
            ``next_state(state, control, disturbance, eval=True)``
        controller: Base state-feedback controller wrapped by the truncated
            predictor
        num_inverters: Number of inverters
        controlled_indices: Indices of controlled inverters
        initial_states: Initial states [num_inverters, 3]; torch tensor,
            numpy array or nested sequence
        disturbances: Per-step disturbances, each [num_inverters, 3], or
            None. Steps beyond ``len(disturbances)`` use zero disturbance.
        time_steps: Number of simulation steps
        delay_time_step: Delay time steps
        device: torch device

    Returns:
        trajectory: numpy array [time_steps+1, num_inverters, 3]
        controls: numpy array [time_steps, num_inverters]
    """
    # Initialize truncated predictor controller
    tp_controller = TruncatedPredictorController(
        base_controller=controller,
        system_dynamics=system,
        delay_steps=delay_time_step,
        num_inverters=num_inverters,
        controlled_indices=controlled_indices,
        device=device
    )
    tp_controller.reset()

    # Initialize trajectory. Tensors are only moved to the device; anything
    # else (numpy array, nested list) is copied into a float32 tensor.
    if isinstance(initial_states, torch.Tensor):
        initial_states = initial_states.to(device)
    else:
        initial_states = torch.tensor(np.asarray(initial_states),
                                      dtype=torch.float32, device=device)

    current_states = initial_states.unsqueeze(0)  # add batch dimension
    trajectory = [current_states.squeeze(0).detach().cpu().numpy()]
    controls_history = []

    dim = 3  # per-inverter state size used for the zero-disturbance fallback

    # Simulate
    with torch.no_grad():
        for t in range(time_steps):
            # Compute control using truncated predictor
            control, control_values = tp_controller.compute_control(current_states, t)

            # Get disturbance
            if disturbances is not None and t < len(disturbances):
                raw = disturbances[t]
                if isinstance(raw, torch.Tensor):
                    dist = raw.to(device=device, dtype=torch.float32).unsqueeze(0)
                else:
                    # Convert through one ndarray: building a tensor from a
                    # Python list of ndarrays is slow and triggers a warning.
                    dist = torch.tensor(np.asarray(raw), dtype=torch.float32,
                                        device=device).unsqueeze(0)
            else:
                dist = torch.zeros(1, num_inverters, dim, device=device)

            # Update states
            current_states = system.next_state(current_states, control, dist, eval=True)
            trajectory.append(current_states.squeeze(0).detach().cpu().numpy())
            controls_history.append(control_values)

    return np.array(trajectory), np.array(controls_history)


if __name__ == "__main__":
    # Smoke test: run both predictor controllers on a three-inverter chain
    # and report the frequency RMSE of each.
    from training_exp_comb import MicrogridFormationDynamics

    # --- Scenario size -----------------------------------------------------
    num_inverters = 3
    controlled_indices = list(range(1, num_inverters))  # every inverter but the first
    delay_time_step = 10
    time_steps = 500

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Dynamics parameters -----------------------------------------------
    omega_star = 2 * np.pi * 50

    # Chain topology: 0.1 between adjacent inverters, zero everywhere else.
    coupling = np.zeros((num_inverters, num_inverters))
    for left in range(num_inverters - 1):
        coupling[left, left + 1] = 0.1
        coupling[left + 1, left] = 0.1

    dynamics_params = {
        'dt': 0.01,
        'n': num_inverters,
        'omega_star': omega_star,
        'tau': [1.4895] * num_inverters,
        'eta': [6.3509e-4] * num_inverters,
        'k': [4.9481] * num_inverters,
        'V': [325.3] * num_inverters,
        'B': coupling,
        'P_L': [1260.0] * num_inverters,
        'P_star': [1260.0] * num_inverters
    }

    # Each inverter is connected to itself and to its in-range neighbours
    # (inserted in the order: self, previous, next).
    connection_matrix = {
        node: {
            peer: 0.01
            for peer in (node, node - 1, node + 1)
            if 0 <= peer < num_inverters
        }
        for node in range(num_inverters)
    }

    system = MicrogridFormationDynamics(dynamics_params, connection_matrix)

    # The test only runs when at least one pre-trained checkpoint is present.
    checkpoints_present = any(
        os.path.exists(f"pre_train_model/control_model_{i}.pth") for i in range(3)
    )

    if not checkpoints_present:
        print("Pre-trained models not found in pre_train_model/")
    else:
        controller = load_pretrained_controller(
            num_inverters, controlled_indices, device
        )

        # Start every inverter at nominal frequency, then offset the first
        # inverter's frequency by 0.5.
        initial_states = np.zeros((num_inverters, 3))
        initial_states[:, 1] = omega_star
        initial_states[0, 1] = omega_star + 0.5

        # No external disturbance at any step.
        disturbances = [np.zeros((num_inverters, 3)) for _ in range(time_steps)]

        # Rollout predictor feedback
        print("Testing Predictor Feedback Controller...")
        trajectory_pf, controls_pf = simulate_with_predictor_feedback(
            system, controller, num_inverters, controlled_indices,
            initial_states, disturbances, time_steps, delay_time_step, device
        )

        omega_errors = trajectory_pf[:, :, 1] - omega_star
        omega_rmse = np.sqrt(np.mean(omega_errors**2))

        print(f"Predictor Feedback - Omega RMSE: {omega_rmse:.6f}")

        # Truncated (extrapolating) predictor
        print("\nTesting Truncated Predictor Controller...")
        trajectory_tp, controls_tp = simulate_with_truncated_predictor(
            system, controller, num_inverters, controlled_indices,
            initial_states, disturbances, time_steps, delay_time_step, device
        )

        omega_errors_tp = trajectory_tp[:, :, 1] - omega_star
        omega_rmse_tp = np.sqrt(np.mean(omega_errors_tp**2))

        print(f"Truncated Predictor - Omega RMSE: {omega_rmse_tp:.6f}")
