import re
from typing import Any, Callable, Dict, List, Sequence, Union

import flax
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict

# Type alias for a flax FrozenDict with string keys and arbitrary values;
# used below to annotate parameter / gradient trees (see get_weight_norm).
Params = flax.core.FrozenDict[str, Any]

def sum_all_values_in_pytree(pytree) -> float:
    """Add up every element of every leaf of ``pytree``.

    Each leaf is reduced with ``jnp.sum`` and the per-leaf results are added
    together, starting from the integer ``0``.

    Args:
        pytree: Any JAX pytree; its leaves are obtained with
            ``jax.tree_util.tree_leaves``.

    Returns:
        The grand total. Note that, despite the ``float`` annotation, this is
        whatever ``0 + jnp.sum(leaf) + ...`` evaluates to: a JAX scalar when
        there is at least one leaf, and the plain integer ``0`` otherwise.
    """
    running_total = 0
    for leaf in jax.tree_util.tree_leaves(pytree):
        running_total = running_total + jnp.sum(leaf)
    return running_total

def flatten_dict(d, parent_key="", sep="_"):
    """Collapse a nested mapping into a single-level ``dict``.

    Values that are ``dict`` or ``FrozenDict`` instances are descended into
    recursively; every other value is treated as a leaf.

    Args:
        d: Mapping to flatten (anything exposing ``.items()``).
        parent_key: Key path accumulated so far. When it is falsy (the
            default ``""``) an entry's own key is used unchanged.
        sep: Separator placed between the parent path and the child key.

    Returns:
        A new plain ``dict`` mapping joined key paths to leaf values. Nested
        mappings that contain no leaves contribute nothing.
    """
    flat = {}
    for name, value in d.items():
        path = name if not parent_key else f"{parent_key}{sep}{name}"
        nested = isinstance(value, dict) or isinstance(value, FrozenDict)
        if not nested:
            flat[path] = value
            continue
        # Merge the flattened children, keeping their traversal order.
        for sub_path, sub_value in flatten_dict(value, path, sep=sep).items():
            flat[sub_path] = sub_value
    return flat


def add_prefix_to_dict(d: dict, prefix: Union[str, None] = None, sep="/") -> dict:
    """Return a copy of ``d`` whose keys are prefixed with ``prefix + sep``.

    Args:
        d: Source dictionary; it is not modified.
        prefix: Text to put in front of every key. When ``None`` (the
            default) the keys are left unchanged; previously this case raised
            ``TypeError`` because ``None`` was concatenated with ``sep``.
        sep: Separator inserted between ``prefix`` and the original key.

    Returns:
        A new ``dict`` with the same values and the rewritten keys.
    """
    if prefix is None:
        # Nothing to prepend: hand back a shallow copy so callers always get
        # a fresh dict, as they do in the prefixed case.
        return dict(d)
    return {f"{prefix}{sep}{key}": value for key, value in d.items()}

def add_all_key(d):
    """Reduce a (possibly nested) mapping of arrays to per-entry norms.

    Leaves (values that are neither ``dict`` nor ``FrozenDict``) are replaced
    by ``jnp.linalg.norm(value)``. Nested mappings are processed recursively;
    when the processed child holds both a ``"kernel"`` and a ``"bias"`` entry
    it is collapsed into the single value ``sqrt(kernel**2 + bias**2)``.

    Args:
        d: Mapping to process (anything exposing ``.items()``). It is not
            modified.

    Returns:
        A new plain ``dict`` with the same keys as ``d``.
    """
    norms = {}
    for name, entry in d.items():
        if not (isinstance(entry, dict) or isinstance(entry, FrozenDict)):
            norms[name] = jnp.linalg.norm(entry)
            continue

        child = add_all_key(entry)
        if "kernel" in child and "bias" in child:
            # Combine the two per-parameter norms into one layer-level norm.
            squared_kernel = jnp.square(child["kernel"])
            squared_bias = jnp.square(child["bias"])
            child = jnp.sqrt(squared_kernel + squared_bias)
        norms[name] = child
    return norms

def _dormant_mask(activs_2d, tau):
    """Return a 0/1 mask over the last axis flagging dormant units of a 2D array."""
    # Taking the mean here conforms to the expectation under D in the main paper's formula
    score = jnp.abs(activs_2d).mean(axis=0)
    # The small epsilon keeps the division finite when every score is zero.
    normalized_score = score / (score.mean() + 1e-9)
    if tau > 0.0:
        return jnp.where(normalized_score <= tau, 1, 0)
    return jnp.where(
        jnp.isclose(normalized_score, jnp.zeros_like(normalized_score)), 1, 0
    )


