import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import logging
from typing import Dict, List, Optional, Union, Tuple, Any
import os
import json

from pruning.initial_rank_allocation import get_optimal_r_config
from utils.logging_utils import log_optimal_r_config

logger = logging.getLogger(__name__)


class LoRAOptimizer:
    """
    Optimizer for finding the optimal LoRA rank configuration.

    Wraps ``get_optimal_r_config`` with target-layer discovery, seed
    bookkeeping and result persistence, and offers helpers for loading a
    saved rank configuration and converting it to a PEFT-style dict.
    """

    # Default regex patterns used by ``find_target_layers`` when the caller
    # does not supply any. They are applied with ``re.match`` (anchored at the
    # start of the module name), hence the leading ``.*``.
    _LLAMA_LAYER_PATTERNS = (
        '.*q_proj', '.*k_proj', '.*v_proj', '.*o_proj',
        '.*gate_proj', '.*up_proj', '.*down_proj',
    )
    _DEFAULT_LAYER_PATTERNS = (
        '.*query', '.*key', '.*value', '.*attention.output',
        '.*intermediate', '.*output',
    )

    # Module names that are already in the projection-style naming emitted by
    # ``convert_to_peft_config`` (these are the LLaMA layer names above).
    _PROJECTION_NAMES = frozenset({
        'q_proj', 'k_proj', 'v_proj', 'o_proj',
        'gate_proj', 'up_proj', 'down_proj',
    })

    def __init__(
        self,
        model: nn.Module,
        r_values: List[int],
        budget: float,
        device: torch.device,
        output_dir: str,
        seed: int = 42
    ):
        """
        Initialize the LoRA Optimizer.

        Creates ``output_dir`` if it does not exist yet.

        Args:
            model: The model to optimize
            r_values: List of r values to consider
            budget: Total budget constraint
            device: Device to run the model on
            output_dir: Directory to save results
            seed: Random seed for reproducibility
        """
        self.model = model
        self.r_values = r_values
        self.budget = budget
        self.device = device
        self.output_dir = output_dir
        self.seed = seed

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Log the seed for reproducibility tracking
        logger.info("Initialized LoRA Optimizer with seed %s", seed)

    def find_target_layers(self, layer_patterns: Optional[List[str]] = None) -> List[str]:
        """
        Find all linear layers in the model that match the given patterns.

        Args:
            layer_patterns: List of regex patterns to match layer names.
                           Patterns are applied with ``re.match``, i.e. they
                           are anchored at the start of the module name.
                           If None, a default pattern set is chosen: the
                           LLaMA projection names when
                           ``model.config.model_type`` is ``'llama'``
                           (case-insensitive), otherwise the BERT/RoBERTa
                           style names.

        Returns:
            List of matching ``nn.Linear`` module names, in the order
            ``model.named_modules()`` yields them
        """
        import re

        if layer_patterns is None:
            # ``model_type`` may be missing or None on some configs; only a
            # real string can identify a LLaMA model.
            config = getattr(self.model, 'config', None)
            model_type = getattr(config, 'model_type', None)
            is_llama = isinstance(model_type, str) and model_type.lower() == 'llama'
            if is_llama:
                layer_patterns = list(self._LLAMA_LAYER_PATTERNS)
            else:
                layer_patterns = list(self._DEFAULT_LAYER_PATTERNS)

        # Compile once instead of handing raw strings to re.match per module.
        compiled_patterns = [re.compile(pattern) for pattern in layer_patterns]

        target_layers = [
            name
            for name, module in self.model.named_modules()
            if isinstance(module, nn.Linear)
            and any(regex.match(name) for regex in compiled_patterns)
        ]

        logger.info("Found %d target layers for optimization", len(target_layers))
        for position, layer in enumerate(sorted(target_layers), start=1):
            logger.info("  %d. %s", position, layer)

        return target_layers

    def optimize(
        self,
        dataloader: DataLoader,
        target_layers: Optional[List[str]] = None,
        seed: Optional[int] = None
    ) -> Dict[str, int]:
        """
        Find the optimal r configuration.

        Delegates the search to ``get_optimal_r_config``, records the seed in
        the returned results, logs them via ``log_optimal_r_config`` and
        writes ``optimization_seed.json`` into ``output_dir``.

        Args:
            dataloader: Dataloader for importance estimation
            target_layers: List of layer names to optimize,
                        if None, will be automatically detected
            seed: Random seed for reproducibility (overrides instance seed)

        Returns:
            Dict mapping layer_name to optimal r value
        """
        # Find target layers if not provided (an explicit empty list is kept)
        if target_layers is None:
            target_layers = self.find_target_layers()

        # Use provided seed or fall back to instance seed
        actual_seed = seed if seed is not None else self.seed
        logger.info("Optimizing with seed %s", actual_seed)

        # Get optimal r configuration
        logger.info("Starting optimization process...")
        optimal_r, optimization_results = get_optimal_r_config(
            model=self.model,
            dataloader=dataloader,
            r_values=self.r_values,
            target_layers=target_layers,
            budget=self.budget,
            device=self.device,
            seed=actual_seed
        )

        # Save seed in optimization results
        optimization_results["seed"] = actual_seed

        # Log results
        log_optimal_r_config(
            logger=logger,
            optimal_r=optimal_r,
            optimization_results=optimization_results,
            output_dir=self.output_dir
        )

        # Save the seed specifically for reproducibility
        seed_path = os.path.join(self.output_dir, "optimization_seed.json")
        with open(seed_path, "w", encoding="utf-8") as f:
            json.dump({"seed": actual_seed}, f, indent=2)

        return optimal_r

    @staticmethod
    def get_r_config_from_file(config_path: str) -> Dict[str, int]:
        """
        Load r configuration from a file.

        Args:
            config_path: Path to the JSON file containing the r configuration

        Returns:
            The parsed JSON content (expected to be a dict mapping
            layer_name to r value; the content is not validated)
        """
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def convert_to_peft_config(r_config: Dict[str, int], prefix: str = '') -> Dict[str, Any]:
        """
        Convert our r configuration format to PEFT's LoRA config format.

        Layers are grouped by projection type. BERT/RoBERTa style names
        (``query``, ``key``, ``value``, ``attention.output``,
        ``intermediate``, ``output``) are mapped to ``q_proj``, ``k_proj``,
        ``v_proj``, ``o_proj``, ``up_proj`` and ``down_proj``. Layers whose
        last name component already is a projection name (the LLaMA layers
        returned by ``find_target_layers``, including ``gate_proj``) keep
        that name. Each group gets the maximum r of its layers. Layers that
        fit no group are skipped with a warning.

        Args:
            r_config: Our r configuration format (layer_name -> r)
            prefix: Currently unused; kept for interface compatibility

        Returns:
            Dict with the keys ``peft_type``, ``task_type``,
            ``target_modules``, ``inference_mode``, ``r``, ``lora_alpha``
            and ``lora_dropout``
        """
        # Default values
        lora_alpha = 16
        lora_dropout = 0.1

        # Collect the r values of every layer under its projection group
        ranks_by_module: Dict[str, List[int]] = {}
        for layer_name, r in r_config.items():
            # Extract module name (the last part of the layer name)
            module_name = layer_name.split('.')[-1]

            # The BERT-style checks come first so existing results are
            # unchanged; the exact projection-name check is the fallback.
            if 'query' in module_name:
                group = 'q_proj'
            elif 'key' in module_name:
                group = 'k_proj'
            elif 'value' in module_name:
                group = 'v_proj'
            elif 'attention.output' in layer_name:
                group = 'o_proj'
            elif 'intermediate' in layer_name:
                group = 'up_proj'
            elif 'output' in layer_name and 'attention' not in layer_name:
                group = 'down_proj'
            elif module_name in LoRAOptimizer._PROJECTION_NAMES:
                group = module_name
            else:
                logger.warning(
                    "Skipping layer %s: cannot map it to a PEFT target module",
                    layer_name,
                )
                continue

            ranks_by_module.setdefault(group, []).append(r)

        # Create config in PEFT format
        # NOTE(review): 'r' and 'lora_alpha' are per-module dicts here; PEFT's
        # LoraConfig appears to take scalars plus rank_pattern/alpha_pattern
        # for per-module overrides - confirm what the consumer expects.
        return {
            'peft_type': 'LORA',
            'task_type': 'SEQ_CLS',
            'target_modules': list(ranks_by_module),
            'inference_mode': False,
            'r': {module: max(ranks) for module, ranks in ranks_by_module.items()},
            'lora_alpha': {module: lora_alpha for module in ranks_by_module},
            'lora_dropout': lora_dropout
        }