import torch
from torch import nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from models.modules.patching2D import unpatchify_2d
import os

import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from models.modules.patching2D import unpatchify_2d
import os

class FrequencyAwareReconstructionLoss(nn.Module):
    """Joint time-/frequency-domain reconstruction loss.

    Intended for data that has been reduced via wavelet decomposition. The
    total loss is a weighted sum of three terms::

        alpha * time_loss + beta * fft_magnitude_loss + gamma * fft_phase_loss

    The time-domain term is averaged over masked patches only; the two
    spectral terms are computed on the full un-patchified signal.

    Args:
        patch_size: ``(ph, pw)`` patch size handed to ``unpatchify_2d``.
        alpha: Weight of the time-domain loss.
        beta: Weight of the FFT magnitude loss.
        gamma: Weight of the FFT phase loss.
        loss_type: Element-wise time-domain loss: ``"l1"``, ``"l2"`` or
            ``"smooth_l1"``.
        using_spectrogram: If true, a diagnostic figure is saved whenever
            ``batch["batch_idx"] == vis_batch_idx``.
        save_dir: Directory the diagnostic figure is written to. ``None``
            keeps the historical hard-coded location.
        vis_batch_idx: Batch index at which the figure is produced.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """

    # Supported ``loss_type`` names and their element-wise functional losses.
    _PIXEL_LOSSES = {
        "l1": F.l1_loss,
        "l2": F.mse_loss,
        "smooth_l1": F.smooth_l1_loss,
    }

    # Historical output location, kept as the default for backward compatibility.
    _DEFAULT_SAVE_DIR = "/leonardo_work/CNHPC_1526560/yanlchen/EMG_Pretrain/waveletMiM_pretraining/figures"

    def __init__(
        self,
        patch_size=(1, 50),
        alpha=0.7,    # time-domain loss weight
        beta=0.2,     # frequency-domain magnitude loss weight
        gamma=0.1,    # frequency-domain phase loss weight
        loss_type="smooth_l1",
        using_spectrogram=True,
        save_dir=None,
        vis_batch_idx=100,
    ):
        super().__init__()
        # Fail fast on a misconfigured loss instead of at the first forward pass.
        self._resolve_pixel_loss(loss_type)

        self.patch_size = patch_size
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type
        self.using_spectrogram = using_spectrogram
        self.save_dir = self._DEFAULT_SAVE_DIR if save_dir is None else save_dir
        self.vis_batch_idx = vis_batch_idx

    @classmethod
    def _resolve_pixel_loss(cls, loss_type):
        """Return the functional loss registered for ``loss_type`` or raise ValueError."""
        try:
            return cls._PIXEL_LOSSES[loss_type]
        except (KeyError, TypeError):
            raise ValueError("Invalid loss_type.") from None

    @staticmethod
    def _upcast_half(tensor):
        """Promote float16/bfloat16 tensors to float32; return other dtypes unchanged.

        ``torch.fft.rfft`` has no CPU half-precision kernel and only limited
        CUDA half support, so reduced-precision inputs are widened first.
        """
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor.float()
        return tensor

    def time_domain_loss(self, pred, target, token_mask):
        """Compute the time-domain reconstruction loss over masked patches.

        Args:
            pred: Predicted patches; the last dimension is averaged per patch.
            target: Ground-truth patches, same shape as ``pred``.
            token_mask: Per-patch weights (non-zero = contributes to the loss).

        Returns:
            Scalar tensor: mask-weighted mean of the per-patch losses.

        Raises:
            ValueError: If ``self.loss_type`` is not a supported name.
        """
        pixel_loss_fn = self._resolve_pixel_loss(self.loss_type)
        loss_pixel = pixel_loss_fn(pred, target, reduction='none')

        # Mean error of each patch => [B, num_patches]
        loss_per_patch = loss_pixel.mean(dim=-1)

        # Only masked patches contribute; the epsilon guards an all-zero mask.
        return (loss_per_patch * token_mask).sum() / (token_mask.sum() + 1e-8)

    @classmethod
    def _spectral_losses(cls, pred_full, target_full):
        """Return ``(magnitude_loss, phase_loss)`` between two un-patchified signals.

        The real FFT is taken along the last dimension with orthonormal scaling.
        """
        pred_fft = torch.fft.rfft(cls._upcast_half(pred_full), dim=-1, norm="ortho")
        target_fft = torch.fft.rfft(cls._upcast_half(target_full), dim=-1, norm="ortho")

        # Magnitude loss
        mag_loss = F.mse_loss(torch.abs(pred_fft), torch.abs(target_fft))

        # Phase loss: 1 - cos(delta) is periodic, so it is insensitive to phase wrapping.
        phase_delta = torch.angle(pred_fft) - torch.angle(target_fft)
        phase_loss = 1 - torch.cos(phase_delta).mean()

        return mag_loss, phase_loss

    def frequency_domain_loss(self, pred, target, wave_gt_shape, token_mask):
        """Compute the frequency-domain reconstruction losses.

        Args:
            pred: Predicted patches.
            target: Ground-truth patches.
            wave_gt_shape: ``(B, C, H, W)`` shape of the un-patchified signal.
            token_mask: Accepted for interface symmetry but currently unused;
                the spectral losses cover the whole signal, not only masked
                patches.

        Returns:
            Tuple ``(mag_loss, phase_loss)`` of scalar tensors.
        """
        _, C, H, W = wave_gt_shape
        ph, pw = self.patch_size

        # Rebuild the full prediction and target from their patches.
        pred_unpatch = unpatchify_2d(pred, (ph, pw), C, H, W)
        target_unpatch = unpatchify_2d(target, (ph, pw), C, H, W)

        return self._spectral_losses(pred_unpatch, target_unpatch)

    def _save_visualization(self, pred, target, token_mask, wave_gt_2d, batch_idx, epoch_idx):
        """Render a 2x3 diagnostic figure for sample 0 / channel 0 and save it as PNG.

        Panels: target, prediction, mask, masked input, reconstruction merged
        with the visible input, and the FFT magnitude difference.

        Returns:
            Path of the written image file.
        """
        _, C, H, W = wave_gt_2d.shape
        ph, pw = self.patch_size
        patch_dim = C * ph * pw
        b_idx, ch_idx = 0, 0

        # Plotting never needs gradients; avoid extending the autograd graph.
        with torch.no_grad():
            # Un-patchify the full predictions and targets.
            pred_unpatch = unpatchify_2d(pred, (ph, pw), C, H, W)
            target_unpatch = unpatchify_2d(target, (ph, pw), C, H, W)

            # Build the 2D mask by broadcasting each token's flag over its patch.
            expanded_mask = token_mask.unsqueeze(-1).repeat(1, 1, patch_dim)
            masked_region = unpatchify_2d(expanded_mask, (ph, pw), C, H, W) > 0.5
            visible_region = ~masked_region

            # Original signal with the masked region zeroed out.
            masked_image = wave_gt_2d * visible_region
            # Reconstruction: visible region kept as-is, masked region from the prediction.
            recon_with_visible = masked_image + pred_unpatch * masked_region

            def _panel(tensor):
                # Half precision cannot be plotted or FFT'd on CPU, so widen it.
                return self._upcast_half(tensor[b_idx, ch_idx].detach().cpu())

            tgt_2d = _panel(target_unpatch)
            pred_2d = _panel(pred_unpatch)
            masked_2d = _panel(masked_image)
            recon_2d = _panel(recon_with_visible)
            mask_2d = visible_region[b_idx, ch_idx].float().detach().cpu()

            # Shared colour range of the amplitude panels. The binary mask is
            # drawn on its own 0..1 scale and must not influence this range.
            amplitude_stack = torch.stack([tgt_2d, pred_2d, masked_2d, recon_2d], dim=0)
            vmin = amplitude_stack.min().item()
            vmax = amplitude_stack.max().item()

            # Spectral comparison between target and prediction.
            tgt_fft = torch.abs(torch.fft.rfft(tgt_2d, dim=-1, norm="ortho"))
            pred_fft = torch.abs(torch.fft.rfft(pred_2d, dim=-1, norm="ortho"))
            fft_diff = torch.abs(tgt_fft - pred_fft)

        fig, axs = plt.subplots(2, 3, figsize=(20, 10))
        try:
            amplitude_panels = (
                (axs[0, 0], tgt_2d, "Target"),
                (axs[0, 1], pred_2d, "Pred"),
                (axs[1, 0], masked_2d, "Masked"),
                (axs[1, 1], recon_2d, "Recon + Visible"),
            )
            images = []
            for ax, data, title in amplitude_panels:
                images.append(
                    ax.imshow(data, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
                )
                ax.set_title(title)

            axs[0, 2].imshow(mask_2d, aspect='auto', cmap='gray_r', vmin=0, vmax=1)
            axs[0, 2].set_title("Mask")

            axs[1, 2].imshow(fft_diff, aspect='auto', cmap='hot', vmin=0)
            axs[1, 2].set_title("FFT Magnitude Diff")

            fig.tight_layout()

            # The colour bar reflects the shared amplitude scale of the "Target" panel.
            cbar = fig.colorbar(images[0], ax=axs.ravel().tolist(), fraction=0.025, pad=0.02)
            cbar.set_label('Amplitude (a.u.)', rotation=270, labelpad=15)

            # Save the figure.
            os.makedirs(self.save_dir, exist_ok=True)
            save_path = os.path.join(
                self.save_dir, f"wave_freqloss_epoch{epoch_idx}_batch{batch_idx}.png"
            )
            fig.savefig(save_path, dpi=200, bbox_inches='tight')
        finally:
            # Always release the figure, even if directory creation or saving fails.
            plt.close(fig)

        return save_path

    def forward(self, pred, batch):
        """
        Compute the combined loss and, on the configured batch, save a diagnostic figure.

        Args:
            pred:  [B, num_patches, patch_dim] from Decoder
            batch: {
               "target":      [B, num_patches, patch_dim],
               "token_mask":  [B, num_patches],
               "wave_gt_2d":  [B, C, H, W], e.g. [B,1,64,1000],
               "batch_idx":   int (only read when ``using_spectrogram`` is true),
               "epoch_idx":   int (only read when a figure is saved)
            }

        Returns:
            (loss, logging_output): the scalar loss tensor and a dict of Python
            floats (plus a ``"note"`` string when a figure was saved).
        """
        target = batch["target"]
        token_mask = batch["token_mask"]
        wave_gt_2d = batch["wave_gt_2d"]

        # 1. Time-domain loss
        time_loss = self.time_domain_loss(pred, target, token_mask)

        # 2. Frequency-domain losses
        mag_loss, phase_loss = self.frequency_domain_loss(
            pred, target, wave_gt_2d.shape, token_mask
        )

        # 3. Weighted combination
        total_loss = self.alpha * time_loss + self.beta * mag_loss + self.gamma * phase_loss

        logging_output = {
            "loss": total_loss.item(),
            "time_loss": time_loss.item(),
            "freq_mag_loss": mag_loss.item(),
            "freq_phase_loss": phase_loss.item()
        }

        # Optional visualization; the batch index is only needed when enabled.
        if self.using_spectrogram:
            batch_idx = batch["batch_idx"]
            if batch_idx == self.vis_batch_idx:
                epoch_idx = batch["epoch_idx"]
                self._save_visualization(
                    pred, target, token_mask, wave_gt_2d, batch_idx, epoch_idx
                )
                logging_output["note"] = f"Saved frequency-aware loss visualization at batch_idx={batch_idx}, epoch={epoch_idx}"

        return total_loss, logging_output
