import torch
import torch.nn as nn
import numpy as np

class MultiViewAlignmentLoss(nn.Module):
    """
    Multi-view alignment loss function.

    Uses L1 loss to align functional hidden states (hf) of equivalent node pairs from AIG to other views
    (MIG, XMG, XAG). Supports using equivalent label data from parse_pair.py.

    The i-th entry of an AIG index list and the i-th entry of the matching
    other-view index list form one equivalent pair, so entries are always
    dropped pair-wise (never from one side only) to keep the pairing intact.
    """

    # (anchor view, other view, graph attribute holding anchor-view indices,
    #  graph attribute holding other-view indices)
    _VIEW_PAIRS = (
        ('aig', 'mig', 'aig_mig_equ', 'mig_aig_equ'),
        ('aig', 'xmg', 'aig_xmg_equ', 'xmg_aig_equ'),
        ('aig', 'xag', 'aig_xag_equ', 'xag_aig_equ'),
    )

    def __init__(self, loss_weight=1.0):
        """
        Args:
            loss_weight: Scalar multiplied onto the summed alignment loss.
        """
        super(MultiViewAlignmentLoss, self).__init__()
        self.loss_weight = loss_weight
        self.l1_loss = nn.L1Loss(reduction='mean')
        # Debug information is collected and printed only on the first forward call.
        self._debug_printed = False

    @staticmethod
    def _to_index_tensor(indices, device):
        """
        Convert raw equivalent-pair indices to a long tensor on `device`.

        Positions are preserved: list/tuple entries that are not int-like are
        replaced with -1 instead of being removed, so the caller can drop the
        whole pair and keep both index lists aligned.
        """
        if isinstance(indices, torch.Tensor):
            return indices.to(device=device, dtype=torch.long)
        if isinstance(indices, (list, tuple)):
            # A list wrapping array-like objects: only the first one is used.
            if len(indices) > 0 and hasattr(indices[0], '__len__'):
                indices = indices[0]
            indices = [
                int(idx) if isinstance(idx, (int, float, np.integer)) else -1
                for idx in indices
            ]
        return torch.tensor(indices, dtype=torch.long, device=device)

    def forward(self, graph, hf_dict):
        """
        Calculate multi-view alignment loss

        Args:
            graph: Graph object containing equivalent pair information, including the following equivalent labels:
                - graph.aig_mig_equ: AIG to MIG equivalent pairs (AIG node indices)
                - graph.mig_aig_equ: MIG to AIG equivalent pairs (MIG node indices)
                - graph.aig_xmg_equ: AIG to XMG equivalent pairs (AIG node indices)
                - graph.xmg_aig_equ: XMG to AIG equivalent pairs (XMG node indices)
                - graph.aig_xag_equ: AIG to XAG equivalent pairs (AIG node indices)
                - graph.xag_aig_equ: XAG to AIG equivalent pairs (XAG node indices)
                Missing or empty attributes make the corresponding view pair contribute zero loss.
            hf_dict: Dictionary containing functional hidden states for each view
                {
                    'aig': aig_hf,  # shape: [num_aig_nodes, hidden_dim]
                    'xmg': xmg_hf,   # shape: [num_xmg_nodes, hidden_dim]
                    'xag': xag_hf,   # shape: [num_xag_nodes, hidden_dim]
                    'mig': mig_hf    # shape: [num_mig_nodes, hidden_dim]
                }

        Returns:
            total_loss: Total alignment loss (always a tensor), scaled by `loss_weight`
            loss_dict: Alignment loss details for each view pair; for every pair it holds
                '<view1>_to_<view2>_hf_alignment_loss' (float) and
                '<view1>_to_<view2>_num_pairs' (int)
        """
        # Collecting debug info forces host/device syncs (.item()), so it is
        # only done for the call that will actually print it.
        collect_debug = not getattr(self, '_debug_printed', False)
        debug_info = {}
        loss_dict = {}
        total_loss = None

        for view1, view2, equ_key1, equ_key2 in self._VIEW_PAIRS:
            loss_key = f'{view1}_to_{view2}_hf_alignment_loss'
            pairs_key = f'{view1}_to_{view2}_num_pairs'
            # Default entry; overwritten only when a loss is actually computed.
            loss_dict[loss_key] = 0.0
            loss_dict[pairs_key] = 0

            raw_indices1 = getattr(graph, equ_key1, None)  # AIG node indices
            raw_indices2 = getattr(graph, equ_key2, None)  # Other view node indices

            if collect_debug:
                debug_info[f'{equ_key1}_exists'] = raw_indices1 is not None
                debug_info[f'{equ_key2}_exists'] = raw_indices2 is not None
                for key, raw in ((equ_key1, raw_indices1), (equ_key2, raw_indices2)):
                    if raw is not None:
                        debug_info[f'{key}_len'] = len(raw)
                        debug_info[f'{key}_type'] = type(raw).__name__
                        debug_info[f'{key}_sample'] = raw[:3] if len(raw) > 0 else []

            if raw_indices1 is None or raw_indices2 is None:
                continue
            if len(raw_indices1) == 0 or len(raw_indices2) == 0:
                continue

            hf1 = hf_dict.get(view1)  # AIG hf
            hf2 = hf_dict.get(view2)  # Other view hf

            if collect_debug:
                debug_info[f'{view1}_hf_exists'] = hf1 is not None
                debug_info[f'{view2}_hf_exists'] = hf2 is not None
                if hf1 is not None:
                    debug_info[f'{view1}_hf_shape'] = hf1.shape
                if hf2 is not None:
                    debug_info[f'{view2}_hf_shape'] = hf2.shape

            if hf1 is None or hf2 is None:
                continue

            # Best effort by design: a malformed view pair is reported and
            # skipped so it cannot abort the whole training step.
            try:
                indices1 = self._to_index_tensor(raw_indices1, hf1.device)
                indices2 = self._to_index_tensor(raw_indices2, hf2.device)

                if collect_debug:
                    debug_info[f'{equ_key1}_tensor_shape'] = indices1.shape
                    debug_info[f'{equ_key2}_tensor_shape'] = indices2.shape
                    debug_info[f'{equ_key1}_tensor_sample'] = indices1[:3].tolist()
                    debug_info[f'{equ_key2}_tensor_sample'] = indices2[:3].tolist()

                # Ensure both index lists have consistent length
                min_len = min(len(indices1), len(indices2))
                indices1 = indices1[:min_len]
                indices2 = indices2[:min_len]

                # Drop a pair when EITHER side is negative/invalid. Filtering
                # each side independently would shift the remaining entries
                # and align non-equivalent nodes; keeping negative values
                # would silently wrap around to the end of hf.
                keep = (indices1 >= 0) & (indices2 >= 0)
                indices1 = indices1[keep]
                indices2 = indices2[keep]
                if indices1.numel() == 0:
                    continue

                # Ensure indices are within valid range
                valid_indices1 = bool((indices1 < len(hf1)).all())
                valid_indices2 = bool((indices2 < len(hf2)).all())

                if collect_debug:
                    debug_info[f'{equ_key1}_valid'] = valid_indices1
                    debug_info[f'{equ_key2}_valid'] = valid_indices2
                    debug_info[f'{equ_key1}_max_idx'] = indices1.max().item()
                    debug_info[f'{equ_key2}_max_idx'] = indices2.max().item()
                    debug_info[f'{view1}_hf_len'] = len(hf1)
                    debug_info[f'{view2}_hf_len'] = len(hf2)

                if not (valid_indices1 and valid_indices2):
                    print(f"Warning: Invalid indices in {equ_key1} or {equ_key2}, skipping this pair")
                    continue

                # Get functional hidden states of equivalent nodes
                hf1_equivalent = hf1[indices1]  # shape: [num_pairs, hidden_dim]
                hf2_equivalent = hf2[indices2]  # shape: [num_pairs, hidden_dim]

                # Calculate functional hidden state alignment loss (L1 loss)
                hf_alignment_loss = self.l1_loss(hf1_equivalent, hf2_equivalent)
                loss_value = hf_alignment_loss.item()

                if collect_debug:
                    debug_info[f'{view1}_hf_equiv_shape'] = hf1_equivalent.shape
                    debug_info[f'{view2}_hf_equiv_shape'] = hf2_equivalent.shape
                    debug_info[f'{view1}_to_{view2}_loss'] = loss_value

                if total_loss is None:
                    total_loss = hf_alignment_loss
                else:
                    total_loss = total_loss + hf_alignment_loss
                loss_dict[loss_key] = loss_value
                loss_dict[pairs_key] = len(indices1)
            except Exception as e:
                print(f"Error processing {equ_key1}/{equ_key2}: {e}")
                debug_info[f'{view1}_to_{view2}_error'] = str(e)

        # Print debug information (only on first call)
        if collect_debug:
            print("Alignment Loss Debug Info:")
            for key, value in debug_info.items():
                print(f"  {key}: {value}")
            self._debug_printed = True

        if total_loss is None:
            # No view pair contributed: still return a tensor (on the same
            # device/dtype as the hidden states when available) so callers get
            # a consistent return type.
            reference = next(
                (hf for hf in hf_dict.values() if isinstance(hf, torch.Tensor)),
                None,
            )
            total_loss = torch.zeros(()) if reference is None else reference.new_zeros(())

        return total_loss * self.loss_weight, loss_dict