
import numpy as np
import torch
from torch_sparse import SparseTensor

def extract_pandapower_matrices(ppci):
    """Pull the power-flow quantities used by the loss out of a ``ppci`` mapping.

    Args:
        ppci: Mapping that provides the keys ``"Ybus"``, ``"Sbus"``, ``"V"``,
            ``"ref"``, ``"pv"``, ``"pq"`` and ``"baseMVA"``.

    Returns:
        A 7-tuple ``(Ybus, Sbus, V, ref, pv, pq, baseMVA)`` holding the values
        stored under those keys, in exactly that order.

    Raises:
        KeyError: If any of the keys is missing (the first missing key in the
            order above is reported).
    """
    # The tuple order defines both the lookup order and the return order.
    wanted_keys = ("Ybus", "Sbus", "V", "ref", "pv", "pq", "baseMVA")
    return tuple(ppci[key] for key in wanted_keys)

def scipy_sparse_to_torch_sparse(scipy_matrix, device='cpu'):
    """Build a ``SparseTensor`` on ``device`` from a scipy sparse matrix.

    The matrix is first converted to COO form; its row/column indices become
    int64 tensors and its stored values are cast to ``torch.complex64``.

    Args:
        scipy_matrix: Object exposing ``tocoo()`` whose result has ``row``,
            ``col``, ``data`` and ``shape`` attributes.
        device: Target device handed to ``SparseTensor.to``.

    Returns:
        SparseTensor with the same shape and non-zero pattern as the input.
    """
    coo = scipy_matrix.tocoo()

    # Stack the two index vectors once so a single conversion yields both.
    index_pairs = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    row_index, col_index = index_pairs

    nonzeros = torch.from_numpy(coo.data).to(torch.complex64)

    tensor = SparseTensor(
        row=row_index,
        col=col_index,
        value=nonzeros,
        sparse_sizes=coo.shape,
    )
    return tensor.to(device)

def _split_real_imag(matrix):
    """Return ``(real_part, imag_part)`` of ``matrix``; ``imag_part`` is None if real.

    Handles plain ``torch.Tensor`` operands as well as objects that follow the
    ``torch_sparse.SparseTensor`` interface (``storage.value()`` / ``set_value``).
    """
    if isinstance(matrix, torch.Tensor):
        if matrix.is_complex():
            return matrix.real, matrix.imag
        return matrix, None

    values = matrix.storage.value()
    if values is None or not values.is_complex():
        return matrix, None

    # Same sparsity pattern, real-valued entries; ``layout='coo'`` matches the
    # order in which ``storage.value()`` returns the entries.
    real_part = matrix.set_value(values.real.contiguous(), layout='coo')
    imag_part = matrix.set_value(values.imag.contiguous(), layout='coo')
    return real_part, imag_part


def torch_sparse_mm_complex(sparse_tensor, dense_tensor):
    """Sparse-dense matrix multiplication for complex tensors.

    Computes ``A @ x`` for a (possibly complex) operator ``A`` and a (possibly
    complex) dense operand ``x`` using only real-valued products:

        (A_r + i*A_i) @ (x_r + i*x_i)
            = (A_r @ x_r - A_i @ x_i) + i * (A_i @ x_r + A_r @ x_i)

    Splitting *both* operands keeps every product real-valued, which is what
    ``torch.complex`` requires, and keeps the imaginary part of the operator
    in the result instead of leaving it to implicit dtype handling.

    Args:
        sparse_tensor: Operator supporting ``@`` with a dense tensor; either a
            ``torch.Tensor`` or a SparseTensor-like object.
        dense_tensor: Dense tensor, real or complex.

    Returns:
        Complex tensor holding the product.
    """
    dense_real = dense_tensor.real
    if dense_tensor.is_complex():
        dense_imag = dense_tensor.imag
    else:
        # ``.imag`` is undefined for real tensors; a real operand has zero imaginary part.
        dense_imag = torch.zeros_like(dense_tensor)

    op_real, op_imag = _split_real_imag(sparse_tensor)

    result_real = op_real @ dense_real
    result_imag = op_real @ dense_imag
    if op_imag is not None:
        result_real = result_real - op_imag @ dense_imag
        result_imag = result_imag + op_imag @ dense_real

    return torch.complex(result_real, result_imag)

