import torch
import torch.nn as nn
import numpy as np

from utils.utils_h import CONV2D_TYPES, BATCH2D_TYPES

import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2dLora(nn.Module):
    """Conv2d with a shared base kernel plus per-subject low-rank (LoRA) kernel deltas."""

    def __init__(self, conv_layer, rank=3, num_subjects=1, recon=None, proc=None):
        """
        Wraps a Conv2d layer with LoRA adapters.

        Args:
            conv_layer: existing nn.Conv2d layer. It is applied unchanged as the
                shared branch, and its stride / padding / dilation are reused
                for the adapter branch.
            rank: low-rank dimension K.
            num_subjects: number of subjects. ``num_subjects + 1`` adapter
                slots are allocated, so valid subject indices are
                ``0..num_subjects``.
            recon: if truthy, Q is initialised from N(0, 1) and R to zero.
            proc: if ``== False`` (and ``recon`` is falsy), same initialisation
                as ``recon``. Otherwise Q and R are both left at zero.

        Note:
            The adapter branch always zero-pads and ignores ``conv_layer.groups``
            and ``conv_layer.padding_mode``; R spans all ``in_channels``.
        """
        super().__init__()
        self.shared_conv = conv_layer
        self.rank = rank
        self.num_subjects = num_subjects

        # Subject-specific low-rank factors; slot s holds subject s's adapter.
        # Q: (out_channels, rank, 1, 1)      -> mixes rank components into filters
        # R: (rank, in_channels, kH, kW)     -> rank spatial kernels
        self.Q = nn.Parameter(torch.zeros(num_subjects + 1,
                                          conv_layer.out_channels,
                                          rank, 1, 1))
        self.R = nn.Parameter(torch.zeros(num_subjects + 1,
                                          rank,
                                          conv_layer.in_channels,
                                          *conv_layer.kernel_size))

        # LoRA-style init: random Q, zero R. The adapter starts as a no-op, yet
        # R still receives non-zero gradients through Q.
        # NOTE(review): when neither condition holds, Q and R are both zero, so
        # every gradient w.r.t. Q and R is zero too and the adapters cannot
        # train unless their values are set elsewhere (e.g. a loaded
        # checkpoint) -- confirm this is intended.
        if recon or proc == False:  # noqa: E712 -- keeps the original `== False` semantics
            nn.init.normal_(self.Q)
            nn.init.zeros_(self.R)

    def forward(self, x, subject_idx):
        """
        Apply the shared convolution plus each sample's subject-specific adapter.

        Args:
            x: input of shape [B, C, H, W]; need not be contiguous.
            subject_idx: subject index per sample. Either a [B] tensor of any
                integer dtype, or a single int / one-element tensor that is
                broadcast to the whole batch.

        Returns:
            ``shared_conv(x) + conv2d(x, Q[s_b] R[s_b])`` per sample b,
            shape [B, out_c, H', W'].

        Raises:
            ValueError: if the number of subject indices is neither 1 nor B.
        """
        out = self.shared_conv(x)
        batch = x.shape[0]

        # Long is the canonical index dtype (older PyTorch rejects int32 indices).
        idx = torch.as_tensor(subject_idx, device=self.Q.device).long().reshape(-1)
        if idx.numel() == 1:
            idx = idx.expand(batch)
        elif idx.numel() != batch:
            raise ValueError(
                f"expected 1 or {batch} subject indices, got {idx.numel()}")

        # Subject-specific adapters
        Qs = self.Q[idx]  # [B, out_c, rank, 1, 1]
        Rs = self.R[idx]  # [B, rank, in_c, kH, kW]
        # Per-sample kernel delta; the trailing 1x1 dims of Q are summed away.
        lora_weight = torch.einsum("borhw,bricd->boicd", Qs, Rs)  # [B, out_c, in_c, kH, kW]

        _, out_c, in_c, kH, kW = lora_weight.shape

        # Fold the batch into channels so a single grouped conv (groups=B)
        # applies each sample's own kernel. `reshape` (not `view`) also works
        # for non-contiguous inputs / einsum outputs.
        x_grouped = x.reshape(1, batch * in_c, *x.shape[2:])        # [1, B*in_c, H, W]
        w_grouped = lora_weight.reshape(batch * out_c, in_c, kH, kW)  # [B*out_c, in_c, kH, kW]

        lora_out = F.conv2d(x_grouped,
                            w_grouped,
                            bias=None,
                            stride=self.shared_conv.stride,
                            padding=self.shared_conv.padding,
                            dilation=self.shared_conv.dilation,
                            groups=batch)

        # Reshape back: [1, B*out_c, H', W'] -> [B, out_c, H', W']
        lora_out = lora_out.reshape(batch, out_c, *lora_out.shape[2:])

        return out + lora_out


