import tensorflow as tf
import numpy as np
import os
import config
from data_utils.npy_util import get_image_paths
from  data_utils import  nus_wide
# Module-level settings, read once from the project config when this module is imported.
opt = os.path  # short alias for os.path
paras = config.get_configs()
nb_view, image_size = paras['nb_view'], paras['image_size']
w, h, c = (image_size[dim] for dim in ('w', 'h', 'c'))
data_name, idx_split = paras['data_name'], paras['idx_split']

def get_data(data_base_dir='..'):
    """Load the image arrays and labels for the train/test splits.

    Reads ``train_X.npy``, ``test_X.npy``, ``train_Y.npy`` and ``test_Y.npy``
    from *data_base_dir*. Pixel values are rescaled with ``x / 127.5 - 1``
    (which maps the range [0, 255] onto [-1, 1]). Labels are one-hot encoded.

    Args:
        data_base_dir: directory containing the four ``.npy`` files.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. When the module-level channel
        count ``c`` equals 1, a trailing axis is appended to both image arrays.
        Both label arrays have the same number of one-hot columns.
    """
    print('Data loading ......')
    train_x = np.load(os.path.join(data_base_dir, 'train_X.npy'))
    test_x = np.load(os.path.join(data_base_dir, 'test_X.npy'))
    if c == 1:
        # Single-channel arrays presumably have no channel axis on disk
        # (TODO confirm); add one so the layout matches the c > 1 case.
        train_x = np.expand_dims(train_x, axis=-1)
        test_x = np.expand_dims(test_x, axis=-1)
    train_x = (train_x / 127.5) - 1.
    test_x = (test_x / 127.5) - 1.
    train_y = np.load(os.path.join(data_base_dir, 'train_Y.npy'))
    test_y = np.load(os.path.join(data_base_dir, 'test_Y.npy'))
    # Derive the class count from BOTH splits. If to_categorical infers it
    # separately per split, the two one-hot arrays end up with different
    # widths whenever the highest label is absent from one split.
    num_classes = int(max(np.max(train_y), np.max(test_y))) + 1
    train_y = tf.keras.utils.to_categorical(train_y, num_classes=num_classes)
    test_y = tf.keras.utils.to_categorical(test_y, num_classes=num_classes)
    print('Data loading finished！！！')
    return train_x, train_y, test_x, test_y


def get_views(view_data_dir='views', model_names=None):
    """Load multi-view features together with one-hot labels.

    When the configured ``data_name`` is ``'nus_wide'``, loading is delegated
    to ``nus_wide.load_nus_wide`` with the configured ``idx_split``. Otherwise
    each view is read from ``<model>train_X.npy`` / ``<model>test_X.npy`` in
    *view_data_dir*, and the shared labels from ``train_Y.npy`` /
    ``test_Y.npy``.

    Args:
        view_data_dir: directory holding the view and label files.
        model_names: optional sequence of file-name prefixes, one per view.
            Defaults to the ten built-in prefixes. Not used for nus_wide.

    Returns:
        ``(view_train_x, train_y, view_test_x, test_y)``. On the non-nus_wide
        path, ``view_train_x`` / ``view_test_x`` are lists holding one array
        per model prefix, in prefix order. Both label arrays have the same
        number of one-hot columns.
    """
    if data_name == 'nus_wide':
        view_train_x, train_y, view_test_x, test_y = nus_wide.load_nus_wide(
            view_data_dir=view_data_dir, idx_split=idx_split)
    else:
        if model_names is None:
            # NOTE(review): 'desnet*' looks like a misspelling of 'densenet',
            # but these strings are file-name prefixes passed to np.load, so
            # they must match the files on disk. Do not rename them alone.
            # NOTE(review): ten views are always loaded here, independent of
            # the configured nb_view; confirm that callers expect this.
            model_names = ['resnet50', 'desnet121', 'MobileNetV2', 'Xception',
                           'InceptionV3', 'resnet18', 'resnet34', 'desnet169',
                           'desnet201', 'NASNetMobile']
        view_train_x = [np.load(os.path.join(view_data_dir, name + 'train_X.npy'))
                        for name in model_names]
        view_test_x = [np.load(os.path.join(view_data_dir, name + 'test_X.npy'))
                       for name in model_names]
        train_y = np.load(os.path.join(view_data_dir, 'train_Y.npy'))
        test_y = np.load(os.path.join(view_data_dir, 'test_Y.npy'))
    # Use one class count for both splits so the one-hot widths always agree,
    # even when the largest label occurs in only one split.
    num_classes = int(max(np.max(train_y), np.max(test_y))) + 1
    train_y = tf.keras.utils.to_categorical(train_y, num_classes=num_classes)
    test_y = tf.keras.utils.to_categorical(test_y, num_classes=num_classes)

    return view_train_x, train_y, view_test_x, test_y


def add_gaussian_noise(features, mean=0, std=0.1):
    """Return *features* plus element-wise Gaussian noise.

    The noise is drawn from the global NumPy RNG with one sample per element
    of ``features.shape``. The input array is not modified.

    Args:
        features: array-like object exposing ``.shape`` (e.g. an ndarray).
        mean: mean of the noise distribution.
        std: standard deviation of the noise distribution.
    """
    return features + np.random.normal(loc=mean, scale=std, size=features.shape)

if __name__ == '__main__':
    # Smoke check: call get_image_paths on the 'fn' directory and print how
    # many training entries it returned.
    base_dir = opt.join('fn')  # a one-component join, i.e. just 'fn'
    train_fns, train_y, test_fns, test_y = get_image_paths(
        base_dir=base_dir)
    print(len(train_fns))
