import torch
import torch.nn as nn
from torch.nn.functional import unfold
from utils.stats import calc_mean_std

class PPIN(nn.Module):
    """Learnable per-patch channel statistics for selected feature patches.

    The content feature map is cut into non-overlapping square patches of
    side ``patch_hw``. Every patch selected through ``ind`` is normalized with
    the mean/std returned by ``calc_mean_std``; the normalized patch is kept
    fixed, while its mean and std become trainable parameters. ``forward``
    rebuilds the patches as ``relu(normalized * style_std + style_mean)``.

    Args:
        content_feat: Feature map of shape (B, C, H, W). A detached copy is
            stored, so no gradient flows back into the caller's tensor.
        div: Stored as ``self.div``; not used anywhere inside this class.
        ind: Per-image patch indices; ``ind[b]`` lists the patch indices
            (row-major order of ``unfold``) to take from image ``b``.
        patch_hw: Side length of the square patches (kernel size and stride
            passed to ``unfold``). Border rows/columns that do not fill a
            whole patch are not covered by any patch.

    Raises:
        IndexError: If ``ind`` has fewer entries than the batch size.
        RuntimeError: Raised by ``torch.stack`` when no patch is selected.
    """

    def __init__(self, content_feat, div=3, ind=None, patch_hw=64):
        super(PPIN, self).__init__()
        # ``None`` replaces the former ``[]`` default so that instances never
        # share one mutable list object through ``self.ind``.
        self.ind = [] if ind is None else ind  # per-image patch indices
        self.div = div
        self.patch_hw = patch_hw
        self.content_feat = content_feat.clone().detach()  # (B, C, H, W)
        self.B, self.C, _, _ = self.content_feat.shape

        # Fail early with a readable message instead of part-way through the
        # selection loop below (same exception type as plain indexing).
        if len(self.ind) < self.B:
            raise IndexError(
                "ind must provide one index list per image: got "
                f"{len(self.ind)} entries for a batch of {self.B}"
            )

        # unfold -> [B, C * patch_hw * patch_hw, N_patches]
        unfolded = unfold(
            self.content_feat, kernel_size=self.patch_hw, stride=self.patch_hw
        )
        # -> [B, N_patches, C * patch_hw * patch_hw]
        unfolded = unfolded.transpose(1, 2).contiguous()
        # -> [B, N_patches, C, patch_hw, patch_hw]
        self.patches_unfolded = unfolded.view(
            self.B, -1, self.C, self.patch_hw, self.patch_hw
        )

        normalized_patches = []  # normalized patches that will be re-styled
        means = []
        stds = []

        for b in range(self.B):
            for idx in self.ind[b]:
                # Patch ``idx`` of image ``b``: (C, patch_hw, patch_hw).
                patch = self.patches_unfolded[b, idx]
                # NOTE(review): calc_mean_std is assumed to return per-channel
                # statistics shaped [1, C, 1, 1] with a non-zero std -- confirm
                # against utils.stats.
                mean, std = calc_mean_std(patch.unsqueeze(0))
                mean = mean.squeeze(0)  # drop the leading batch dimension
                std = std.squeeze(0)

                normalized_patches.append((patch - mean) / std)
                means.append(mean)
                stds.append(std)

        # Registered as a non-persistent buffer: it follows the module across
        # ``.to()`` / ``.cuda()`` together with the parameters, but stays out
        # of ``state_dict`` so existing checkpoints keep loading.
        self.register_buffer(
            "selected_patches",
            torch.stack(normalized_patches, dim=0),  # [P, C, patch_hw, patch_hw]
            persistent=False,
        )

        # One trainable mean/std per selected patch, stacked along dim 0.
        self.style_mean = nn.Parameter(torch.stack(means, dim=0), requires_grad=True)
        self.style_std = nn.Parameter(torch.stack(stds, dim=0), requires_grad=True)

        self.relu = nn.ReLU(inplace=True)
        self.size = self.selected_patches[0].size()  # size of a single patch

    def forward(self):
        """Rebuild every selected patch from the current statistics.

        Returns:
            Tensor shaped like ``self.selected_patches`` holding
            ``relu(selected_patches * style_std + style_mean)``.
        """
        # Broadcasting applies each patch's statistics over its spatial
        # positions; this replaces the per-patch Python loop and the
        # throw-away clone of the patch tensor.
        patches_prime = self.selected_patches * self.style_std + self.style_mean
        return self.relu(patches_prime)
