import torch
import numpy as np
import pulp
import logging
from typing import Dict, List, Tuple, Optional, Any, Union
import time
import os
import gc

# Module-level logger. This module attaches no handlers of its own, so whether
# and where these messages appear depends on the application's logging setup.
logger = logging.getLogger(__name__)

def _loss_from_outputs(outputs: Any) -> Optional[torch.Tensor]:
    """
    Return a scalar loss to back-propagate from a model's forward outputs.

    Uses ``outputs.loss`` when it exists and is not None. Otherwise falls back
    to a pseudo-loss on ``outputs.logits``: cross-entropy against class 0 when
    the last dimension has more than one entry (all leading dimensions are
    flattened), or the mean squared logit otherwise. Returns None when neither
    a loss nor logits are available.
    """
    loss = getattr(outputs, "loss", None)
    if loss is not None:
        return loss
    logits = getattr(outputs, "logits", None)
    if logits is None:
        return None
    if logits.dim() > 1 and logits.size(-1) > 1:
        # Flatten leading dims so (batch, seq, classes) logits work as well as
        # (batch, classes); for 2-D logits the reshape is a no-op.
        flat_logits = logits.reshape(-1, logits.size(-1))
        pseudo_targets = torch.zeros(flat_logits.size(0), dtype=torch.long, device=flat_logits.device)
        return torch.nn.functional.cross_entropy(flat_logits, pseudo_targets)
    return torch.mean(logits ** 2)

