# -*- coding: utf-8 -*-
"""
Unified Weight Normalization implementation

Options:
1. norm_type: 'spectral' (default), 'frobenius', or 'variance'
2. learnable_scale: True (learnable scale g) or False (fixed target_norm)
3. target_norm: fixed norm value used when learnable_scale=False
4. clip: whether to clip the scale factor
5. clip_value: clipping threshold
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class UnifiedWeightNorm:
    """
    Unified weight normalization.

    Core idea::

        weight = normalized_weight * scale_factor

    Modes (``norm_type``):
        - 'spectral':  scale_factor = s / sigma_max(W), where W is the weight
          reshaped to (shape[0], -1) and sigma_max is obtained by SVD.
        - 'frobenius': scale_factor = s / ||W||_F, computed per slice along ``dim``.
        - 'variance':  scale_factor = s / ((sqrt(m) + sqrt(n)) * sqrt(var(W))),
          where (m, n) are the operator dimensions of the layer.

    ``s`` is the learnable parameter ``<name>_g`` when ``learnable_scale=True``
    and the constant ``target_norm`` otherwise.

    A forward pre-hook recomputes the normalized weight before every forward
    call and stores it on the module as a plain tensor attribute.
    """

    # Supported module types -> (parameter name, default normalization dim)
    _target_modules = {
        nn.Conv1d: ('weight', 0),
        nn.Conv2d: ('weight', 0),
        nn.Conv3d: ('weight', 0),
        nn.ConvTranspose1d: ('weight', 1),
        nn.ConvTranspose2d: ('weight', 1),
        nn.ConvTranspose3d: ('weight', 1),
        nn.Linear: ('weight', 0),
    }

    # Accepted values for ``norm_type``.
    _norm_types = ('spectral', 'frobenius', 'variance')

    def __init__(self, name, dim, norm_type='spectral', learnable_scale=True,
                 target_norm=1.0, clip=False, clip_value=1.0,
                 n_power_iterations=1, eps=1e-12):
        """
        Args:
            name: name of the parameter to normalize (usually 'weight').
            dim: dimension kept when computing the Frobenius norm
                ('frobenius' mode only; None means the norm of the whole tensor).
            norm_type: 'spectral' (default), 'frobenius' or 'variance'.
                - 'spectral': divide by the largest singular value of the
                  weight reshaped to (shape[0], -1) (strongest constraint).
                - 'frobenius': divide by the per-slice Frobenius norm
                  (classic weight normalization).
                - 'variance': divide by (sqrt(m) + sqrt(n)) * sqrt(var(weight)).
            learnable_scale:
                - True: use a learnable ``<name>_g`` parameter as the scale.
                - False: use the fixed ``target_norm`` as the scale.
            target_norm: scale used when learnable_scale=False; in 'variance'
                mode it is also the initial value of ``<name>_g``.
            clip: whether to clip the scale factor (see ``_compute_weight``).
            clip_value: clipping threshold.
            n_power_iterations: power-iteration steps used by the spectral
                fallback when SVD fails.
            eps: epsilon for numerical stability.

        Raises:
            ValueError: if ``norm_type`` is not one of the supported modes.
        """
        if norm_type not in self._norm_types:
            raise ValueError("Unsupported norm_type {!r}; expected one of {}".format(
                norm_type, self._norm_types))
        self.name = name
        self.dim = dim
        self.norm_type = norm_type
        self.learnable_scale = learnable_scale
        self.target_norm = target_norm
        self.clip = clip
        self.clip_value = clip_value
        self.n_power_iterations = n_power_iterations
        self.eps = eps
        # Handle of the forward pre-hook, set by apply() and used by remove().
        self._hook_handle = None

    def apply(self, module):
        """
        Install the normalization on ``module`` in place.

        The original ``<name>`` parameter is removed and re-registered under a
        mode-specific name (``<name>_v`` for 'frobenius', ``<name>_orig``
        otherwise). Auxiliary parameters and buffers are created, a forward
        pre-hook is registered, and the initial normalized weight is assigned.
        """
        weight = getattr(module, self.name)

        # Drop the original Parameter; it is re-registered under a new name below.
        delattr(module, self.name)

        if self.norm_type == 'frobenius':
            module.register_parameter(self.name + '_v', Parameter(weight.data))

            if self.learnable_scale:
                # g starts at the current norm, so the initial (unclipped) weight is unchanged.
                with torch.no_grad():
                    norm = self._compute_frobenius_norm(weight, self.dim)
                module.register_parameter(self.name + '_g', Parameter(norm))

        elif self.norm_type == 'variance':
            module.register_parameter(self.name + '_orig', Parameter(weight.data))

            # sqrt(m) + sqrt(n) depends only on the layer shape, so cache it.
            m, n = self._get_operator_dimensions(module, weight)
            denominator_base = (m ** 0.5) + (n ** 0.5)
            module.register_buffer(
                '_norm_denominator_base',
                torch.tensor(denominator_base, dtype=weight.dtype, device=weight.device))

            if self.learnable_scale:
                g = torch.ones(1, dtype=weight.dtype, device=weight.device) * self.target_norm
                module.register_parameter(self.name + '_g', Parameter(g))

        else:  # spectral
            module.register_parameter(self.name + '_orig', Parameter(weight.data))

            if self.learnable_scale:
                # g starts at the current spectral norm, so the initial (unclipped) weight is unchanged.
                with torch.no_grad():
                    spectral_norm = self._compute_spectral_norm(weight)
                    if weight.dim() == 4:  # Conv2d / ConvTranspose2d
                        g = spectral_norm.view(1, 1, 1, 1)
                    elif weight.dim() == 2:  # Linear
                        g = spectral_norm.view(1, 1)
                    else:
                        g = spectral_norm
                module.register_parameter(self.name + '_g', Parameter(g))

            # Singular-vector buffers (also used by the power-iteration fallback).
            # They are created on the weight's device and dtype so torch.mv does not
            # fail with a mismatch.
            weight_mat = weight.reshape(weight.shape[0], -1)
            u = torch.randn(weight_mat.shape[0], device=weight.device).to(weight.dtype)
            v = torch.randn(weight_mat.shape[1], device=weight.device).to(weight.dtype)
            module.register_buffer(self.name + '_u', F.normalize(u, dim=0, eps=self.eps))
            module.register_buffer(self.name + '_v', F.normalize(v, dim=0, eps=self.eps))

        # Keep the handle so remove() can detach the hook again.
        self._hook_handle = module.register_forward_pre_hook(self._recompute_weight_hook)

        # Initial normalized weight.
        setattr(module, self.name, self._compute_weight(module))

        # Store the normalization object so it can be removed later.
        module._unified_norm = self

    def _compute_weight(self, module):
        """
        Return the normalized (and optionally clipped) weight of ``module``.

        Clipping semantics per mode:
            - 'spectral':  the scale factor itself is capped at ``clip_value``
              (TorchDEQ-style factor clipping).
            - 'frobenius': the scale factor is capped at
              ``clip_value / current_norm``, so each final slice norm is
              <= ``clip_value``.
            - 'variance':  the scale multiplier (g or target_norm) is capped at
              ``clip_value``, so the estimated final norm is <= ``clip_value``.
        """
        if self.norm_type == 'frobenius':
            return self._frobenius_weight(module)
        if self.norm_type == 'variance':
            return self._variance_weight(module)
        return self._spectral_weight(module)

    def _scale(self, module):
        """Return the learnable ``<name>_g`` or the fixed ``target_norm``."""
        if self.learnable_scale:
            return getattr(module, self.name + '_g')
        return self.target_norm

    def _frobenius_weight(self, module):
        """Frobenius mode: ``v * scale / ||v||`` per slice along ``self.dim``."""
        weight_v = getattr(module, self.name + '_v')
        current_norm = self._compute_frobenius_norm(weight_v, self.dim)
        scale_factor = self._scale(module) / current_norm

        if self.clip:
            # final_norm = current_norm * scale_factor <= clip_value
            max_scale_factor = self.clip_value / current_norm
            scale_factor = torch.minimum(max_scale_factor * torch.ones_like(scale_factor),
                                         scale_factor)
        return weight_v * scale_factor

    def _variance_weight(self, module):
        """Variance mode: ``W * scale / ((sqrt(m) + sqrt(n)) * sqrt(var(W)))``."""
        weight_orig = getattr(module, self.name + '_orig')

        # Shape-agnostic variance, clamped away from zero.
        variance = weight_orig.var().clamp(min=self.eps)
        denominator_base = getattr(module, '_norm_denominator_base')
        # alpha = 1 / ((sqrt(m) + sqrt(n)) * sqrt(var))
        alpha = 1.0 / (denominator_base * torch.sqrt(variance))

        scale = self._scale(module)
        if self.clip:
            # alpha rescales W to an estimated norm of 1, so capping the
            # multiplier caps the estimated final norm at clip_value.
            if torch.is_tensor(scale):
                scale = torch.clamp(scale, max=self.clip_value)
            else:
                scale = min(scale, self.clip_value)
        return weight_orig * (alpha * scale)

    def _spectral_weight(self, module):
        """Spectral mode: ``W * scale / sigma_max(W)``."""
        weight_orig = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')

        current_spectral_norm = self._spectral_norm_with_vectors(weight_orig, u, v)
        scale_factor = self._scale(module) / current_spectral_norm

        if self.clip:
            # TorchDEQ-style: cap the factor itself at clip_value.
            scale_factor = torch.minimum(self.clip_value * torch.ones_like(scale_factor),
                                         scale_factor)
        return weight_orig * scale_factor

    def _spectral_norm_with_vectors(self, weight, u, v):
        """
        Return the largest singular value of ``weight`` reshaped to
        (shape[0], -1). The buffers ``u``/``v`` are refreshed in place with the
        matching left/right singular vectors.

        Falls back to power iteration when SVD raises (e.g. does not converge).
        """
        weight_mat = weight.reshape(weight.shape[0], -1)
        try:
            U_svd, S_svd, V_svd = torch.svd(weight_mat)
        except RuntimeError:  # torch.linalg.LinAlgError is a RuntimeError subclass
            return self._power_iteration(weight, u, v)

        with torch.no_grad():
            u.copy_(U_svd[:, 0])
            # torch.svd returns V (not V^H): right singular vectors are its columns.
            v.copy_(V_svd[:, 0])
        return S_svd.max()

    def _compute_frobenius_norm(self, weight, dim):
        """
        Frobenius norm over every dimension except ``dim``. The result has size
        1 in all reduced dimensions, so it broadcasts against ``weight``.
        ``dim=None`` returns the norm of the whole tensor.
        """
        if dim is None:
            return weight.norm()
        if dim == 0:
            output_size = (weight.size(0),) + (1,) * (weight.dim() - 1)
            return weight.contiguous().view(weight.size(0), -1).norm(dim=1).view(*output_size)
        if dim == weight.dim() - 1:
            output_size = (1,) * (weight.dim() - 1) + (weight.size(-1),)
            return weight.contiguous().view(-1, weight.size(-1)).norm(dim=0).view(*output_size)
        # Any other dim (e.g. dim=1 for ConvTranspose weights): move it to the
        # front, reduce over everything else, then move it back.
        return self._compute_frobenius_norm(weight.transpose(0, dim), 0).transpose(0, dim)

    def _get_operator_dimensions(self, module, weight):
        """
        Return the operator matrix dimensions (m, n) for the layer type.

        - Linear: m=d_out, n=d_in
        - ConvNd: m=C_out, n=C_in * prod(kernel)
        - ConvTransposeNd: weight is laid out (C_in, C_out, *kernel); m=C_out,
          n=C_in * prod(kernel)
        - otherwise: first dimension as m, the rest flattened as n
        """
        if isinstance(module, nn.Linear):
            return weight.shape[0], weight.shape[1]

        elif isinstance(module, nn.Conv2d):
            C_out, C_in, k_H, k_W = weight.shape
            return C_out, C_in * k_H * k_W

        elif isinstance(module, nn.Conv1d):
            C_out, C_in, k_W = weight.shape
            return C_out, C_in * k_W

        elif isinstance(module, nn.Conv3d):
            C_out, C_in, k_D, k_H, k_W = weight.shape
            return C_out, C_in * k_D * k_H * k_W

        elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            # ConvTranspose weights store the input channels first.
            if isinstance(module, nn.ConvTranspose1d):
                C_in, C_out, k_W = weight.shape
                return C_out, C_in * k_W
            elif isinstance(module, nn.ConvTranspose2d):
                C_in, C_out, k_H, k_W = weight.shape
                return C_out, C_in * k_H * k_W
            else:  # ConvTranspose3d
                C_in, C_out, k_D, k_H, k_W = weight.shape
                return C_out, C_in * k_D * k_H * k_W
        else:
            return weight.shape[0], weight.numel() // weight.shape[0]

    def _compute_spectral_norm(self, weight):
        """Largest singular value of ``weight`` reshaped to (shape[0], -1), via SVD."""
        weight_mat = weight.reshape(weight.shape[0], -1)
        return torch.svd(weight_mat)[1].max()

    def _power_iteration(self, weight, u, v):
        """
        Approximate the spectral norm by power iteration.

        ``u`` and ``v`` are updated in place (``n_power_iterations`` steps,
        without gradient). The returned estimate u^T W v is differentiable
        with respect to ``weight``.
        """
        weight_mat = weight.reshape(weight.shape[0], -1)

        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                v = F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                u = F.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)

            if self.n_power_iterations > 0:
                # Clone so the in-place buffer updates above don't affect autograd.
                u = u.clone(memory_format=torch.contiguous_format)
                v = v.clone(memory_format=torch.contiguous_format)

        return torch.dot(u, torch.mv(weight_mat, v))

    def _recompute_weight_hook(self, module, _inputs):
        """Forward pre-hook: recompute the normalized weight before each forward."""
        setattr(module, self.name, self._compute_weight(module))

    @classmethod
    def apply_to_module(cls, module, name='weight', dim=0, **kwargs):
        """Create a normalization, apply it to a single ``module`` and return it."""
        norm = cls(name, dim, **kwargs)
        norm.apply(module)
        return norm

    def remove(self, module):
        """
        Bake the current normalized weight back into a plain ``<name>``
        Parameter and detach the normalization. This removes the hook and the
        auxiliary parameters and buffers.

        ``module._unified_norm`` is not deleted here; callers handle that.
        """
        with torch.no_grad():
            weight = self._compute_weight(module)

        handle = getattr(self, '_hook_handle', None)
        if handle is not None:
            handle.remove()
            self._hook_handle = None
        else:
            print("Warning: Forward hooks not removed. Recreate module if needed.")

        delattr(module, self.name)
        for suffix in ('_v', '_g', '_orig', '_u'):
            attr_name = self.name + suffix
            if hasattr(module, attr_name):
                delattr(module, attr_name)
        if self.norm_type == 'variance' and hasattr(module, '_norm_denominator_base'):
            delattr(module, '_norm_denominator_base')

        module.register_parameter(self.name, Parameter(weight))


def apply_unified_norm(model, norm_type='spectral', learnable_scale=True, target_norm=1.0,
                       clip=False, clip_value=1.0, filter_out=None, **kwargs):
    """
    Apply unified weight normalization to every supported submodule of a model.

    A submodule is normalized when its exact type is a key of
    ``UnifiedWeightNorm._target_modules``, when no ``filter_out`` pattern is a
    substring of its qualified name, and when it has not already been
    normalized. Modules already normalized are skipped, because re-applying
    would try to re-register their ``<name>_orig`` / ``<name>_v`` attributes.

    Args:
        model: model to normalize (modified in place).
        norm_type: 'spectral' (default), 'frobenius' or 'variance'.
        learnable_scale: True (learnable g) or False (fixed target_norm).
        target_norm: fixed norm value used when learnable_scale=False.
        clip: whether to clip the scale factor.
        clip_value: clipping threshold.
        filter_out: module-name substring(s) to exclude (str or list of str).
        **kwargs: extra arguments for UnifiedWeightNorm (n_power_iterations, eps, ...).

    Returns:
        int: number of modules that were normalized.

    Example:
        # Spectral norm + learnable scale + clipping
        apply_unified_norm(model, norm_type='spectral', learnable_scale=True,
                           clip=True, clip_value=1.0)

        # Frobenius norm + fixed target + no clipping
        apply_unified_norm(model, norm_type='frobenius', learnable_scale=False,
                           target_norm=0.8, clip=False)
    """
    if filter_out is None:
        filter_out = []
    elif isinstance(filter_out, str):
        filter_out = [filter_out]

    count = 0
    for name, module in model.named_modules():
        if any(pattern in name for pattern in filter_out):
            continue

        target = UnifiedWeightNorm._target_modules.get(type(module))
        if target is None:
            continue

        # Already normalized: applying again would fail on duplicate registration.
        if hasattr(module, '_unified_norm'):
            continue

        param_name, dim = target
        if not hasattr(module, param_name):
            continue

        norm = UnifiedWeightNorm(
            param_name, dim, norm_type, learnable_scale, target_norm,
            clip, clip_value, **kwargs
        )
        norm.apply(module)
        count += 1

    scale_info = "learnable_scale" if learnable_scale else "target_norm={}".format(target_norm)
    clip_info = ", clip≤{}".format(clip_value) if clip else ""
    print("Applied unified {} norm to {} modules ({}{})".format(norm_type, count, scale_info, clip_info))
    return count


def remove_unified_norm(model):
    """
    Undo unified weight normalization on every submodule of ``model``.

    For each submodule that carries a ``_unified_norm`` attribute, that
    object's ``remove(submodule)`` is called and the attribute is then
    deleted. A confirmation message is printed once at the end.
    """
    normalized = (sub for sub in model.modules() if hasattr(sub, '_unified_norm'))
    for submodule in normalized:
        submodule._unified_norm.remove(submodule)
        del submodule._unified_norm
    print("Removed unified normalization from model")
