import torch
import torch.nn as nn
import numpy as np
from pre_train_model.learn_dynamics_control import DynamicsNN

class GraphCouplingMatrix(nn.Module):
    """Learnable coupling matrix masked by a communication graph.

    The raw weights live in a single ``(N, N)`` parameter. ``forward`` makes
    them non-negative, masks them with an adjacency matrix, and replaces the
    diagonal so that every diagonal entry exceeds the sum of the off-diagonal
    entries of its row.
    """

    def __init__(self, N, G, device='cuda:0'):
        """
        A learnable coupling matrix for microgrid communication graph.

        Args:
            N (int): Number of inverters.
            G (torch.Tensor): Initial adjacency matrix, used as an elementwise
                mask on the initial weights.
            device (str or torch.device): Device on which the parameter is
                created.
        """
        super().__init__()
        self.N = N
        self.G = G
        self.device = device
        # Register the parameter directly on the target device. Wrapping a
        # tensor in nn.Parameter and then calling ``.to(device)`` only keeps
        # the Parameter when the move happens to be a no-op.
        self.coupling_matrix = nn.Parameter(
            torch.zeros(N, N, device=device), requires_grad=True
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weights from U(0.01, 0.1) and zero the masked entries.

        The update is done in place under ``no_grad`` so the attribute stays
        the registered ``nn.Parameter``. Rebinding it to a plain tensor would
        raise ``TypeError`` once the parameter is registered, which made this
        method unusable after construction.
        """
        with torch.no_grad():
            nn.init.uniform_(self.coupling_matrix, a=0.01, b=0.1)
            # Zero out non-adjacent connections (for microgrid, usually a line
            # topology). The mask is moved to the parameter's device/dtype so
            # a graph created elsewhere (e.g. on CPU) can still be used.
            mask = torch.as_tensor(self.G).to(
                device=self.coupling_matrix.device,
                dtype=self.coupling_matrix.dtype,
            )
            self.coupling_matrix.mul_(mask)

    def forward(self, G):
        """
        Forward pass to get the masked coupling matrix.

        Args:
            G (torch.Tensor): Adjacency matrix of shape (N, N), binary values {0,1}.

        Returns:
            torch.Tensor: Masked, non-negative coupling matrix whose diagonal
            entries equal the off-diagonal row sum plus 1e-5.
        """
        # Apply ReLU for nonnegativity
        A_tilde = torch.relu(self.coupling_matrix).to(self.device)

        # Apply adjacency matrix mask (moved next to the weights if needed);
        # entries are additionally capped at 2.
        G = torch.as_tensor(G, device=A_tilde.device)
        A_masked = torch.clamp(A_tilde * G, 0, 2)

        # Diagonal dominance adjustment.
        # sum(row) - 2*diag == (off-diagonal sum) - diag, so adding this term
        # back onto A_masked cancels the existing diagonal entry and leaves
        # exactly (off-diagonal sum + 1e-5) on the diagonal.
        row_sum = torch.sum(A_masked, dim=1) - 2*torch.diag(A_masked)
        diag_values = row_sum + 1e-5
        A_diag = torch.diag_embed(diag_values)  # Create diagonal matrix

        # Update diagonal elements to ensure diagonal dominance
        A_final = A_masked + A_diag

        return A_final

class VectorLyapunovNetwork(nn.Module):
    """One small MLP per inverter producing a vector of Lyapunov candidates.

    Inverters are arranged on a line: inverter ``i`` is coupled to ``i-1`` and
    ``i+1`` where those exist. ``forward`` reads the state with a fixed layout
    of three values per inverter (δ, ω, ξ).
    """

    def __init__(self, input_dim=3, hidden_dim=64, G=None, num_inverters=3):
        """
        Vector Lyapunov Function Network for microgrid stability.

        Args:
            input_dim (int): Dimension of state (δ, ω, ξ). ``forward`` indexes
                the state with a stride of 3, so the default is the only value
                that matches it.
            hidden_dim (int): Dimension of hidden layers
            G (torch.Tensor): Communication graph adjacency matrix. When given,
                a ``GraphCouplingMatrix`` is attached as ``coupling_matrix``.
            num_inverters (int): Number of inverters on the line.
        """
        super().__init__()

        self.num_inverters = num_inverters
        self.input_dim = input_dim

        # Each neighbour contributes one e_delta and one e_xi feature, on top
        # of the inverter's own e_omega:
        #   1 neighbour  -> input_dim      (3 for the default layout)
        #   2 neighbours -> input_dim + 2  (5 for the default layout)
        #   0 neighbours -> input_dim - 2  (single-inverter case, e_omega only)
        self.networks = nn.ModuleList()
        for i in range(num_inverters):
            num_neighbors = int(i > 0) + int(i < num_inverters - 1)
            network_input_dim = input_dim + 2 * (num_neighbors - 1)

            network = nn.Sequential(
                nn.Linear(network_input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )
            self.networks.append(network)

        # For backward compatibility with num_inverters=3
        if num_inverters == 3:
            self.network_1 = self.networks[0]
            self.network_2 = self.networks[1]
            self.network_3 = self.networks[2]

        if G is not None:
            # Build the coupling matrix on the graph's own device instead of
            # the hard-coded default, so a CPU graph does not require CUDA.
            graph_device = getattr(G, 'device', None)
            if graph_device is None:
                self.coupling_matrix = GraphCouplingMatrix(self.num_inverters, G)
            else:
                self.coupling_matrix = GraphCouplingMatrix(
                    self.num_inverters, G, device=str(graph_device)
                )

    def forward(self, state):
        """
        Compute Lyapunov function value for given state.

        Args:
            state (torch.Tensor): State tensor [batch_size, num_inverters*3]
                                Contains phase angles, frequencies, and controller states

        Returns:
            torch.Tensor: Lyapunov function values [batch_size, num_inverters];
            column ``i`` is ``net_i(e_i) - net_i(0) + 0.001``.
        """
        # Per-inverter error features, in this order:
        #   [δ_i - δ_j for each neighbour j], ω_i, [ξ_i - ξ_j for each neighbour j]
        # with neighbours visited as (i-1, i+1).
        V = []
        for i in range(self.num_inverters):
            base = 3 * i
            delta_i = state[:, base:base + 1]
            e_omega_i = state[:, base + 1:base + 2]
            xi_i = state[:, base + 2:base + 3]

            e_delta_ij = []
            e_xi_ij = []
            for j in (i - 1, i + 1):
                if 0 <= j < self.num_inverters:
                    e_delta_ij.append(delta_i - state[:, 3 * j:3 * j + 1])
                    e_xi_ij.append(xi_i - state[:, 3 * j + 2:3 * j + 3])

            # A single concatenation also covers the no-neighbour case, where
            # concatenating the (empty) difference lists on their own would fail.
            error_state = torch.cat(e_delta_ij + [e_omega_i] + e_xi_ij, dim=-1)
            equilibrium_state = torch.zeros_like(error_state)
            # Subtracting the network's value at zero error pins V_i to the
            # constant offset 0.001 at the equilibrium.
            V_i = self.networks[i](error_state) - self.networks[i](equilibrium_state) + 0.001
            V.append(V_i)
        return torch.cat(V, dim=-1)

class ControllerNN(nn.Module):
    """Three-layer ReLU MLP mapping a controller input vector to a control signal."""

    def __init__(self, input_dim, output_dim=1, hidden_dim=64):
        """
        Build the controller MLP.

        Args:
            input_dim (int): Width of the input vector fed to the first layer
            output_dim (int): Width of the produced control vector
            hidden_dim (int): Width of both hidden layers
        """
        super().__init__()
        # Layers are created in forward order and stored under ``self.net``
        # as positions 0..4 of a Sequential container.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state):
        """
        Evaluate the controller.

        Args:
            state (torch.Tensor): Input batch [batch_size, input_dim]

        Returns:
            torch.Tensor: Control output [batch_size, output_dim]
        """
        control = self.net(state)
        return control

class CombinedController(nn.Module):
    """Bank of per-inverter controllers for inverters arranged on a line.

    Controller ``i`` sees its own state followed by the states of its existing
    neighbours ``i-1`` and ``i+1`` (in that order).
    """

    def __init__(self, input_dim, output_dim=1, hidden_dim=64, num_inverters=3):
        """
        Combined controller for multiple inverters.

        Args:
            input_dim (int): Input dimension per inverter
            output_dim (int): Output dimension per inverter
            hidden_dim (int): Hidden layer dimension
            num_inverters (int): Number of inverters
        """
        super().__init__()
        self.num_inverters = num_inverters
        self.input_dim = input_dim

        # Create controllers for each inverter. The input width is the ego
        # state plus one state per *existing* neighbour, which is exactly what
        # ``forward`` concatenates. Counting neighbours (instead of assuming
        # the end inverters always have one) keeps the single-inverter case
        # consistent with ``forward``, which then passes the ego state only.
        self.controllers = nn.ModuleList()
        for i in range(num_inverters):
            num_neighbors = int(i > 0) + int(i < num_inverters - 1)
            self.controllers.append(
                ControllerNN(input_dim * (1 + num_neighbors), output_dim, hidden_dim)
            )

        # For backward compatibility with num_inverters=3
        if num_inverters == 3:
            self.controller_1 = self.controllers[0]
            self.controller_2 = self.controllers[1]
            self.controller_3 = self.controllers[2]

    def forward(self, state):
        """
        Forward pass.

        Args:
            state (torch.Tensor): Input state [batch_size, num_inverters*input_dim]

        Returns:
            torch.Tensor: Control inputs [batch_size, num_inverters*output_dim]
        """
        width = self.input_dim
        controls = []

        for i in range(self.num_inverters):
            # Ego state first, then the left and right neighbours if present.
            pieces = [state[:, i * width:(i + 1) * width]]
            if i > 0:
                pieces.append(state[:, (i - 1) * width:i * width])
            if i < self.num_inverters - 1:
                pieces.append(state[:, (i + 1) * width:(i + 2) * width])

            state_input = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=-1)
            controls.append(self.controllers[i](state_input))

        return torch.cat(controls, dim=-1)
    

class CombinedSystemDynamics(nn.Module):
    """Bank of per-inverter dynamics models for inverters arranged on a line.

    Model ``i`` is called with the ego state of inverter ``i``, the
    concatenated states of its existing neighbours ``i-1`` and ``i+1``, and
    column ``i`` of the control tensor.
    """

    def __init__(self, state_dim, neighbor_dim, control_dim, hidden_dim=64, num_inverters=3):
        """
        Build one ``DynamicsNN`` per inverter.

        Args:
            state_dim (int): State dimension per inverter
            neighbor_dim (int): Neighbour-input dimension contributed by one neighbour
            control_dim (int): Control dimension passed through to ``DynamicsNN``
            hidden_dim (int): Hidden layer dimension
            num_inverters (int): Number of inverters
        """
        super().__init__()
        self.num_inverters = num_inverters
        self.state_dim = state_dim

        # Create dynamics for each inverter. The neighbour width scales with
        # the number of *existing* neighbours so it matches what ``forward``
        # concatenates, including the single-inverter case where ``forward``
        # passes an empty neighbour tensor.
        # NOTE(review): assumes DynamicsNN accepts a neighbour width of 0 for
        # that case -- confirm against its definition.
        self.dynamics = nn.ModuleList()
        for i in range(num_inverters):
            num_neighbors = int(i > 0) + int(i < num_inverters - 1)
            self.dynamics.append(
                DynamicsNN(state_dim, neighbor_dim * num_neighbors, control_dim, hidden_dim)
            )

        # One-hot column selectors, one per inverter (0-based names). They are
        # kept as registered buffers so the module's state_dict keys do not
        # change.
        for i in range(num_inverters):
            control_sel = torch.zeros(num_inverters, 1)
            control_sel[i, 0] = 1
            self.register_buffer(f'control_selection_{i}', control_sel)

        # For backward compatibility with num_inverters=3
        if num_inverters == 3:
            self.dynamics_1 = self.dynamics[0]
            self.dynamics_2 = self.dynamics[1]
            self.dynamics_3 = self.dynamics[2]
            # The 1-based names ``control_selection_1`` and
            # ``control_selection_2`` are already taken by the 0-based buffers
            # registered above. Assigning the aliases used to overwrite those
            # buffers, leaving every selector equal to the first one, so every
            # inverter received inverter 0's control. Only the non-colliding
            # alias is kept.
            self.control_selection_3 = getattr(self, 'control_selection_2')

    def forward(self, state, control):
        """
        Forward pass.

        Args:
            state (torch.Tensor): Input state [batch_size, num_inverters*state_dim]
            control (torch.Tensor): Control input [batch_size, num_inverters]

        Returns:
            torch.Tensor: Per-inverter ``DynamicsNN`` outputs concatenated
            along the last dimension (presumably the next states,
            [batch_size, num_inverters*state_dim] -- confirm against DynamicsNN).
        """
        batch_size = state.shape[0]
        width = self.state_dim
        next_states = []

        for i in range(self.num_inverters):
            # Extract ego state for inverter i
            state_ego = state[:, i * width:(i + 1) * width]

            # Control input for inverter i: column i, kept as a trailing
            # dimension of size 1. Indexing by position does not depend on
            # buffer contents, so selectors restored from a checkpoint saved
            # with the overwritten buffers cannot misroute the controls.
            control_i = control[..., i:i + 1]

            # Get neighbor states (left first, then right)
            neighbor_states = []
            if i > 0:
                neighbor_states.append(state[:, (i - 1) * width:i * width])
            if i < self.num_inverters - 1:
                neighbor_states.append(state[:, (i + 1) * width:(i + 2) * width])

            # Concatenate neighbor states
            if neighbor_states:
                state_neighbor = torch.cat(neighbor_states, dim=-1)
            else:
                state_neighbor = torch.zeros(batch_size, 0, device=state.device, dtype=state.dtype)

            # Compute next state
            next_states.append(self.dynamics[i](state_ego, state_neighbor, control_i))

        return torch.cat(next_states, dim=-1)