import logging
import os
import sys
from typing import Dict, Any, Optional, Tuple, Union
import json
import matplotlib.pyplot as plt
import numpy as np
import time
import torch

def ensure_lora_merged_state(model: torch.nn.Module) -> Tuple[bool, Dict[str, bool]]:
    """
    Put LoRA-style submodules into inference (eval) mode and record their merge flags.

    A submodule counts as LoRA-style when it exposes both ``lora_A`` and
    ``lora_B``. For those that also expose ``merged``, the current value is
    recorded under the submodule's qualified name. If the submodule is not yet
    merged and has a positive ``r``, ``eval()`` is called on it (in loralib-style
    implementations this presumably triggers the weight merge — verify for the
    LoRA library in use).

    Args:
        model: The model whose submodules are inspected.

    Returns:
        ``(has_lora_layers, original_merge_states)`` where the dict maps a
        qualified submodule name to its ``merged`` flag before this call.
    """
    found_lora = False
    saved_flags: Dict[str, bool] = {}

    for qual_name, sub in model.named_modules():
        if not (hasattr(sub, 'lora_A') and hasattr(sub, 'lora_B')):
            continue
        found_lora = True

        if not hasattr(sub, 'merged'):
            continue
        saved_flags[qual_name] = sub.merged

        # Only rank-positive, still-unmerged layers are switched to eval mode.
        needs_merge = (not sub.merged) and hasattr(sub, 'r') and sub.r > 0
        if needs_merge:
            sub.eval()

    return found_lora, saved_flags

def restore_lora_merge_states(model: torch.nn.Module, original_states: Dict[str, bool]):
    """
    Return LoRA-style submodules to the merge state recorded earlier.

    For every submodule whose qualified name appears in ``original_states`` and
    that exposes a ``merged`` attribute: if it was recorded as unmerged but is
    merged now, ``train()`` is called on it; if it was recorded as merged but is
    unmerged now, ``eval()`` is called. In loralib-style layers these calls
    presumably unmerge/merge the weights (verify for the LoRA library in use).

    Args:
        model: The model whose submodules are restored.
        original_states: Mapping of qualified submodule name -> recorded ``merged`` flag.
    """
    for qual_name, sub in model.named_modules():
        if qual_name not in original_states or not hasattr(sub, 'merged'):
            continue

        wanted_merged = original_states[qual_name]
        if not wanted_merged and sub.merged:
            sub.train()
        elif wanted_merged and not sub.merged:
            sub.eval()

def measure_batch_inference_time(
    model: torch.nn.Module,
    batch: Dict[str, torch.Tensor],
    device: Optional[torch.device] = None,
    return_outputs: bool = False,
    warmup: bool = True,
) -> Union[float, Tuple[float, Any]]:
    """
    Measure wall-clock model inference time for a single batch.

    The model is switched to eval mode for the measurement. For loralib-style
    LoRA layers, eval mode presumably merges the LoRA weights — verify for the
    LoRA library in use. The model's original training mode is restored
    afterwards, even if the forward pass raises.

    Only the timed forward pass is measured. Device transfer and the optional
    warm-up pass are excluded. On CUDA devices the target device is synchronized
    around the timed region, so asynchronous kernel launches are fully accounted for.

    Args:
        model: The model to time.
        batch: Input keyword arguments for ``model(**batch)``. Tensor values are
            moved to ``device``; other values are passed through unchanged.
        device: Target device (``torch.device`` or string). Defaults to
            ``model.device`` if present, otherwise the device of the first
            parameter, otherwise CPU for parameterless models.
        return_outputs: Whether to return the model outputs along with the time.
        warmup: Whether to run one untimed forward pass first (e.g. to trigger
            CUDA kernel initialization). Defaults to True.

    Returns:
        If return_outputs=False: inference time in seconds.
        If return_outputs=True: tuple of (inference_time, model_outputs).
    """
    was_training = model.training

    # Resolve the target device. Fall back to CPU for parameterless models
    # instead of letting next() raise StopIteration.
    if device is None:
        if hasattr(model, 'device'):
            device = model.device
        else:
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device('cpu')
    device = torch.device(device)

    def _synchronize() -> None:
        # CUDA work is asynchronous w.r.t. the host. Sync the device the model
        # actually runs on, not merely the current default CUDA device.
        if device.type == 'cuda' and torch.cuda.is_available():
            torch.cuda.synchronize(device)

    # Tensor.to() is a no-op when the tensor already lives on ``device``.
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
             for k, v in batch.items()}

    model.eval()
    try:
        # Wait for any previously queued work before measuring.
        _synchronize()
        with torch.no_grad():
            if warmup:
                model(**batch)
                _synchronize()

            start_time = time.perf_counter()
            outputs = model(**batch)
            _synchronize()
            inference_time = time.perf_counter() - start_time
    finally:
        # Restore the caller's mode even if the forward pass failed.
        if was_training:
            model.train()

    if return_outputs:
        return inference_time, outputs
    return inference_time

