import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Linear(nn.Linear):
    """
    nn.Linear subclass that applies an orthogonal subspace projection to its weight.

    The weight used in ``forward`` is ``Q_eff @ weight`` with
    ``Q_eff = Q_total @ BlockDiag(C, ..., C)``:

    * ``Q_total`` is a fixed ``(out_features, out_features)`` orthogonal buffer.
    * ``C`` is a ``(subspace_dim, subspace_dim)`` orthogonal matrix obtained from the
      trainable ``C_param`` through a Cayley transform.
    * The same ``C`` is shared by all ``out_features // subspace_dim`` subspaces.

    With ``pretrained=True``, ``Q_total`` starts as the identity and ``C_param`` as
    zero. ``Q_eff`` is then the identity, so a loaded pretrained ``weight`` is used
    unchanged.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample; must be divisible by ``subspace_dim``.
        subspace_dim: dimension of each rotated output subspace (positive).
        bias: if True, add a learnable bias (as in ``nn.Linear``).
        pretrained: use the identity initialization described above.
        device, dtype: forwarded to ``nn.Linear`` and also used for ``Q_total`` and
            ``C_param`` so that all tensors live on the same device/dtype.

    Raises:
        ValueError: if ``subspace_dim`` is not positive or does not divide ``out_features``.
    """

    def __init__(self, in_features, out_features, subspace_dim, bias=True, pretrained=False, device=None, dtype=None):
        # 1. nn.Linear creates self.weight / self.bias. Its __init__ already calls
        #    self.reset_parameters(). At that point Q_total / C_param do not exist yet,
        #    so the hasattr guards there reduce it to the plain nn.Linear init.
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

        if subspace_dim <= 0:
            raise ValueError(f"subspace_dim must be positive, got {subspace_dim}")
        if out_features % subspace_dim != 0:
            raise ValueError(f"out_features({out_features}) must be divisible by subspace_dim({subspace_dim})")

        self.subspace_dim = subspace_dim
        self.num_subspaces = out_features // subspace_dim
        self.pretrained = pretrained

        # Create the extra tensors with the same device/dtype as self.weight.
        # Otherwise torch.mm(Q_eff, self.weight) fails on a device or dtype mismatch.
        factory_kwargs = {'device': device, 'dtype': dtype}

        # 2. Fixed orthogonal basis over the output space: (out_features, out_features).
        #    It is a buffer, so it is not returned by parameters().
        self.register_buffer('Q_total', torch.empty(out_features, out_features, **factory_kwargs))

        # 3. Unconstrained parameter for the in-subspace rotation: (subspace_dim, subspace_dim).
        self.C_param = nn.Parameter(torch.empty(subspace_dim, subspace_dim, **factory_kwargs))

        # 4. Second initialization pass, now that Q_total / C_param exist.
        self.reset_parameters()

    def reset_parameters(self):
        """
        (Re)initialize Q_total, C_param, weight and bias.

        nn.Linear.__init__ calls this before Q_total / C_param are registered,
        hence the hasattr guards. weight and bias are initialized by
        nn.Linear.reset_parameters, which already draws the bias from
        U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        """
        # 1. Q_total
        if hasattr(self, 'Q_total'):
            with torch.no_grad():
                n = self.out_features
                device, dtype = self.Q_total.device, self.Q_total.dtype
                if self.pretrained:
                    # Identity basis: Q_eff == I while C == I.
                    self.Q_total.copy_(torch.eye(n, device=device, dtype=dtype))
                else:
                    # The Q factor of a Gaussian matrix is orthogonal. torch.linalg.qr
                    # supports float32/float64 only, so factorize in at least float32
                    # (float64 when the buffer is float64) and let copy_ cast back.
                    work_dtype = torch.promote_types(dtype, torch.float32)
                    H = torch.randn(n, n, device=device, dtype=work_dtype)
                    q_complete, _ = torch.linalg.qr(H)
                    self.Q_total.copy_(q_complete)

        # 2. C_param
        if hasattr(self, 'C_param'):
            if self.pretrained:
                # Zero -> skew-symmetric part is zero -> Cayley transform yields I.
                nn.init.zeros_(self.C_param)
            else:
                # Small random values -> orthogonal C close to the identity.
                nn.init.normal_(self.C_param, mean=0, std=0.01)

        # 3. weight / bias: standard nn.Linear initialization (bias included).
        super().reset_parameters()

    def get_orthogonal_C(self):
        """
        Return the (subspace_dim, subspace_dim) orthogonal matrix C (Cayley transform).

        A = C_param - C_param^T is skew-symmetric, so I + A is always invertible and
        C = (I - A)(I + A)^-1 = (I + A)^-1 (I - A) is orthogonal (the factors commute).
        C_param == 0 gives C == I.
        """
        skew_A = self.C_param - self.C_param.t()
        I = torch.eye(self.subspace_dim, device=skew_A.device, dtype=skew_A.dtype)

        # Solve (I + A) C = (I - A) rather than forming an explicit inverse:
        # mathematically identical, numerically better conditioned.
        return torch.linalg.solve(I + skew_A, I - skew_A)

    def get_effective_basis(self):
        """
        Return Q_eff = Q_total @ BlockDiag(C, ..., C) with shape (out_features, out_features).

        The columns of Q_total are split into num_subspaces consecutive groups of
        subspace_dim columns. Each group is right-multiplied by the shared C.
        """
        C = self.get_orthogonal_C()

        # (out, num_subspaces, subspace_dim): column group k is Q_reshaped[:, k, :].
        Q_reshaped = self.Q_total.view(self.out_features, self.num_subspaces, self.subspace_dim)

        # Broadcast matmul rotates every column group by the same C.
        Q_transformed = torch.matmul(Q_reshaped, C)

        # Back to (out, out).
        return Q_transformed.reshape(self.out_features, self.out_features)

    def forward(self, input):
        """Compute ``F.linear(input, Q_eff @ weight, bias)``."""
        # self.weight: (out_features, in_features); Q_eff: (out_features, out_features)
        weight_effective = torch.mm(self.get_effective_basis(), self.weight)
        return F.linear(input, weight_effective, self.bias)

    def get_hyperspherical_energy(self, epsilon=1e-6):
        """
        Return the sum of 1 / (||w_i - w_j|| + epsilon) over all ordered pairs i != j.

        w_i are the L2-normalized rows of the effective weight Q_eff @ weight.
        Diagonal (i == j) terms are masked out.
        """
        w_final = torch.mm(self.get_effective_basis(), self.weight)

        w_norm = F.normalize(w_final, p=2, dim=1)
        dist_matrix = torch.cdist(w_norm, w_norm, p=2)

        eye_mask = torch.eye(self.out_features, device=w_norm.device, dtype=torch.bool)

        inverse_dist = 1.0 / (dist_matrix + epsilon)
        inverse_dist = inverse_dist.masked_fill(eye_mask, 0.0)

        return inverse_dist.sum()

    def get_ideal_energy_lower_bound(self):
        """
        Return N * (N - 1) / sqrt(2) with N = out_features.

        This is the value get_hyperspherical_energy would take (ignoring epsilon)
        if all N normalized rows were mutually orthogonal, i.e. every pairwise
        distance were sqrt(2).
        NOTE(review): despite the name this is not a strict lower bound in general.
        Configurations with obtuse pairwise angles have lower energy, and
        N > in_features rows cannot all be orthogonal. Confirm intended use.
        """
        N = self.out_features
        num_pairs = N * (N - 1)
        return num_pairs * (1.0 / math.sqrt(2))

    def get_orthogonality_reg_loss(self):
        """
        Return sum_k ||C S_k C^T - I||_F^2 over the num_subspaces row groups of the raw weight.

        S_k is the Gram matrix of the L2-normalized rows of group k, where group k is
        weight rows [k * subspace_dim, (k + 1) * subspace_dim).
        NOTE(review): C is orthogonal, so C S_k C^T - I = C (S_k - I) C^T has the same
        Frobenius norm as S_k - I. The value therefore does not depend on C_param,
        and C_param receives (numerically) no gradient from it. Confirm this is
        intended.
        """
        # Shape: (K, D, M) = (num_subspaces, subspace_dim, in_features)
        w_reshaped = self.weight.reshape(self.num_subspaces, self.subspace_dim, -1)
        w_norm = F.normalize(w_reshaped, p=2, dim=2)

        # (K, D, D) per-group Gram matrices.
        s_base_blocks = torch.bmm(w_norm, w_norm.transpose(1, 2))

        C = self.get_orthogonal_C()
        g_sub_blocks = torch.matmul(torch.matmul(C, s_base_blocks), C.t())

        I = torch.eye(self.subspace_dim, device=g_sub_blocks.device, dtype=g_sub_blocks.dtype)
        # Squared Frobenius norm over all blocks, without the sqrt-then-square round trip.
        return (g_sub_blocks - I).pow(2).sum()



