import json
import os
import numpy as np
import torch

# Filename suffixes recognised as images by is_image_file(). That check is a
# plain case-sensitive suffix match, hence the explicit upper-case variants.
# NOTE(review): '.tiff' has no '.TIFF' (or '.tif') counterpart — confirm this
# omission is intentional.
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
    '.tiff',
]


def is_image_file(filename):
    """Return True when *filename* ends with one of ``IMG_EXTENSIONS``.

    The comparison is case-sensitive; only the suffixes listed in
    ``IMG_EXTENSIONS`` are accepted.
    """
    # str.endswith accepts a tuple and checks every suffix in one call.
    accepted_suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(accepted_suffixes)


def make_dataset(dir):
    """Collect image paths and their labels from *dir* and its subdirectories.

    Every directory visited by ``os.walk`` that contains at least one image
    file (per ``is_image_file``) must hold a ``dataset.json`` whose
    ``'labels'`` entry is anything ``dict()`` accepts (e.g. a list of
    ``[filename, label]`` pairs). Each image's label is looked up by its
    bare filename in that directory's ``dataset.json``.

    Args:
        dir: Root directory to scan.

    Returns:
        ``(images, poses)``: ``images`` is a list of image paths and
        ``poses`` is a NumPy array of the matching labels. The array is cast to
        ``int64`` when the labels stack to 1-D (scalar labels) and to
        ``float32`` when they stack to 2-D (vector labels).

    Raises:
        AssertionError: if *dir* is not a directory.
        FileNotFoundError: if a directory containing images has no
            ``dataset.json``.
        KeyError: if an image has no entry in its directory's labels, or the
            stacked labels are neither 1-D nor 2-D.
    """
    images = []
    poses = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        # Sort so the output order does not depend on filesystem listing order.
        image_fnames = sorted(fname for fname in fnames if is_image_file(fname))
        if not image_fnames:
            # Directories without images (e.g. pure parent folders) need no
            # dataset.json.
            continue

        with open(os.path.join(root, 'dataset.json'), 'rb') as f:
            labels = dict(json.load(f)['labels'])

        for fname in image_fnames:
            images.append(os.path.join(root, fname))
            poses.append(labels[fname.replace('\\', '/')])

    # Convert only after every directory has been scanned. Converting inside
    # the loop turned `poses` into an ndarray, and `.append` then failed on the
    # next directory.
    poses = np.array(poses)
    poses = poses.astype({1: np.int64, 2: np.float32}[poses.ndim])
    return images, poses


def aggregate_loss_dict(agg_loss_dict):
    """Average each loss term across a sequence of per-step loss dicts.

    Args:
        agg_loss_dict: Iterable of dicts mapping loss names to values. Values
            must support ``sum`` (addition starting from ``0``) and division by
            an int, e.g. floats or tensors. Despite the name, this is a
            sequence of dicts, not a single dict.

    Returns:
        Dict mapping every key seen in any input dict to the mean of the
        values recorded for that key. A key absent from some dicts is averaged
        only over the dicts that contain it. Keys keep first-seen order; an
        empty input yields ``{}``.
    """
    collected = {}
    for output in agg_loss_dict:
        for key in output:
            # Append in place. Rebuilding the list on every value was O(n^2).
            collected.setdefault(key, []).append(output[key])
    # Each collected list has at least one element by construction, so no
    # empty-list guard is needed.
    return {key: sum(values) / len(values) for key, values in collected.items()}


def get_fourier_descriptor(mask, device, num_selected_freqs=10):
    """Build a shape descriptor from the largest 2-D Fourier magnitudes of *mask*.

    The 2-D FFT of *mask* (over its last two dimensions) is computed on the
    CPU. Its magnitude spectrum is flattened and the ``num_selected_freqs``
    largest magnitudes are kept. The descriptor is meant for comparing shape
    similarity between masks.

    Args:
        mask: Tensor with at least two dimensions. ``torch.fft.fft2``
            transforms the last two.
        device: Device the returned descriptor is moved to.
        num_selected_freqs: Number of largest magnitudes to keep (default 10).

    Returns:
        1-D tensor of length ``num_selected_freqs`` holding the selected
        magnitudes in descending order, located on *device*.

    Raises:
        ``topk`` raises if *mask* has fewer than ``num_selected_freqs``
        elements.
    """
    # The FFT always runs on CPU, whatever device *mask* lives on.
    spectrum = torch.fft.fft2(mask.cpu())
    # Only the magnitude is used; the phase is discarded.
    magnitude = torch.abs(spectrum)
    # NOTE(review): flatten() also merges any leading (batch) dimensions, so
    # for batched masks the top-k is taken across the whole batch. Confirm
    # callers pass a single 2-D mask.
    top_magnitudes = magnitude.flatten().topk(num_selected_freqs).values
    return top_magnitudes.to(device)