def measure_evaluation_time(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    device: Optional[torch.device] = None,
    max_batches: Optional[int] = None,
    log_interval: int = 10,
    logger: Optional[logging.Logger] = None
) -> Dict[str, Any]:
    """
    Measure model inference time over the batches of a dataloader.

    Each batch is timed with ``measure_batch_inference_time``, so only model
    execution is counted, not data loading. The model is put into eval mode
    for the run, and its original training mode is restored afterwards, even
    if a batch raises.

    Args:
        model: The model to evaluate.
        dataloader: Iterable of batch dicts (passed as ``model(**batch)``).
            Loaders without ``len()`` (e.g. over an IterableDataset) are supported.
        device: Device to run on (forwarded to ``measure_batch_inference_time``).
        max_batches: Maximum number of batches to process (None for all).
        log_interval: Log progress every ``log_interval`` batches; values
            <= 0 disable progress logging.
        logger: Logger for progress and summary output; nothing is logged if None.

    Returns:
        Dictionary with timing statistics:
        - total_time: Total measured inference time (seconds)
        - avg_batch_time: Average time per batch (seconds)
        - batch_times: List of individual batch times
        - num_batches: Number of batches processed
        - steps_per_second: Batches processed per second
        - total_samples: Sum of ``batch["input_ids"].size(0)`` over batches that
          contain ``input_ids``
        - samples_per_second: total_samples / total_time
    """
    from itertools import islice

    batch_times = []
    total_samples = 0

    was_training = model.training
    model.eval()

    has_lora = any(
        hasattr(module, 'lora_A') and hasattr(module, 'lora_B')
        for module in model.modules()
    )
    if logger and has_lora:
        logger.info("Model contains LoRA layers - ensuring optimal inference state")

    # len() raises TypeError for loaders over IterableDatasets. Degrade the
    # progress message instead of aborting the measurement.
    total_display: Any = "?"
    if logger:
        try:
            total_display = len(dataloader)
        except TypeError:
            pass
        else:
            if max_batches is not None:
                total_display = min(total_display, max_batches)

    try:
        # islice stops before fetching a batch that would only be discarded.
        for batch_idx, batch in enumerate(islice(dataloader, max_batches)):
            if logger and log_interval > 0 and batch_idx % log_interval == 0:
                logger.info(f"Processing batch {batch_idx+1}/{total_display}")

            batch_time = measure_batch_inference_time(model, batch, device)
            batch_times.append(batch_time)

            if "input_ids" in batch:
                total_samples += batch["input_ids"].size(0)
    finally:
        if was_training:
            model.train()

    total_time = sum(batch_times)
    num_batches = len(batch_times)
    avg_batch_time = total_time / num_batches if num_batches > 0 else 0
    steps_per_second = num_batches / total_time if total_time > 0 else 0

    if logger:
        logger.info(f"Evaluation completed: {num_batches} batches in {total_time:.4f}s")
        logger.info(f"Average batch time: {avg_batch_time:.6f}s")
        logger.info(f"Processing speed: {steps_per_second:.2f} batches/second")
        if has_lora:
            logger.info("Note: LoRA layers were in optimal state for inference")

    return {
        "total_time": total_time,
        "avg_batch_time": avg_batch_time,
        "batch_times": batch_times,
        "num_batches": num_batches,
        "steps_per_second": steps_per_second,
        "total_samples": total_samples,
        "samples_per_second": total_samples / total_time if total_time > 0 else 0
    }
    
