import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Linear(nn.Linear):
    """Linear layer whose weight is re-expressed in a learnable orthogonal basis.

    The weight actually applied in ``forward`` is ``Q_eff @ weight`` with
    ``Q_eff = Q_total @ BlockDiag(C, ..., C)``:

    * ``Q_total`` is a fixed ``(out_features, out_features)`` buffer (not
      trained). It is a random orthogonal matrix obtained from a QR
      decomposition, or the identity when ``pretrained=True``.
    * ``C`` is a ``(subspace_dim, subspace_dim)`` matrix produced from the
      trainable ``C_param`` by a Cayley transform. The same ``C`` is applied
      to each of the ``out_features // subspace_dim`` column blocks of
      ``Q_total``.

    With ``pretrained=True`` both ``Q_total`` and ``C`` start as the identity,
    so the layer initially behaves like a plain ``nn.Linear`` with the same
    ``weight`` and ``bias``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        subspace_dim: Size of one subspace; must divide ``out_features``.
        bias: Whether the layer has an additive bias.
        pretrained: If True, initialise ``Q_total`` to the identity and
            ``C_param`` to zeros instead of random values.
        device: Device for the weight, bias, ``Q_total`` and ``C_param``.
        dtype: Dtype for the weight, bias, ``Q_total`` and ``C_param``.

    Raises:
        ValueError: If ``out_features`` is not divisible by ``subspace_dim``.
    """

    def __init__(self, in_features, out_features, subspace_dim, bias=True, pretrained=False, device=None, dtype=None):
        # nn.Linear.__init__ already calls reset_parameters() once. At that
        # point Q_total / C_param do not exist yet, which is why
        # reset_parameters() guards them with hasattr().
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

        if out_features % subspace_dim != 0:
            raise ValueError(f"out_features({out_features}) must be divisible by subspace_dim({subspace_dim})")

        self.subspace_dim = subspace_dim
        self.num_subspaces = out_features // subspace_dim
        self.pretrained = pretrained

        # Create the extra tensors with the same device/dtype as the weight;
        # otherwise torch.mm(Q_eff, weight) fails for non-default settings.
        factory_kwargs = {'device': device, 'dtype': dtype}

        # Fixed basis, shape (out_features, out_features).
        self.register_buffer('Q_total', torch.empty(out_features, out_features, **factory_kwargs))

        # Trainable generator of the in-subspace rotation,
        # shape (subspace_dim, subspace_dim).
        self.C_param = nn.Parameter(torch.empty(subspace_dim, subspace_dim, **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise ``Q_total``, ``C_param``, the weight and the bias."""
        if hasattr(self, 'Q_total'):
            with torch.no_grad():
                if self.pretrained:
                    # Pretrained mode: identity basis.
                    self.Q_total.copy_(torch.eye(self.out_features, device=self.Q_total.device))
                else:
                    # Standard mode: the Q factor of a random Gaussian matrix
                    # is orthogonal. copy_() casts it to the buffer's dtype.
                    gaussian = torch.randn(self.out_features, self.out_features, device=self.Q_total.device)
                    q_factor, _ = torch.linalg.qr(gaussian)
                    self.Q_total.copy_(q_factor)

        if hasattr(self, 'C_param'):
            if self.pretrained:
                # Zero generator -> Cayley transform yields the identity.
                nn.init.zeros_(self.C_param)
            else:
                nn.init.normal_(self.C_param, mean=0, std=0.01)

        # nn.Linear initialises both the weight and the bias, so no separate
        # bias initialisation is required here.
        super().reset_parameters()

    def get_orthogonal_C(self) -> torch.Tensor:
        """Return the orthogonal matrix ``C`` built from ``C_param``.

        Cayley transform: ``C = (I + A)^-1 @ (I - A)`` with the skew-symmetric
        ``A = C_param - C_param.T``. The two factors commute, so this equals
        ``(I - A) @ (I + A)^-1``.

        Returns:
            Tensor of shape ``(subspace_dim, subspace_dim)``.
        """
        skew = self.C_param - self.C_param.t()
        eye = torch.eye(self.subspace_dim, device=skew.device, dtype=skew.dtype)

        # Solve (I + A) C = (I - A) rather than forming the explicit inverse.
        return torch.linalg.solve(eye + skew, eye - skew)

    def get_effective_basis(self) -> torch.Tensor:
        """Return ``Q_eff = Q_total @ BlockDiag(C, ..., C)``.

        Returns:
            Tensor of shape ``(out_features, out_features)``.
        """
        C = self.get_orthogonal_C()

        # Split the columns of Q_total into num_subspaces blocks:
        # (out, out) -> (out, num_subspaces, subspace_dim).
        q_blocks = self.Q_total.view(self.out_features, self.num_subspaces, self.subspace_dim)

        # Broadcasted matmul right-multiplies every column block by C.
        q_rotated = torch.matmul(q_blocks, C)

        # Flatten back to (out, out).
        return q_rotated.view(self.out_features, self.out_features)

    def forward(self, input):
        """Apply the layer using the effective weight ``Q_eff @ weight``."""
        q_eff = self.get_effective_basis()

        # (out, out) @ (out, in) -> (out, in)
        weight_effective = torch.mm(q_eff, self.weight)

        return F.linear(input, weight_effective, self.bias)

    def get_hyperspherical_energy(self, epsilon=1e-6):
        """Return the sum of inverse pairwise distances between unit-normalised rows.

        The rows are those of the effective weight ``Q_eff @ weight``. Every
        ordered pair ``(i, j)`` with ``i != j`` contributes
        ``1 / (dist_ij + epsilon)``; diagonal entries are excluded.

        Args:
            epsilon: Constant added to each distance before inversion.

        Returns:
            Scalar tensor.
        """
        q_eff = self.get_effective_basis()
        w_final = torch.mm(q_eff, self.weight)

        w_unit = F.normalize(w_final, p=2, dim=1)
        dist_matrix = torch.cdist(w_unit, w_unit, p=2)

        diag_mask = torch.eye(self.out_features, device=w_unit.device, dtype=torch.bool)

        inverse_dist = 1.0 / (dist_matrix + epsilon)
        # Drop the self-distance terms on the diagonal.
        inverse_dist = inverse_dist.masked_fill(diag_mask, 0.0)

        return inverse_dist.sum()

    def get_ideal_energy_lower_bound(self):
        """Return ``N * (N - 1) / sqrt(2)`` with ``N = out_features``.

        This is the value ``get_hyperspherical_energy`` takes (up to
        ``epsilon``) when every pair of distinct unit rows is at distance
        ``sqrt(2)``, i.e. when the rows are mutually orthogonal.
        """
        n = self.out_features
        num_pairs = n * (n - 1)
        return num_pairs * (1.0 / math.sqrt(2))

    def get_orthogonality_reg_loss(self) -> torch.Tensor:
        """Return the per-subspace orthogonality penalty.

        The rows of ``weight`` are grouped into ``num_subspaces`` consecutive
        blocks of ``subspace_dim`` rows and unit-normalised. For each block
        the Gram matrix ``S`` is rotated to ``C @ S @ C.T`` and compared with
        the identity; the result is the sum of squared differences over all
        blocks.

        Returns:
            Scalar tensor.
        """
        # (num_subspaces, subspace_dim, in_features)
        w_blocks = self.weight.reshape(self.num_subspaces, self.subspace_dim, -1)
        w_unit = F.normalize(w_blocks, p=2, dim=2)

        gram_blocks = torch.bmm(w_unit, w_unit.transpose(1, 2))

        C = self.get_orthogonal_C()
        rotated = torch.matmul(torch.matmul(C, gram_blocks), C.t())

        eye = torch.eye(self.subspace_dim, device=rotated.device, dtype=rotated.dtype)

        # Squared Frobenius norm taken over all blocks at once.
        return (rotated - eye).pow(2).sum()



class Conv2d(nn.Conv2d):
    """2-D convolution whose kernel is re-expressed in a learnable orthogonal basis.

    The kernel is flattened to ``(out_channels, -1)`` and left-multiplied by
    ``Q_eff = Q_total @ BlockDiag(C, ..., C)`` before the convolution:

    * ``Q_total`` is a fixed ``(out_channels, out_channels)`` buffer (not
      trained). It is a random orthogonal matrix obtained from a QR
      decomposition, or the identity when ``pretrained=True``.
    * ``C`` is a ``(subspace_dim, subspace_dim)`` matrix produced from the
      trainable ``C_param`` by a Cayley transform. The same ``C`` is applied
      to each of the ``out_channels // subspace_dim`` column blocks of
      ``Q_total``.

    With ``pretrained=True`` both ``Q_total`` and ``C`` start as the identity,
    so the effective kernel initially equals ``weight``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode, device, dtype: Forwarded to
            ``nn.Conv2d``; ``device`` and ``dtype`` are also used for
            ``Q_total`` and ``C_param``.
        subspace_dim: Size of one subspace; must divide ``out_channels``.
        pretrained: If True, initialise ``Q_total`` to the identity and
            ``C_param`` to zeros instead of random values.

    Raises:
        ValueError: If ``out_channels`` is not divisible by ``subspace_dim``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, subspace_dim,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None, pretrained=False):

        # nn.Conv2d.__init__ already calls reset_parameters() once. At that
        # point Q_total / C_param do not exist yet, which is why
        # reset_parameters() guards them with hasattr().
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode, device, dtype
        )

        if out_channels % subspace_dim != 0:
            raise ValueError(f"out_channels({out_channels}) must be divisible by subspace_dim({subspace_dim})")

        self.subspace_dim = subspace_dim
        self.num_subspaces = out_channels // subspace_dim
        self.pretrained = pretrained

        # Fixed basis defined on the output-channel space,
        # shape (out_channels, out_channels).
        self.register_buffer('Q_total', torch.empty(out_channels, out_channels, device=device, dtype=dtype))

        # Trainable generator of the in-subspace rotation,
        # shape (subspace_dim, subspace_dim).
        self.C_param = nn.Parameter(torch.empty(subspace_dim, subspace_dim, device=device, dtype=dtype))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re)initialise ``Q_total``, ``C_param``, the kernel and the bias."""
        if hasattr(self, 'Q_total'):
            with torch.no_grad():
                if self.pretrained:
                    # Pretrained mode: identity basis.
                    self.Q_total.copy_(torch.eye(self.out_channels, device=self.Q_total.device))
                else:
                    # Standard mode: the Q factor of a random Gaussian matrix
                    # is orthogonal. copy_() casts it to the buffer's dtype.
                    gaussian = torch.randn(self.out_channels, self.out_channels, device=self.Q_total.device)
                    q_factor, _ = torch.linalg.qr(gaussian)
                    self.Q_total.copy_(q_factor)

        if hasattr(self, 'C_param'):
            if self.pretrained:
                # Zero generator -> Cayley transform yields the identity.
                nn.init.zeros_(self.C_param)
            else:
                nn.init.normal_(self.C_param, mean=0, std=0.01)

        # Kernel and bias use the stock nn.Conv2d initialisation.
        super().reset_parameters()

    def get_orthogonal_C(self) -> torch.Tensor:
        """Return the orthogonal matrix ``C`` built from ``C_param``.

        Cayley transform: ``C = (I + A)^-1 @ (I - A)`` with the skew-symmetric
        ``A = C_param - C_param.T``. The two factors commute, so this equals
        ``(I - A) @ (I + A)^-1``.

        Returns:
            Tensor of shape ``(subspace_dim, subspace_dim)``.
        """
        skew = self.C_param - self.C_param.t()
        eye = torch.eye(self.subspace_dim, device=skew.device, dtype=skew.dtype)

        # Solve (I + A) C = (I - A) rather than forming the explicit inverse.
        return torch.linalg.solve(eye + skew, eye - skew)

    def get_effective_basis(self) -> torch.Tensor:
        """Return ``Q_eff = Q_total @ BlockDiag(C, ..., C)``.

        Returns:
            Tensor of shape ``(out_channels, out_channels)``.
        """
        C = self.get_orthogonal_C()

        # Split the columns of Q_total into num_subspaces blocks:
        # (out, out) -> (out, num_subspaces, subspace_dim).
        q_blocks = self.Q_total.view(self.out_channels, self.num_subspaces, self.subspace_dim)

        # Broadcasted matmul right-multiplies every column block by C.
        q_rotated = torch.matmul(q_blocks, C)

        # Flatten back to (out, out).
        return q_rotated.view(self.out_channels, self.out_channels)

    def get_convolution_weight(self, return_flat=False):
        """Return the effective kernel ``Q_eff @ weight.reshape(out_channels, -1)``.

        Args:
            return_flat: If True, return the 2-D ``(out_channels, -1)`` matrix;
                otherwise return it reshaped to the shape of ``self.weight``.

        Returns:
            The effective kernel, 2-D or 4-D depending on ``return_flat``.
        """
        q_eff = self.get_effective_basis()

        # reshape (not view): the kernel may be stored non-contiguously, e.g.
        # after module.to(memory_format=torch.channels_last), in which case
        # view() raises.
        w_flat = self.weight.reshape(self.out_channels, -1)

        # (out, out) @ (out, flat_in) -> (out, flat_in)
        weight_effective_flat = torch.mm(q_eff, w_flat)

        if return_flat:
            return weight_effective_flat

        # Back to the 4-D kernel shape.
        return weight_effective_flat.view_as(self.weight)

    def forward(self, input):
        """Convolve ``input`` with the effective kernel."""
        weight = self.get_convolution_weight()

        # _conv_forward takes care of the configured padding mode.
        return self._conv_forward(input, weight, self.bias)

    def get_hyperspherical_energy(self, epsilon=1e-6):
        """Return the sum of inverse pairwise distances between unit-normalised filters.

        The filters are the rows of the flattened effective kernel. Every
        ordered pair ``(i, j)`` with ``i != j`` contributes
        ``1 / (dist_ij + epsilon)``; diagonal entries are excluded.

        Args:
            epsilon: Constant added to each distance before inversion.

        Returns:
            Scalar tensor.
        """
        w_flat = self.get_convolution_weight(return_flat=True)
        w_unit = F.normalize(w_flat, p=2, dim=1)
        dist_matrix = torch.cdist(w_unit, w_unit, p=2)

        diag_mask = torch.eye(self.out_channels, device=w_unit.device, dtype=torch.bool)
        inverse_dist = 1.0 / (dist_matrix + epsilon)
        # Drop the self-distance terms on the diagonal.
        inverse_dist = inverse_dist.masked_fill(diag_mask, 0.0)

        return inverse_dist.sum()

    def get_ideal_energy_lower_bound(self):
        """Return ``N * (N - 1) / sqrt(2)`` with ``N = out_channels``.

        This is the value ``get_hyperspherical_energy`` takes (up to
        ``epsilon``) when every pair of distinct unit filters is at distance
        ``sqrt(2)``, i.e. when the filters are mutually orthogonal.
        """
        n = self.out_channels
        num_pairs = n * (n - 1)
        return num_pairs * (1.0 / math.sqrt(2))

    def get_orthogonality_reg_loss(self) -> torch.Tensor:
        """Return the per-subspace orthogonality penalty.

        The flattened filters of ``weight`` are grouped into ``num_subspaces``
        consecutive blocks of ``subspace_dim`` filters and unit-normalised.
        For each block the Gram matrix ``S`` is rotated to ``C @ S @ C.T`` and
        compared with the identity; the result is the sum of squared
        differences over all blocks.

        Returns:
            Scalar tensor.
        """
        # (num_subspaces, subspace_dim, flat_in). reshape (not view) so that a
        # non-contiguous kernel, e.g. channels_last, is handled as well.
        w_blocks = self.weight.reshape(self.num_subspaces, self.subspace_dim, -1)
        w_unit = F.normalize(w_blocks, p=2, dim=2)

        gram_blocks = torch.bmm(w_unit, w_unit.transpose(1, 2))

        C = self.get_orthogonal_C()
        rotated = torch.matmul(torch.matmul(C, gram_blocks), C.t())

        eye = torch.eye(self.subspace_dim, device=rotated.device, dtype=rotated.dtype)

        # Squared Frobenius norm taken over all blocks at once.
        return (rotated - eye).pow(2).sum()



