import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig
import loralib as lora
from typing import Dict, List, Optional, Union, Tuple, Any
import logging
import os
import json
from functools import reduce

logger = logging.getLogger(__name__)

def get_module_by_name(model, module_name):
    """
    Get a module from a model by its dotted name.

    Args:
        model: The root object to search in
        module_name: Dot-separated attribute path (e.g. "encoder.layer.0").
            The empty string refers to the root itself, matching the name
            that ``named_modules()`` reports for the top-level module.

    Returns:
        The object found at the given path (``model`` itself for "")

    Raises:
        AttributeError: If any component of the path does not exist
    """
    # ''.split('.') is [''], and getattr(model, '') would fail, so the root
    # name has to be handled before walking the path.
    if module_name == '':
        return model
    return reduce(getattr, module_name.split('.'), model)

def set_module_by_name(model, module_name, module):
    """
    Replace the attribute found at a dotted path inside a model.

    Args:
        model: The root object to modify
        module_name: Dot-separated attribute path of the entry to replace
        module: The new value to store at that path
    """
    # Everything before the last dot addresses the parent; the last
    # component is the attribute to overwrite on that parent.
    *parent_path, leaf = module_name.split('.')

    owner = model
    for attr in parent_path:
        owner = getattr(owner, attr)

    setattr(owner, leaf, module)
    
def replace_linear_with_lora(
    model: nn.Module,
    r_config: Dict[str, int],
    alpha_config: Optional[Dict[str, int]] = None,
    dropout: float = 0.0,
    merge_weights: bool = False
) -> nn.Module:
    """
    Replace Linear layers in the model with LoRA layers based on the r configuration.
    
    This function performs targeted replacement of model's Linear layers with 
    Low-Rank Adaptation (LoRA) layers. LoRA is a parameter-efficient fine-tuning
    method that adds low-rank decomposition matrices to pre-trained weights.
    
    For each target layer, the function:
    1. Determines the appropriate rank (r) and scaling factor (alpha)
    2. Creates a LoRA layer with the same dimensions as the original layer
    3. Copies the original weights and biases to the LoRA layer
    4. Replaces the original module with the new LoRA module
    5. Records details about the replacement for tracking and evaluation
    
    Args:
        model: The PyTorch model to modify
        r_config: Dictionary mapping layer name to rank (r) value
        alpha_config: Dictionary mapping layer name to alpha value (defaults to 2*r if not provided)
        dropout: Dropout rate for LoRA layers (default: 0.0)
        merge_weights: Whether to merge LoRA weights with original weights when setting to eval mode
        
    Returns:
        Tuple containing:
        - Modified model with LoRA layers
        - Dictionary containing details about replaced layers
    
    Note:
        Layers with r=0 in the r_config will be tracked but not modified with LoRA.
    """
    logger.info(f"Replacing linear layers with LoRA layers using provided r configuration...")
    
    from functools import reduce
    
    # Keep track of replaced layers
    replaced_layers = {}
    
    # Iterate through named modules
    for name, module in list(model.named_modules()):
        # Only target modules specified in the r_config
        if name in r_config and isinstance(module, nn.Linear):
            r = r_config[name]
            alpha = alpha_config.get(name, 2 * r) if alpha_config else 2 * r
            
            # Skip if r = 0 - clearer logging and handling
            if r <= 0:
                 # The original layer already exists, so keep it as is and just record in replaced_layers
                logger.info(f"SKIPPING LoRA for {name} (r = {r}): Original layer will be preserved")
                replaced_layers[name] = {
                    'r': 0,
                    'alpha': 0,
                    'in_features': module.in_features,
                    'out_features': module.out_features,
                    'applied': False,
                    'trainable_params': 0,  # Since r=0, the number of trainable parameters is 0
                    'message': 'LoRA NOT applied (r=0)'
                }
                # Verify that all parameters in the layer are not set to be trained
                for param_name, param in module.named_parameters():
                    if param.requires_grad:
                        logger.debug(f"  - Parameter {param_name} in r=0 layer will not be trained")
                continue
                
            logger.info(f"Replacing {name} with LoRA (r={r}, alpha={alpha})")
            
            # Pre-save necessary module attributes
            in_features = module.in_features
            out_features = module.out_features
            has_bias = module.bias is not None
            weight_data = module.weight.data.clone()
            bias_data = module.bias.data.clone() if has_bias else None
            
            # Create LoRA layer
            lora_layer = lora.Linear(
                in_features=in_features,
                out_features=out_features,
                r=r,
                lora_alpha=alpha,
                lora_dropout=dropout,
                bias=has_bias,
                merge_weights=merge_weights
            )
            
            # Copy weights from original layer and ensure dtype consistency
            lora_layer.weight.data = weight_data
            if has_bias:
                lora_layer.bias.data = bias_data
            
            # Ensure LoRA parameters have the same dtype as the base weights
            if hasattr(lora_layer, 'lora_A') and hasattr(lora_layer, 'lora_B'):
                lora_layer.lora_A.data = lora_layer.lora_A.data.to(weight_data.dtype)
                lora_layer.lora_B.data = lora_layer.lora_B.data.to(weight_data.dtype)
            
            # Replace the module more efficiently
            names = name.split('.')
            parent_module = reduce(getattr, names[:-1], model)
            setattr(parent_module, names[-1], lora_layer)
            
            del module
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            # Record the replacement
            replaced_layers[name] = {
                'r': r,
                'alpha': alpha,
                'in_features': in_features,
                'out_features': out_features,
                'applied': True
            }

    applied_count = sum(1 for layer in replaced_layers.values() if layer['applied'])
    skipped_count = sum(1 for layer in replaced_layers.values() if not layer['applied'])
    zero_r_count = sum(1 for layer in replaced_layers.values() if layer['r'] == 0)

    logger.info(f"LoRA Layer Summary: Replaced {applied_count} layers with LoRA, skipped {skipped_count} layers")
    logger.info(f"  - Layers with r=0 (no training): {zero_r_count}")
    logger.info(f"  - Layers with r>0 (with training): {applied_count}")

    for name, info in replaced_layers.items():
        if info['r'] == 0 and info['applied']:
            logger.error(f"CRITICAL ERROR: Layer {name} has r=0 but LoRA was applied!")

    return model, replaced_layers

