import jax
import jax.numpy as jnp
from typing import Any, Callable, Dict, Optional, Tuple, Union
import functools

def topk_sparsification(params: Any, k_percent: float = 0.1) -> Any:
    """Keep only the largest-magnitude entries of every leaf in a pytree.

    Each leaf is handled independently: the k-th largest absolute value of
    the flattened leaf is used as a threshold, and every entry whose
    magnitude is below that threshold is set to zero.

    Args:
        params: A pytree of array leaves.
        k_percent: Fraction of entries to keep per leaf (between 0 and 1).
            At least one entry is kept for every non-empty leaf, and
            fractions above 1 keep the whole leaf.

    Returns:
        A pytree with the same structure as params. Entries tied with the
        threshold are all kept, so more than k entries may survive.
    """
    def _topk_array(x):
        magnitudes = jnp.abs(x).reshape(-1)
        size = magnitudes.size
        # An empty leaf has no k-th largest element; pass it through.
        if size == 0:
            return x
        # Clamp k to [1, size] so the negative index below is always valid,
        # even when k_percent is 0 or larger than 1.
        k = min(size, max(1, int(k_percent * size)))
        # The k-th largest magnitude is the k-th element from the end of
        # the ascending sort.
        threshold = jnp.sort(magnitudes)[-k]
        mask = jnp.abs(x) >= threshold
        return x * mask

    return jax.tree_util.tree_map(_topk_array, params)

def quantization(params: Any, bits: int = 8, scale_per_tensor: bool = True) -> Any:
    """Simulate uniform affine quantization on every leaf of a pytree.

    Each leaf is mapped onto 2**bits evenly spaced levels spanning its
    [min, max] range and immediately mapped back, so the result has the same
    shape as the input but only takes values on that grid.

    Args:
        params: A pytree of array leaves.
        bits: Number of bits to quantize to (1-32).
        scale_per_tensor: If True, one [min, max] range is computed over the
            whole leaf. If False, a separate range is computed for every
            index of the last axis (the reduction runs over all other axes),
            so leaves with fewer than two dimensions are returned unchanged
            in value.

    Returns:
        A pytree with the same structure as params holding the dequantized
        values. Empty leaves are returned unchanged.

    Raises:
        ValueError: If bits is a Python int smaller than 1.
    """
    # With fewer than one bit the grid has zero steps and the arithmetic
    # below yields NaN. Only concrete ints are checked so that a traced
    # value is still accepted as before.
    if isinstance(bits, int) and bits < 1:
        raise ValueError(f"bits must be at least 1, got {bits}")

    # Number of steps between the lowest and highest quantization level.
    levels = 2**bits - 1

    def _quantize_array(x):
        # min/max of an empty array is undefined; nothing to quantize.
        if x.size == 0:
            return x

        if scale_per_tensor:
            x_min = jnp.min(x)
            x_max = jnp.max(x)
        else:
            # Per-channel quantization (assuming last dim is channels)
            reduce_axes = tuple(range(x.ndim - 1))
            x_min = jnp.min(x, axis=reduce_axes, keepdims=True)
            x_max = jnp.max(x, axis=reduce_axes, keepdims=True)

        scale = (x_max - x_min) / levels
        # A constant range gives scale 0; use 1 to avoid dividing by zero
        # (every value then maps to level 0 and dequantizes to x_min).
        scale = jnp.where(scale == 0, 1.0, scale)

        codes = jnp.clip(jnp.round((x - x_min) / scale), 0, levels)
        # Dequantize (to simulate the effect of quantization)
        return codes * scale + x_min

    return jax.tree_util.tree_map(_quantize_array, params)

