from torchvision import transforms
import random
import torch
import numpy as np
from math import sqrt

import cv2
from PIL import Image

import mmcv

def dataset_info(filepath):
    """Parse an image-list file into parallel lists of paths and integer labels.

    Each non-blank line holds a file name followed by an integer label,
    separated by whitespace; any further columns are ignored.

    Args:
        filepath: path of the list file.

    Returns:
        ``(file_names, labels)``: a list of str and a list of int, in file order.

    Raises:
        IndexError: if a non-blank line has fewer than two columns.
        ValueError: if the label column is not an integer.
    """
    file_names = []
    labels = []
    with open(filepath, 'r') as f:
        for line in f:
            # split() with no argument tolerates repeated spaces and tabs,
            # which split(' ') would turn into empty fields.
            fields = line.split()
            if not fields:
                # Skip blank lines (e.g. a trailing empty line at end of file).
                continue
            file_names.append(fields[0])
            labels.append(int(fields[1]))
    return file_names, labels


def get_img_transform(train=False, image_size=224, crop=False, jitter=0):
    """Build a torchvision pipeline that turns an image into a normalized tensor.

    Args:
        train: when True, prepend augmentation (random resized crop or plain
            resize, optional colour jitter, random horizontal flip); when
            False, only resize to ``image_size`` x ``image_size``.
        image_size: output side length in pixels.
        crop: in training mode, use ``RandomResizedCrop`` with scale
            0.8-1.0 instead of a fixed square resize.
        jitter: strength passed to ``ColorJitter`` for brightness, contrast
            and saturation (hue is capped at 0.5); ``0`` disables jitter.

    Returns:
        A ``transforms.Compose`` that ends with ``ToTensor`` and ``Normalize``.
    """
    # Per-channel statistics (the values commonly used with ImageNet models).
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    # Both pipelines finish with tensor conversion followed by normalization.
    tail = [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]

    if not train:
        return transforms.Compose([transforms.Resize((image_size, image_size))] + tail)

    steps = []
    if crop:
        steps.append(transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0]))
    else:
        steps.append(transforms.Resize((image_size, image_size)))
    if jitter > 0:
        steps.append(transforms.ColorJitter(brightness=jitter,
                                            contrast=jitter,
                                            saturation=jitter,
                                            hue=min(0.5, jitter)))
    steps.append(transforms.RandomHorizontalFlip())
    return transforms.Compose(steps + tail)


def get_pre_transform(image_size=224, crop=False, jitter=0):
    """Build the augmentation stage that outputs a NumPy array.

    Applies a random resized crop (scale 0.8-1.0) or a fixed square resize,
    optional colour jitter, a random horizontal flip, and finally converts
    the image with ``np.asarray``.

    Args:
        image_size: output side length in pixels.
        crop: use ``RandomResizedCrop`` instead of a fixed resize.
        jitter: strength passed to ``ColorJitter`` for brightness, contrast
            and saturation (hue is capped at 0.5); ``0`` disables jitter.

    Returns:
        A ``transforms.Compose`` whose output is an ndarray.
    """
    if crop:
        img_transform = [transforms.RandomResizedCrop(image_size, scale=[0.8, 1.0])]
    else:
        img_transform = [transforms.Resize((image_size, image_size))]
    if jitter > 0:
        img_transform.append(transforms.ColorJitter(brightness=jitter,
                                                    contrast=jitter,
                                                    saturation=jitter,
                                                    hue=min(0.5, jitter)))
    # Use np.asarray directly rather than wrapping it in a lambda: the result
    # is identical, but a lambda cannot be pickled, which breaks DataLoader
    # workers started with the 'spawn' method.
    img_transform += [transforms.RandomHorizontalFlip(), np.asarray]
    return transforms.Compose(img_transform)


