import math
import torch
import torch.nn.functional as F
from torch import nn
import torch.nn.init as init

from src.models.jcgel.layers.utils import _get_hue_rotation_matrix, _rotate_kernel_bank

class CNNTrasnposedLayer(nn.Module):
    """Transposed 2-D convolution followed by an in-place ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
        stride: Stride forwarded to ``nn.ConvTranspose2d``.
        padding: Padding forwarded to ``nn.ConvTranspose2d``. Defaults to 1,
            the value that used to be hard-coded; together with the default
            ``kernel_size=4`` / ``stride=2`` this doubles the spatial size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # Batch normalisation is intentionally not used in this layer.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, input):
        """Apply the transposed convolution, then ReLU, and return the result."""
        return self.relu(self.conv(input))

class JCGConv2d(nn.Module):
    """
    순수 PyTorch로 구현된 색상(Color) 및 기하(Geometric) 등변 합성곱 레이어.
    Version 2 (e2cnn 기반)의 로직을 그대로 따르며, Lifting과 Group 연산을 명확히 구분합니다.

    - is_lifting=True (입력층): Z2(일반 이미지) -> C_color × G_geom (그룹 표현)
    - is_lifting=False (은닉층): C_color × G_geom -> C_color × G_geom (그룹 표현 위에서의 연산)
    """

    def __init__(self,
                 # channels and kernel size
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 # colour (hue) group parameters
                 in_c_rotations: int,
                 out_c_rotations: int,
                 # geometric group parameters
                 g_rotations: int,
                 n_flip = 0,
                 # layer type and miscellaneous options
                 is_lifting: bool = False,
                 bias: bool = True,
                 stride: int = 1,
                 padding: int = 0,
                 init_scale: float = 1.0,
                 temperature: float = 0.01,
                 normalization: bool = False,
                 soft: bool = False,
                 init_method : str = 'xavier_uniform'):
        """Create the kernels (and optional biases) of the layer.

        Args:
            in_channels: Number of input channels ``Cin`` (must be 3 when lifting).
            out_channels: Number of output channels ``Cout``.
            kernel_size: Side length ``k`` of the square kernels.
            in_c_rotations: Size of the input colour group (must be 1 when
                lifting, and equal to ``out_c_rotations`` otherwise).
            out_c_rotations: Size of the output colour group.
            g_rotations: Number of geometric rotations ``g_rot``.
            n_flip: Flip setting; it sizes the soft gate as
                ``g_rot * (n_flip + 1)`` and is read by
                ``_forward_group_relative_2``.
            is_lifting: ``True`` builds a single kernel plus the hue rotation
                matrices; ``False`` builds one kernel per colour offset.
            bias: Whether bias parameters are created.
            stride: Stride passed to ``F.conv2d``.
            padding: Padding passed to ``F.conv2d``.
            init_scale: Extra factor applied to the initialisation bound.
            temperature: Divisor of the soft-gate logits before the softmax.
            normalization: Whether the group paths scale their kernels by a
                ``1 / sqrt(...)`` factor.
            soft: Whether a learnable gate over the rotated kernels is created.
            init_method: Stored on the instance; not read anywhere in this class.
        """
        super().__init__()

        # Layer mode and sizes.
        self.is_lifting = is_lifting
        self.Cin, self.Cout, self.k = in_channels, out_channels, kernel_size
        self.in_c, self.out_c = in_c_rotations, out_c_rotations
        self.g_rot, self.n_flip = g_rotations, n_flip

        # Convolution and initialisation options.
        self.stride, self.padding = stride, padding
        self.use_bias = bias
        self.init_method = init_method
        self.init_scale = init_scale
        self.temp = temperature
        self.normalization = normalization
        self.soft = soft

        if self.soft:
            # One gate per rotated kernel; with n_flip == 0 this is exactly g_rot.
            self.weighted_rotkernel = nn.Parameter(torch.rand(self.g_rot * (self.n_flip + 1)))

        # NOTE: the g_rotations == 1 case gets no special handling.
        kernel_shape = (self.Cout, self.Cin, self.k, self.k)
        if not self.is_lifting:
            assert self.in_c == self.out_c, "For group conv, in_c must equal out_c."
            # One kernel (and bias) per colour offset.
            self.weights = nn.ParameterList(
                [nn.Parameter(torch.Tensor(*kernel_shape)) for _ in range(self.in_c)]
            )
            if self.use_bias:
                self.biases = nn.ParameterList(
                    [nn.Parameter(torch.Tensor(self.Cout)) for _ in range(self.in_c)]
                )
            else:
                self.register_parameter('biases', None)
        else:
            assert self.in_c == 1 and self.Cin == 3, \
                "Lifting layer must have in_c_rotations=1 and in_channels=3 (RGB)."
            self.weight = nn.Parameter(torch.Tensor(*kernel_shape))
            if self.use_bias:
                self.bias = nn.Parameter(torch.Tensor(self.Cout))
            else:
                self.register_parameter('bias', None)

            # Powers 0 .. out_c - 1 of the base hue rotation matrix.
            hue_step = _get_hue_rotation_matrix(self.out_c)
            hue_powers = [torch.matrix_power(hue_step, power) for power in range(self.out_c)]
            self.register_buffer("color_reps", torch.stack(hue_powers))

        self.reset_parameters()


    def reset_parameters(self) -> None:
        """
        He(Kaiming) 초기화.
        - 리프팅 레이어: fan_in = Cin * k * k
        - 그룹 레이어:   fan_in = Cin * in_c * g_rot * k * k
        - 비편향(가중치만 He), bias는 0으로 초기화(Conv 기본 관행)
        - self.init_scale로 (필요 시) 추가 스케일 조정 가능(기본 1.0 권장)
        """

        # 비선형 정보 (기본 relu)
        nonlin = getattr(self, "nonlinearity", "relu")
        a = float(getattr(self, "negative_slope", 0.0)) if nonlin == "leaky_relu" else 0.0

        def he_std(fan_in: int) -> float:
            gain = init.calculate_gain(nonlin, a)
            return (gain / math.sqrt(fan_in)) * float(getattr(self, "init_scale", 1.0))

        with torch.no_grad():
            if self.is_lifting:
                # Lifting: Z^2 -> G 공간으로 보낼 때 입력 fan만 고려
                fan_in = int(self.Cin) * int(self.k) * int(self.k)
                std = he_std(fan_in)
                # self.weight.normal_(mean=0.0, std=std)
                self.weight.data.uniform_(-std, std)
                if getattr(self, "use_bias", False):
                    self.bias.zero_()
            else:
                # Group layer: 입력 섬유 차원(in_c)과 그룹의 크기(g_rot)까지 입력 팬에 포함
                # (합성곱 시 이 차원들을 따라 합산되므로 fan_in에 들어가는 것이 He의 기본 가정과 일치)
                fan_in = int(self.Cin) * int(self.in_c) * int(self.g_rot) * int(self.k) * int(self.k)
                std = he_std(fan_in)

                # self.weights: Iterable[Tensor]
                for w in self.weights:
                    w.uniform_(-std, std) # .normal_(mean=0.0, std=std)

                if getattr(self, "use_bias", False):
                    for b in self.biases:
                        b.zero_()



    def forward(self, x: torch.Tensor) -> torch.Tensor:

        # if self.g_rot == 1:
        #     # g_rot=1이면 회전 등변성이 없으므로, CEConv2d와 동일한 로직 수행
        #     return self._forward_color_only(x)
        #
        # # g_rot > 1이면 기존 CGEConv2d 로직 수행
        # if self.is_lifting:
        #     return self._forward_lifting(x)
        # else:
        #     return self._forward_group(x)

        if self.is_lifting:
            return self._forward_lifting(x)
        else:
            return self._forward_group(x)
            # return self._forward_group_relative(x)

    def _forward_lifting(self, x: torch.Tensor) -> torch.Tensor:
        """Lift a 3-channel image to the colour x geometry group representation.

        Args:
            x: Tensor fed to ``F.conv2d`` with 3 input channels, i.e.
                ``(B, 3, H, W)``.

        Returns:
            Tensor of shape ``(B, Cout, out_c, g_rot, H_out, W_out)``.
        """
        # Rotated copies of the base kernel; the einsum below requires a 5-D
        # bank, presumably (Cout, Cin, g_rot, k, k) -- confirm against
        # _rotate_kernel_bank.
        rotated = _rotate_kernel_bank(self.weight, self.g_rot)

        if self.soft:
            # Temperature-scaled softmax gate, rescaled so its largest entry is 1.
            gate = F.softmax(self.weighted_rotkernel / self.temp, dim=-1)
            gate = gate / gate.max()
            rotated = rotated * gate.view(1, 1, -1, 1, 1)

        # Mix the input colour axis with every hue rotation matrix:
        # result axes are (Cout, out_c, 3, g_rot, k, k).
        hue_kernels = torch.einsum('oij, cjdkl -> coidkl', self.color_reps, rotated)

        # Flatten (out_c, Cout, g_rot) into the conv2d output-channel axis.
        conv_weight = (
            hue_kernels.permute(1, 0, 3, 2, 4, 5)
            .contiguous()
            .view(self.out_c * self.Cout * self.g_rot, 3, self.k, self.k)
        )

        out = F.conv2d(x, conv_weight, bias=None, stride=self.stride, padding=self.padding)
        batch, _, h_out, w_out = out.shape

        # Unflatten the channel axis and move Cout in front of the colour axis.
        out = out.view(batch, self.out_c, self.Cout, self.g_rot, h_out, w_out)
        out = out.permute(0, 2, 1, 3, 4, 5)

        if not self.use_bias:
            return out
        # One bias per output channel, shared by all colour/rotation copies.
        return out + self.bias.view(1, self.Cout, 1, 1, 1, 1)

    # ### 수정된 부분: Version2의 정확한 로직 재현 ###
    # _forward_group_forloop: 연산 속도 개선.
    def _forward_group(self, x: torch.Tensor) -> torch.Tensor:
        """Group convolution on an already lifted feature map.

        Every output hue ``oc`` uses the per-offset kernels cyclically shifted
        by ``-oc`` along the colour axis. The kernels of all output hues are
        concatenated so that the whole layer is a single ``F.conv2d`` call.

        Args:
            x: Tensor of shape ``(B, Cin, c_in, g_in, H, W)``; ``c_in`` must
                match the number of kernels in ``self.weights``.

        Returns:
            Tensor of shape ``(B, Cout, out_c, g_rot, H_out, W_out)``.
        """
        B, Cin, c_in, g_in, H, W = x.shape
        k = self.k
        out_ch = self.Cout * self.g_rot

        # 1. Rotate every colour-offset kernel geometrically.
        stacked = torch.stack(list(self.weights), dim=0)  # (c_in, Cout, Cin, k, k)
        bank = _rotate_kernel_bank(stacked.view(-1, Cin, k, k), self.g_rot)
        bank = bank.view(c_in, self.Cout, Cin, self.g_rot, k, k)

        if self.soft:
            # Temperature-scaled softmax gate, rescaled so its largest entry is 1.
            gate = F.softmax(self.weighted_rotkernel / self.temp, dim=-1)
            gate = gate / gate.max()
            # (1, 1, G, 1, 1) right-aligns with the (g_rot, k, k) tail of the bank.
            bank = bank * gate.view(1, 1, -1, 1, 1)

        if self.normalization:
            # Scale the kernels once instead of once per output hue.
            bank = bank * (1.0 / math.sqrt(self.out_c * self.in_c * self.g_rot))

        # 2. (c_in, Cout*g_rot, Cin*g_in, k, k): the same kernel is reused for
        #    every input rotation index.
        bank = bank.permute(0, 1, 3, 2, 4, 5).reshape(c_in, out_ch, Cin, k, k)
        bank = bank.unsqueeze(3).expand(-1, -1, -1, g_in, -1, -1)
        bank = bank.reshape(c_in, out_ch, Cin * g_in, k, k)

        # 3. Flatten the input to (B, c_in*Cin*g_in, H, W).
        x_flat = x.permute(0, 2, 1, 3, 4, 5).reshape(B, c_in * Cin * g_in, H, W)

        # 4. Kernels of all output hues: (out_c, c_in, Cout*g_rot, Cin*g_in, k, k),
        #    flattened hue-major so the conv output unflattens to
        #    (out_c, Cout, g_rot).
        per_colour = torch.stack(
            [torch.roll(bank, shifts=-oc, dims=0) for oc in range(self.out_c)], dim=0
        )
        conv_weight = per_colour.permute(0, 2, 1, 3, 4, 5).reshape(
            self.out_c * out_ch, c_in * Cin * g_in, k, k
        )

        y = F.conv2d(x_flat, conv_weight, bias=None, stride=self.stride, padding=self.padding)

        # 5. Restore the group axes and put Cout in front of the colour axis.
        H_out, W_out = y.shape[-2:]
        y = y.view(B, self.out_c, self.Cout, self.g_rot, H_out, W_out).permute(0, 2, 1, 3, 4, 5)

        if self.use_bias:
            # The per-offset biases are summed over the colour axis. Rolling
            # along that axis cannot change the sum, so it is identical for
            # every output hue and is computed once.
            bias_total = torch.stack(list(self.biases), dim=0).sum(dim=0)  # (Cout,)
            y = y + bias_total.view(1, self.Cout, 1, 1, 1, 1)

        return y

    # rotation
    def _forward_group_relative(self, x: torch.Tensor) -> torch.Tensor:
        """Group convolution with relative indexing on both group axes.

        The colour offset is realised by rolling the kernel bank along its
        colour axis; the rotation offset by picking, for output rotation
        ``g_out`` and input rotation ``m``, the kernel rotated by
        ``(g_out - m) mod g_rot``.

        Args:
            x: Tensor of shape ``(B, Cin, c_in, g_in, H, W)``.

        Returns:
            Tensor of shape ``(B, Cout, out_c, g_rot, H_out, W_out)``.
        """
        B, Cin, c_in, g_in, H, W = x.shape
        k = self.k

        # 1. Rotated kernel bank for every colour offset.
        stacked = torch.stack(list(self.weights), dim=0)  # (c_in, Cout, Cin, k, k)
        bank = _rotate_kernel_bank(stacked.view(-1, Cin, k, k), self.g_rot)
        bank = bank.view(c_in, self.Cout, Cin, self.g_rot, k, k)

        # 2. Optional soft gate over the rotation axis.
        if self.soft:
            gate = F.softmax(self.weighted_rotkernel / self.temp, dim=-1)  # (g_rot,)
            gate = gate / gate.max()
            bank = bank * gate.view(1, 1, 1, self.g_rot, 1, 1)

        if self.normalization:
            # Scale once instead of once per (oc, g_out) pair.
            bank = bank * (1.0 / math.sqrt(self.out_c * self.in_c * self.g_rot))

        # 3. Flatten the input to (B, c_in*Cin*g_in, H, W).
        x_flat = x.permute(0, 2, 1, 3, 4, 5).reshape(B, c_in * Cin * g_in, H, W)

        # 4. rel_idx[g_out, m] = (g_out - m) mod g_rot; independent of the hue,
        #    so it is built once.
        rot_out = torch.arange(self.g_rot, device=bank.device).view(-1, 1)
        rot_in = torch.arange(g_in, device=bank.device).view(1, -1)
        rel_idx = (rot_out - rot_in) % self.g_rot  # (g_rot, g_in)

        output_colors = []
        for oc in range(self.out_c):
            # NOTE(review): assumes axis 0 of the bank really encodes the
            # colour offset -- confirm against how self.weights is indexed.
            w_c = torch.roll(bank, shifts=-oc, dims=0)  # (c_in, Cout, Cin, g_rot, k, k)

            conv_outs = []
            for g_out in range(self.g_rot):
                # (c_in, Cout, Cin, g_in, k, k): one rotated kernel per input rotation.
                w_sel = w_c.index_select(3, rel_idx[g_out])
                w_rel = w_sel.permute(1, 0, 2, 3, 4, 5).reshape(self.Cout, c_in * Cin * g_in, k, k)
                conv_outs.append(
                    F.conv2d(x_flat, w_rel, bias=None, stride=self.stride, padding=self.padding)
                )  # (B, Cout, H_out, W_out)

            output_colors.append(torch.stack(conv_outs, dim=2))  # (B, Cout, g_rot, H_out, W_out)

        # Stacking on dim=2 already yields (B, Cout, out_c, g_rot, H_out, W_out);
        # no further view/permute is needed (reinterpreting it as
        # (B, out_c, Cout, ...) would scramble the channels).
        y = torch.stack(output_colors, dim=2)

        if self.use_bias:
            # Same bias as _forward_group: the per-offset biases summed over
            # the colour axis, shared by every hue and rotation.
            bias_total = torch.stack(list(self.biases), dim=0).sum(dim=0)  # (Cout,)
            y = y + bias_total.view(1, self.Cout, 1, 1, 1, 1)

        return y


    # flip + rotation
    def _forward_group_relative_2(self, x: torch.Tensor) -> torch.Tensor:
        """Group convolution with relative indexing over colour, flip and rotation.

        For output element ``(f_out, r_out)`` and input element
        ``(f_in, r_in)`` the kernel is taken at

        - ``f_rel = f_out XOR f_in``
        - ``r_rel = (r_out - r_in) mod g_rot`` if ``f_rel == 0``
          else ``(r_out + r_in) mod g_rot``

        Here ``self.n_flip`` is used as the number of flip states (1 or 2).

        Args:
            x: ``(B, Cin, c_in, g_in, H, W)`` (a flip axis of size 1 is
                inserted) or ``(B, Cin, c_in, f_in, g_in, H, W)``.

        Returns:
            Tensor of shape ``(B, Cout, out_c, n_flip, g_rot, H_out, W_out)``.
        """
        B, Cin, c_in = x.size(0), x.size(1), x.size(2)
        if x.dim() == 7:
            f_in, g_in, H, W = x.shape[3:]
            x_in = x
        else:
            # Add the flip axis when the input does not carry one.
            _, _, _, g_in, H, W = x.shape
            f_in = 1
            x_in = x.unsqueeze(3)  # (B, Cin, c_in, 1, g_in, H, W)

        assert g_in == self.g_rot, "g_in must equal g_rot"

        k = self.k

        # 1. Kernel bank with rotations (and flips when n_flip == 2).
        stacked = torch.stack(list(self.weights), dim=0)  # (c_in, Cout, Cin, k, k)
        bank = _rotate_kernel_bank(
            stacked.view(-1, Cin, k, k),
            g_rot=self.g_rot,
            flip=(self.n_flip == 2)
        )
        g_count = bank.shape[2]
        expected = self.g_rot * self.n_flip
        assert g_count == expected, f"kernel bank mismatch: {g_count} vs {expected}"

        # Assumes the helper orders the bank flip-major, i.e. as
        # (n_flip, g_rot) -- TODO confirm against _rotate_kernel_bank.
        bank = bank.view(c_in, self.Cout, Cin, self.n_flip, self.g_rot, k, k)

        if self.soft:
            gate = F.softmax(self.weighted_rotkernel / self.temp, dim=-1)
            gate = gate / gate.max()
            # NOTE(review): __init__ allocates g_rot * (n_flip + 1) gates while
            # the bank holds n_flip * g_rot kernels, so this product cannot
            # line up with the bank for the n_flip values this method accepts.
            # The n_flip convention (flag vs. count) must be settled first.
            gate = gate.view(1, 1, 1, (self.n_flip + 1), self.g_rot, 1, 1)
            bank = bank * gate

        if self.normalization:
            # Scale once instead of once per output element.
            bank = bank * (1.0 / math.sqrt(self.out_c * self.in_c * self.g_rot * (self.n_flip + 1)))

        # 2. Flatten the input to (B, c_in*Cin*f_in*g_in, H, W).
        x_flat = x_in.permute(0, 2, 1, 3, 4, 5, 6).reshape(B, c_in * Cin * f_in * g_in, H, W)

        # 3. Relative index tables, one (f_rel, r_rel) pair per output element
        #    in (f_out, r_out) order. They do not depend on the hue.
        device = bank.device
        flips_in = torch.arange(f_in, device=device).view(f_in, 1)  # (f_in, 1)
        rots_in = torch.arange(g_in, device=device).view(1, g_in)   # (1, g_in)
        rel_tables = []
        for f_out in range(self.n_flip):
            f_rel = (flips_in ^ f_out).expand(f_in, g_in)
            for r_out in range(self.g_rot):
                r_rel = torch.where(
                    f_rel.bool(),
                    (r_out + rots_in) % self.g_rot,
                    (r_out - rots_in) % self.g_rot,
                )  # (f_in, g_in)
                rel_tables.append((f_rel, r_rel))

        outputs_per_color = []
        for oc in range(self.out_c):
            # NOTE(review): assumes axis 0 of the bank encodes the colour offset.
            w_c = torch.roll(bank, shifts=-oc, dims=0)  # (c_in, Cout, Cin, n_flip, g_rot, k, k)

            conv_outs = []
            for f_rel, r_rel in rel_tables:
                # Joint lookup on the (flip, rotation) axes keeps the full
                # rotation axis available:
                # -> (c_in, Cout, Cin, f_in, g_in, k, k)
                w_sel = w_c[:, :, :, f_rel, r_rel]
                w_rel = w_sel.permute(1, 0, 2, 3, 4, 5, 6).reshape(
                    self.Cout, c_in * Cin * f_in * g_in, k, k
                )
                conv_outs.append(
                    F.conv2d(x_flat, w_rel, bias=None, stride=self.stride, padding=self.padding)
                )  # (B, Cout, H_out, W_out)

            # (B, Cout, n_flip*g_rot, H_out, W_out) -> split the group axis.
            y_oc = torch.stack(conv_outs, dim=2)
            y_oc = y_oc.view(B, self.Cout, self.n_flip, self.g_rot, y_oc.shape[-2], y_oc.shape[-1])
            outputs_per_color.append(y_oc)

        # (B, Cout, out_c, n_flip, g_rot, H_out, W_out)
        y = torch.stack(outputs_per_color, dim=2)

        if self.use_bias:
            # Same bias as _forward_group: the per-offset biases summed over
            # the colour axis, shared by every group element.
            bias_total = torch.stack(list(self.biases), dim=0).sum(dim=0)  # (Cout,)
            y = y + bias_total.view(1, self.Cout, 1, 1, 1, 1, 1)

        return y