import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig
import loralib as lora
from typing import Dict, List, Optional, Union, Tuple, Any
import logging
import os
import json
from functools import reduce

logger = logging.getLogger(__name__)

def get_module_by_name(model, module_name):
    """
    Get a module from a model by its dotted name.

    Args:
        model: The model (root object) to search in
        module_name: Dot-separated attribute path of the module to find,
            e.g. "encoder.layer.0". An empty string refers to the model
            itself, which is the name ``nn.Module.named_modules()`` yields
            for the root module.

    Returns:
        The module corresponding to the name

    Raises:
        AttributeError: If any component of the path does not exist
    """
    # ''.split('.') is [''], and getattr(model, '') would raise, so the
    # root name has to be short-circuited explicitly.
    if not module_name:
        return model
    return reduce(getattr, module_name.split('.'), model)

def set_module_by_name(model, module_name, module):
    """
    Store an object under a dotted name inside a model.

    Args:
        model: The root object that owns the attribute path
        module_name: Dot-separated path; every component but the last is
            resolved with getattr, the last one is the attribute to set
        module: The object to store at that path
    """
    *parent_path, leaf_name = module_name.split('.')

    # Walk down to the object that directly owns the attribute to replace.
    owner = model
    for attr in parent_path:
        owner = getattr(owner, attr)

    setattr(owner, leaf_name, module)
    
def replace_linear_with_lora(
    model: nn.Module,
    r_config: Dict[str, int],
    alpha_config: Optional[Dict[str, int]] = None,
    dropout: float = 0.0,
    merge_weights: bool = False
) -> Tuple[nn.Module, Dict[str, Dict[str, Any]]]:
    """
    Replace Linear layers in the model with LoRA layers based on the r configuration.

    Only modules whose name (as yielded by ``model.named_modules()``) is a key
    of ``r_config`` AND that are instances of ``nn.Linear`` are considered.

    For each target layer with r > 0, the function:
    1. Determines the rank (r) and scaling factor (alpha, default 2 * r)
    2. Creates a ``lora.Linear`` with the same in/out features and bias setting
    3. Moves it to the original layer's device (and floating dtype), copies the
       original weight and bias into it, and mirrors the train/eval flag
    4. Replaces the original module on its parent (or becomes the returned
       model when the target is the root module itself)
    5. Records details about the replacement

    Args:
        model: The PyTorch model to modify (modified in place)
        r_config: Dictionary mapping layer name to rank (r) value
        alpha_config: Dictionary mapping layer name to alpha value (defaults to 2*r if not provided)
        dropout: Dropout rate for LoRA layers (default: 0.0)
        merge_weights: Passed through to ``lora.Linear(merge_weights=...)``

    Returns:
        Tuple containing:
        - Modified model with LoRA layers
        - Dictionary mapping layer name to a record with the keys
          'r', 'alpha', 'in_features', 'out_features', 'applied',
          'trainable_params' and 'message'

    Note:
        Layers with r <= 0 in the r_config are recorded (with 'r': 0 and
        'applied': False) but left untouched. This function does not change
        ``requires_grad`` on them; freezing is left to the caller.
        r_config entries that match no ``nn.Linear`` are reported with a
        warning and otherwise ignored.
    """
    logger.info("Replacing linear layers with LoRA layers using provided r configuration...")

    # Keep track of replaced (and deliberately skipped) layers
    replaced_layers: Dict[str, Dict[str, Any]] = {}

    # Snapshot the module list: the module tree is mutated while we walk it.
    for name, module in list(model.named_modules()):
        # Only target modules specified in the r_config
        if name not in r_config or not isinstance(module, nn.Linear):
            continue

        r = r_config[name]

        if r <= 0:
            # Keep the original layer as is and just record it
            logger.info(f"SKIPPING LoRA for {name} (r = {r}): Original layer will be preserved")
            replaced_layers[name] = {
                'r': 0,
                'alpha': 0,
                'in_features': module.in_features,
                'out_features': module.out_features,
                'applied': False,
                'trainable_params': 0,  # no LoRA parameters are added for this layer
                'message': 'LoRA NOT applied (r=0)'
            }
            # Diagnostic only: nothing is frozen here, the log just flags
            # parameters that currently still have requires_grad=True.
            for param_name, param in module.named_parameters():
                if param.requires_grad:
                    logger.debug(f"  - Parameter {param_name} in r=0 layer will not be trained")
            continue

        alpha = alpha_config.get(name, 2 * r) if alpha_config else 2 * r
        logger.info(f"Replacing {name} with LoRA (r={r}, alpha={alpha})")

        lora_layer = _build_lora_linear(module, r, alpha, dropout, merge_weights)

        if name:
            *parent_path, leaf_name = name.split('.')
            parent_module = reduce(getattr, parent_path, model)
            setattr(parent_module, leaf_name, lora_layer)
        else:
            # The root module itself is the target: there is no parent to
            # patch, so the LoRA layer becomes the returned model.
            model = lora_layer

        # Record the replacement (same keys as the skipped-layer record)
        replaced_layers[name] = {
            'r': r,
            'alpha': alpha,
            'in_features': module.in_features,
            'out_features': module.out_features,
            'applied': True,
            # Counted from the layer itself; relies on loralib naming its
            # low-rank parameters with a "lora_" prefix (lora_A / lora_B).
            'trainable_params': sum(
                p.numel() for n, p in lora_layer.named_parameters() if 'lora_' in n
            ),
            'message': 'LoRA applied'
        }

    applied_count = sum(1 for layer in replaced_layers.values() if layer['applied'])
    skipped_count = sum(1 for layer in replaced_layers.values() if not layer['applied'])
    zero_r_count = sum(1 for layer in replaced_layers.values() if layer['r'] == 0)

    if applied_count and torch.cuda.is_available():
        # Flush once, after the snapshot list (which kept every old module
        # alive during the loop) is gone, instead of once per layer.
        del module
        torch.cuda.empty_cache()

    # A typo in a layer name would otherwise silently disable LoRA for it.
    unmatched = [name for name in r_config if name not in replaced_layers]
    if unmatched:
        preview = ', '.join(repr(name) for name in unmatched[:10])
        suffix = ', ...' if len(unmatched) > 10 else ''
        logger.warning(
            f"{len(unmatched)} r_config entries did not match any nn.Linear module "
            f"and were ignored: {preview}{suffix}"
        )

    logger.info(f"LoRA Layer Summary: Replaced {applied_count} layers with LoRA, skipped {skipped_count} layers")
    logger.info(f"  - Layers with r=0 (no training): {zero_r_count}")
    logger.info(f"  - Layers with r>0 (with training): {applied_count}")

    return model, replaced_layers


