# -*- coding: utf-8 -*-
# Simple config class for normalization parameters
class NormConfig:
    """
    Lightweight dict-backed container for normalization settings.

    Accepted inputs:
        - ``None`` -> empty configuration
        - another ``NormConfig`` -> copy of its settings
        - any mapping (``dict``, ``MappingProxyType``, ...) -> copy of its items
        - any object with ``__dict__`` (e.g. ``argparse.Namespace``) -> copy of
          its attributes
        - anything else -> empty configuration

    The source is shallow-copied, so ``update`` never mutates the caller's
    object.
    """

    def __init__(self, args=None):
        from collections.abc import Mapping

        if args is None:
            self._data = {}
        elif isinstance(args, NormConfig):
            # Must be checked before the ``__dict__`` branch: NormConfig's own
            # ``__dict__`` is ``{'_data': {...}}``, which would nest the real
            # settings under a '_data' key and hide them from ``get``.
            self._data = dict(args._data)
        elif isinstance(args, Mapping):
            self._data = dict(args)
        elif hasattr(args, '__dict__'):
            self._data = vars(args).copy()
        else:
            self._data = {}

    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        return self._data.get(key, default)

    def update(self, **kwargs):
        """Insert or overwrite settings from keyword arguments."""
        self._data.update(kwargs)

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"

from .unified_weight_norm import apply_unified_norm, remove_unified_norm

__all__ = ['apply_norm', 'reset_norm', 'remove_norm']

# Registry of user-defined normalization classes keyed by norm type name.
# register_norm() writes entries and register_norm_module() looks them up, so
# the dict must exist at module level before either is called.
_norm_class = {}


def register_norm(norm_type, norm_class):
    """
    Store ``norm_class`` in the module-level ``_norm_class`` registry.

    An existing entry for the same ``norm_type`` is overwritten.

    NOTE(review): ``apply_norm`` in this module does not read ``_norm_class``
    (it always delegates to ``apply_unified_norm``); confirm where registered
    classes are actually consumed.

    Args:
        norm_type (str): Registry key naming the normalization.
        norm_class (type): Normalization class stored under ``norm_type``.

    Example:
        >>> register_norm('custom_norm', CustomNorm)
    """
    registry = _norm_class
    registry[norm_type] = norm_class


def register_norm_module(module_class, norm_type, names='weight', dims=0):
    """
    Mark ``module_class`` as a target of an already-registered normalization.

    Looks up ``norm_type`` in the module-level ``_norm_class`` registry and
    records ``(names, dims)`` in that class's ``_target_modules`` mapping under
    ``module_class``. An unregistered ``norm_type`` fails with the registry's
    lookup error (``KeyError`` for a plain dict). The registered class must
    expose a ``_target_modules`` mapping.

    Args:
        module_class (type): Module class to register as a normalization target.
        norm_type (str): Key of the normalization class in ``_norm_class``.
        names (str, optional): Attribute of ``module_class`` to normalize.
            Default ``'weight'``.
        dims (int, optional): Dimension over which the norm is computed.
            Default 0.

    Example:
        >>> register_norm_module(Conv2d, 'custom_norm', 'weight', 0)
    """
    norm_cls = _norm_class[norm_type]
    norm_cls._target_modules[module_class] = (names, dims)


def _is_skip_prefix(name, prefix_filter_out):
    """
    Helper function to check if a module name starts with any string in the filter_out list.

    Args:
        name (str): Name of the module.
        prefix_filter_out (list of str): List of string prefixes to filter out.

    Returns:
        bool: True if the module name starts with any string in the filter_out list, False otherwise.
    """
    for skip_name in prefix_filter_out:
        if name.startswith(skip_name):
            return True
    
    return False


def _is_skip_name(name, filter_out):
    """
    Helper function to check if a given module name contains any string in the filter_out list.

    Args:
        name (str): Name of the module.
        filter_out (list of str): List of strings to be filtered out.

    Returns:
        bool: True if the module name contains any string in the filter_out list, False otherwise.
    """
    for skip_name in filter_out:
        if skip_name in name:
            return True
    
    return False