def setup_logger(name: str, log_file: Optional[str] = None, level=logging.INFO):
    """
    Set up a logger that writes to stdout and, optionally, a file.

    Existing handlers on the named logger are removed and closed first, so
    calling this repeatedly neither duplicates output nor leaks open log files.
    When the effective level is INFO or lower, a block of system/library
    version information is logged for reproducibility.

    Args:
        name: Name of the logger (passed to ``logging.getLogger``).
        log_file: Optional path to a log file. Missing parent directories are
            created; a bare filename is written to the current directory.
        level: Logging level (int or level name accepted by ``Logger.setLevel``).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Detach AND close old handlers; removing alone leaves FileHandler streams open.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file:
        log_dir = os.path.dirname(log_file)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # Compare against the numeric level stored by setLevel, so that string
    # levels such as "INFO" do not raise TypeError here.
    if logger.level <= logging.INFO:
        try:
            import platform

            logger.info("=" * 80)
            logger.info("SYSTEM ENVIRONMENT INFORMATION FOR REPRODUCIBILITY")
            logger.info("=" * 80)
            logger.info(f"Python version: {platform.python_version()}")
            logger.info(f"OS: {platform.platform()}")
            logger.info(f"CPU: {platform.processor()}")

            if torch.cuda.is_available():
                logger.info(f"CUDA version: {torch.version.cuda}")
                logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

            logger.info(f"PyTorch version: {torch.__version__}")
            logger.info(f"NumPy version: {np.__version__}")

            try:
                import pulp
                logger.info(f"PuLP version: {pulp.__version__}")
            except (ImportError, AttributeError):
                logger.info("PuLP not available")

            logger.info("=" * 80)

        except Exception as e:
            # Environment reporting is best-effort; never fail logger setup over it.
            logger.warning(f"Error logging system information: {e}")

    return logger

# Colour used for each coarse layer type in the r-config plots.
_LAYER_TYPE_COLORS = {
    "query": "blue",
    "key": "green",
    "value": "red",
    "attention": "purple",
    "ffn": "orange",
    "output": "brown",
    "other": "gray"
}


def _classify_layer_type(layer_name: str) -> str:
    """Map a layer name to a coarse type (a key of ``_LAYER_TYPE_COLORS``)."""
    # First match wins, so e.g. "attention.output.dense" is "attention".
    for keyword, layer_type in (
        ("query", "query"),
        ("key", "key"),
        ("value", "value"),
        ("attention", "attention"),
        ("intermediate", "ffn"),
        ("output", "output"),
    ):
        if keyword in layer_name:
            return layer_type
    return "other"


def _to_jsonable(obj: Any) -> Any:
    """Recursively convert NumPy scalars/arrays (and tuples) into JSON-serializable Python values."""
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): _to_jsonable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


def _plot_r_bar_chart(short_names, r_values, colors, layer_types, path):
    """Save a bar chart of r per layer, coloured by layer type, and close the figure."""
    from matplotlib.patches import Patch

    fig = plt.figure(figsize=(15, 10))
    try:
        plt.bar(range(len(r_values)), r_values, color=colors)
        plt.xticks(range(len(short_names)), short_names, rotation=90)
        legend_elements = [Patch(facecolor=_LAYER_TYPE_COLORS[t], label=t)
                           for t in sorted(set(layer_types))]
        plt.legend(handles=legend_elements)
        plt.xlabel('Layer')
        plt.ylabel('Optimal r value')
        plt.title('Optimal r configuration')
        plt.tight_layout()
        plt.savefig(path)
    finally:
        plt.close(fig)


def _plot_cost_gain_scatter(logger, layers, short_names, colors, optimal_r,
                            optimization_results, path):
    """Save a cost-vs-gain scatter for each layer's chosen r; skip with a warning if data is missing."""
    try:
        r_value_indices = {r: i for i, r in enumerate(optimization_results["r_values"])}
        layer_costs = []
        layer_gains = []
        for layer_name in layers:
            r_idx = r_value_indices[optimal_r[layer_name]]
            layer_costs.append(optimization_results["costs"][layer_name][r_idx])
            layer_gains.append(optimization_results["gains"][layer_name][r_idx])
    except (KeyError, IndexError) as err:
        logger.warning(f"Skipping cost-gain scatter plot: missing optimization data ({err!r})")
        return

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.scatter(layer_costs, layer_gains, c=colors, alpha=0.7)
        for short_name, cost, gain in zip(short_names, layer_costs, layer_gains):
            plt.annotate(short_name, (cost, gain),
                         textcoords="offset points", xytext=(0, 5), ha='center')
        plt.xlabel('Cost')
        plt.ylabel('Gain')
        plt.title('Cost vs. Gain for Each Layer')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(path)
    finally:
        plt.close(fig)


