import math
import random
import typing
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.init as init

def _get_hue_rotation_matrix(rotations: int) -> torch.Tensor:

    assert rotations > 0
    angle = 2 * math.pi / rotations
    cos_a = math.cos(angle)
    sin_a = math.sin(angle)

    # 회전 행렬 공식. 회색 축 (1,1,1)을 중심으로 회전합니다.
    const_a = (1.0 - cos_a) / 3.0
    const_b = math.sqrt(1.0 / 3.0) * sin_a

    R = torch.tensor([
        [cos_a + const_a, const_a - const_b, const_a + const_b],
        [const_a + const_b, cos_a + const_a, const_a - const_b],
        [const_a - const_b, const_a + const_b, cos_a + const_a],
    ], dtype=torch.float32)
    return R



def _rotate_kernel_bank(
        weight: torch.Tensor,
        g_rot: int,
        flip: int = 0,
        align_corners: bool = True
) -> torch.Tensor:
    """
    주어진 기본 커널(weight)을 회전 및 반전시켜 커널 뱅크를 생성합니다.

    Args:
        weight (torch.Tensor): 변환시킬 기본 커널. (C_out, C_in, k, k)
        g_rot (int): 회전 개수 (e.g., 4는 90도씩 회전).
        flip (int): 반전 옵션.
                    - 0: 반전 없음 (Rotation only, C_N 그룹)
                    - 1: Y축 대칭(좌우 반전) 추가 (Dihedral Group, D_N 그룹)
                    - 2: Y축, X축, XY축 대칭 모두 추가 (Rotation x Klein Group)

    Returns:
        torch.Tensor: 변환된 커널 뱅크.
                      - flip=0: (C_out, C_in, g_rot, k, k)
                      - flip=1: (C_out, C_in, g_rot * 2, k, k)
                      - flip=2: (C_out, C_in, g_rot * 4, k, k)
    """
    assert flip in [0, 1], "flip 값은 0, 1 중 하나여야 합니다."
    C_out, C_in, k, _ = weight.shape

    # 1. 회전 커널 생성
    rotated_kernels = []
    if g_rot in [1, 2, 4] and k % 2 == 1:
        rotation_steps = {1: [0], 2: [0, 2], 4: [0, 1, 2, 3]}[g_rot]
        for r in rotation_steps:
            rotated_kernels.append(torch.rot90(weight, k=r, dims=(-2, -1)))
    else:
        B = C_out * C_in
        w_flat = weight.view(B, 1, k, k)
        angles = [2 * math.pi * i / g_rot for i in range(g_rot)]
        for angle in angles:
            c, s = math.cos(angle), math.sin(angle)
            theta = torch.tensor([[c, -s, 0.0], [s, c, 0.0]], dtype=w_flat.dtype, device=w_flat.device)
            theta = theta.unsqueeze(0).expand(B, -1, -1)
            grid = F.affine_grid(theta, size=(B, 1, k, k), align_corners=align_corners)
            w_rotated = F.grid_sample(w_flat, grid, mode='bilinear', padding_mode='zeros', align_corners=align_corners)
            rotated_kernels.append(w_rotated.view(C_out, C_in, k, k))

    all_kernels = list(rotated_kernels)

    # 2. flip 옵션에 따라 대칭 커널 추가
    if flip >= 1:
        # Y축 대칭 (좌우 반전)
        flipped_y = [torch.flip(kernel, dims=[-1]) for kernel in rotated_kernels]
        all_kernels.extend(flipped_y)

    if flip == 2:
        # X축 대칭 (상하 반전)
        flipped_x = [torch.flip(kernel, dims=[-2]) for kernel in rotated_kernels]
        all_kernels.extend(flipped_x)

        # XY축 동시 대칭 (180도 회전과 동일)
        # 이미 y축 대칭된 커널을 x축 대칭하여 생성
        # flipped_xy = [torch.flip(kernel, dims=[-2]) for kernel in flipped_y]
        # all_kernels.extend(flipped_xy)

    return torch.stack(all_kernels, dim=2)



class CRBatchNorm2d(nn.Module):
    def __init__(self, C, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True, sync_stats: bool = True):
        super().__init__()
        self.C = C
        self.sync_stats = sync_stats

        # PyTorch 표준 BatchNorm2d 사용
        # momentum은 PyTorch 기본값(0.1)이 일반적
        self.bn = nn.BatchNorm2d(
            num_features=C,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats
        )

        # SyncBatchNorm 사용 여부
        # if sync_stats and dist.is_available() and dist.is_initialized():
        #     self.bn = nn.SyncBatchNorm.convert_sync_batchnorm(self.bn)

    def forward(self, x):
        assert x.dim() == 6, f"Expected 6D tensor, got {x.dim()}D"
        B, C, Cc, R, H, W = x.shape

        # (B, C, Cc, R, H, W) -> (B*Cc*R, C, H, W)
        # BatchNorm2d는 (N, C, H, W) 형태를 기대함
        x_reshaped = x.permute(0, 2, 3, 1, 4, 5).contiguous()  # (B, Cc, R, C, H, W)
        x_reshaped = x_reshaped.view(B * Cc * R, C, H, W)  # (B*Cc*R, C, H, W)

        # BatchNorm 적용
        x_normed = self.bn(x_reshaped)  # (B*Cc*R, C, H, W)

        # 원래 형태로 복원
        x_normed = x_normed.view(B, Cc, R, C, H, W)  # (B, Cc, R, C, H, W)
        x_normed = x_normed.permute(0, 3, 1, 2, 4, 5)  # (B, C, Cc, R, H, W)

        return x_normed



