import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
from einops import rearrange
from util import *
from block import *
import inspect

class W(nn.Module): ## renamed to ISO
    """Learnable multi-kernel Wiener-style deconvolution (renamed ISO).

    The input is deconvolved in the frequency domain with each of the ``k``
    PSF kernels using ``conj(P) / (|P|^2 + epsilon + relu(psf_weights))``.
    The ``k`` spatial results are blended with learnable per-kernel weights
    (``kernel_weights``) and passed through a single-group GroupNorm.

    Args:
        channels: channel count of ``psf_weights`` and of the GroupNorm.
        height, width: spatial size of the ``raw`` input seen by ``forward``.
        height_p, width_p: spatial size of the PSF.
        k: number of PSF kernels packed into the PSF channel axis.
    """

    def __init__(self, channels, height, width, height_p, width_p, k=16):
        super(W, self).__init__()
        # Size of the rfft2 spectrum of the (H + H_p, W + W_p) padded input.
        self.height_freq = height + height_p
        self.width_freq = (width + width_p) // 2 + 1

        # Learnable per-frequency term added to |P|^2 (ReLU keeps it >= 0).
        self.psf_weights = nn.Parameter(
            torch.ones(k, channels, self.height_freq, self.width_freq) * 0.01
        )
        # One group: statistics are taken over all of (C, H, W) per sample.
        self.group_norm = nn.GroupNorm(num_groups=1, num_channels=channels)

        # Per-kernel learnable scale of the PSF normalisation.
        self.alpha = nn.Parameter(torch.ones(k, 1, 1, 1) * 1,
                                  requires_grad=True)
        # Blending weights from the project helper generate_roi; the forward
        # broadcast assumes a (k, height, width) shape -- TODO confirm.
        self.kernel_weights = nn.Parameter(generate_roi(k, height, width),
                                           requires_grad=True)
        self.relu = nn.ReLU()
        self.k = k

    def forward(self, raw: torch.Tensor, psf: torch.Tensor, epsilon=1e-6) -> torch.Tensor:
        """Deconvolve ``raw`` with each of the k PSFs and blend the results.

        Args:
            raw: 4-D input (B, C, H, W); H and W are expected to match the
                ``height``/``width`` given at construction.
            psf: 4-D PSF stack, regrouped here as (k, C', H_p, W_p).
            epsilon: constant added to the Wiener denominator.

        Returns:
            The GroupNorm-ed blend of the k deconvolutions. It is expected to
            be (B, C, H, W), given the kernel_weights shape assumption above.
        """
        _, _, height, width = raw.shape
        _, _, psf_h, psf_w = psf.shape

        # Split the channel axis into k kernels and normalise each kernel by
        # its alpha-scaled absolute sum.
        psf = psf.reshape(self.k, -1, psf_h, psf_w)
        psf_sum = psf.sum(dim=(-2, -1), keepdim=True)  # (k, C', 1, 1)
        psf_normalized = psf / (psf_sum.abs() * self.alpha + 1e-12)

        # Pad the input (replicate) and the PSF (zeros) to a common
        # (H + H_p, W + W_p) size for the FFT product.
        raw_padded = F.pad(
            raw,
            (psf_w // 2, psf_w - psf_w // 2, psf_h // 2, psf_h - psf_h // 2),
            mode='replicate'
        )
        # Project helper; presumably Gaussian smoothing with FWHM 2 -- confirm.
        raw_padded = gaus_t(raw_padded, fwhm=2)
        psf_padded = F.pad(
            psf_normalized,
            (width // 2, width - width // 2, height // 2, height - height // 2),
            mode='constant'
        )
        full_size = raw_padded.shape[-2:]

        # Default ("backward") FFT normalisation.
        raw_fft = torch.fft.rfft2(raw_padded, dim=(-2, -1))  # (B, C, H_f, W_f)
        psf_fft = torch.fft.rfft2(psf_padded, s=full_size, dim=(-2, -1))  # (k, C', H_f, W_f)

        pw = self.relu(self.psf_weights)
        wiener_filter = psf_fft.conj() / (psf_fft.abs() ** 2 + epsilon + pw)
        out_fft = raw_fft.unsqueeze(0) * wiener_filter.unsqueeze(1)  # (k, B, C, H_f, W_f)

        # Give irfft2 the real-space size explicitly. Without ``s`` it assumes
        # an even last dimension and returns W + W_p - 1 columns whenever
        # W + W_p is odd, which corrupts the shift and crop below.
        out_spatial = torch.fft.irfft2(out_fft, s=full_size, dim=(-2, -1))
        out_spatial = torch.fft.ifftshift(out_spatial, dim=(-2, -1))
        start_h = psf_h // 2
        start_w = psf_w // 2
        out_cropped = out_spatial[..., start_h:start_h + height, start_w:start_w + width]  # (k, B, C, H, W)
        kw = self.kernel_weights.unsqueeze(1).unsqueeze(1)  # (k, 1, 1, H, W)
        out_cropped = (out_cropped * kw).sum(dim=0)  # (B, C, H, W)

        return self.group_norm(out_cropped)
    
class C(nn.Module): ## renamed to FSO
    """Frequency-domain convolution with the mean of the k stacked PSFs (renamed FSO).

    The PSF stack is regrouped as (k, -1, H_p, W_p) and averaged over the k
    kernels. The averaged kernel is divided by its alpha-scaled absolute sum
    and convolved with the input via rfft2/irfft2. The result is cropped back
    to the input size and passed through a single-group GroupNorm.
    """

    def __init__(self, in_channels, height, width, height_p, width_p, init_scale=1.0, k=16):
        super(C, self).__init__()
        # Learnable scale applied to the PSF sum before normalisation.
        self.alpha = nn.Parameter(torch.ones(1, 1, 1, 1))
        # Return value unused: the GroupNorm below is fixed to one group.
        get_num_groups(in_channels)
        self.group_norm = nn.GroupNorm(num_groups=1, num_channels=in_channels)
        self.k = k
        # NOTE(review): height, width, height_p, width_p and init_scale are
        # accepted but unused here (presumably kept for parity with W).

    def forward(self, x, p):
        """Convolve ``x`` with the averaged, normalised PSF.

        Args:
            x: 4-D input tensor (B, C, H, W).
            p: 4-D PSF stack; its elements are regrouped as (k, -1, H_p, W_p).

        Returns:
            Tensor with the spatial size of ``x``, after GroupNorm.
        """
        _, _, height, width = x.shape
        _, _, psf_h, psf_w = p.shape

        # Average the k kernels into one (1, C', H_p, W_p) kernel, then
        # normalise it to keep the output scale bounded.
        kernel = p.reshape(self.k, -1, psf_h, psf_w).mean(dim=0, keepdim=True)
        kernel_sum = kernel.sum(dim=(-2, -1), keepdim=True)
        kernel = kernel / ((kernel_sum * self.alpha).abs() + 1e-12)

        # Zero-pad input and kernel to a shared (H + H_p, W + W_p) canvas.
        x_padded = F.pad(
            x,
            (psf_w // 2, psf_w - psf_w // 2, psf_h // 2, psf_h - psf_h // 2),
            mode='constant'
        )
        kernel = F.pad(
            kernel,
            (width // 2, width - width // 2, height // 2, height - height // 2),
            mode='constant'
        )
        fft_size = (x_padded.size(-2), x_padded.size(-1))

        # Pointwise product of spectra == circular convolution in space.
        spectrum = (torch.fft.rfft2(x_padded, dim=(-2, -1))
                    * torch.fft.rfft2(kernel, s=fft_size, dim=(-2, -1)))
        y = torch.fft.ifftshift(
            torch.fft.irfft2(spectrum, s=fft_size, dim=(-2, -1)),
            dim=(-2, -1),
        )

        top, left = psf_h // 2, psf_w // 2
        y = y[:, :, top:top + height, left:left + width]
        return self.group_norm(y.real).real

class CAWBlock(nn.Module):  ## renamed to IFIB
    """Interaction block coupling a Wiener branch (w) and a convolution branch (c).

    Each branch is mixed with the output of the other branch's PSF operator
    before its conv block:

        w <- conv1_w(w * alpha_w + W(c, p) * delta_w)
        c <- conv1_c(c * alpha_c + C(w, p) * delta_c)

    ``alpha_*`` start at ``1 - exchange`` and ``delta_*`` at ``exchange``.
    """

    def __init__(self, in_channels, out_channels, height, width, height_p, width_p, block_cls=DoubleConvLN, exchange=0.2, k=16):
        super(CAWBlock, self).__init__()
        spatial = (height, width, height_p, width_p)
        self.W = W(in_channels, *spatial, k=k)
        self.C = C(in_channels, *spatial, k=k)

        self.block_cls = block_cls
        self.in_channels = in_channels

        self.conv1_w = self.create_block(block_cls, in_channels, out_channels)
        self.conv1_c = self.create_block(block_cls, in_channels, out_channels)
        # NOTE(review): registered (so present in state_dict) but never used
        # by forward.
        self.res_conv_w = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.res_conv_c = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        def _scalar(value):
            # Learnable (1, 1, 1, 1) float32 mixing coefficient.
            return nn.Parameter(torch.full((1, 1, 1, 1), value, dtype=torch.float32),
                                requires_grad=True)

        keep, mix = 1 - exchange, exchange
        self.alpha_c = _scalar(keep)
        self.delta_c = _scalar(mix)
        self.alpha_w = _scalar(keep)
        self.delta_w = _scalar(mix)

    def create_block(self, block_cls, in_channels, out_channels):
        """Instantiate ``block_cls`` with a single normalisation group."""
        get_num_groups(in_channels)  # return value unused; num_groups is fixed to 1
        return block_cls(in_channels, out_channels, num_groups=1)

    def forward(self, w, c, p):
        """Exchange information between branches and refine each one.

        Returns:
            Tuple ``(w, c)`` of the refined branch features.
        """
        # Cross-branch: the W operator is applied to c, the C operator to w.
        w_from_c = self.W(c, p)
        c_from_w = self.C(w, p)
        mixed_w = w * self.alpha_w + w_from_c * self.delta_w
        mixed_c = c * self.alpha_c + c_from_w * self.delta_c
        return self.conv1_w(mixed_w), self.conv1_c(mixed_c)

class UpCAW(nn.Module):
    """Decoder stage.

    Upsamples the deeper features by 2 (bicubic) and pads them to the skip
    connection's size. The skip features and the upsampled features are then
    concatenated, projected to ``mid_channels`` with 1x1 convs, and passed
    through a CAWBlock.
    """

    def __init__(self, in_channels, out_channels, mid_channels, height, width, height_p, width_p, block_cls=DoubleConvLN, exchange=0.2, k=16):
        super(UpCAW, self).__init__()

        def _upsampler():
            return nn.Upsample(scale_factor=2, mode='bicubic', align_corners=True)

        self.upw = _upsampler()
        self.upc = _upsampler()
        self.caw_block = CAWBlock(mid_channels, out_channels, height, width, height_p, width_p,
                                  block_cls=block_cls, exchange=exchange, k=k)
        # 1x1 projections of the concatenated (skip, upsampled) features.
        self.convw1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.convc1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)

    def forward(self, w1, w2, c1, c2, p):
        """Fuse deeper features (w1, c1) with skip features (w2, c2).

        Returns:
            Tuple ``(w, c)`` produced by the CAWBlock.
        """
        w1 = self.upw(w1)
        c1 = self.upc(c1)
        # Pad to the skip size; any odd remainder goes to the right/bottom.
        dy = w2.size(2) - w1.size(2)
        dx = w2.size(3) - w1.size(3)
        pad = [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]
        w = self.convw1(torch.cat([w2, F.pad(w1, pad)], dim=1))
        c = self.convc1(torch.cat([c2, F.pad(c1, pad)], dim=1))
        w, c = self.caw_block(w, c, p)
        return w, c

class DownCAW(nn.Module):
    """Encoder stage.

    Average-pools w, c and p by 2 and applies a CAWBlock at the reduced
    resolution. The pooled p is then mapped from ``in_channels * k`` to
    ``out_channels * k`` channels with the project block ConvG.
    """

    def __init__(self, in_channels, out_channels, height, width, height_p, width_p, block_cls=DoubleConvLN, exchange=0.2, k=16):
        super(DownCAW, self).__init__()
        self.pool = nn.AvgPool2d(2)
        self.caw_block = CAWBlock(in_channels, out_channels, height, width, height_p, width_p,
                                  block_cls=block_cls, exchange=exchange, k=k)
        self.conv_block = ConvG(in_channels * k, out_channels * k)

    def forward(self, w, c, p):
        """Return ``(w, c, p)`` for the next (coarser) level."""
        w, c, p = map(self.pool, (w, c, p))
        w, c = self.caw_block(w, c, p)
        # The CAWBlock consumes the pooled p; the ConvG output is returned.
        return w, c, self.conv_block(p)


class CAWNet(nn.Module): ## renamed to IFIN
    """Two-branch U-Net (renamed IFIN) driven by a learnable PSF.

    The "w" branch starts from a Wiener-style deconvolution of the input
    (``WieNerH``) and the "c" branch from the raw input. The two branches
    exchange information at every encoder/decoder level through CAWBlocks
    that share a per-level PSF feature stack.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of each branch output.
        psf: 4-D PSF tensor; its last two dims give the PSF spatial size.
        height, width: input spatial size (used to size frequency-domain
            parameters at every level).
        dim: channel width of the first level; doubled per level.
        depth: number of encoder/decoder levels (must be >= 1).
        block_cls: conv block class used throughout.
        exchange: initial cross-branch mixing weight.
        k: number of PSF kernels.
        repeat: tile the PSF channel axis k times before learning it.
        random: re-initialise the PSF with Xavier-uniform.

    Raises:
        ValueError: if ``depth < 1``.
    """

    def __init__(self, in_channels, out_channels, psf, height=270, width=480, dim=32, depth=3, block_cls=DoubleConvLN, exchange=0.2, k=16, repeat=True, random=False):
        super().__init__()
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")
        _, _, height_p, width_p = psf.size()
        self.depth = depth

        # Learnable PSF. With repeat=True the channel axis is tiled k times,
        # matching the reshape(k, -1, ...) done inside W and C.
        if repeat:
            self.psf = nn.Parameter(psf.repeat(1, k, 1, 1), requires_grad=True)
        else:
            self.psf = nn.Parameter(psf, requires_grad=True)
        if random:
            nn.init.xavier_uniform_(self.psf)

        # Per-level channel widths and (input / PSF) spatial sizes.
        channels = [dim * (2 ** i) for i in range(depth)]
        heights = [height // (2 ** i) for i in range(depth + 1)]
        widths = [width // (2 ** i) for i in range(depth + 1)]
        psf_heights = [height_p // (2 ** i) for i in range(depth + 1)]
        psf_widths = [width_p // (2 ** i) for i in range(depth + 1)]

        self.WieNerH = W(in_channels, heights[0], widths[0], psf_heights[0], psf_widths[0], k=k)

        self.start_w = self.create_block(block_cls, in_channels, channels[0])
        self.start_c = self.create_block(block_cls, in_channels, channels[0])
        self.start_p = ConvG(self.psf.size(1), channels[0] * k)

        # Encoder: each level doubles the width, except the deepest one, which
        # keeps it. Testing the deepest case first also makes depth == 1 work
        # (previously the i == 0 branch indexed channels[1] -> IndexError).
        self.down_layers = nn.ModuleList()
        for i in range(depth):
            out_ch = channels[i + 1] if i < depth - 1 else channels[i]
            self.down_layers.append(
                DownCAW(channels[i], out_ch, heights[i + 1], widths[i + 1],
                        psf_heights[i + 1], psf_widths[i + 1],
                        block_cls=block_cls, exchange=exchange, k=k))

        # Decoder, deepest level first; the top level outputs channels[0].
        self.up_layers = nn.ModuleList()
        for i in range(depth - 1, -1, -1):
            out_ch = channels[i - 1] if i > 0 else channels[0]
            self.up_layers.append(
                UpCAW(channels[i] * 2, out_ch, channels[i], heights[i], widths[i],
                      psf_heights[i], psf_widths[i],
                      block_cls=block_cls, exchange=exchange, k=k))

        self.refine_w = self.create_block(block_cls, channels[0], channels[0])
        self.out_w = nn.Conv2d(channels[0], out_channels, kernel_size=3, padding=1, stride=1, bias=True)
        self.refine_c = self.create_block(block_cls, channels[0], channels[0])
        self.out_c = nn.Conv2d(channels[0], out_channels, kernel_size=3, padding=1, stride=1, bias=True)

    def create_block(self, block_cls, in_channels, out_channels):
        """Instantiate ``block_cls`` with a single normalisation group."""
        return block_cls(in_channels, out_channels, num_groups=1)

    def forward(self, x):
        """Run both branches through the encoder/decoder.

        Returns:
            ``(out_w, out_c, x_wiener)``: the w-branch output, the c-branch
            output and the initial Wiener-filtered input.
        """
        p = self.psf
        x_wiener = self.WieNerH(x, p)
        w = self.start_w(x_wiener)
        c = self.start_c(x)
        p = self.start_p(p)

        w_feats, c_feats, p_feats = [w], [c], [p]

        # Encoder: each stage consumes the previous stage's outputs.
        for down in self.down_layers:
            w, c, p = down(w, c, p)
            w_feats.append(w)
            c_feats.append(c)
            p_feats.append(p)

        # Decoder: stage i fuses with the skip features of level -(i + 2).
        for i, up in enumerate(self.up_layers):
            skip = -(i + 2)
            w, c = up(w, w_feats[skip], c, c_feats[skip], p_feats[skip])

        out_w = self.out_w(self.refine_w(w))
        out_c = self.out_c(self.refine_c(c))
        return out_w, out_c, x_wiener