def get_dormant_ratio(
    activations: Dict[str, Any], prefix: str, tau: float = 0.1
) -> Dict[str, jnp.ndarray]:
    """
    Compute the percentage of dormant units per layer and over all layers.

    Only entries whose name contains ``'Dense'`` and does not contain
    ``'predictor'`` are inspected. For each of them the value stored under
    the ``'__call__'`` key is converted with ``jnp.array`` and handled by
    rank:

    * 4D: the first dimension must be 1 and is squeezed away.
    * 3D: a leading dimension of size 1 is squeezed away; otherwise a mask is
      computed for every slice along axis 0 and the masks are averaged.
    * 2D: the mask is computed directly.

    A unit's score is its mean absolute activation over axis 0, divided by
    the mean score of the layer. With ``tau > 0`` a unit is dormant when its
    normalized score is ``<= tau``; otherwise when it is (close to) zero.

    Args:
        activations: Mapping from layer name to a mapping that holds the
            layer outputs under ``'__call__'`` (presumably flax captured
            intermediates -- confirm against the caller).
        prefix: A string prefix for naming the returned keys.
        tau: The threshold for determining dormancy.

    Returns:
        A dictionary of dormancy ratios (in percent). Per-layer keys have the
        form ``"{prefix}/{key}_{prefix}_{layer}"`` (the prefix appears twice;
        kept as-is so existing log keys do not change) and the aggregate is
        stored under ``"{prefix}/{key}_total"``, where ``key`` is
        ``"dormant"`` for ``tau > 0`` and ``"zeroactiv"`` otherwise. If no
        layer matches the name filter the total is omitted and the returned
        dictionary is empty (previously this case raised inside
        ``jnp.concatenate``).

    Raises:
        ValueError: If an inspected activation is 4D with a first dimension
            other than 1, or is not 2D/3D/4D at all.

    Source : https://github.com/timoklein/redo/blob/dcaeff1c6afd0f1615a21da5beda870487b2ed15/src/redo.py#L215
    """
    key = "dormant" if tau > 0.0 else "zeroactiv"
    ratios = {}
    total_activs = []

    for sub_layer_name, activs in activations.items():
        if 'Dense' not in sub_layer_name or 'predictor' in sub_layer_name:
            continue

        layer_name = f"{prefix}_{sub_layer_name}"
        activs = jnp.array(activs['__call__'])

        # Handle dimensionality cases
        if activs.ndim == 4:
            if activs.shape[0] != 1:
                raise ValueError(f"4D input must have first dimension=1, got shape {activs.shape}")
            activs = jnp.squeeze(activs, axis=0)

        if activs.ndim == 3 and activs.shape[0] == 1:
            activs = jnp.squeeze(activs, axis=0)

        if activs.ndim == 3:
            # One mask per slice along axis 0, then the mean across slices
            # (so entries of the resulting mask may be fractional).
            layer_masks = jax.vmap(lambda batch: _dormant_mask(batch, tau))(activs)
            layer_mask = jnp.mean(layer_masks, axis=0)
        elif activs.ndim == 2:
            layer_mask = _dormant_mask(activs, tau)
        else:
            raise ValueError(f"Input must be 2D, 3D, or 4D with first dimension=1, got shape {activs.shape}")

        ratios[f"{prefix}/{key}_{layer_name}"] = (
            jnp.sum(layer_mask) / layer_mask.size
        ) * 100
        total_activs.append(layer_mask)

    if not total_activs:
        # No layer was inspected: there is nothing to aggregate.
        return ratios

    # aggregated mask of entire network
    total_mask = jnp.concatenate(total_activs)

    ratios[f"{prefix}/{key}_total"] = (jnp.sum(total_mask) / total_mask.size) * 100

    return ratios