def _as_name_list(value):
    """Return ``value`` as a list: falsy -> ``[]``, str -> ``[value]``, else ``list(value)``."""
    if not value:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def apply_norm(model, norm_type='spectral_norm', prefix_filter_out=None, filter_out=None, args=None, **norm_kwargs):
    """
    Apply the unified weight normalization to ``model``.

    Public ``norm_type`` names are translated to the names passed to
    ``apply_unified_norm`` (``'weight_norm'`` -> ``'frobenius'``,
    ``'spectral_norm'`` -> ``'spectral'``); any other value is forwarded
    unchanged. ``'none'`` returns immediately without touching the model.

    Args:
        model (torch.nn.Module): Model to apply normalization to.
        norm_type (str): 'spectral_norm' (default), 'weight_norm', or 'none'.
            - 'spectral_norm': spectral normalization (default)
            - 'weight_norm': Frobenius-norm based weight normalization
            - 'none': no normalization
        prefix_filter_out (list or str, optional):
            Module name prefixes to skip when applying normalization.
        filter_out (list or str, optional):
            Module name patterns to skip when applying normalization.
        args (Union[argparse.Namespace, dict, NormConfig, Any], optional):
            Configuration holding normalization parameters. Values given in
            ``norm_kwargs`` take precedence over values found here.
        **norm_kwargs: Configuration overrides:
            - clip (bool): enable clipping (falls back to ``norm_clip``,
              then False)
            - clip_value (float): clipping threshold (falls back to
              ``norm_clip_value``, then 1.0)
            - learnable_scale (bool): learnable scale instead of a fixed
              ``target_norm`` (default True)
            - target_norm (float): target norm when ``learnable_scale`` is
              False (default 1.0)
            When clipping is enabled, ``learnable_scale`` is forced to False
            and ``target_norm`` is set to ``clip_value``.

    Returns:
        None. The work is done by ``apply_unified_norm`` on ``model``.

    Example:
        >>> # Spectral norm with clipping
        >>> apply_norm(model, 'spectral_norm', clip=True, clip_value=1.0)

        >>> # Weight norm with a fixed target, no clipping
        >>> apply_norm(model, 'weight_norm', learnable_scale=False, target_norm=0.8)

        >>> # No normalization
        >>> apply_norm(model, 'none')
    """
    # Unwrap an existing NormConfig so its settings are copied rather than
    # nested under its internal ``_data`` attribute (which would hide them).
    if isinstance(args, NormConfig):
        args = args._data
    config = NormConfig(args)
    config.update(**norm_kwargs)

    # Translate public norm_type names to the names used by apply_unified_norm.
    norm_type_mapping = {
        'weight_norm': 'frobenius',
        'spectral_norm': 'spectral',
        'none': 'none',
    }
    unified_norm_type = norm_type_mapping.get(norm_type, norm_type)
    if unified_norm_type == 'none':
        return

    # NOTE(review): prefixes and patterns are merged into one ``filter_out``
    # list, so prefixes are presumably matched the same way as patterns
    # (not only at the start of a name); confirm against apply_unified_norm.
    combined_filter_out = _as_name_list(prefix_filter_out) + _as_name_list(filter_out)

    # Explicit keys win; the 'norm_*' keys are fallbacks for configs that
    # use prefixed option names.
    clip = config.get('clip', config.get('norm_clip', False))
    clip_value = config.get('clip_value', config.get('norm_clip_value', 1.0))
    learnable_scale = config.get('learnable_scale', True)
    target_norm = config.get('target_norm', 1.0)

    # With clipping enabled, use a fixed scale whose target is the clip value.
    if clip:
        learnable_scale = False
        target_norm = clip_value

    apply_unified_norm(
        model=model,
        norm_type=unified_norm_type,
        learnable_scale=learnable_scale,
        target_norm=target_norm,
        clip=clip,
        clip_value=clip_value,
        filter_out=combined_filter_out,
    )


def reset_norm(model):
    """
    Recompute the normalized weight of every submodule carrying ``_unified_norm``.

    For each module yielded by ``model.modules()`` that has a ``_unified_norm``
    attribute, calls ``_unified_norm._compute_weight(module)``. The result is
    assigned to the attribute named by ``_unified_norm.name``. Modules without
    ``_unified_norm`` are left untouched.

    Args:
        model (torch.nn.Module): Model whose normalization should be reset.

    Example:
        >>> reset_norm(model)
    """
    for submodule in model.modules():
        if not hasattr(submodule, '_unified_norm'):
            continue
        norm = submodule._unified_norm
        setattr(submodule, norm.name, norm._compute_weight(submodule))


def remove_norm(model):
    """
    Strip the unified normalization from ``model``.

    Thin wrapper that delegates to ``remove_unified_norm`` (imported from
    ``.unified_weight_norm``); that call's return value is discarded.

    Args:
        model (torch.nn.Module): Model whose normalization should be removed.

    Example:
        >>> remove_norm(model)
    """
    remove_unified_norm(model)
    return None