def get_post_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the stage that converts an image into a normalized tensor.

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel standard deviation passed to ``transforms.Normalize``.

    Returns:
        ``transforms.Compose([ToTensor(), Normalize(mean, std)])``.
    """
    # Defaults are tuples: Normalize keeps a reference to the sequence it is
    # given, and a shared mutable list default could be altered through the
    # returned transform and leak into later calls.
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    return img_transform


def get_spectrum(img, axes=(-2, -1)):
    """Return the amplitude and phase of the 2-D FFT of ``img``.

    Args:
        img: array with at least two dimensions.
        axes: the two axes to transform. The default ``(-2, -1)`` matches
            ``np.fft.fft2``; pass ``(0, 1)`` for an [H, W, C] image.

    Returns:
        ``(amplitude, phase)`` arrays with the same shape as ``img``; phase
        is in radians in [-pi, pi]. The spectrum is not shifted.
    """
    spectrum = np.fft.fft2(img, axes=axes)
    return np.abs(spectrum), np.angle(spectrum)

def get_centralized_spectrum(img, axes=(-2, -1)):
    """Return amplitude and phase of the 2-D FFT with zero frequency centred.

    Args:
        img: array with at least two dimensions.
        axes: the two axes to transform and shift. The default ``(-2, -1)``
            matches ``np.fft.fft2``; pass ``(0, 1)`` for an [H, W, C] image.

    Returns:
        ``(amplitude, phase)`` arrays with the same shape as ``img``, shifted
        with ``np.fft.fftshift`` so the zero-frequency term sits at index
        ``n // 2`` of each transformed axis.
    """
    spectrum = np.fft.fft2(img, axes=axes)
    # Shift only the transformed axes: shifting every axis (the fftshift
    # default) would also roll any untransformed axis of a >2-D input.
    spectrum = np.fft.fftshift(spectrum, axes=axes)
    return np.abs(spectrum), np.angle(spectrum)


def colorful_spectrum_mix(img1, img2, alpha, ratio=1.0,
                          work_scale=(225, 400), out_scale=(900, 1600)):
    """Blend part of img2's amplitude spectrum into img1, keeping img1's phase.

    Both images are rescaled with ``mmcv.imrescale`` to fit ``work_scale``.
    The centred window of img1's amplitude spectrum covering a
    ``sqrt(ratio)`` fraction of each spatial side is replaced by
    ``lam * A2 + (1 - lam) * A1`` with ``lam ~ U(0, alpha)``. The result is
    recombined with img1's phase, inverse transformed and rescaled for output.

    Args:
        img1: ndarray of [H, W, C]; supplies the phase and the base amplitude.
        img2: ndarray with the same shape as img1; supplies the mixed-in amplitude.
        alpha: upper bound of the uniform mixing weight ``lam``.
        ratio: fraction of the (centred) spectrum area that is mixed.
        work_scale: ``scale`` passed to ``mmcv.imrescale`` before the FFT.
        out_scale: ``scale`` passed to ``mmcv.imrescale`` for the result, or
            ``None`` to resize the result back to img1's original H x W.

    Returns:
        uint8 ndarray of the mixed image.

    Raises:
        ValueError: if img1 and img2 have different shapes.
    """
    # Draw lam first so the global NumPy RNG is consumed in the same order as before.
    lam = np.random.uniform(0, alpha)

    if img1.shape != img2.shape:
        raise ValueError(
            f"img1 and img2 must have the same shape, got {img1.shape} and {img2.shape}")
    orig_h, orig_w = img1.shape[:2]

    img1 = mmcv.imrescale(img1, scale=work_scale)
    img2 = mmcv.imrescale(img2, scale=work_scale)

    h, w = img1.shape[:2]
    h_crop = int(h * sqrt(ratio))
    w_crop = int(w * sqrt(ratio))
    h_start = h // 2 - h_crop // 2
    w_start = w // 2 - w_crop // 2
    window = (slice(h_start, h_start + h_crop), slice(w_start, w_start + w_crop))

    img1_fft = np.fft.fft2(img1, axes=(0, 1))
    img1_pha = np.angle(img1_fft)
    # Centre the amplitude spectra so the window covers the low frequencies.
    img1_abs = np.fft.fftshift(np.abs(img1_fft), axes=(0, 1))
    # Only img2's amplitude is needed.
    img2_abs = np.fft.fftshift(np.abs(np.fft.fft2(img2, axes=(0, 1))), axes=(0, 1))

    # The right-hand side is fully evaluated before assignment, so no copies are needed.
    img1_abs[window] = lam * img2_abs[window] + (1 - lam) * img1_abs[window]

    img1_abs = np.fft.ifftshift(img1_abs, axes=(0, 1))
    mixed = np.real(np.fft.ifft2(img1_abs * np.exp(1j * img1_pha), axes=(0, 1)))

    if out_scale is None:
        mixed = mmcv.imresize(mixed, (orig_w, orig_h))
    else:
        mixed = mmcv.imrescale(mixed, scale=out_scale)
    return np.uint8(np.clip(mixed, 0, 255))
    

   

    
    

def _recompose(amplitude, phase):
    """Inverse-FFT an unshifted amplitude/phase pair over axes (0, 1) into a uint8 image."""
    recon = np.real(np.fft.ifft2(amplitude * np.exp(1j * phase), axes=(0, 1)))
    return np.uint8(np.clip(recon, 0, 255))


def reconstruct(name):
    """Visualise an image's Fourier spectrum and two perturbed reconstructions.

    Reads ``name`` with ``mmcv.imread`` (an [H, W, C] array) and writes, into
    the current working directory:

    * ``spectrum_image.png``: centred log-amplitude, ``20*log1p(|F|)`` clipped to [0, 255];
    * ``phase_image.png``: phase mapped from [-pi, pi] to [0, 255];
    * ``am_recons.jpg``: amplitude kept, phase scaled by 0.9;
    * ``phase_recons.jpg``: phase kept, amplitude scaled by 0.9.

    Args:
        name: path of the image to read.

    Returns:
        uint8 ndarray of the phase-preserving reconstruction (the image
        written to ``phase_recons.jpg``).
    """
    img1 = mmcv.imread(name)

    img1_fft = np.fft.fft2(img1, axes=(0, 1))
    img1_abs, img1_pha = np.abs(img1_fft), np.angle(img1_fft)
    # The centred copy is used only for visualisation; the reconstructions use
    # the unshifted amplitude, so no shift ever has to be undone.
    abs_centred = np.fft.fftshift(img1_abs, axes=(0, 1))

    # log1p avoids log(0) = -inf; clipping prevents uint8 wrap-around above 255.
    spectrum = np.clip(20 * np.log1p(abs_centred), 0, 255).astype('uint8')
    # np.angle lies in [-pi, pi]; offset by pi first so negative phases do not
    # wrap around when cast to uint8.
    phase = np.clip((img1_pha + np.pi) / (2 * np.pi) * 255, 0, 255).astype('uint8')

    # Save the visualisations. NOTE: mmcv.imread returns BGR by default and the
    # channels are not reordered here, so PIL interprets them as RGB.
    Image.fromarray(spectrum).save("spectrum_image.png")
    Image.fromarray(phase).save("phase_image.png")

    # Reconstruct keeping the amplitude and perturbing the phase.
    mmcv.imwrite(_recompose(img1_abs, img1_pha * 0.9), "am_recons.jpg")

    # Reconstruct keeping the phase and perturbing the amplitude.
    img_new = _recompose(img1_abs * 0.9, img1_pha)
    mmcv.imwrite(img_new, "phase_recons.jpg")

    return img_new
    
if __name__ == '__main__':
    # Demo entry point: visualise the spectrum of a sample image.
    sample_path = "./0.3132585284297136_ori.jpg"
    reconstruct(sample_path)