# -*- coding: utf-8 -*-
"""
Normalization Wrapper - 기존 norm 모듈의 간단한 래퍼

기존 norm/base_norm.py의 기능을 그대로 사용하되,
new_models 아키텍처에서 쉽게 사용할 수 있도록 래핑합니다.
"""

try:
    # 통합된 normalization 모듈 사용
    from .base_norm import apply_norm, reset_norm, remove_norm
    NORM_AVAILABLE = True
except ImportError:
    NORM_AVAILABLE = False
    
    # Fallback functions for when norm is not available
    def apply_norm(model, norm_type='spectral_norm', **kwargs):
        """Fallback: do nothing when norm module unavailable"""
        pass
    
    def reset_norm(model):
        """Fallback: do nothing when norm module unavailable"""
        pass
    
    def remove_norm(model):
        """Fallback: do nothing when norm module unavailable"""
        pass


class NormalizationManager:
    """
    정규화 관리 클래스
    
    기존 norm 모듈의 단순한 래퍼로, new_models에서 사용하기 편하게 만듦.
    실제 정규화 로직은 모두 기존 norm/base_norm.py 사용.
    """
    
    def __init__(self, args=None):
        """
        Args:
            args: 설정 객체 (기존 norm 모듈에 그대로 전달)
        """
        self.args = args
        self.normalization_applied = False
        
    def apply_normalization(self, model, norm_type=None, **kwargs):
        """
        모델에 정규화 적용
        
        Args:
            model: PyTorch 모델
            norm_type: 정규화 타입 ('spectral_norm', 'weight_norm', 'none')
            **kwargs: 추가 파라미터
            
        Returns:
            bool: 정규화 적용 성공 여부
        """
        if not NORM_AVAILABLE:
            return False
            
        # 기본값 설정
        if norm_type is None:
            norm_type = getattr(self.args, 'norm_type', 'spectral_norm')
        
        try:
            # 기존 norm 모듈 그대로 사용
            apply_norm(model, norm_type, args=self.args, **kwargs)
            self.normalization_applied = True
            return True
        except Exception:
            return False
    
    def reset_normalization(self, model):
        """
        정규화 리셋 (기존 norm 모듈 그대로 사용)
        
        Args:
            model: PyTorch 모델
            
        Returns:
            bool: 리셋 성공 여부
        """
        if not NORM_AVAILABLE or not self.normalization_applied:
            return False
            
        try:
            reset_norm(model)
            return True
        except Exception:
            return False
    
    def remove_normalization(self, model):
        """
        정규화 제거 (기존 norm 모듈 그대로 사용)
        
        Args:
            model: PyTorch 모델
            
        Returns:
            bool: 제거 성공 여부
        """
        if not NORM_AVAILABLE:
            return False
            
        try:
            remove_norm(model)
            self.normalization_applied = False
            return True
        except Exception:
            return False
    
    def is_normalization_available(self):
        """정규화 모듈 사용 가능 여부"""
        return NORM_AVAILABLE
    
    def is_applied(self):
        """정규화 적용 상태"""
        return self.normalization_applied


# 편의 함수들 (기존 norm 모듈 그대로 노출)
def apply_model_normalization(model, args=None, norm_type=None, **kwargs):
    """
    모델에 정규화 적용하는 편의 함수
    
    Args:
        model: PyTorch 모델
        args: 설정 객체
        norm_type: 정규화 타입
        **kwargs: 추가 파라미터
        
    Returns:
        bool: 적용 성공 여부
    """
    manager = NormalizationManager(args)
    return manager.apply_normalization(model, norm_type, **kwargs)


def reset_model_normalization(model):
    """
    모델 정규화 리셋하는 편의 함수
    
    Args:
        model: PyTorch 모델
        
    Returns:
        bool: 리셋 성공 여부
    """
    if not NORM_AVAILABLE:
        return False
        
    try:
        reset_norm(model)
        return True
    except Exception:
        return False


# 기존 인터페이스와 동일하게 export
__all__ = [
    'NormalizationManager',
    'apply_model_normalization', 
    'reset_model_normalization',
    'apply_norm',  # 기존 함수 그대로 노출
    'reset_norm',  # 기존 함수 그대로 노출
    'remove_norm'  # 기존 함수 그대로 노출
]