class OptimalLoRAModel(nn.Module):
    """
    Wrapper around a pre-trained model whose selected Linear layers are replaced by LoRA layers.

    The wrapped model is stored as ``self.model``; ``forward`` simply delegates to it.
    """

    def __init__(
        self,
        base_model: Union[str, PreTrainedModel],
        r_config: Dict[str, int],
        alpha_config: Optional[Dict[str, int]] = None,
        dropout: float = 0.0,
        merge_weights: bool = False
    ):
        """
        Initialize the OptimalLoRAModel.

        Args:
            base_model: Base model instance, or a model name/path loaded via AutoModel.from_pretrained
            r_config: Dict mapping layer name to r value
            alpha_config: Dict mapping layer name to alpha value (defaults to 2*r if not provided)
            dropout: Dropout rate for LoRA layers
            merge_weights: Forwarded to replace_linear_with_lora
        """
        super().__init__()

        # Load the base model when only a name/path was given
        if isinstance(base_model, str):
            self.model = AutoModel.from_pretrained(base_model)
        else:
            self.model = base_model

        # Replace linear layers with LoRA layers; replaced_layers holds one
        # record per layer that was handled
        self.model, self.replaced_layers = replace_linear_with_lora(
            self.model, r_config, alpha_config, dropout, merge_weights
        )

        # Store configurations (alpha falls back to 2*r for every configured layer)
        self.r_config = r_config
        self.alpha_config = alpha_config or {name: 2 * r for name, r in r_config.items()}

    def forward(self, *args, **kwargs):
        """Forward pass to the underlying model."""
        return self.model(*args, **kwargs)

    def save_lora_config(self, save_dir: str):
        """
        Save the LoRA configuration to ``<save_dir>/lora_config.json``.

        The file contains the keys 'r_config', 'alpha_config' and 'replaced_layers'.

        Args:
            save_dir: Directory to save the configuration (created if missing)
        """
        os.makedirs(save_dir, exist_ok=True)
        config = {
            'r_config': self.r_config,
            'alpha_config': self.alpha_config,
            'replaced_layers': self.replaced_layers
        }

        with open(os.path.join(save_dir, 'lora_config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2)

        logger.info(f"Saved LoRA configuration to {save_dir}")

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        lora_config_path: Optional[str] = None,
        **kwargs
    ):
        """
        Load a model with LoRA layers from a saved configuration.

        Args:
            model_name_or_path: Path to the base model
            lora_config_path: Path to the LoRA configuration JSON file, or to a
                directory containing ``lora_config.json`` (the layout written
                by ``save_lora_config``)
            **kwargs: Additional arguments to pass to the model constructor

        Returns:
            OptimalLoRAModel instance. If no usable configuration is found, the
            instance is built with an empty r_config (no LoRA layers).
        """
        # Load base model
        base_model = AutoModel.from_pretrained(model_name_or_path)

        if lora_config_path:
            # Accept the directory used by save_lora_config as well as the file itself
            if os.path.isdir(lora_config_path):
                lora_config_path = os.path.join(lora_config_path, 'lora_config.json')

            if os.path.exists(lora_config_path):
                with open(lora_config_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)

                r_config = config.get('r_config', {})
                alpha_config = config.get('alpha_config', None)

                return cls(
                    base_model=base_model,
                    r_config=r_config,
                    alpha_config=alpha_config,
                    **kwargs
                )

            # A path was requested but is missing: keep the fallback, but say so
            logger.warning(
                f"LoRA configuration not found at {lora_config_path}; "
                f"returning a model with no LoRA layers"
            )

        # No usable LoRA config: return a model with no LoRA layers
        return cls(
            base_model=base_model,
            r_config={},
            **kwargs
        )

    def mark_only_lora_as_trainable(self):
        """
        Mark only LoRA parameters as trainable and log the trainable share.

        Returns:
            self, so the call can be chained
        """
        lora.mark_only_lora_as_trainable(self.model)

        # Count trainable parameters
        lora_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.model.parameters())
        # Guard against a parameterless model so logging cannot raise
        trainable_share = lora_params / total_params if total_params else 0.0

        logger.info(f"Trainable parameters: {lora_params:,} ({trainable_share:.2%} of total)")

        return self
    