def get_weight_norm(
    param_dict: Params,
    prefix: Union[str, None] = None,
) -> dict:
    """
    Compute per-entry and total norms of a parameter (or gradient) tree.

    param_dict is a frozen dictionary which contains the gradients/values of each individual parameter

    Every leaf is first reduced to its ``jnp.linalg.norm``. ``add_all_key``
    then merges the ``kernel``/``bias`` norms of a sub-dict into a single
    value, and a ``"total"`` entry holds the square root of the sum of all
    squared leaf norms. The nested result is flattened with ``"_"`` as the
    key separator and every key is prefixed.

    Args:
        param_dict: Tree of parameter or gradient arrays.
        prefix: Namespace for the returned keys, which take the form
            ``"{prefix}/weightnorm_{flattened_key}"``. When ``None`` (the
            default) the keys start with ``"weightnorm_"``; previously this
            case raised ``TypeError`` on ``None + "/weightnorm"``.

    Return:
        param gradient/value norm dictionary
        (Caution : norm values for vmapped functions (multi-head Q-networks) are summed to a single value)
    """
    param_norm_dict = jax.tree_util.tree_map(jnp.linalg.norm, param_dict)

    updated_params = add_all_key(param_norm_dict)
    squared_param_norm_dict = jax.tree_util.tree_map(jnp.square, param_norm_dict)
    updated_params["total"] = jnp.sqrt(
        sum_all_values_in_pytree(squared_param_norm_dict)
    )

    # Build the key prefix here so a missing prefix never reaches string
    # concatenation.
    full_prefix = "weightnorm" if prefix is None else prefix + "/weightnorm"

    return add_prefix_to_dict(flatten_dict(updated_params), full_prefix, sep="_")

def _srank_2d(matrix_2d, threshold):
    """Count the leading singular values needed to reach ``1 - threshold`` of their sum."""
    # SVD returns singular values in descending order
    singular_vals = jnp.linalg.svd(
        matrix_2d, full_matrices=False, compute_uv=False)
    sum_sing = jnp.sum(singular_vals)
    target_ratio = 1.0 - threshold

    def body_fn(i, carry):
        accumulated_sum, k = carry
        # Keep adding singular values only while the mass accumulated so far
        # is still below the target; afterwards the carry is left untouched.
        should_continue = accumulated_sum / sum_sing < target_ratio
        new_accumulated_sum = jnp.where(
            should_continue, accumulated_sum + singular_vals[i], accumulated_sum
        )
        new_k = jnp.where(should_continue, k + 1, k)
        return (new_accumulated_sum, new_k)

    # The accumulator shares the singular values' dtype so the loop carry
    # keeps one type across iterations.
    init_carry = (jnp.zeros((), dtype=singular_vals.dtype), jnp.array(0))
    _, k = jax.lax.fori_loop(0, singular_vals.shape[0], body_fn, init_carry)
    return k


def get_srank(matrix, thershold=0.01):
    """Compute the effective rank (srank) of a matrix or a stack of matrices.

    The srank of a 2D matrix is the smallest number ``k`` of leading singular
    values whose sum reaches a fraction ``1 - thershold`` of the sum of all
    singular values.

    Input handling by rank, after conversion with ``jnp.array``:

    * 4D: the first dimension must be 1 and is squeezed away.
    * 3D: a leading dimension of size 1 is squeezed away; otherwise the srank
      of every 2D slice along axis 0 is computed and the mean is returned.
    * 2D: the srank is computed directly.

    Args:
        matrix: Array-like input of rank 2, 3 or 4.
        thershold: Fraction of the singular-value mass that may be left out.
            The misspelled parameter name is kept so keyword callers keep
            working.

    Returns:
        A JAX scalar: the integer count ``k`` for 2D input, or the mean of
        the per-slice counts for 3D input.

    Raises:
        ValueError: If the input is 4D with a first dimension other than 1,
            or is not 2D/3D/4D at all.
    """
    matrix = jnp.array(matrix)

    # Handle 4D case with first dim=1
    if matrix.ndim == 4:
        if matrix.shape[0] != 1:
            raise ValueError(f"4D input must have first dimension=1, got shape {matrix.shape}")
        matrix = jnp.squeeze(matrix, axis=0)

    if matrix.ndim == 3 and matrix.shape[0] == 1:
        matrix = jnp.squeeze(matrix, axis=0)

    # Handle 3D case - compute rank for each slice and take mean
    if matrix.ndim == 3:
        sranks = jax.vmap(lambda mat: _srank_2d(mat, thershold))(matrix)
        return jnp.mean(sranks)

    # Handle 2D case
    if matrix.ndim == 2:
        return _srank_2d(matrix, thershold)

    # Raise error for other cases
    raise ValueError(f"Input must be 2D, 3D, or 4D with first dimension=1, got shape {matrix.shape}")