import torch
from torch import nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from models.modules.patching2D import unpatchify_2d
import os

import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from models.modules.patching2D import unpatchify_2d
import os

class FrequencyAwareReconstructionLoss(nn.Module):
    """
    时频域联合重建损失，特别适用于小波分解降维后的数据
    """
    def __init__(
        self,
        patch_size=(1, 50),
        alpha=0.7,    # weight of the time-domain loss
        beta=0.2,     # weight of the frequency-domain magnitude loss
        gamma=0.1,    # weight of the frequency-domain phase loss
        loss_type="smooth_l1",
        using_spectrogram=True,
        save_dir=None,  # directory for saved diagnostic figures
    ):
        """
        Store the loss configuration.

        Args:
            patch_size: (patch_height, patch_width) used to unpatchify tokens
                back into the 2D waveform layout for visualization.
            alpha: weight of the time-domain term in the total loss.
            beta: weight of the frequency-domain magnitude term.
            gamma: weight of the frequency-domain phase term.
            loss_type: name of the element-wise loss; presumably consumed by
                ``time_domain_loss`` (not visible here) — confirm.
            using_spectrogram: when true, ``forward`` may write a diagnostic figure.
            save_dir: default directory for that figure; a ``"save_dir"`` entry
                in the batch takes precedence.
        """
        super().__init__()
        # Weights of the three loss terms combined in forward().
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.loss_type = loss_type
        # Patch geometry and visualization settings.
        self.patch_size = patch_size
        self.using_spectrogram = using_spectrogram
        self.save_dir = save_dir
        
    # time_domain_loss and frequency_domain_loss are unchanged and omitted from this listing; forward() requires both.
    
    def forward(self, pred, batch):
        """
        Compute the combined time/frequency-domain reconstruction loss.

        Args:
            pred:  [B, num_patches, patch_dim] from Decoder
            batch: dict with keys
               "target":      [B, num_patches, patch_dim],
               "token_mask":  [B, num_patches],
               "wave_gt_2d":  [B, 1, 64, 1000],
               "batch_idx":   int,
               "epoch_idx":   int (read only when a figure is saved),
               "save_dir":    str, optional; overrides the directory given at init

        Returns:
            (loss, logging_output): ``loss`` is
            ``alpha * time_loss + beta * mag_loss + gamma * phase_loss``;
            ``logging_output`` holds the Python-float value of each term and,
            when a figure was written, a ``"note"`` with its path.

        Relies on ``self.time_domain_loss`` and ``self.frequency_domain_loss``,
        which are defined outside this method.
        """
        target = batch["target"]
        token_mask = batch["token_mask"]
        wave_gt_2d = batch["wave_gt_2d"]

        # 1. Time-domain loss.
        time_loss = self.time_domain_loss(pred, target, token_mask)

        # 2. Frequency-domain magnitude / phase losses. The full 2D shape is
        #    passed along, presumably so the loss can unpatchify the tokens.
        mag_loss, phase_loss = self.frequency_domain_loss(
            pred, target, wave_gt_2d.shape, token_mask
        )

        # 3. Weighted combination.
        total_loss = self.alpha * time_loss + self.beta * mag_loss + self.gamma * phase_loss

        logging_output = {
            "loss": total_loss.item(),
            "time_loss": time_loss.item(),
            "freq_mag_loss": mag_loss.item(),
            "freq_phase_loss": phase_loss.item()
        }

        # A single diagnostic figure is written when batch index 100 is reached.
        batch_idx = batch["batch_idx"]
        if self.using_spectrogram and batch_idx == 100:
            save_path = self._save_reconstruction_figure(pred, batch)
            logging_output["note"] = f"Saved frequency-aware loss visualization at {save_path}"

        return total_loss, logging_output

    def _save_reconstruction_figure(self, pred, batch):
        """
        Render a 2x3 diagnostic figure for sample 0 / channel 0 and save it as a PNG.

        Panels: target, prediction, visibility mask, visible-only input,
        reconstruction (visible input + predicted masked region), and the
        absolute difference between the per-row rFFT magnitudes of target and
        prediction.

        Returns:
            The path of the written PNG file.
        """
        # Directory priority: batch["save_dir"], then self.save_dir, then ./figures.
        # None or empty values fall through to the next option.
        save_dir = batch.get("save_dir") or self.save_dir or os.path.join(os.getcwd(), "figures")
        os.makedirs(save_dir, exist_ok=True)

        target = batch["target"]
        token_mask = batch["token_mask"]
        X = batch["wave_gt_2d"]
        _, C, H, W = X.shape
        ph, pw = self.patch_size
        patch_dim = C * ph * pw

        # Purely diagnostic: keep these tensors out of the autograd graph.
        with torch.no_grad():
            pred_unpatch = unpatchify_2d(pred, (ph, pw), C, H, W)
            target_unpatch = unpatchify_2d(target, (ph, pw), C, H, W)

            # Repeat each token's mask value over every element of its patch and
            # unpatchify it into image layout. Cast to float first so bool and
            # numeric masks take the same path.
            expanded_mask = token_mask.float().unsqueeze(-1).repeat(1, 1, patch_dim)
            token_mask_2d = unpatchify_2d(expanded_mask, (ph, pw), C, H, W)
            masked_region = token_mask_2d > 0.5  # True where the token was masked

            # Keep only the visible part of the input.
            masked_image = X * (~masked_region)
            # Reconstruction: visible input kept as-is, masked region from the prediction.
            recon_with_visible = X * (~masked_region) + pred_unpatch * masked_region

            # Slice one sample/channel. Upcast to float32 because numpy/imshow
            # cannot handle bfloat16 and CPU FFTs may not support half precision.
            b_idx, ch_idx = 0, 0
            tgt_2d = target_unpatch[b_idx, ch_idx].detach().float().cpu()
            pred_2d = pred_unpatch[b_idx, ch_idx].detach().float().cpu()
            mask_2d = (~masked_region[b_idx, ch_idx]).float().cpu()
            masked_2d = masked_image[b_idx, ch_idx].detach().float().cpu()
            recon_2d = recon_with_visible[b_idx, ch_idx].detach().float().cpu()

            # Shared colour range for the viridis panels only. The mask panel has
            # its own fixed 0..1 range and must not stretch the data range.
            stacked_data = torch.stack([tgt_2d, pred_2d, masked_2d, recon_2d], dim=0)
            vmin = stacked_data.min().item()
            vmax = stacked_data.max().item()

            # Row-wise spectrum comparison.
            tgt_fft = torch.fft.rfft(tgt_2d, dim=-1, norm="ortho").abs()
            pred_fft = torch.fft.rfft(pred_2d, dim=-1, norm="ortho").abs()
            fft_diff = (tgt_fft - pred_fft).abs()

        batch_idx = batch["batch_idx"]
        fig, axs = plt.subplots(2, 3, figsize=(20, 10))
        try:
            im0 = axs[0, 0].imshow(tgt_2d, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
            axs[0, 0].set_title("Target")

            axs[0, 1].imshow(pred_2d, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
            axs[0, 1].set_title("Pred")

            axs[0, 2].imshow(mask_2d, aspect='auto', cmap='gray_r', vmin=0, vmax=1)
            axs[0, 2].set_title("Mask")

            axs[1, 0].imshow(masked_2d, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
            axs[1, 0].set_title("Masked")

            axs[1, 1].imshow(recon_2d, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
            axs[1, 1].set_title("Recon + Visible")

            axs[1, 2].imshow(fft_diff, aspect='auto', cmap='hot', vmin=0)
            axs[1, 2].set_title("FFT Magnitude Diff")

            fig.tight_layout()

            cbar = fig.colorbar(im0, ax=axs.ravel().tolist(), fraction=0.025, pad=0.02)
            cbar.set_label('Amplitude (a.u.)', rotation=270, labelpad=15)

            epoch_idx = batch["epoch_idx"]
            save_path = os.path.join(save_dir, f"wave_freqloss_epoch{epoch_idx}_batch{batch_idx}.png")
            fig.savefig(save_path, dpi=200, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        return save_path