# Task-specific model preparation helpers (added as modifications to models/optimal_lora.py)

# 1. prepare_model_for_glue stores the initial r_config on the returned model

def prepare_model_for_glue(
    base_model_name: str,
    r_config: Dict[str, int],
    num_labels: int,
    dropout: float = 0.1,
    tokenizer=None,
):
    """
    Prepare a model with optimal LoRA configuration for GLUE tasks.

    Loads a sequence-classification model, replaces the Linear layers named in
    ``r_config`` with LoRA layers (alpha = 2*r, merge_weights=True), freezes
    everything except the LoRA parameters and then re-enables training of the
    classification head (``score`` and/or ``classifier`` if the model has one).

    Args:
        base_model_name: Name of the base model; names containing "llama"
            (case-insensitive) are loaded with LlamaForSequenceClassification in
            bfloat16 with device_map="auto", all others with
            AutoModelForSequenceClassification
        r_config: Dict mapping layer name to r value
        num_labels: Number of labels for the classification head
        dropout: Dropout probability for the LoRA layers
        tokenizer: Tokenizer used to set pad_token_id on the LLaMA model. If it
            has no pad token, its pad token is set to its EOS token (the
            tokenizer is modified in place). Ignored for non-LLaMA models.

    Returns:
        Tuple containing:
        - Model with LoRA layers and classification head; the initial r_config
          is stored on it as ``initial_r_config``
        - Dictionary with details about the handled layers
    """
    if "llama" in base_model_name.lower():
        from transformers import LlamaForSequenceClassification

        logger.info(f"Loading LLaMA model: {base_model_name}")

        load_kwargs = {
            "num_labels": num_labels,
            "torch_dtype": torch.bfloat16,  # Use bfloat16 for memory efficiency
            "device_map": "auto",  # Automatically handle device placement
        }

        if tokenizer is not None:
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            # Only forward a real id: an explicit None would overwrite whatever
            # pad_token_id the checkpoint's own config defines
            if tokenizer.pad_token_id is not None:
                load_kwargs["pad_token_id"] = tokenizer.pad_token_id

        model = LlamaForSequenceClassification.from_pretrained(base_model_name, **load_kwargs)
    else:
        # BERT/RoBERTa-style models
        from transformers import AutoModelForSequenceClassification

        model = AutoModelForSequenceClassification.from_pretrained(
            base_model_name,
            num_labels=num_labels
        )

    # Replace linear layers with LoRA layers
    model, replaced_layers = replace_linear_with_lora(
        model, r_config, alpha_config=None, dropout=dropout, merge_weights=True
    )

    # Mark only LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(model)

    # The classification head is created for num_labels and is not a LoRA
    # parameter, so the call above freezes it. Re-enable it for both model
    # families so the head is not left frozen.
    for head_name in ('score', 'classifier'):
        head = getattr(model, head_name, None)
        if isinstance(head, nn.Module):
            for param in head.parameters():
                param.requires_grad = True

    # Count trainable parameters
    lora_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    logger.info(f"Trainable parameters: {lora_params:,} ({lora_params/total_params:.2%} of total)")

    # Store initial r_config as attribute for later reference (used in pruning)
    model.initial_r_config = r_config.copy()

    return model, replaced_layers


