import torch, os
from torch import inf
from torch import nn
import pandas as pd
from loguru import logger

# Cut-offs applied to the raw singular values. Each key becomes a metric name
# whose value is the number of singular values strictly above the cut-off.
_RAW_RANK_THRESHOLDS = {
    'rank_3e': 1e3,
    'rank_2e': 1e2,
    'rank_1e': 1e1,
    'rank_e0': 1e0,
    'rank_e1': 1e-1,
    'rank_e2': 1e-2,
    'rank_e3': 1e-3,
    'rank_e4': 1e-4,
    'rank_e5': 1e-5,
    'rank_e6': 1e-6,
}

# Cut-offs applied to singular values divided by the largest singular value.
_NORM_RANK_THRESHOLDS = {
    'rank_7e1': 7e-1,
    'rank_5e1': 5e-1,
    'rank_3e1': 3e-1,
    'rank_1e1': 1e-1,
    'rank_5e2': 5e-2,
    'rank_1e2': 1e-2,
    'rank_1e3': 1e-3,
}


def _weight_as_matrix(layer):
    """Return the layer's weight as a 2-D matrix, or None for unsupported layers.

    Linear weights are returned as-is; Conv2d weights are flattened to
    (out_channels, everything else).
    """
    if isinstance(layer, nn.Linear):
        return layer.weight.detach()
    if isinstance(layer, nn.Conv2d):
        weight = layer.weight.detach()
        return weight.reshape(weight.shape[0], -1)
    return None


def _matrix_metrics(matrix):
    """Compute threshold ranks and summary statistics from a matrix's singular values.

    Args:
        matrix: 2-D tensor; it is cast to float32 before the decomposition.

    Returns:
        dict with, in this order:
            'total_dim': min(matrix.shape)
            'unclears':  sum of the singular values above 1e-6
            'vols':      product of the singular values above 1e-6
            one count per key of the raw and normalized threshold tables
            '<key>_ratio' for each count, i.e. count / total_dim
    """
    total_dim = min(matrix.shape)
    with torch.no_grad():
        if total_dim == 0:
            # Nothing to decompose; keep every downstream reduction well defined.
            singular_values = matrix.new_zeros(0, dtype=torch.float32)
        else:
            # Only the singular values are needed, so skip computing U and V.
            singular_values = torch.linalg.svdvals(matrix.float())

        largest = singular_values[0] if singular_values.numel() > 0 else None
        if largest is not None and largest > 0:
            normalized = singular_values / largest
        else:
            # All-zero (or empty) matrix: avoid 0/0 producing NaNs.
            normalized = torch.zeros_like(singular_values)

        ranks = {
            key: torch.sum(singular_values > cutoff).item()
            for key, cutoff in _RAW_RANK_THRESHOLDS.items()
        }
        norm_ranks = {
            key: torch.sum(normalized > cutoff).item()
            for key, cutoff in _NORM_RANK_THRESHOLDS.items()
        }

        significant = singular_values[singular_values > 1e-6]
        metrics = {
            'total_dim': total_dim,
            'unclears': torch.sum(significant).item(),
            'vols': torch.prod(significant).item(),
        }

    metrics.update(ranks)
    metrics.update(norm_ranks)
    for group in (ranks, norm_ranks):
        metrics.update({
            f"{key}_ratio": (rank / total_dim if total_dim else 0.0)
            for key, rank in group.items()
        })
    return metrics


def _layer_index(name):
    """Return the path component that follows the '...layers' component of *name*.

    Works for both 'model.layers.3.self_attn.q_proj' and
    'layers.3.self_attn.q_proj' (both give '3'). Returns None when no such
    component exists.
    """
    parts = name.split('.')
    for position, part in enumerate(parts[:-1]):
        if part.endswith('layers'):
            return parts[position + 1]
    return None