class EuclideanConvDenoiserJointLora(nn.Module):
    """Two conv + BatchNorm stages whose convolutions carry per-subject LoRA adapters.

    conv1 uses kernel (kernel_timewise, 1) with no padding; conv2 uses kernel
    (1, kernel_channelwise) with padding (0, padding). Each convolution is
    wrapped in ``Conv2dLora`` and followed by ``nn.BatchNorm2d``.
    """

    def __init__(self,
                 in_channels,
                 intermediate_channels,
                 out_channels,
                 kernel_timewise,
                 kernel_channelwise,
                 padding,
                 num_subjects=1,
                 rank=3,
                 recon=None,
                 proc=None
                 ):
        """
        Args:
            in_channels: channels of the (4-D) input.
            intermediate_channels: channels produced by conv1.
            out_channels: channels produced by conv2.
            kernel_timewise: conv1 kernel height, kernel_size=(kernel_timewise, 1).
            kernel_channelwise: conv2 kernel width, kernel_size=(1, kernel_channelwise).
            padding: conv2 width padding, padding=(0, padding).
            num_subjects, rank, recon, proc: forwarded to both ``Conv2dLora``
                wrappers.
        """
        super(EuclideanConvDenoiserJointLora, self).__init__()

        # Construction order is kept stable so seeded initialisation is reproducible.
        base_conv1 = nn.Conv2d(in_channels,
                               intermediate_channels,
                               kernel_size=(kernel_timewise, 1))
        self.conv1 = Conv2dLora(base_conv1, rank=rank, num_subjects=num_subjects, recon=recon, proc=proc)
        self.Bn1 = nn.BatchNorm2d(intermediate_channels)

        base_conv2 = nn.Conv2d(intermediate_channels,
                               out_channels,
                               kernel_size=(1, kernel_channelwise),
                               padding=(0, padding))
        self.conv2 = Conv2dLora(base_conv2, rank=rank, num_subjects=num_subjects, recon=recon, proc=proc)
        self.Bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x, x_sub):
        """
        Args:
            x: [B, C, H, W] input, or [B, H, W] (a singleton channel dim is
                inserted at dim 1).
            x_sub: tensor with the batch as its leading dim. The subject index
                of sample b is its first element, i.e. ``x_sub[b, 0, 0, 0]``
                for 4-D input (1-D to 3-D layouts are also accepted). Values
                are truncated to integers.

        Returns:
            Output of conv1 -> Bn1 -> conv2 -> Bn2.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)

        # First entry of each sample's flattened row: identical to
        # x_sub[:, 0, 0, 0] for 4-D input. Long is the canonical index dtype
        # (older PyTorch rejects int32 index tensors).
        subj_idx = x_sub.flatten(1)[:, 0] if x_sub.dim() > 1 else x_sub
        subj_idx = subj_idx.long()

        x = self.conv1(x, subj_idx)
        x = self.Bn1(x)

        x = self.conv2(x, subj_idx)
        x = self.Bn2(x)
        return x

class EuclideanConvDenoiser(nn.Module):
    """Two conv + BatchNorm stages with no subject-specific parameters.

    conv1 uses kernel (kernel_timewise, 1) with no padding; conv2 uses kernel
    (1, kernel_channelwise) with padding (0, padding). Each convolution is
    followed by ``nn.BatchNorm2d``.
    """

    def __init__(self,
                 in_channels,
                 intermediate_channels,
                 out_channels,
                 kernel_timewise,
                 kernel_channelwise,
                 padding,
                 num_subjects=1,
                 subject_embed=None,
                 subject_dim=16):
        """
        Args:
            in_channels: channels of the (4-D) input.
            intermediate_channels: channels produced by conv1.
            out_channels: channels produced by conv2.
            kernel_timewise: conv1 kernel height, unless ``subject_embed``
                overrides it.
            kernel_channelwise: conv2 kernel width.
            padding: conv2 width padding, padding=(0, padding).
            num_subjects, subject_dim: passed to ``subject_embed``.
            subject_embed: optional factory ``f(num_subjects, subject_dim)``
                returning ``(embedding_module, kernel_height)``. The module is
                stored as ``self.subject_embeddings`` and the height replaces
                ``kernel_timewise`` for conv1.
        """
        super().__init__()

        conv1_height = kernel_timewise
        if subject_embed is not None:
            # NOTE(review): subject_embeddings is stored but never used by forward().
            self.subject_embeddings, conv1_height = subject_embed(num_subjects, subject_dim)

        # Kept as separate named layers (rather than nn.Sequential) for easier debugging.
        self.conv1 = nn.Conv2d(in_channels, intermediate_channels,
                               kernel_size=(conv1_height, 1))
        self.Bn1 = nn.BatchNorm2d(intermediate_channels)
        self.conv2 = nn.Conv2d(intermediate_channels, out_channels,
                               kernel_size=(1, kernel_channelwise),
                               padding=(0, padding))
        self.Bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run conv1 -> Bn1 -> conv2 -> Bn2; a 3-D input gets a channel dim at dim 1."""
        if x.dim() == 3:
            x = x.unsqueeze(1)
        stage1 = self.Bn1(self.conv1(x))
        return self.Bn2(self.conv2(stage1))