def prepare_model_for_alpaca(
    base_model_name: str,
    r_config: Dict[str, int],
    dropout: float = 0.1,
    tokenizer=None,
):
    """
    Prepare a model with optimal LoRA configuration for Alpaca instruction-following task.

    Loads a causal language model, replaces the Linear layers named in
    ``r_config`` with LoRA layers (alpha = 2*r, merge_weights=True) and freezes
    everything except the LoRA parameters.

    Args:
        base_model_name: Name of the base model; names containing "llama"
            (case-insensitive) are loaded with LlamaForCausalLM in bfloat16 with
            device_map="auto", all others with AutoModelForCausalLM
        r_config: Dict mapping layer name to r value
        dropout: Dropout probability for the LoRA layers
        tokenizer: Tokenizer used to set pad_token_id on the LLaMA model. If it
            has no pad token, its pad token is set to its EOS token (the
            tokenizer is modified in place). Ignored for non-LLaMA models.

    Returns:
        Tuple containing:
        - Model with LoRA layers for the generation task; the initial r_config
          is stored on it as ``initial_r_config``
        - Dictionary with details about the handled layers
    """
    if "llama" in base_model_name.lower():
        from transformers import LlamaForCausalLM

        logger.info(f"Loading LLaMA model for Alpaca task: {base_model_name}")

        load_kwargs = {
            "torch_dtype": torch.bfloat16,  # Use bfloat16 for memory efficiency
            "device_map": "auto",  # Automatically handle device placement
        }

        if tokenizer is not None:
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token = tokenizer.eos_token
                tokenizer.pad_token_id = tokenizer.eos_token_id
            # Only forward a real id: an explicit None would overwrite whatever
            # pad_token_id the checkpoint's own config defines
            if tokenizer.pad_token_id is not None:
                load_kwargs["pad_token_id"] = tokenizer.pad_token_id

        model = LlamaForCausalLM.from_pretrained(base_model_name, **load_kwargs)
    else:
        # Other causal language models (e.g. GPT-2)
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(
            base_model_name
        )

    # Replace linear layers with LoRA layers
    model, replaced_layers = replace_linear_with_lora(
        model, r_config, alpha_config=None, dropout=dropout, merge_weights=True
    )

    # Mark only LoRA parameters as trainable; unlike the classification setup,
    # no head is unfrozen here
    lora.mark_only_lora_as_trainable(model)

    # Count trainable parameters
    lora_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    logger.info(f"Trainable parameters: {lora_params:,} ({lora_params/total_params:.2%} of total)")

    # Store initial r_config as attribute for later reference (used in pruning)
    model.initial_r_config = r_config.copy()

    return model, replaced_layers




def get_default_llama_r_config(model_name: str = "meta-llama/Meta-Llama-3-8B", default_r: int = 8):
    """
    Build a uniform r configuration for the projection layers of a LLaMA model.

    Every nn.Linear module whose name contains one of the attention/MLP
    projection identifiers (q_proj, k_proj, v_proj, o_proj, gate_proj,
    up_proj, down_proj) is assigned ``default_r``.

    Args:
        model_name: Model name or path passed to LlamaModel.from_pretrained
        default_r: Rank value assigned to every selected layer

    Returns:
        Dictionary mapping layer names to r values
    """
    from transformers import LlamaModel

    projection_keys = (
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    )

    # The model is loaded only so that its module names can be enumerated
    probe_model = LlamaModel.from_pretrained(model_name, torch_dtype=torch.bfloat16)

    ranks = {
        layer_name: default_r
        for layer_name, layer in probe_model.named_modules()
        if isinstance(layer, nn.Linear)
        and any(key in layer_name for key in projection_keys)
    }

    # Drop the temporary model before handing back the configuration
    del probe_model
    torch.cuda.empty_cache()

    return ranks