def calculate_layer_ranks(model):
    """Calculate ranks and metrics of linear/conv2d layers and QK attention matrices.

    Only modules whose qualified name contains 'layers.' are inspected. For
    every module named 'q_proj' that has a sibling 'k_proj', the product
    q_weight @ k_weight.T is analysed as well.

    Args:
        model: The PyTorch model. If it exposes a `.module` attribute (as a
            DDP wrapper does), the wrapped model is used instead.

    Returns:
        dict: Maps each analysed module name to its metrics dict, plus one
        'layer_{index}_qk' entry per attention block for the QK product.
    """
    layer_metrics = {}

    # For DDP models, get the underlying model
    if hasattr(model, 'module'):
        model = model.module

    for name, module in model.named_modules():
        if 'layers.' not in name:
            continue

        # Handle linear and conv2d layers
        weight_matrix = _weight_as_matrix(module)
        if weight_matrix is not None:
            layer_metrics[name] = _matrix_metrics(weight_matrix)

        # Handle the QK product exactly once per attention block: only when the
        # module itself is q_proj, not for children nested underneath it.
        parent_name, _, leaf = name.rpartition('.')
        if leaf != 'q_proj':
            continue

        parent_module = model
        for part in parent_name.split('.'):
            parent_module = getattr(parent_module, part)

        k_proj = getattr(parent_module, 'k_proj', None)
        q_weight = getattr(module, 'weight', None)
        k_weight = getattr(k_proj, 'weight', None)
        if q_weight is None or k_weight is None:
            # No sibling key projection (or no plain weight): nothing to pair.
            continue

        layer_num = _layer_index(name)
        qk = torch.matmul(q_weight.detach(), k_weight.detach().t())
        layer_metrics[f"layer_{layer_num}_qk"] = _matrix_metrics(qk)

    return layer_metrics

def _simplify_layer_name(name):
    """Map a metrics key to a short column label, or None if it is not tracked.

    'layer_3_qk' becomes 'layer3_QK'. Dotted module names with at least four
    components are mapped through the attention/MLP tables below, e.g.
    'model.layers.3.self_attn.q_proj' becomes 'layer3_Q'.
    """
    if '_qk' in name:  # Handle QK attention product
        return f"layer{name.split('_')[1]}_QK"

    parts = name.split('.')
    if len(parts) < 4:  # Not a layer component
        return None

    # The layer number follows the 'layers' component; when there is no such
    # component, fall back to the third path component.
    if 'layers' in parts[:-1]:
        layer_num = parts[parts.index('layers') + 1]
    else:
        layer_num = parts[2]

    # Map component names to simplified versions
    component_map = {
        'self_attn': {
            'q_proj': 'Q',
            'k_proj': 'K',
            'v_proj': 'V',
            'o_proj': 'O',
        },
        'mlp': {
            'gate_proj': 'MLP1',
            'up_proj': 'MLP2',
            'down_proj': 'MLP3',
        },
    }

    for module_key, mappings in component_map.items():
        if module_key not in name:
            continue
        for old_name, new_name in mappings.items():
            if old_name in name:
                return f"layer{layer_num}_{new_name}"
    return None


def _read_existing_sheets(excel_path):
    """Load every sheet of an existing workbook in one pass.

    Returns:
        dict mapping sheet name to DataFrame; empty when the file does not
        exist or cannot be read (the failure is logged, not raised, so that
        logging metrics stays best-effort).
    """
    if not os.path.exists(excel_path):
        return {}
    try:
        # sheet_name=None returns all sheets as a dict, parsing the file once.
        return pd.read_excel(excel_path, sheet_name=None)
    except Exception as exc:
        logger.warning(
            f"Could not read existing rank log file {excel_path} ({exc!r}); "
            "starting a fresh file and dropping its previous contents."
        )
        return {}