class EquivariantSpatialPool(nn.Module):
    """
    등변성을 유지하며 공간적 다운샘플링(H, W)을 수행합니다.
    윈도우 내에서 가장 큰 Norm을 가진 특징 벡터를 선택합니다.
    """

    def __init__(self, kernel_size=3, stride=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Unfold를 사용하여 겹치지 않는 윈도우를 효율적으로 추출합니다.
        self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        입력: [B, C, Cr, Gr, H, W]
        출력: [B, C, Cr, Gr, H_out, W_out]
        """
        B, C, Cr, Gr, H, W = x.shape

        # 1. Unfold를 위해 텐서를 4D [B, Channels, H, W] 형태로 임시 변환
        x_reshaped = x.contiguous().view(B, C * Cr * Gr, H, W)

        # 2. 공간적 윈도우(패치) 추출
        # patches: [B, (C*Cr*Gr) * (k*k), L]
        # L = 패치(윈도우)의 개수 = (H_out * W_out)
        patches = self.unfold(x_reshaped)

        # 3. Norm 계산을 위해 패치를 다시 그룹 구조로 복원
        # patches_structured: [B, C, Cr, Gr, k*k, L]
        patches_structured = patches.view(B, C, Cr, Gr, self.kernel_size * self.kernel_size, -1)

        # 4. 각 패치 내에서 그룹 축에 대한 Norm 계산
        # patch_norms: [B, C, k*k, L]
        patch_norms = torch.linalg.vector_norm(patches_structured, ord=2, dim=(2, 3))

        # 5. 각 윈도우에서 Norm이 가장 큰 픽셀의 인덱스 찾기
        # max_indices: [B, C, 1, L]
        _, max_indices = torch.max(patch_norms, dim=2, keepdim=True)

        # 6. 찾은 인덱스를 사용하여 원본 패치에서 해당 특징 벡터 추출 (gather)
        # gather를 위해 인덱스 텐서의 차원을 원본 패치와 맞춰줍니다.
        # index_expanded: [B, C, Cr, Gr, 1, L]
        idx_expanded = max_indices.unsqueeze(2).unsqueeze(2).expand(B, C, Cr, Gr, 1, -1)

        # selected_features: [B, C, Cr, Gr, 1, L]
        selected_features = torch.gather(patches_structured, dim=4, index=idx_expanded)

        # 7. 최종 출력을 다시 공간적 형태로 복원
        # H_out, W_out 계산
        H_out = (H - self.kernel_size) // self.stride + 1
        W_out = (W - self.kernel_size) // self.stride + 1

        # out: [B, C, Cr, Gr, H_out, W_out]
        out = selected_features.view(B, C, Cr, Gr, H_out, W_out)

        return out


class GlobalAvgPoolCRHW(nn.Module):
    """Global average pool over (color, rot, H, W) -> (B, C)."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mean(dim=(2, 3, 4, 5))  # (B, C)


class ColorGeometricPooling(nn.Module):
    """
    색상(dim=2) 또는 기하학(dim=3) 그룹 축, 혹은 둘 모두에 대해 Max Pooling을 수행.
    풀링된 차원은 사라집니다 (keepdim=False).
    """

    def __init__(self, pool_color: bool = False, pool_geom: bool = False, pool_global: bool = False):
        super().__init__()
        if not pool_color and not pool_geom:
            raise ValueError("pool_color 또는 pool_geom 중 하나는 반드시 True여야 합니다.")
        self.pool_color = pool_color
        self.pool_geom = pool_geom
        self.pool_global = pool_global

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 입력 x: (B, Cout, c_out, g_out, H, W)

        # 색상 축 풀링 수행
        if self.pool_color:
            x = torch.max(x, dim=2, keepdim=False)[0]

        # 기하학 축 풀링 수행
        # 만약 색상 풀링이 먼저 수행되었다면, 기하학 축의 인덱스는 3에서 2로 바뀜
        if self.pool_geom and self.pool_global:
            # 기하학 축(dim=3)에 대해 먼저 풀링. 출력: (B, Cout, c_out, H, W)
            geom_dim = 2 if self.pool_color else 3
            x = torch.max(x, dim=geom_dim, keepdim=False)[0]

            # 완전한 불변성을 위해 공간 축(H, W)에 대해 Global Max Pooling 수행.
            x = torch.max(x, dim=-1, keepdim=False)[0]
            x = torch.max(x, dim=-1, keepdim=False)[0]
        else:
            geom_dim = 2 if self.pool_color else 3
            x = torch.max(x, dim=geom_dim, keepdim=False)[0]


        return x