def measure_layer_importance(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    target_layers: List[str],
    device: torch.device,
    max_batches: int = 5,
    chunk_size: int = 2,
) -> Dict[str, float]:
    """
    Score each target ``nn.Linear`` layer by the magnitude of its weight gradient.

    For each of the first ``max_batches`` batches, the loss is back-propagated
    and the mean absolute gradient ``mean(|dL/dW|)`` of every target layer's
    weight is recorded. The per-batch scores are averaged over the batches that
    produced a loss. This is a cheap sensitivity proxy used in place of Fisher
    information. It is *not* the squared-gradient (diagonal Fisher) estimate.

    Each batch is processed in chunks of ``chunk_size`` examples. Every chunk is
    back-propagated immediately, so only one chunk's autograd graph is alive at
    a time. The accumulated gradient is divided by the number of chunks, so the
    score equals the one obtained from the mean of the chunk losses.

    Args:
        model: The model to analyze
        dataloader: Small batch of data for importance estimation; batches are
            expected to be dicts (non-tensor values are dropped)
        target_layers: List of layer names to analyze
        device: Device to run the model on
        max_batches: Maximum number of batches to use (default 5)
        chunk_size: Number of examples per forward/backward pass (default 2)

    Returns:
        Dict mapping every name in ``target_layers`` to its importance score.
        Names that are not ``nn.Linear`` modules of ``model``, or whose weight
        receives no gradient, score 0.0.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.

    Side effects:
        Moves ``model`` to ``device``, puts it in eval mode, enables gradient
        checkpointing when the model supports it, and clears gradients before
        each batch and before returning.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    logger.info("Measuring layer importance...")
    model.to(device)
    model.eval()

    # Enable gradient checkpointing if available to save memory.
    # NOTE(review): some implementations only checkpoint in train mode, so after
    # model.eval() this may have no effect — confirm for the models in use.
    if hasattr(model, 'gradient_checkpointing_enable'):
        model.gradient_checkpointing_enable()
        logger.info("Gradient checkpointing enabled for memory efficiency")

    importances = {name: 0.0 for name in target_layers}

    wanted = set(target_layers)
    linear_layers = {
        name: module
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and name in wanted
    }

    use_cuda = torch.cuda.is_available()
    batches_used = 0
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= max_batches:
            break

        # Aggressive memory cleanup before processing
        if use_cuda:
            torch.cuda.empty_cache()
            gc.collect()

        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        if not batch:
            logger.warning(f"Batch {batch_idx} contains no tensors, skipping")
            continue
        reference = batch["input_ids"] if "input_ids" in batch else next(iter(batch.values()))
        batch_size = reference.size(0)

        # Start every batch from clean gradients; otherwise each batch's score
        # would also include the gradients of all earlier batches.
        model.zero_grad()
        num_chunks = 0
        for start in range(0, batch_size, chunk_size):
            chunk_batch = {k: v[start:start + chunk_size] for k, v in batch.items()}
            outputs = model(**chunk_batch)
            loss = _loss_from_outputs(outputs)
            del outputs
            if loss is not None:
                # Backward per chunk so this chunk's graph is freed before the
                # next forward; gradients accumulate in .grad across chunks.
                loss.backward()
                num_chunks += 1
            del loss
            if use_cuda:
                torch.cuda.empty_cache()

        if num_chunks == 0:
            logger.warning("Unable to compute loss: outputs has no loss or logits attribute")
            continue

        batches_used += 1
        for name, layer in linear_layers.items():
            grad = layer.weight.grad
            if grad is not None:
                # .grad holds the SUM of per-chunk gradients; dividing the
                # (positively homogeneous) mean |grad| by num_chunks gives the
                # score of the mean chunk loss.
                importances[name] += grad.abs().mean().item() / num_chunks

    # Average over the batches that actually contributed a score.
    if batches_used:
        for name in importances:
            importances[name] /= batches_used

    # Don't leave estimation gradients on the parameters for a later optimizer step.
    model.zero_grad()

    for name, importance in sorted(importances.items(), key=lambda x: x[1], reverse=True):
        logger.info(f"Layer {name}: importance = {importance:.6f}")

    return importances

def estimate_performance_gain(
    r_value: int,
    layer_importance: float,
    layer_size: Tuple[int, int],
) -> float:
    """
    Heuristically score how much a rank-``r_value`` LoRA adapter helps a layer.

    The score is ``layer_importance * (1 - exp(-r_value / min(in, out)))``.
    It grows with the rank but with diminishing returns, and it is scaled by
    the layer's importance.

    Args:
        r_value: Rank value for LoRA
        layer_importance: Importance score of the layer
        layer_size: Tuple of (in_features, out_features) for the layer

    Returns:
        Estimated performance gain (0 when ``r_value`` is 0).
    """
    fan_in, fan_out = layer_size
    # Express the rank relative to min(in, out), the largest rank the layer's
    # weight matrix can have.
    rank_fraction = r_value / min(fan_in, fan_out)
    saturation = 1 - np.exp(-rank_fraction)
    return layer_importance * saturation

def estimate_computational_cost(
    r_value: int,
    layer_size: Tuple[int, int],
    *,
    forward_weight: float = 1.0,
    param_weight: float = 1.0,
) -> float:
    """
    Estimate the cost of adding a rank-``r_value`` LoRA adapter to a layer.

    The cost is a weighted sum of two components. For a LoRA adapter both
    have size ``r_value * (in_features + out_features)``:

    - forward cost: ``lora_A @ x`` takes ``r * in_features`` operations and
      ``lora_B @ (lora_A @ x)`` takes ``r * out_features`` operations;
    - parameter count of the two low-rank factors.

    Args:
        r_value: Rank value for LoRA.
        layer_size: Tuple of (in_features, out_features) for the layer.
        forward_weight: Weight applied to the forward-pass cost (keyword-only).
        param_weight: Weight applied to the parameter count (keyword-only).

    Returns:
        ``forward_weight * forward_cost + param_weight * param_count``. With the
        default weights this equals ``2 * r_value * (in_features + out_features)``.
    """
    in_features, out_features = layer_size
    per_component = r_value * (in_features + out_features)
    forward_cost = per_component  # r*in for A @ x plus r*out for B @ (A @ x)
    param_count = per_component   # r*in entries in A plus r*out entries in B
    return forward_weight * forward_cost + param_weight * param_count

def prepare_optimization_data(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    target_layers: List[str],
    r_values: List[int],
    device: torch.device,
) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]:
    """
    Build the per-layer gain and cost tables consumed by ``optimize_r_values``.

    Layer importances are measured with ``measure_layer_importance``. Each
    (layer, r) pair is then scored with ``estimate_performance_gain`` and
    ``estimate_computational_cost``.

    Args:
        model: The model to optimize
        dataloader: Dataloader for importance estimation
        target_layers: List of layer names to optimize
        r_values: List of r values to consider
        device: Device to run the model on

    Returns:
        Tuple of (gains, costs) dicts mapping each layer to a list with one value
        per entry of ``r_values``, in the same order. Target layers that are not
        ``nn.Linear`` modules of ``model`` are logged and left out of both dicts.
        Every returned list therefore has exactly ``len(r_values)`` entries, and
        no empty list reaches the optimizer.
    """
    logger.info(f"Preparing optimization data for {len(target_layers)} layers and {len(r_values)} r values...")

    # Measure layer importance
    importances = measure_layer_importance(model, dataloader, target_layers, device)

    # Collect (in_features, out_features) for every target nn.Linear layer.
    wanted = set(target_layers)
    layer_sizes = {
        name: (module.in_features, module.out_features)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and name in wanted
    }

    gains: Dict[str, List[float]] = {}
    costs: Dict[str, List[float]] = {}
    for layer_name in target_layers:
        if layer_name in gains:
            continue  # duplicate entry in target_layers; already tabulated
        layer_size = layer_sizes.get(layer_name)
        if layer_size is None:
            logger.warning(f"Layer {layer_name} not found in model, skipping...")
            continue

        importance = importances.get(layer_name, 0.001)  # Default small value if not found
        gains[layer_name] = [estimate_performance_gain(r, importance, layer_size) for r in r_values]
        costs[layer_name] = [estimate_computational_cost(r, layer_size) for r in r_values]

    return gains, costs

def _classify_layer_types(layer_names: List[str]) -> Dict[str, List[str]]:
    """
    Group layer names by projection type using substring rules.

    Rules are tried in order and the first match wins. Names that match no rule
    are left out. BERT/RoBERTa-style names (query, key, value, attention.output,
    intermediate, output) and LLaMA-style names (q_proj, k_proj, v_proj, o_proj,
    gate_proj/up_proj, down_proj) are both recognised.
    """
    layer_types: Dict[str, List[str]] = {
        "query": [],
        "key": [],
        "value": [],
        "attention_output": [],
        "intermediate": [],
        "output": [],
    }
    for layer_name in layer_names:
        if "query" in layer_name or "q_proj" in layer_name:
            layer_types["query"].append(layer_name)
        elif "key" in layer_name or "k_proj" in layer_name:
            layer_types["key"].append(layer_name)
        elif "value" in layer_name or "v_proj" in layer_name:
            layer_types["value"].append(layer_name)
        elif "attention.output" in layer_name or "o_proj" in layer_name:
            layer_types["attention_output"].append(layer_name)
        elif "intermediate" in layer_name or "gate_proj" in layer_name or "up_proj" in layer_name:
            layer_types["intermediate"].append(layer_name)
        elif ("output" in layer_name and "attention" not in layer_name) or "down_proj" in layer_name:
            layer_types["output"].append(layer_name)
    return layer_types

def _cbc_options(seed: int) -> List[str]:
    """
    Extra CBC command-line options used for reproducible solves.

    The time limit, thread count and gap tolerances are passed through the
    ``PULP_CBC_CMD`` keyword arguments. They are deliberately not repeated here,
    because a later duplicate option (e.g. ``sec``) would override them.
    """
    return [
        f"randomSeed {seed}",     # Seed for the LP (Clp) component
        f"randomCbcSeed {seed}",  # Seed for branch-and-cut
        "presolve on",
        "passPresolve 5",
        "strong 10",              # Strong branching on 10 variables
        "perturbation on",
        "cuts on",
        "passCuts 10",            # Number of cut passes
        "cost off",
        "nodeStrategy depth",
        "integerT 1e-9",
        "primalT 1e-9",
        "dualT 1e-9",
        "autoScale on",
        "logLevel 1",
    ]

def optimize_r_values(
    gains: Dict[str, List[float]],
    costs: Dict[str, List[float]],
    r_values: List[int],
    budget: float,
    time_limit: Optional[int] = None,
    solver="pulp_cbc",
    seed: int = 42
) -> Dict[str, int]:
    """
    Optimize r values using integer linear programming.

    Mathematical formulation:

    Maximize: Σ(layer_i, r_j) gain[layer_i][r_j] * x[layer_i][r_j]
    Subject to:
      - Σ(r_j) x[layer_i][r_j] = 1 for all layer_i  (one r value per layer)
      - Σ(layer_i, r_j) cost[layer_i][r_j] * x[layer_i][r_j] ≤ budget
      - Σ(layer_i ∈ T, r_j) r_j * x[layer_i][r_j] ≥ 0.001 for every layer type T
        recognised by name (query/key/value/attention_output/intermediate/output)
      - x[layer_i][r_j] ∈ {0, 1}

    Where:
      - x[layer_i][r_j] is a binary variable indicating if layer_i uses rank r_j
      - gain[layer_i][r_j] is the estimated performance gain for using rank r_j on layer_i
      - cost[layer_i][r_j] is the computational/parameter cost for using rank r_j on layer_i

    Args:
        gains: Dict mapping layer_name to list of gain estimates for each r value
        costs: Dict mapping layer_name to list of cost estimates for each r value
        r_values: List of r values corresponding to the indices in gains and costs
        budget: Total budget constraint
        time_limit: Time limit for optimization in seconds (default: 5 s per
            layer, clamped to [60, 600])
        solver: Only "pulp_cbc" is supported; any other value logs a warning
            and CBC is used anyway
        seed: Random seed for reproducibility

    Returns:
        Dict mapping layer_name to optimal r value. Layers whose gain and cost
        lists are both empty are skipped. If the solver finds no integer
        solution, every layer gets ``min(r_values)``.

    Raises:
        ImportError: If PuLP is not installed.
        ValueError: If a layer's gain/cost lists are shorter than ``r_values``,
            or if ``r_values`` is empty while there are layers to optimize.
    """
    try:
        import pulp
    except ImportError as err:
        logger.error("PuLP is required for this optimization. Please install pulp.")
        logger.error("If you cannot install PuLP, consider using a simple greedy algorithm as fallback.")
        raise ImportError("PuLP optimizer is required for reliable optimization.") from err

    start_time = time.time()
    logger.info(f"Starting r-value optimization with budget {budget:.2f}...")

    if solver != "pulp_cbc":
        logger.warning(f"Unsupported solver {solver!r}; falling back to PuLP's CBC solver.")

    num_r = len(r_values)
    rank_indices = range(num_r)

    # Only layers with a complete gain/cost table take part in the ILP. Layers
    # with no data at all are skipped; partially filled tables are a caller bug.
    layer_names: List[str] = []
    for layer_name, layer_gains in gains.items():
        layer_costs = costs.get(layer_name, [])
        if not layer_gains and not layer_costs:
            logger.warning(f"Layer {layer_name} has no gain/cost data, skipping...")
            continue
        if len(layer_gains) < num_r or len(layer_costs) < num_r:
            raise ValueError(
                f"Layer {layer_name}: expected {num_r} gains and costs, "
                f"got {len(layer_gains)} and {len(layer_costs)}"
            )
        layer_names.append(layer_name)

    if not layer_names:
        logger.warning("No layers to optimize; returning an empty configuration.")
        return {}
    if not r_values:
        raise ValueError("r_values must contain at least one candidate rank.")

    if time_limit is None:
        time_limit = min(600, max(60, len(layer_names) * 5))

    logger.info(f"Setting optimization time limit to {time_limit} seconds and seed to {seed}")

    # CBC receives its seeds through the randomSeed/randomCbcSeed options.
    # NOTE(review): nothing in this module reads CBC_RANDOM_SEED, and the ILP does
    # not use NumPy's global RNG; both are kept only for compatibility.
    os.environ["CBC_RANDOM_SEED"] = str(seed)
    np.random.seed(seed)

    ilp_model = pulp.LpProblem("optimal_r", pulp.LpMaximize)

    # variables[(layer, i)] == 1 iff `layer` uses rank r_values[i].
    variables = {
        (layer_name, i): pulp.LpVariable(f"{layer_name}_r{r}", cat=pulp.LpBinary)
        for layer_name in layer_names
        for i, r in enumerate(r_values)
    }

    # Constraint: each layer must choose exactly one r value.
    for layer_name in layer_names:
        ilp_model += (
            pulp.lpSum(variables[(layer_name, i)] for i in rank_indices) == 1,
            f"one_r_{layer_name}",
        )

    # Constraint: every recognised layer type gets a positive total rank, which
    # is equivalent to a positive average rank. The epsilon turns "> 0" into ">=".
    for layer_type, type_layers in _classify_layer_types(layer_names).items():
        if not type_layers:
            continue
        sum_r_expr = pulp.lpSum(
            r_values[i] * variables[(layer_name, i)]
            for layer_name in type_layers
            for i in rank_indices
        )
        ilp_model += (sum_r_expr >= 0.001, f"avg_r_positive_{layer_type}")

    # Constraint: total cost must not exceed the budget.
    total_cost = pulp.lpSum(
        costs[layer_name][i] * variables[(layer_name, i)]
        for layer_name in layer_names
        for i in rank_indices
    )
    ilp_model += (total_cost <= budget, "budget_constraint")

    # Objective: maximize total estimated gain.
    ilp_model += pulp.lpSum(
        gains[layer_name][i] * variables[(layer_name, i)]
        for layer_name in layer_names
        for i in rank_indices
    )

    cbc = pulp.PULP_CBC_CMD(
        msg=False,
        timeLimit=time_limit,
        options=_cbc_options(seed),
        keepFiles=False,  # Do not keep the intermediate MPS/solution files
        mip=True,
        threads=1,        # Single thread for deterministic behavior
        gapRel=0.0,       # Relative MIP gap tolerance
        gapAbs=0.0,       # Absolute MIP gap tolerance
    )

    logger.info(f"Starting deterministic optimization with CBC solver (seed: {seed}, threads: 1)")
    solve_start = time.time()
    ilp_model.solve(cbc)
    solve_time = time.time() - solve_start

    status = pulp.LpStatus[ilp_model.status]
    # NOTE(review): per PuLP's status docs (confirm for the installed version), a
    # time-limited run that found an integer solution reports status "Optimal"
    # with sol_status LpSolutionIntegerFeasible (2). Any other status means no
    # usable integer assignment, and variable values may be None or fractional.
    integer_feasible = getattr(pulp, "LpSolutionIntegerFeasible", 2)
    if status != "Optimal":
        logger.warning(f"Optimization failed with status {status} after {solve_time:.2f} seconds.")
        # Return default configuration (smallest r for all layers)
        fallback_r = min(r_values)
        return {layer_name: fallback_r for layer_name in layer_names}
    if getattr(ilp_model, "sol_status", None) == integer_feasible:
        logger.warning(f"Time limit reached after {solve_time:.2f} seconds, returning best solution found so far.")
    else:
        logger.info(f"Optimal solution found in {solve_time:.2f} seconds!")

    # Extract the solution. An integer solution has exactly one binary set per
    # layer; taking the argmax also tolerates round-off and unset (None) values.
    chosen: Dict[str, int] = {}
    optimal_r: Dict[str, int] = {}
    for layer_name in layer_names:
        values = [variables[(layer_name, i)].varValue or 0.0 for i in rank_indices]
        best_i = max(rank_indices, key=values.__getitem__)
        chosen[layer_name] = best_i
        optimal_r[layer_name] = r_values[best_i]

    total_gain = sum(gains[name][i] for name, i in chosen.items())
    total_cost_value = sum(costs[name][i] for name, i in chosen.items())

    elapsed_time = time.time() - start_time
    logger.info(f"Optimization completed in {elapsed_time:.2f} seconds")
    logger.info(f"Optimal solution with total gain: {total_gain:.4f}")
    budget_share = f"{100 * total_cost_value / budget:.1f}%" if budget else "n/a"
    logger.info(f"Total cost: {total_cost_value:.2f} / {budget:.2f} ({budget_share})")

    for layer_name in sorted(optimal_r):
        i = chosen[layer_name]
        logger.info(f"{layer_name}: r={r_values[i]}, gain={gains[layer_name][i]:.4f}, cost={costs[layer_name][i]:.2f}")

    return optimal_r

def _optimization_data_from_importances(
    model: torch.nn.Module,
    target_layers: List[str],
    r_values: List[int],
    importances: Dict[str, float],
) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]:
    """
    Build per-layer gain and cost tables from already-measured importances.

    Only target layers that are ``nn.Linear`` modules of ``model`` are included.
    Others are logged and skipped, so each list has one entry per r value.
    """
    logger.info(f"Preparing optimization data for {len(target_layers)} layers and {len(r_values)} r values...")

    wanted = set(target_layers)
    layer_sizes = {
        name: (module.in_features, module.out_features)
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear) and name in wanted
    }

    gains: Dict[str, List[float]] = {}
    costs: Dict[str, List[float]] = {}
    for layer_name in target_layers:
        if layer_name in gains:
            continue  # duplicate entry in target_layers
        layer_size = layer_sizes.get(layer_name)
        if layer_size is None:
            logger.warning(f"Layer {layer_name} not found in model, skipping...")
            continue
        importance = importances.get(layer_name, 0.001)
        gains[layer_name] = [estimate_performance_gain(r, importance, layer_size) for r in r_values]
        costs[layer_name] = [estimate_computational_cost(r, layer_size) for r in r_values]
    return gains, costs

def get_optimal_r_config(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    r_values: List[int],
    target_layers: List[str],
    budget: float,
    device: torch.device,
    seed: int = 42
) -> Tuple[Dict[str, int], Dict[str, Any]]:
    """
    Get the optimal r configuration for the given model.

    Layer importances are measured once. Those same values drive the gain
    estimates and are reported in the results.

    Args:
        model: The model to optimize
        dataloader: Dataloader for importance estimation
        r_values: List of r values to consider
        target_layers: List of layer names to optimize
        budget: Total budget constraint
        device: Device to run the model on
        seed: Random seed for reproducibility

    Returns:
        Tuple of (optimal_r, optimization_results) where:
            optimal_r: Dict mapping layer_name to optimal r value
            optimization_results: Dict with keys "gains", "costs", "r_values",
                "budget", "seed" and "importances"
    """
    logger.info(f"Computing optimal r configuration with budget {budget}...")
    logger.info(f"Considering r values: {r_values}")
    logger.info(f"Optimizing {len(target_layers)} layers with seed {seed}")

    # Measure once and reuse. A second measurement would repeat every
    # forward/backward pass and, if the dataloader yields different batches,
    # would report importances that differ from those used to build the gains.
    importances = measure_layer_importance(model, dataloader, target_layers, device)
    gains, costs = _optimization_data_from_importances(model, target_layers, r_values, importances)

    # Optimize r values with seed for reproducibility
    optimal_r = optimize_r_values(gains, costs, r_values, budget, seed=seed)

    optimization_results = {
        "gains": gains,
        "costs": costs,
        "r_values": r_values,
        "budget": budget,
        "seed": seed,
        "importances": importances,
    }

    return optimal_r, optimization_results