import numpy as np
import os
from scipy import fftpack
def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
    """Replace the low-frequency band of ``amp_src`` with that of ``amp_trg``.

    Both inputs are amplitude spectra in the layout produced by
    ``np.fft.fft2`` over the last two axes (zero frequency at ``[..., 0, 0]``).
    The spectra are centred with ``fftshift``. A square window of half-width
    ``b = floor(min(H, W) * L)`` around the zero frequency is copied from the
    target into the source, and the result is shifted back.

    Args:
        amp_src: source amplitude array of shape (..., H, W).
        amp_trg: target amplitude array; expected to have the same shape as
            ``amp_src``.
        L: window half-width as a fraction of ``min(H, W)``.

    Returns:
        A new array with the shape of ``amp_src``, in unshifted FFT layout.
        ``fftshift`` returns a copy, so neither input is modified.
    """
    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    h, w = a_src.shape[-2:]
    b = int(np.floor(min(h, w) * L))
    # After fftshift the zero frequency sits at index n // 2 (even or odd n).
    c_h = h // 2
    c_w = w // 2

    # Clamp the lower edges. With L > 0.5 they would become negative, and
    # Python slicing would wrap around and select the wrong (or an empty) band.
    h1 = max(c_h - b, 0)
    h2 = c_h + b + 1
    w1 = max(c_w - b, 0)
    w2 = c_w + b + 1

    # Ellipsis indexing supports any number of leading (batch/channel) axes.
    a_src[..., h1:h2, w1:w2] = a_trg[..., h1:h2, w1:w2]
    return np.fft.ifftshift(a_src, axes=(-2, -1))


def FFT_source_to_target_gpt(img1, img2, structure_weight=1.0, style_weight=0.2):
    """Combine the FFT phase of ``img1`` with the FFT magnitude of ``img2``.

    The phase of ``img1`` carries its structure. It is multiplied by
    ``structure_weight`` and paired with the amplitude of ``img2``, which
    carries its texture/style. The amplitude is scaled by ``style_weight``.
    The combined spectrum is inverse-transformed and its pointwise magnitude
    is returned as pixel values.

    Because the inverse FFT is linear and ``abs`` is absolutely homogeneous,
    ``style_weight`` only rescales the whole output by ``|style_weight|``.

    Args:
        img1: structure source; 2-D FFT over its last two axes.
        img2: style source; same shape as ``img1``.
        structure_weight: multiplier applied to ``img1``'s phase angles.
        style_weight: multiplier applied to ``img2``'s magnitude.

    Returns:
        Non-negative real array with the shape of the inputs.
    """
    source_phase = np.angle(np.fft.fft2(img1))
    target_magnitude = np.abs(np.fft.fft2(img2))

    scaled_magnitude = style_weight * target_magnitude
    rotation = np.exp(1j * (structure_weight * source_phase))
    mixed = np.fft.ifft2(scaled_magnitude * rotation)

    # Magnitude (not real part) of the inverse transform is used as the image.
    return np.abs(mixed)
    



def FDA_source_to_target_np(src_img, trg_img, L=0.1):
    """Give ``src_img`` the low-frequency amplitude spectrum of ``trg_img``.

    This is Fourier-domain style transfer:
    1. Take the 2-D FFT of both images over their last two axes.
    2. Replace the low-frequency part of the source amplitude with the
       target's via ``low_freq_mutate_np``.
    3. Recombine with the *source* phase and inverse-transform.

    Args:
        src_img: source image array. Its shape must be accepted by
            ``low_freq_mutate_np``; the demo in this file passes (1, H, W)
            single-channel float arrays.
        trg_img: target image array; expected to match ``src_img``'s shape.
        L: window-size fraction forwarded to ``low_freq_mutate_np``.

    Returns:
        Real-valued array with the shape of ``src_img``. The imaginary
        residue of the inverse FFT is discarded.
    """
    fft_src = np.fft.fft2(src_img, axes=(-2, -1))
    fft_trg = np.fft.fft2(trg_img, axes=(-2, -1))

    # Only the source phase and the target amplitude are needed; the target
    # phase plays no role in the transfer.
    amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)
    amp_trg = np.abs(fft_trg)

    # Swap in the target's low-frequency amplitude band.
    amp_src_ = low_freq_mutate_np(amp_src, amp_trg, L=L)

    # Rebuild the spectrum with the original source phase and invert it.
    fft_src_ = amp_src_ * np.exp(1j * pha_src)
    src_in_trg = np.fft.ifft2(fft_src_, axes=(-2, -1))
    return np.real(src_in_trg)

from PIL import Image
import cv2
def read_img(img_name, size=(384, 384)):
    """Load an image file, resize it, and return its R, G, B channels.

    Args:
        img_name: path passed to ``cv2.imread``.
        size: output size as ``(width, height)``, the ``dsize`` order used
            by ``cv2.resize``. Defaults to (384, 384).

    Returns:
        Tuple ``(r, g, b)`` of 2-D channel arrays of shape (height, width).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_name``.
    """
    img = cv2.imread(img_name)
    # cv2.imread signals failure by returning None instead of raising.
    # Without this check the error would only surface as an opaque
    # cv2.resize assertion.
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_name}")
    # Honour the caller's size (it used to be ignored in favour of 384x384).
    img = cv2.resize(img, tuple(size))
    # OpenCV loads BGR; convert so the returned channels are in RGB order.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    r, g, b = cv2.split(img)
    return r, g, b
if __name__ == '__main__':
    import shutil

    # Demo: per RGB channel, transfer the low-frequency amplitude spectrum of
    # a real photo onto a sketch. Save both inputs and the result.
    out_dir = 'visual_mixup'
    path1 = '../dataset/domainnet/sketch/mailbox/sketch_179_000056.jpg'
    path2 = '../dataset/domainnet/real/mailbox/real_179_000324.jpg'
    path1n = os.path.splitext(os.path.basename(path1))[0]
    path2n = os.path.splitext(os.path.basename(path2))[0]

    # Without the directory, shutil.copy raises and cv2.imwrite fails
    # silently.
    os.makedirs(out_dir, exist_ok=True)
    shutil.copy(path1, os.path.join(out_dir, os.path.basename(path1)))
    shutil.copy(path2, os.path.join(out_dir, os.path.basename(path2)))

    img1 = read_img(path1)
    img2 = read_img(path2)
    s2t = []
    for c1, c2 in zip(img1, img2):
        c1 = np.asarray(c1, np.float32)
        c2 = np.asarray(c2, np.float32)
        # FFT_source_to_target_gpt(c1, c2, style_weight=0.0001) is an
        # alternative mixing strategy.
        s2t.append(FDA_source_to_target_np(c1[None, ...], c2[None, ...], L=0.02))

    # Stack the three (1, H, W) channels into an (H, W, 3) image.
    s2t = np.concatenate(s2t).transpose((1, 2, 0))
    # Min-max normalise to [0, 255]; guard against a constant result, which
    # would otherwise divide by zero and yield NaNs.
    lo, hi = s2t.min(), s2t.max()
    span = hi - lo
    if span > 0:
        s2t = 255 * (s2t - lo) / span
    else:
        s2t = np.zeros_like(s2t)
    s2t = cv2.cvtColor(s2t.astype(np.uint8), cv2.COLOR_RGB2BGR)

    out_path = os.path.join(out_dir, path1n + '_' + path2n + ".jpeg")
    # cv2.imwrite reports failure via its return value, not an exception.
    if not cv2.imwrite(out_path, s2t):
        raise RuntimeError(f"failed to write {out_path}")