import os
import torch
import torch.fft as fft
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
from scipy.interpolate import griddata

from PIL import Image
from einops import rearrange


def fftSplitAP(x):
    """Return the centred Fourier amplitude and phase of ``x``.

    The 2-D FFT is taken over the last two dimensions and shifted so the
    zero-frequency bin sits at the centre (index ``n // 2`` on each axis).

    Args:
        x: real or complex tensor; the last two dims are transformed.

    Returns:
        ``(amplitude, phase)`` tensors with the same shape as ``x``, where
        ``amplitude = sqrt(re**2 + im**2)`` and ``phase = atan2(im, re)``.
    """
    spatial_dims = (-2, -1)
    spectrum = fft.fftn(x, dim=spatial_dims)
    spectrum = fft.fftshift(spectrum, dim=spatial_dims)

    re, im = spectrum.real, spectrum.imag
    amplitude = torch.sqrt(re**2 + im**2)
    phase = torch.atan2(im, re)
    return amplitude, phase


def fftConbineAP(A_x, P_x):
    """Rebuild a spatial-domain tensor from centred amplitude and phase.

    Inverse of the amplitude/phase split: the complex spectrum
    ``A * (cos P + i sin P)`` is un-shifted and inverse-transformed over the
    last two dimensions.

    Args:
        A_x: amplitude tensor (real, floating point).
        P_x: phase tensor in radians, same shape and dtype as ``A_x``.

    Returns:
        Real part of the inverse 2-D FFT, same shape as the inputs.
    """
    spatial_dims = (-2, -1)
    centred = torch.complex(A_x * torch.cos(P_x), A_x * torch.sin(P_x))
    uncentred = torch.fft.ifftshift(centred, dim=spatial_dims)
    return torch.fft.ifftn(uncentred, dim=spatial_dims).real


def High_Enhance_Module(x, threshold, scale):
    """Boost high-frequency content of ``x`` in the centred Fourier domain.

    The spectrum is FFT-shifted so the zero frequency is at
    ``(H // 2, W // 2)``. Along each axis, a band of half-width ``threshold``
    around the centre is left alone. Every row outside the row band, and
    every column outside the column band, is multiplied by
    ``1 + scale_eff / 2``. Positions outside both bands (the corners)
    therefore receive that gain twice.

    ``scale_eff`` is ``scale`` rescaled by the spatial height:
    ``scale * ((H - 8) / (64 - 8) + 0.5)``.

    Args:
        x: real tensor whose last two dims are spatial (H, W). Any leading
            dims are allowed; the usual input is (B, C, H, W).
        threshold: non-negative integer half-width (in frequency bins) of the
            un-amplified central band. If it exceeds half the size of an axis,
            nothing is amplified along that axis.
        scale: enhancement strength (Python number or 0-d tensor). The
            caller's object is not modified.

    Returns:
        Real tensor with the same shape as ``x`` and on the same device.
    """
    spatial_dims = (-2, -1)
    x_freq = fft.fftshift(fft.fftn(x, dim=spatial_dims), dim=spatial_dims)
    H, W = x_freq.shape[-2:]

    Hmax, Hmin = 64, 8
    # Out-of-place: ``scale *= ...`` would mutate a caller-supplied tensor.
    scale = scale * ((H - Hmin) / (Hmax - Hmin) + 0.5)
    gain = 1 + scale / 2

    crow, ccol = H // 2, W // 2
    # Clamp at 0: a negative stop index would wrap around and select nearly
    # every row/column instead of none.
    row_lo, row_hi = max(crow - threshold, 0), crow + threshold
    col_lo, col_hi = max(ccol - threshold, 0), ccol + threshold

    # Build the mask on the spectrum's device (not hard-coded CUDA) as a small
    # (H, W) tensor that broadcasts over leading dims. Use at least the default
    # float dtype so result promotion matches the original float32 mask.
    mask_dtype = torch.promote_types(x_freq.real.dtype, torch.get_default_dtype())
    row_gain = torch.ones(H, device=x_freq.device, dtype=mask_dtype)
    col_gain = torch.ones(W, device=x_freq.device, dtype=mask_dtype)
    row_gain[:row_lo] *= gain
    row_gain[row_hi:] *= gain
    col_gain[:col_lo] *= gain
    col_gain[col_hi:] *= gain
    # Outer product: corners get gain * gain, as in the original mask.
    mask = row_gain[:, None] * col_gain[None, :]

    x_freq = x_freq * mask

    # IFFT back to the spatial domain.
    x_freq = fft.ifftshift(x_freq, dim=spatial_dims)
    x_filtered = fft.ifftn(x_freq, dim=spatial_dims).real
    return x_filtered


def Amp_Enhance_Module(x_bone, Pa, Ps):
    """Amplify the Fourier amplitude of the weakest channels of ``x_bone``.

    Steps:
      1. Cast to float32, then split into centred amplitude ``A`` and phase
         ``P`` (via ``fftSplitAP``).
      2. Each channel's contribution is its mean amplitude over the last two
         dims. Channels whose contribution is ``<= min + Ps * (max - min)``
         are selected. The min/max are taken over the whole batch, not per
         sample.
      3. ``M`` is the channel-averaged amplitude map, min-max normalised per
         sample to [0, 1]. Selected channels get ``A * (1 + M * Pa)``; all
         other channels are unchanged.
      4. Recombine with the original phase (via ``fftConbineAP``).

    Args:
        x_bone: 4-D tensor (indexed as [B, C, H, W] below).
        Pa: amplitude enhancement strength.
        Ps: fraction of the contribution range used as the selection cut-off.

    Returns:
        Float32 tensor with the same shape as ``x_bone``.
    """
    x_bone = x_bone.float()

    # Step 1: Perform FFT and split A & P
    A_bone, P_bone = fftSplitAP(x_bone)

    # Step 2: choose channels
    channel_contributions = torch.abs(A_bone.mean(dim=[2, 3]))  # shape [B, C]
    c_min = torch.min(channel_contributions)
    c_max = torch.max(channel_contributions)
    c_split = c_min + Ps * (c_max - c_min)
    # [B, C, 1, 1]; broadcasts against A_bone in torch.where below.
    small_channel_mask = (channel_contributions <= c_split)[..., None, None]

    # Step 3: Enhance A_Bone
    hidden_mean = A_bone.mean(1, keepdim=True)  # [B, 1, H, W]
    B = hidden_mean.shape[0]
    flat = hidden_mean.reshape(B, -1)
    hidden_min = flat.min(dim=-1).values.view(B, 1, 1, 1)
    hidden_max = flat.max(dim=-1).values.view(B, 1, 1, 1)
    # Guard a flat map (max == min, e.g. 1x1 spatial size). Without this the
    # normalisation is 0/0 = NaN and the NaN spreads into every selected
    # channel. The numerator is 0 there, so the map becomes 0 (no boost).
    denom = (hidden_max - hidden_min).clamp_min(torch.finfo(hidden_mean.dtype).tiny)
    hidden_mean = (hidden_mean - hidden_min) / denom
    A_bone = torch.where(small_channel_mask, A_bone * (1 + hidden_mean * Pa), A_bone)

    # Step 4: reconstruct x_bone
    x_reconstructed = fftConbineAP(A_bone, P_bone)
    return x_reconstructed