def random_sparsification(params: Any, keep_percent: float = 0.1, key: Optional[jax.random.PRNGKey] = None) -> Any:
    """Randomly keep a fraction of the entries of every leaf in a pytree.

    Every entry is kept independently with probability keep_percent. Kept
    entries are multiplied by 1 / keep_percent so that the expected value of
    each entry is preserved; all other entries are set to zero.

    Args:
        params: A pytree of array leaves.
        keep_percent: The fraction of values to randomly keep (between 0 and
            1). A non-positive fraction zeroes every entry.
        key: A PRNG key for random number generation. If None, the fixed key
            jax.random.PRNGKey(0) is used, so repeated calls without a key
            draw the same mask.

    Returns:
        A pytree with the same structure as params, but with randomly
        selected values kept.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    leaves, treedef = jax.tree_util.tree_flatten(params)
    if not leaves:
        return params

    # Nothing is kept for a non-positive fraction, so the scale is never
    # applied; using 0 avoids a ZeroDivisionError. Only concrete Python
    # numbers are checked so that a traced value is still accepted.
    if isinstance(keep_percent, (int, float)) and keep_percent <= 0:
        scaling = 0.0
    else:
        scaling = 1.0 / keep_percent

    def _random_sparsify_array(x, subkey):
        mask = jax.random.uniform(subkey, shape=x.shape) < keep_percent
        return jnp.where(mask, x * scaling, 0.0)

    # One independent subkey per leaf, paired with the leaves in flatten
    # order.
    subkeys = jax.random.split(key, len(leaves))
    sparsified = [
        _random_sparsify_array(leaf, subkey)
        for leaf, subkey in zip(leaves, subkeys)
    ]
    return jax.tree_util.tree_unflatten(treedef, sparsified)

def _single_cocktail_compression(
    params: Any, 
    topk_percent: float = 0.3, 
    quantize_bits: int = 8, 
    random_percent: float = 0.0,
    compression_order: str = "random,topk,quantize",
    key: Optional[jax.random.PRNGKey] = None
) -> Any:
    """Apply a combination of compression techniques as in CocktailSGD to a single instance.

    The stages named in compression_order are applied one after another,
    each one operating on the output of the previous stage.

    Args:
        params: A pytree of parameters.
        topk_percent: The percentage of top values to keep (between 0 and 1).
            The "topk" stage is skipped when this is 1.0 or larger.
        quantize_bits: Number of bits to quantize to (1-32). The "quantize"
            stage is skipped when this is 32 or larger.
        random_percent: The percentage of values to randomly keep (between 0
            and 1). The "random" stage is skipped when this is 0 or smaller.
        compression_order: Comma-separated stage names ("random", "topk",
            "quantize") in the order they should run. Whitespace around a
            name is ignored; names that match no stage are skipped.
        key: A PRNG key for random number generation. If None, the fixed key
            jax.random.PRNGKey(0) is used.

    Returns:
        A pytree with the same structure as params, but with compression applied.
    """
    # Initialize random key if not provided
    if key is None:
        key = jax.random.PRNGKey(0)

    # Strip whitespace so that "topk, quantize" selects both stages.
    stages = [name.strip() for name in compression_order.split(",")]

    compressed_params = params
    # Tracks whether a sparsifying stage has already zeroed entries.
    sparsified = False

    for technique in stages:
        if technique == "topk" and topk_percent < 1.0:
            compressed_params = topk_sparsification(compressed_params, topk_percent)
            sparsified = True
        elif technique == "quantize" and quantize_bits < 32:
            quantized = quantization(compressed_params, quantize_bits)
            if sparsified:
                # The quantization grid starts at the leaf minimum and need
                # not contain 0, so dropped entries would come back as small
                # non-zero values. Put the exact zeros back to keep the
                # sparsity produced by the earlier stage.
                quantized = jax.tree_util.tree_map(
                    lambda after, before: jnp.where(before == 0, 0, after),
                    quantized,
                    compressed_params,
                )
            compressed_params = quantized
        elif technique == "random" and random_percent > 0.0:
            key, subkey = jax.random.split(key)
            compressed_params = random_sparsification(compressed_params, random_percent, subkey)
            sparsified = True

    return compressed_params

def cocktail_compression(
    params: Any, 
    topk_percent: float = 0.3, 
    quantize_bits: int = 8, 
    random_percent: float = 0.0,
    compression_order: str = "random,topk,quantize",
    key: Optional[jax.random.PRNGKey] = None
) -> Any:
    """Apply a combination of compression techniques as in CocktailSGD.
    This function is vmapped over the first dimension of params.

    Every slice along the leading axis is compressed independently with its
    own PRNG key derived from key.

    Args:
        params: A pytree of parameters with a batch dimension. The batch size
            is read from the leading axis of the first leaf.
        topk_percent: The percentage of top values to keep (between 0 and 1).
        quantize_bits: Number of bits to quantize to (1-32).
        random_percent: The percentage of values to randomly keep (between 0 and 1).
        compression_order: The order in which to apply the compression techniques.
        key: A PRNG key for random number generation. If None, the fixed key
            jax.random.PRNGKey(0) is used.

    Returns:
        A pytree with the same structure as params, but with compression
        applied. A pytree without leaves is returned unchanged.
    """
    leaves = jax.tree_util.tree_leaves(params)
    # Without leaves there is no batch axis to read and nothing to compress.
    if not leaves:
        return params

    # Initialize random key if not provided
    if key is None:
        key = jax.random.PRNGKey(0)

    # Create batch keys for each element in the batch
    batch_size = leaves[0].shape[0]
    batch_keys = jax.random.split(key, batch_size)

    # Bind the Python-level settings up front so that only array arguments
    # (the parameters and their keys) are handed to vmap.
    compress_one = functools.partial(
        _single_cocktail_compression,
        topk_percent=topk_percent,
        quantize_bits=quantize_bits,
        random_percent=random_percent,
        compression_order=compression_order,
    )

    def _compress_slice(single_params, single_key):
        return compress_one(single_params, key=single_key)

    vmapped_compression = jax.vmap(_compress_slice, in_axes=(0, 0), out_axes=0)
    return vmapped_compression(params, batch_keys)

if __name__ == "__main__":
    # Set random seed for reproducibility
    key = jax.random.PRNGKey(42)

    # Create a sample pytree: four linear layers, each leaf carrying a
    # leading batch axis of size 2. Every leaf gets its own subkey so that no
    # two leaves share the same random values.
    batch = 2
    layer_dims = [(3072, 128), (128, 128), (128, 128), (128, 10)]
    leaf_keys = jax.random.split(key, 2 * len(layer_dims))
    sample_params = {}
    for index, (fan_in, fan_out) in enumerate(layer_dims):
        sample_params[f'mlp/~/linear_{index}'] = {
            'b': jax.random.normal(leaf_keys[2 * index], (batch, fan_out)),
            'w': jax.random.normal(leaf_keys[2 * index + 1], (batch, fan_in, fan_out)),
        }

    def _sparsity(tree):
        """Return the fraction of entries in tree that are exactly zero."""
        tree_leaves = jax.tree_util.tree_leaves(tree)
        non_zeros = sum(int(jnp.count_nonzero(x)) for x in tree_leaves)
        total_elements = sum(x.size for x in tree_leaves)
        return 1.0 - (non_zeros / total_elements)

    # Test different compression techniques
    print("Testing cocktail compression with different settings:")

    # (label, keyword settings, PRNG seed); the full cocktail runs last so
    # that the summary below describes it.
    experiments = [
        ("Top-k only", dict(topk_percent=0.3, compression_order="topk"), 0),
        ("Quantization only", dict(quantize_bits=8, compression_order="quantize"), 1),
        ("Random sparsification only", dict(random_percent=0.2, compression_order="random"), 2),
        (
            "Full cocktail (all techniques)",
            dict(
                topk_percent=0.3,
                quantize_bits=8,
                random_percent=0.2,
                compression_order="topk,random,quantize",
            ),
            3,
        ),
    ]
    for label, settings, seed in experiments:
        compressed_params = cocktail_compression(
            sample_params,
            key=jax.random.PRNGKey(seed),
            **settings
        )
        print(f"  {label}: sparsity {_sparsity(compressed_params):.2%}")

    # Print structure of the compressed parameters
    print("\nCompressed parameters structure:")
    for leaf in jax.tree_util.tree_leaves(compressed_params):
        print(f"{leaf.shape}, {leaf.dtype}")

    # Calculate compression statistics
    original_size = sum(x.size * x.dtype.itemsize for x in jax.tree_util.tree_leaves(sample_params))
    print(f"\nOriginal parameters size: {original_size / 1024:.2f} KB")

    # Count non-zero elements in compressed params (for sparsity measurement)
    sparsity = _sparsity(compressed_params)
    print(f"Sparsity achieved: {sparsity:.2%}")