def log_optimal_r_config(
    logger: logging.Logger,
    optimal_r: Dict[str, int],
    optimization_results: Dict[str, Any],
    output_dir: str
):
    """
    Log the optimal r configuration and persist it alongside diagnostic plots.

    Writes the following into ``output_dir``, creating it if needed:
    - optimal_r_config.json: the raw layer -> r mapping
    - reproducible_r_config.json: the mapping plus seed/timestamp/r_values/budget
    - optimal_r_config.png: bar chart of r per layer, grouped by layer type
    - cost_gain_scatter.png: cost vs. gain for each layer's chosen r (skipped
      with a warning if r_values/gains/costs are missing or inconsistent)
    - optimization_results.json: optimization_results made JSON-serializable

    Every matplotlib figure created here is closed, so repeated calls do not
    accumulate open figures.

    Args:
        logger: Logger object.
        optimal_r: Dict mapping layer_name to optimal r value.
        optimization_results: Dict of detailed optimization results. Keys read
            here: "seed", "r_values", "budget", "gains", "costs", "importances".
            "gains"/"costs" are assumed to map layer_name to a per-r sequence
            indexed like "r_values".
        output_dir: Directory to save plots and results.
    """
    seed = optimization_results.get("seed", None)
    seed_info = f" (seed: {seed})" if seed is not None else ""
    logger.info(f"Optimal r configuration{seed_info}:")
    for layer_name in sorted(optimal_r):
        logger.info(f"  {layer_name}: r={optimal_r[layer_name]}")

    os.makedirs(output_dir, exist_ok=True)

    # Serialize before opening each file, so a serialization error cannot leave a truncated file.
    config_text = json.dumps(_to_jsonable(optimal_r), indent=2)
    with open(os.path.join(output_dir, "optimal_r_config.json"), "w") as f:
        f.write(config_text)

    reproducible_config = {
        "optimal_r": optimal_r,
        "reproducibility": {
            "seed": optimization_results.get("seed", None),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "r_values": optimization_results.get("r_values", []),
            "budget": optimization_results.get("budget", 0)
        }
    }
    reproducible_path = os.path.join(output_dir, "reproducible_r_config.json")
    reproducible_text = json.dumps(_to_jsonable(reproducible_config), indent=2)
    with open(reproducible_path, "w") as f:
        f.write(reproducible_text)
    logger.info(f"Saved reproducible configuration to {reproducible_path}")

    # Group layers by type. sorted() is stable, so insertion order is kept
    # within each type (np.argsort's default quicksort guarantees no such thing).
    layers = list(optimal_r.keys())
    layer_types = [_classify_layer_type(name) for name in layers]
    order = sorted(range(len(layers)), key=layer_types.__getitem__)
    sorted_layers = [layers[i] for i in order]
    sorted_r_values = [optimal_r[layers[i]] for i in order]
    sorted_short_names = [layers[i].rsplit('.', 1)[-1] for i in order]
    colors = [_LAYER_TYPE_COLORS[layer_types[i]] for i in order]

    _plot_r_bar_chart(
        sorted_short_names, sorted_r_values, colors, layer_types,
        os.path.join(output_dir, "optimal_r_config.png"),
    )
    _plot_cost_gain_scatter(
        logger, sorted_layers, sorted_short_names, colors, optimal_r,
        optimization_results, os.path.join(output_dir, "cost_gain_scatter.png"),
    )

    # Convert numeric containers (possibly NumPy) to plain Python for JSON.
    results_json = {}
    for key, value in optimization_results.items():
        if key in ("gains", "costs"):
            results_json[key] = {k: [float(v) for v in values] for k, values in value.items()}
        elif key == "importances":
            results_json[key] = {k: float(v) for k, v in value.items()}
        else:
            results_json[key] = _to_jsonable(value)
    results_text = json.dumps(results_json, indent=2)
    with open(os.path.join(output_dir, "optimization_results.json"), "w") as f:
        f.write(results_text)