def _build_lora_linear(module, r, alpha, dropout, merge_weights):
    """Create a lora.Linear mirroring `module`: shape, bias, weights, device, dtype and train/eval flag."""
    has_bias = module.bias is not None
    lora_layer = lora.Linear(
        in_features=module.in_features,
        out_features=module.out_features,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias=has_bias,
        merge_weights=merge_weights
    )

    # The fresh layer is created on the default device/dtype. Move ALL of its
    # parameters (including the low-rank matrices) next to the original
    # weights, otherwise the forward pass mixes devices/dtypes.
    source_weight = module.weight
    if source_weight.dtype.is_floating_point:
        lora_layer.to(device=source_weight.device, dtype=source_weight.dtype)
    else:
        lora_layer.to(device=source_weight.device)

    # Copy in place: avoids an extra temporary clone of the weight matrix and
    # leaves the requires_grad flags chosen by lora.Linear untouched.
    with torch.no_grad():
        lora_layer.weight.copy_(source_weight)
        if has_bias:
            lora_layer.bias.copy_(module.bias)

    # New modules start in training mode; mirror the replaced layer so an
    # eval-mode model does not end up with active LoRA dropout.
    lora_layer.train(module.training)
    return lora_layer

class OptimalLoRAModel(nn.Module):
    """
    Wrapper around a pre-trained model whose selected Linear layers are replaced by LoRA layers.

    Attributes set by ``__init__``:
        model: The wrapped model after LoRA replacement
        replaced_layers: Per-layer records returned by ``replace_linear_with_lora``
        r_config: Copy of the rank configuration that was applied
        alpha_config: Alpha per layer; provided values, with 2*r filled in
            for layers of ``r_config`` that had no explicit alpha
        lora_dropout: Dropout rate that was passed to the LoRA layers
        merge_weights: ``merge_weights`` flag that was passed to the LoRA layers
    """

    def __init__(
        self,
        base_model: Union[str, PreTrainedModel],
        r_config: Dict[str, int],
        alpha_config: Optional[Dict[str, int]] = None,
        dropout: float = 0.0,
        merge_weights: bool = False
    ):
        """
        Initialize the OptimalLoRAModel.

        Args:
            base_model: Base model instance, or a model name/path that is
                loaded with ``AutoModel.from_pretrained``
            r_config: Dict mapping layer name to r value
            alpha_config: Dict mapping layer name to alpha value (defaults to 2*r if not provided)
            dropout: Dropout rate for LoRA layers
            merge_weights: Passed through to the LoRA layers
        """
        super().__init__()

        # Load the base model
        if isinstance(base_model, str):
            self.model = AutoModel.from_pretrained(base_model)
        else:
            self.model = base_model

        # Replace linear layers with LoRA layers
        self.model, self.replaced_layers = replace_linear_with_lora(
            self.model, r_config, alpha_config, dropout, merge_weights
        )

        # Store configurations. Copies are kept so later mutation of the
        # caller's dicts cannot silently change what gets saved.
        self.r_config = dict(r_config)
        # Provided alphas win; layers without one get the same 2*r default
        # that replace_linear_with_lora applies.
        default_alphas = {name: 2 * r for name, r in r_config.items()}
        self.alpha_config = {**default_alphas, **(alpha_config or {})}
        # Kept so save_lora_config can persist the full set of LoRA settings.
        self.lora_dropout = dropout
        self.merge_weights = merge_weights

    def forward(self, *args, **kwargs):
        """Forward pass to the underlying model."""
        return self.model(*args, **kwargs)

    def save_lora_config(self, save_dir: str):
        """
        Save the LoRA configuration to ``<save_dir>/lora_config.json``.

        Only the configuration is written (ranks, alphas, per-layer records,
        dropout and merge_weights); no model weights are saved here.

        Args:
            save_dir: Directory to save the configuration (created if missing)
        """
        os.makedirs(save_dir, exist_ok=True)
        config = {
            'r_config': self.r_config,
            'alpha_config': self.alpha_config,
            'replaced_layers': self.replaced_layers,
            'dropout': self.lora_dropout,
            'merge_weights': self.merge_weights
        }

        with open(os.path.join(save_dir, 'lora_config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        logger.info(f"Saved LoRA configuration to {save_dir}")

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        lora_config_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load a base model and apply LoRA layers from a saved configuration.

        Args:
            model_name_or_path: Name or path of the base model
                (loaded with ``AutoModel.from_pretrained``)
            lora_config_path: Path to the LoRA configuration JSON file. If it
                is omitted or the file does not exist, no LoRA layers are
                added (a warning is logged in the latter case).
            **kwargs: Additional arguments to pass to the model constructor.
                They override 'dropout' / 'merge_weights' stored in the file.

        Returns:
            OptimalLoRAModel instance
        """
        # Load base model
        base_model = AutoModel.from_pretrained(model_name_or_path)

        # Load LoRA config if provided
        if lora_config_path and os.path.exists(lora_config_path):
            with open(lora_config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)

            # Saved hyper-parameters act as defaults; explicit kwargs win.
            # Older config files without these keys fall back to the
            # constructor defaults, as before.
            init_kwargs = {
                key: config[key] for key in ('dropout', 'merge_weights') if key in config
            }
            init_kwargs.update(kwargs)

            return cls(
                base_model=base_model,
                r_config=config.get('r_config', {}),
                alpha_config=config.get('alpha_config', None),
                **init_kwargs
            )

        if lora_config_path:
            # Best-effort fallback is kept, but a wrong path must not go unnoticed.
            logger.warning(
                f"LoRA configuration not found at {lora_config_path}; "
                f"returning the base model without LoRA layers"
            )

        # If no LoRA config is available, return a model with no LoRA layers
        return cls(
            base_model=base_model,
            r_config={},
            **kwargs
        )

    def mark_only_lora_as_trainable(self):
        """
        Mark only LoRA parameters as trainable and log the trainable share.

        Delegates to ``lora.mark_only_lora_as_trainable`` on the wrapped model.

        Returns:
            self, to allow call chaining
        """
        lora.mark_only_lora_as_trainable(self.model)

        # Count trainable parameters
        lora_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.model.parameters())
        # Guard against a parameter-less model, which would divide by zero.
        trainable_share = lora_params / total_params if total_params else 0.0

        logger.info(f"Trainable parameters: {lora_params:,} ({trainable_share:.2%} of total)")

        return self
    
# GLUE helpers (these notes originally referred to this file as models/optimal_lora.py)

# prepare_model_for_glue stores the initial r_config on the returned model for later reference.

def prepare_model_for_glue(
    base_model_name: str,
    r_config: Dict[str, int],
    num_labels: int,
    dropout: float = 0.1,
    alpha_config: Optional[Dict[str, int]] = None,
    trainable_modules: Optional[List[str]] = None,
) -> Tuple[nn.Module, Dict[str, Dict[str, Any]]]:
    """
    Prepare a model with optimal LoRA configuration for GLUE tasks.

    Loads ``AutoModelForSequenceClassification``, replaces the Linear layers
    named in ``r_config`` with LoRA layers, calls
    ``lora.mark_only_lora_as_trainable`` and stores a copy of ``r_config``
    on the model as ``initial_r_config``.

    Args:
        base_model_name: Name of the base model
        r_config: Dict mapping layer name to r value
        num_labels: Number of labels for the classification head
        dropout: Dropout rate for the LoRA layers
        alpha_config: Optional dict mapping layer name to alpha value
            (defaults to 2*r per layer when omitted)
        trainable_modules: Optional module names (e.g. ["classifier"]) whose
            parameters are switched back to ``requires_grad=True`` after the
            LoRA-only freeze. Default None keeps the previous behaviour.

    Returns:
        Tuple containing:
        - Model with LoRA layers and classification head
        - Dictionary containing details about replaced layers

    Note:
        NOTE(review): ``lora.mark_only_lora_as_trainable`` presumably freezes
        every parameter that is not a LoRA parameter, which would include the
        newly initialised classification head. Confirm against loralib and
        pass ``trainable_modules`` if the head is meant to be trained.
    """
    from transformers import AutoModelForSequenceClassification

    # Load base model with classification head
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_name,
        num_labels=num_labels
    )

    # Replace linear layers with LoRA layers
    model, replaced_layers = replace_linear_with_lora(
        model, r_config, alpha_config=alpha_config, dropout=dropout, merge_weights=False
    )

    # Mark only LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(model)

    # Optionally re-enable training for whole modules (e.g. the task head)
    if trainable_modules:
        # Match on "<module>." so "classifier" does not also match "classifier_2".
        prefixes = tuple(f"{module_name}." for module_name in trainable_modules)
        matched_prefixes = set()
        for param_name, param in model.named_parameters():
            for prefix in prefixes:
                if param_name.startswith(prefix):
                    param.requires_grad = True
                    matched_prefixes.add(prefix)
        for prefix in prefixes:
            if prefix not in matched_prefixes:
                logger.warning(f"trainable_modules entry {prefix[:-1]!r} matched no parameters")

    # Count trainable parameters
    lora_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    # Guard against a parameter-less model, which would divide by zero.
    trainable_share = lora_params / total_params if total_params else 0.0

    logger.info(f"Trainable parameters: {lora_params:,} ({trainable_share:.2%} of total)")

    # Store initial r_config as attribute for later reference (used in pruning)
    model.initial_r_config = r_config.copy()

    return model, replaced_layers