import numpy as np
import torch
import math

def count_parameters(model):
    """Return the total element count of all parameters that require gradients."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad unset are excluded from the count.
        if param.requires_grad:
            total += param.numel()
    return total

def numel_nonzero(p):
    """Return the number of non-zero entries in ``p``.

    The value is whatever ``torch.count_nonzero`` produces (a tensor, not a
    Python int), computed with gradient tracking disabled.
    """
    with torch.no_grad():
        nonzero_count = torch.count_nonzero(p)
    return nonzero_count

def count_nonzero_parameters(model):
    """Return the number of non-zero entries across all trainable parameters.

    Per-parameter counts come from ``numel_nonzero``; they are accumulated
    starting from the integer 0, so a model without trainable parameters
    yields plain 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total = total + numel_nonzero(param)
    return total

def calculate_mdl_score(model, val_loss, capacity, precision_bits=7):
    """
    Calculate a Minimum Description Length (MDL) style score

    MDL = model_complexity + data_complexity
    where:
    - model_complexity = num_nonzero_params * log2(num_params) * precision_bits / capacity
    - data_complexity = val_loss * model.input_dim

    Args:
        model: PyTorch model; must expose an ``input_dim`` attribute
        val_loss: Validation loss (MSE)
        capacity: Divisor that normalizes the model complexity by the minimum
            necessary complexity; must be non-zero
        precision_bits: Number of bits used to store each parameter

    Returns:
        MDL score (lower is better). The non-zero count comes from
        ``count_nonzero_parameters``, so the score has whatever numeric type
        that count propagates through the arithmetic.
    """
    # Total trainable parameters and how many of them are non-zero
    num_params = count_parameters(model)
    num_nonzero_params = count_nonzero_parameters(model)

    if num_params > 0:
        # NOTE(review): log2(num_params) and precision_bits are multiplied, not
        # summed -- confirm this matches the intended per-parameter bit cost.
        bits_per_param = math.log2(num_params) * precision_bits
        # Normalize by capacity (the minimum necessary complexity)
        model_complexity = num_nonzero_params * bits_per_param / capacity
    else:
        # math.log2(0) raises ValueError; a model with no trainable parameters
        # has nothing to encode, so its model complexity is zero.
        model_complexity = 0.0

    # Data complexity (derived from negative log likelihood)
    # For MSE loss, assuming Gaussian noise:
    # NLL = n/2 * log(2π * σ²) + 1/(2σ²) * sum((y - ŷ)²)
    # where σ² is the noise variance and sum((y - ŷ)²) is proportional to MSE.
    # val_loss stands in for the MSE term; the remaining terms are treated as constants.
    data_complexity = val_loss * model.input_dim  # Scale by input dimension

    # Total MDL score (no extra rescaling of the data term is applied)
    mdl_score = model_complexity + data_complexity
    print(f"Model complexity: {model_complexity:.2f}, Data complexity: {data_complexity:.2f}, MDL score: {mdl_score:.2f}")

    return mdl_score

def calculate_fitness_score(model, val_loss, capacity, fitness_type="negative_loss", precision_bits=7):
    """
    Turn a validation loss into a fitness value for architecture search

    Args:
        model: PyTorch model (only used for the 'mdl' fitness type)
        val_loss: Validation loss (MSE)
        capacity: Forwarded to calculate_mdl_score for the 'mdl' fitness type
        fitness_type: Which score to compute ('negative_loss' or 'mdl')
        precision_bits: For MDL, number of bits per parameter (take 8-1 from Ayonrinde et al. 2024)

    Returns:
        Fitness score (higher is better)

    Raises:
        ValueError: If fitness_type is neither 'negative_loss' nor 'mdl'
    """
    # Plain loss: negate it so that a smaller loss ranks higher.
    if fitness_type == "negative_loss":
        return -val_loss

    # Only one other mode is supported; reject everything else.
    if fitness_type != "mdl":
        raise ValueError(f"Unknown fitness type: {fitness_type}")

    # A lower MDL score is better, so the sign is flipped to get a fitness.
    return -calculate_mdl_score(model, val_loss, capacity, precision_bits)

def get_model_size_info(model):
    """
    Summarize the storage footprint of a model's trainable parameters

    Args:
        model: PyTorch model

    Returns:
        Dict with the trainable parameter count and its estimated size in
        bits, bytes, kilobytes and megabytes
    """
    param_count = count_parameters(model)

    # Size estimate assumes 32 bits per parameter (float32).
    total_bits = param_count * 32
    total_bytes = total_bits / 8
    total_kb = total_bytes / 1024
    total_mb = total_kb / 1024

    return dict(
        num_parameters=param_count,
        size_bits=total_bits,
        size_bytes=total_bytes,
        size_kb=total_kb,
        size_mb=total_mb,
    )