class Conv2d(nn.Conv2d):
    """
    nn.Conv2d subclass that applies an orthogonal subspace projection over output channels.

    The kernel used in ``forward`` is ``Q_eff @ weight.reshape(out_channels, -1)``,
    reshaped back to the 4D weight shape. Here
    ``Q_eff = Q_total @ BlockDiag(C, ..., C)``:

    * ``Q_total`` is a fixed ``(out_channels, out_channels)`` orthogonal buffer.
    * ``C`` is a ``(subspace_dim, subspace_dim)`` orthogonal matrix obtained from the
      trainable ``C_param`` through a Cayley transform, shared by all subspaces.

    The layer uses nn.Conv2d's own ``self.weight`` rather than a separate weight
    tensor. With ``pretrained=True``, ``Q_total`` starts as the identity and
    ``C_param`` as zero. ``Q_eff`` is then the identity, so a loaded pretrained
    ``weight`` is used unchanged.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
        bias, padding_mode, device, dtype: as in ``nn.Conv2d``. ``out_channels``
            must be divisible by ``subspace_dim``.
        subspace_dim: dimension of each rotated output-channel subspace (positive).
        pretrained: use the identity initialization described above.

    Raises:
        ValueError: if ``subspace_dim`` is not positive or does not divide ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, subspace_dim,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None, pretrained=False):

        # 1. nn.Conv2d creates self.weight / self.bias. Its __init__ already calls
        #    self.reset_parameters(). At that point Q_total / C_param do not exist yet,
        #    so the hasattr guards there reduce it to the plain nn.Conv2d init.
        #    Keywords guard against nn.Conv2d's positional order.
        super().__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias, padding_mode=padding_mode, device=device, dtype=dtype,
        )

        if subspace_dim <= 0:
            raise ValueError(f"subspace_dim must be positive, got {subspace_dim}")
        if out_channels % subspace_dim != 0:
            raise ValueError(f"out_channels({out_channels}) must be divisible by subspace_dim({subspace_dim})")

        self.subspace_dim = subspace_dim
        self.num_subspaces = out_channels // subspace_dim
        self.pretrained = pretrained

        # 2. Fixed orthogonal basis over the output-channel space: (out_channels, out_channels).
        self.register_buffer('Q_total', torch.empty(out_channels, out_channels, device=device, dtype=dtype))

        # 3. Unconstrained parameter for the in-subspace rotation: (subspace_dim, subspace_dim).
        self.C_param = nn.Parameter(torch.empty(subspace_dim, subspace_dim, device=device, dtype=dtype))

        # 4. Second initialization pass, now that Q_total / C_param exist.
        self.reset_parameters()

    def reset_parameters(self):
        """
        (Re)initialize Q_total, C_param, weight and bias.

        nn.Conv2d.__init__ calls this before Q_total / C_param are registered,
        hence the hasattr guards. weight and bias use nn.Conv2d's initialization.
        """
        # 1. Q_total
        if hasattr(self, 'Q_total'):
            with torch.no_grad():
                n = self.out_channels
                device, dtype = self.Q_total.device, self.Q_total.dtype
                if self.pretrained:
                    # Identity basis: Q_eff == I while C == I.
                    self.Q_total.copy_(torch.eye(n, device=device, dtype=dtype))
                else:
                    # The Q factor of a Gaussian matrix is orthogonal. torch.linalg.qr
                    # supports float32/float64 only, so factorize in at least float32
                    # (float64 when the buffer is float64) and let copy_ cast back.
                    work_dtype = torch.promote_types(dtype, torch.float32)
                    H = torch.randn(n, n, device=device, dtype=work_dtype)
                    q_complete, _ = torch.linalg.qr(H)
                    self.Q_total.copy_(q_complete)

        # 2. C_param
        if hasattr(self, 'C_param'):
            if self.pretrained:
                # Zero -> skew-symmetric part is zero -> Cayley transform yields I.
                nn.init.zeros_(self.C_param)
            else:
                # Small random values -> orthogonal C close to the identity.
                nn.init.normal_(self.C_param, mean=0, std=0.01)

        # 3. weight / bias: standard nn.Conv2d initialization.
        super().reset_parameters()

    def get_orthogonal_C(self):
        """
        Return the (subspace_dim, subspace_dim) orthogonal matrix C (Cayley transform).

        A = C_param - C_param^T is skew-symmetric, so I + A is always invertible and
        C = (I - A)(I + A)^-1 = (I + A)^-1 (I - A) is orthogonal (the factors commute).
        C_param == 0 gives C == I.
        """
        skew_A = self.C_param - self.C_param.t()
        I = torch.eye(self.subspace_dim, device=skew_A.device, dtype=skew_A.dtype)

        # Solve (I + A) C = (I - A) instead of forming an explicit inverse.
        return torch.linalg.solve(I + skew_A, I - skew_A)

    def get_effective_basis(self):
        """
        Return Q_eff = Q_total @ BlockDiag(C, ..., C) with shape (out_channels, out_channels).

        The columns of Q_total are split into num_subspaces consecutive groups of
        subspace_dim columns. Each group is right-multiplied by the shared C.
        """
        C = self.get_orthogonal_C()

        # (out, num_subspaces, subspace_dim): column group k is Q_reshaped[:, k, :].
        Q_reshaped = self.Q_total.view(self.out_channels, self.num_subspaces, self.subspace_dim)

        # Broadcast matmul rotates every column group by the same C.
        Q_transformed = torch.matmul(Q_reshaped, C)

        # Back to (out, out).
        return Q_transformed.reshape(self.out_channels, self.out_channels)

    def get_convolution_weight(self, return_flat=False):
        """
        Return the effective kernel Q_eff @ weight_flat.

        self.weight has shape (out, in/groups, kH, kW) and is flattened to (out, -1).
        With return_flat=True the (out, in/groups * kH * kW) matrix is returned;
        otherwise it is reshaped to self.weight's shape.
        """
        Q_eff = self.get_effective_basis()

        # reshape (not view): a non-contiguous weight (e.g. after
        # .to(memory_format=torch.channels_last)) cannot be viewed as (out, -1).
        w_flat = self.weight.reshape(self.out_channels, -1)

        # (out, out) @ (out, flat_in) -> (out, flat_in)
        weight_effective_flat = torch.mm(Q_eff, w_flat)

        if return_flat:
            return weight_effective_flat

        # The mm result is a fresh contiguous tensor, so viewing it to 4D is valid.
        return weight_effective_flat.view_as(self.weight)

    def forward(self, input):
        """Convolve ``input`` with the effective kernel (padding_mode handled by _conv_forward)."""
        weight = self.get_convolution_weight()
        return self._conv_forward(input, weight, self.bias)

    def get_hyperspherical_energy(self, epsilon=1e-6):
        """
        Return the sum of 1 / (||w_i - w_j|| + epsilon) over all ordered pairs i != j.

        w_i are the L2-normalized rows of the flattened effective kernel.
        Diagonal (i == j) terms are masked out.
        """
        w_flat = self.get_convolution_weight(return_flat=True)
        w_norm = F.normalize(w_flat, p=2, dim=1)
        dist_matrix = torch.cdist(w_norm, w_norm, p=2)

        eye_mask = torch.eye(self.out_channels, device=w_norm.device, dtype=torch.bool)
        inverse_dist = 1.0 / (dist_matrix + epsilon)
        inverse_dist = inverse_dist.masked_fill(eye_mask, 0.0)

        return inverse_dist.sum()

    def get_ideal_energy_lower_bound(self):
        """
        Return N * (N - 1) / sqrt(2) with N = out_channels.

        This is the value get_hyperspherical_energy would take (ignoring epsilon)
        if all N normalized kernels were mutually orthogonal (pairwise distance sqrt(2)).
        NOTE(review): not a strict lower bound in general. Obtuse configurations
        have lower energy. Confirm intended use.
        """
        N = self.out_channels
        num_pairs = N * (N - 1)
        return num_pairs * (1.0 / math.sqrt(2))

    def get_orthogonality_reg_loss(self):
        """
        Return sum_k ||C S_k C^T - I||_F^2 over the num_subspaces groups of raw kernels.

        S_k is the Gram matrix of the L2-normalized flattened kernels of group k,
        where group k is output channels [k * subspace_dim, (k + 1) * subspace_dim).
        NOTE(review): C is orthogonal, so the value equals sum_k ||S_k - I||_F^2
        and does not depend on C_param. Confirm this is intended.
        """
        # Shape: (K, D, M) = (num_subspaces, subspace_dim, in/groups * kH * kW).
        # reshape (not view) so non-contiguous weights are accepted.
        w_reshaped = self.weight.reshape(self.num_subspaces, self.subspace_dim, -1)
        w_norm = F.normalize(w_reshaped, p=2, dim=2)

        # (K, D, D) per-group Gram matrices.
        s_base_blocks = torch.bmm(w_norm, w_norm.transpose(1, 2))

        C = self.get_orthogonal_C()
        g_sub_blocks = torch.matmul(torch.matmul(C, s_base_blocks), C.t())

        I = torch.eye(self.subspace_dim, device=g_sub_blocks.device, dtype=g_sub_blocks.dtype)
        # Squared Frobenius norm over all blocks.
        return (g_sub_blocks - I).pow(2).sum()


# ---------------------------------------------------------
# [검증 함수: Projector Orthogonality Check]
# ---------------------------------------------------------
def verify_projector_orthogonality():
    """
    Print whether the per-subspace projectors of the effective basis are mutually orthogonal.

    Builds a Linear layer with 12 outputs split into three 4-dimensional subspaces
    and overwrites C_param with a large random matrix. For every pair i < j it
    prints max |P_i @ P_j|, where P_k = Q_k Q_k^T and Q_k is the k-th column block
    of the effective basis. The check fails if any value exceeds 1e-5.
    Results are only printed; nothing is returned.
    """
    print("=== 부분공간 투영(Projection) 직교성 검증 시작 ===")
    out_dim = 12
    sub_dim = 4

    layer = Linear(10, out_dim, sub_dim, pretrained=False)

    # Large random C_param so the Cayley rotation is far from the identity.
    layer.C_param.data = torch.randn(sub_dim, sub_dim) * 5.0
    print(f"-> C 파라미터를 랜덤 설정했습니다. (Identity 아님)")

    basis = layer.get_effective_basis()
    n_sub = out_dim // sub_dim

    # Projector onto subspace k: P_k = Q_k Q_k^T, with Q_k the k-th column block.
    projectors = []
    for start in range(0, n_sub * sub_dim, sub_dim):
        cols = basis[:, start:start + sub_dim]
        projectors.append(cols @ cols.t())

    print("\n[검증 A] 투영 행렬 간의 곱 (P_i @ P_j) 확인")
    tol = 1e-5
    orthogonal = True
    for i, p_i in enumerate(projectors):
        for j in range(i + 1, n_sub):
            worst = (p_i @ projectors[j]).abs().max().item()
            print(f"   - Subspace {i} vs {j}: P_{i} P_{j} Max Error = {worst:.10f}")
            if worst > tol:
                orthogonal = False

    verdict = (
        "결과: >> 성공 << 모든 부분공간의 투영 벡터들은 완벽하게 직교합니다."
        if orthogonal
        else "결과: >> 실패 << 직교성이 깨졌습니다."
    )
    print(verdict)


def verify_orthogonality():
    """
    Print how far the Cayley matrix C and the effective basis Q_eff are from orthogonal.

    A Linear(10, 12, subspace_dim=4) layer gets a large random C_param so that C
    is far from the identity. The function prints the maximum absolute entry of
    C^T C - I and of Q_eff^T Q_eff - I. It returns None.
    """
    sub_dim = 4
    total_dim = 12
    layer = Linear(10, total_dim, sub_dim, pretrained=False)
    layer.C_param.data = torch.randn(sub_dim, sub_dim) * 10

    def max_identity_error(mat, size):
        # Largest |(mat^T mat - I)_ij|; zero for an exactly orthogonal matrix.
        gram = mat.T @ mat
        return (gram - torch.eye(size)).abs().max().item()

    c_err = max_identity_error(layer.get_orthogonal_C(), sub_dim)
    print(f"1. C^T @ C 와 Identity의 오차: {c_err:.10f}")

    q_err = max_identity_error(layer.get_effective_basis(), total_dim)
    print(f"2. Q_eff^T @ Q_eff 와 Identity의 오차: {q_err:.10f}")


if __name__ == "__main__":
    # Run the self-checks in order: matrix orthogonality first, then projectors.
    for check in (verify_orthogonality, verify_projector_orthogonality):
        check()