class BatchPhysicsInformedLoss(torch.nn.Module):
    """
    Batch-compatible physics-informed loss function for multiple power flow networks.

    Each network contributes the sum of squared real and imaginary parts of the
    complex power mismatch ``V * conj(Ybus @ V) - Sbus``; ``forward`` averages
    these per-network values over the batch. Converted network matrices are
    kept in ``network_cache`` so that a ppci is converted only once per cache key.
    """

    def __init__(self, device='cpu', default_vm=1.0, default_va=0.0):
        """
        Args:
            device: Device on which cached network tensors are stored.
            default_vm: Default voltage magnitude for unmapped buses.
            default_va: Default voltage angle for unmapped buses.
        """
        super().__init__()
        self.device = device
        self.network_cache = {}  # Cache for network matrices, keyed by network id
        # Retained for interface compatibility: the mapping below fills every
        # ppci bus from either a prediction or the reference solution, so these
        # defaults never reach the tensor that is returned.
        self.default_vm = default_vm
        self.default_va = default_va

    @staticmethod
    def _network_id(index, ppci):
        """
        Build the cache key / reported identifier for one network of a batch.

        Args:
            index: Position of the network inside the batch.
            ppci: The network's ppci object.

        Returns:
            String of the form ``"net_<index>_<hash of str(ppci)>"``.
        """
        # NOTE(review): the key contains the batch position, so the same ppci
        # seen at different positions is cached once per position, and it
        # relies on ``str(ppci)``, which may abbreviate large arrays -- a
        # dedicated per-sample identifier would be more robust; confirm with callers.
        return f"net_{index}_{hash(str(ppci))}"

    def _get_network_matrices(self, net_id, ppci):
        """
        Get or compute network matrices with caching.

        Args:
            net_id: Unique identifier for the network (e.g., hash or index)
            ppci: Mapping with the keys read by ``extract_pandapower_matrices``;
                only accessed when ``net_id`` is not cached yet.

        Returns:
            Dictionary with the keys ``'Ybus'``, ``'Sbus'``, ``'ref'``, ``'pv'``,
            ``'pq'``, ``'baseMVA'`` and ``'V_ref'``; all tensors live on
            ``self.device``.
        """
        if net_id not in self.network_cache:
            # Extract matrices for this network
            Ybus, Sbus, V_ref, ref, pv, pq, baseMVA = extract_pandapower_matrices(ppci)

            self.network_cache[net_id] = {
                'Ybus': scipy_sparse_to_torch_sparse(Ybus, self.device),
                'Sbus': torch.from_numpy(Sbus).to(torch.complex64).to(self.device),
                # Bus type indices are stored as int64 tensors
                'ref': torch.from_numpy(ref).long().to(self.device),
                'pv': torch.from_numpy(pv).long().to(self.device),
                'pq': torch.from_numpy(pq).long().to(self.device),
                'baseMVA': baseMVA,
                'V_ref': torch.from_numpy(V_ref).to(torch.complex64).to(self.device),
            }

        return self.network_cache[net_id]

    def _map_net_predictions_to_ppci(self, net_predictions, network_data):
        """
        Map predictions from net bus order to ppci bus order.

        Row 0 is always replaced by the reference voltage of bus 0 (treated as
        the slack bus). Rows ``1 .. min(num_net_buses, num_ppci_buses) - 1`` are
        taken from the predictions, and any remaining ppci buses are filled from
        the reference solution. The result is assembled as a new tensor; the
        caller's ``net_predictions`` is left untouched, so it stays valid for
        other loss terms and for autograd.

        Args:
            net_predictions: torch tensor of shape [num_net_buses, 2] with
                columns (voltage magnitude in p.u., voltage angle in degrees)
            network_data: Dictionary providing ``'V_ref'`` (complex reference
                voltages) and ``'Sbus'`` (its length defines num_ppci_buses)

        Returns:
            ppci_predictions: torch tensor of shape [num_ppci_buses, 2] on the
            device and with the dtype of ``net_predictions``
        """
        v_ref = network_data["V_ref"]
        num_ppci_buses = len(network_data["Sbus"])

        # Simple case: assume the first buses in ppci correspond to the net
        # buses one-to-one. This is often true but not guaranteed.
        # NOTE(review): the slack bus is assumed to sit at index 0 instead of
        # being looked up in network_data['ref'] -- confirm.
        num_mapped = min(len(net_predictions), num_ppci_buses)

        # Reference solution in the same (magnitude, angle-in-degrees) layout
        # as the predictions, on the same device/dtype so rows can be mixed.
        reference = torch.stack(
            [torch.abs(v_ref), torch.angle(v_ref) * (180.0 / torch.pi)],
            dim=1,
        ).to(device=net_predictions.device, dtype=net_predictions.dtype)

        mapped = net_predictions[:num_mapped]
        if num_mapped > 0:
            # Pin the slack bus to the reference voltage without writing into
            # the prediction tensor (an in-place write would alter the model
            # output and is rejected by autograd for leaf tensors).
            mapped = torch.cat([reference[:1], mapped[1:]], dim=0)

        # ppci buses without a prediction fall back to the reference solution.
        return torch.cat([mapped, reference[num_mapped:num_ppci_buses]], dim=0)

    def _calculate_single_network_loss(self, predictions, network_data):
        """
        Calculate physics loss for a single network.

        Args:
            predictions: torch tensor of shape [num_nodes, 2] for this network
                with columns (voltage magnitude in p.u., voltage angle in degrees)
            network_data: Dictionary with network matrices and bus info

        Returns:
            loss: Scalar loss for this network (sum of squared real and
            imaginary power mismatches over all ppci buses)
        """
        # Cached matrices live on self.device; bring the predictions there too.
        predictions = predictions.to(self.device)

        # Map net predictions to ppci order
        ppci_predictions = self._map_net_predictions_to_ppci(predictions, network_data)

        # Convert predictions to complex voltage vector
        vm_pu = ppci_predictions[:, 0]
        va_rad = ppci_predictions[:, 1] * (torch.pi / 180.0)
        V = vm_pu * torch.exp(1j * va_rad)

        # Calculate power mismatch: injected power V * conj(I) with I = Ybus @ V
        Ybus_V = torch_sparse_mm_complex(network_data['Ybus'], V.unsqueeze(1)).squeeze(1)
        mis = V * torch.conj(Ybus_V) - network_data['Sbus']

        F = torch.cat([mis.real, mis.imag])
        return torch.sum(F**2)

    def forward(self, batch_predictions, batch_data):
        """
        Calculate physics-informed loss for a batch of networks.

        Args:
            batch_predictions: List of prediction tensors, each of shape [num_nodes_i, 2]
            batch_data: List of PyTorch Geometric Data objects, each with data.ppci

        Returns:
            loss: Average physics-based loss across the batch

        Raises:
            RuntimeError: If the batch is empty (``torch.stack`` needs at least
                one tensor).
        """
        batch_losses = []

        for i, (predictions, data) in enumerate(zip(batch_predictions, batch_data)):
            # Get network matrices (cached if available)
            network_data = self._get_network_matrices(self._network_id(i, data.ppci), data.ppci)

            # Calculate loss for this network
            batch_losses.append(self._calculate_single_network_loss(predictions, network_data))

        # Return average loss across batch
        return torch.stack(batch_losses).mean()

    def verify_batch(self, batch_data):
        """
        Verify the loss function with pandapower reference solutions for a batch.

        The reference voltages stored in each ppci are fed through the loss; the
        mean of the resulting values is printed.

        Args:
            batch_data: List of PyTorch Geometric Data objects

        Returns:
            List of verification results for each network in the batch; each
            entry is a dict with the keys ``'network_id'``, ``'reference_loss'``,
            ``'num_buses'``, ``'num_pv_buses'`` and ``'num_pq_buses'``.
        """
        results = []

        for i, data in enumerate(batch_data):
            net_id = self._network_id(i, data.ppci)
            network_data = self._get_network_matrices(net_id, data.ppci)

            with torch.no_grad():
                # Convert reference solution to expected format
                vm_ref = torch.abs(network_data['V_ref'])
                va_ref = torch.angle(network_data['V_ref']) * (180.0 / torch.pi)
                ref_predictions = torch.stack([vm_ref, va_ref], dim=1)

                # Calculate loss with reference solution
                ref_loss = self._calculate_single_network_loss(ref_predictions, network_data)

                results.append({
                    'network_id': net_id,
                    'reference_loss': ref_loss.item(),
                    'num_buses': len(network_data['V_ref']),
                    'num_pv_buses': len(network_data['pv']),
                    'num_pq_buses': len(network_data['pq'])
                })
        print("Batch reference loss:", np.mean([result['reference_loss'] for result in results]).item())

        return results

# Factory function for batch processing
def create_batch_physics_loss(device='cpu'):
    """
    Build a ``BatchPhysicsInformedLoss`` bound to the given device.

    Args:
        device: PyTorch device ('cpu' or 'cuda') passed to the loss constructor;
            all other constructor arguments keep their defaults.

    Returns:
        loss_fn: BatchPhysicsInformedLoss instance
    """
    loss_fn = BatchPhysicsInformedLoss(device=device)
    return loss_fn