def log_layer_ranks(layer_metrics, global_step, args, global_rank=0):
    """Log layer metrics to local excel file with separate sheets.

    Appends one row (for `global_step`) to every sheet of two workbooks under
    'log_rank/': '<swanlab_name>_raw.xlsx' and '<swanlab_name>_normalized.xlsx'.
    Each sheet holds one metric; each column holds one simplified layer name.

    Args:
        layer_metrics: dict mapping layer name to a metrics dict that contains
            every sheet name used below as a key.
        global_step: value written to the 'step' column of the new row.
        args: object providing `swanlab_name`, used to build the file names.
        global_rank: only rank 0 writes; other ranks return immediately.
    """
    if global_rank != 0:
        return

    os.makedirs('log_rank', exist_ok=True)
    raw_excel_path = f'log_rank/{args.swanlab_name}_raw.xlsx'
    norm_excel_path = f'log_rank/{args.swanlab_name}_normalized.xlsx'

    # Sheet order per workbook: counts, then their ratios, then the summaries.
    raw_rank_names = [
        'rank_3e', 'rank_2e', 'rank_1e', 'rank_e0', 'rank_e1',
        'rank_e2', 'rank_e3', 'rank_e4', 'rank_e5', 'rank_e6',
    ]
    norm_rank_names = [
        'rank_7e1', 'rank_5e1', 'rank_3e1', 'rank_1e1',
        'rank_5e2', 'rank_1e2', 'rank_1e3',
    ]
    summary_names = ['unclears', 'vols']
    raw_sheet_names = (
        raw_rank_names + [f'{n}_ratio' for n in raw_rank_names] + summary_names
    )
    norm_sheet_names = (
        norm_rank_names + [f'{n}_ratio' for n in norm_rank_names] + summary_names
    )

    # One new row per sheet, keyed by simplified layer name.
    raw_metrics = {sheet: {'step': global_step} for sheet in raw_sheet_names}
    norm_metrics = {sheet: {'step': global_step} for sheet in norm_sheet_names}

    for layer_name, layer_info in layer_metrics.items():
        simple_name = _simplify_layer_name(layer_name)
        if not simple_name:
            continue
        for sheet in raw_sheet_names:
            raw_metrics[sheet][simple_name] = layer_info[sheet]
        for sheet in norm_sheet_names:
            norm_metrics[sheet][simple_name] = layer_info[sheet]

    for excel_path, metrics, sheet_names in (
        (raw_excel_path, raw_metrics, raw_sheet_names),
        (norm_excel_path, norm_metrics, norm_sheet_names),
    ):
        existing = _read_existing_sheets(excel_path)

        with pd.ExcelWriter(excel_path, engine='openpyxl', mode='w') as writer:
            for sheet in sheet_names:
                new_row = pd.DataFrame([metrics[sheet]])
                previous = existing.get(sheet)
                if previous is None or previous.empty:
                    # Missing or header-only sheet: the new row is the history.
                    combined = new_row
                else:
                    combined = pd.concat([previous, new_row], ignore_index=True)
                combined.to_excel(writer, sheet_name=sheet, index=False)


def init_rank_logging(args, global_rank=0):
    """Initialize two excel files for logging rank information.

    Deletes any previous '<swanlab_name>_raw.xlsx' and
    '<swanlab_name>_normalized.xlsx' under 'log_rank/' and recreates them with
    one empty sheet (a single 'step' column) per metric. The sheet names match
    the ones `log_layer_ranks` reads and appends to.

    Args:
        args: object providing `swanlab_name`, used to build the file names.
        global_rank: only rank 0 touches the files; other ranks do nothing.
    """
    if global_rank != 0:
        return

    os.makedirs('log_rank', exist_ok=True)

    # Define paths for both files
    raw_excel_path = f'log_rank/{args.swanlab_name}_raw.xlsx'
    norm_excel_path = f'log_rank/{args.swanlab_name}_normalized.xlsx'

    # Remove existing files if they exist
    for excel_path in (raw_excel_path, norm_excel_path):
        if os.path.exists(excel_path):
            os.remove(excel_path)
            logger.info(f"Removed existing rank log file: {excel_path}")

    # Sheet order per workbook: counts, then their ratios, then the summaries.
    raw_rank_names = [
        'rank_3e', 'rank_2e', 'rank_1e', 'rank_e0', 'rank_e1',
        'rank_e2', 'rank_e3', 'rank_e4', 'rank_e5', 'rank_e6',
    ]
    # Normalized metrics use their own threshold names (0.7 ... 0.001), not a
    # 'norm_' prefix on the raw names.
    norm_rank_names = [
        'rank_7e1', 'rank_5e1', 'rank_3e1', 'rank_1e1',
        'rank_5e2', 'rank_1e2', 'rank_1e3',
    ]
    summary_names = ['unclears', 'vols']
    raw_sheet_names = (
        raw_rank_names + [f'{n}_ratio' for n in raw_rank_names] + summary_names
    )
    norm_sheet_names = (
        norm_rank_names + [f'{n}_ratio' for n in norm_rank_names] + summary_names
    )

    # Every sheet starts as a header-only table with just the 'step' column.
    df_empty = pd.DataFrame(columns=['step'])

    # Initialize both files
    for excel_path, sheet_names in (
        (raw_excel_path, raw_sheet_names),
        (norm_excel_path, norm_sheet_names),
    ):
        with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
            for sheet in sheet_names:
                df_empty.to_excel(writer, sheet_name=sheet, index=False)
        logger.info(f"Created new rank log file: {excel_path}")