import os
import argparse
import scipy.io as scio
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def parser_args():
    """Parse the command-line options of the data-splitting script.

    Returns:
        argparse.Namespace with an integer ``seed`` attribute (default 1).
    """
    cli = argparse.ArgumentParser(description='Split data')
    cli.add_argument(
        '--seed',
        type=int,
        default=1,
        help='seed for initializing training. ',
    )
    return cli.parse_args()

def split_data(file_path, data, label, seed, dataset, test_size=0.2, unlabel_size=0.5):
    """Split a dataset into labelled/unlabelled train, val and test parts and save them.

    Three ``train_test_split`` calls are made, all with ``random_state=seed``:
      1. ``test_size`` of the samples is held out for validation + test;
      2. the remaining training part is divided into a labelled and an
         unlabelled part (``unlabel_size`` of it goes to the unlabelled part);
      3. the held-out part is divided evenly into validation and test.

    Each part is written with ``np.save`` to ``<file_path>/<dataset>_<part>.npy``
    where ``<part>`` is one of: train_label_data, train_label_label,
    train_unlabel_data, train_unlabel_label, val_data, val_label, test_data,
    test_label.

    Args:
        file_path: directory the ``.npy`` files are written to.
        data: feature array; samples are split along the first axis.
        label: labels aligned with ``data`` along the first axis.
        seed: ``random_state`` passed to every ``train_test_split`` call.
        dataset: prefix used in every output file name.
        test_size: fraction of all samples held out for validation + test
            (default 0.2, the original hard-coded value).
        unlabel_size: fraction of the training split treated as unlabelled
            (default 0.5, the original hard-coded value).
    """
    train_data, val_test_data, train_label, val_test_label = train_test_split(
        data, label, test_size=test_size, random_state=seed)
    (train_label_data, train_unlabel_data,
     train_label_label, train_unlabel_label) = train_test_split(
        train_data, train_label, test_size=unlabel_size, random_state=seed)
    # The held-out portion is always shared equally between validation and test.
    val_data, test_data, val_label, test_label = train_test_split(
        val_test_data, val_test_label, test_size=0.5, random_state=seed)

    # Same file names and save order as before, driven by one table instead of
    # eight copy-pasted np.save calls.
    parts = {
        'train_label_data': train_label_data,
        'train_label_label': train_label_label,
        'train_unlabel_data': train_unlabel_data,
        'train_unlabel_label': train_unlabel_label,
        'val_data': val_data,
        'val_label': val_label,
        'test_data': test_data,
        'test_label': test_label,
    }
    for part_name, array in parts.items():
        # np.save appends the '.npy' extension to the path.
        np.save(os.path.join(file_path, dataset + '_' + part_name), array)


def prepare_flickr(file_path, seed, label_csv='../data/flickr/flickr_distribution.csv'):
    """Load the Flickr features and label distributions, then split and save them.

    Features are the ``'x'`` array of ``<file_path>/Flickr_LDL.npz``. Labels are
    read from ``label_csv`` with its first column dropped (presumably an id /
    index column -- TODO confirm).

    Args:
        file_path: directory holding ``Flickr_LDL.npz``; also the output directory.
        seed: random seed forwarded to ``split_data``.
        label_csv: path to the label-distribution CSV. The default is the
            original hard-coded path, which is resolved against the current
            working directory rather than ``file_path``.
    """
    # Use a context manager so the NpzFile's underlying file handle is closed;
    # indexing it loads the array into memory, so ``data`` outlives the block.
    with np.load(os.path.join(file_path, 'Flickr_LDL.npz')) as data_label:
        data = data_label['x']
    label = pd.read_csv(label_csv).iloc[:, 1:].values
    split_data(file_path, data, label, seed, 'flickr')
    

def prepare_twitter(file_path, seed, label_csv='../data/twitter/twitter_distribution.csv'):
    """Load the Twitter features and label distributions, then split and save them.

    Features are the ``'x'`` array of ``<file_path>/Twitter_LDL.npz``. Labels are
    read from ``label_csv`` with its first column dropped (presumably an id /
    index column -- TODO confirm). The npz file's own ``'y'`` array was always
    overwritten by the CSV labels, so it is no longer loaded.

    Args:
        file_path: directory holding ``Twitter_LDL.npz``; also the output directory.
        seed: random seed forwarded to ``split_data``.
        label_csv: path to the label-distribution CSV. The default is the
            original hard-coded path, which is resolved against the current
            working directory rather than ``file_path``.
    """
    # Close the NpzFile once the feature array has been read into memory.
    with np.load(os.path.join(file_path, 'Twitter_LDL.npz')) as data_label:
        data = data_label['x']
    label = pd.read_csv(label_csv).iloc[:, 1:].values
    split_data(file_path, data, label, seed, 'twitter')


def prepare_fbp5500(file_path, seed):
    """Load SCUT-FBP5500 features and labels, then split and save them.

    Features and labels are the ``'x'`` and ``'y'`` arrays of
    ``<file_path>/SCUT-FBP5500.npz``.

    Args:
        file_path: directory holding the npz file; also the output directory.
        seed: random seed forwarded to ``split_data``.
    """
    # Close the NpzFile once both arrays have been read into memory.
    with np.load(os.path.join(file_path, 'SCUT-FBP5500.npz')) as data_label:
        data = data_label['x']
        label = data_label['y']
    split_data(file_path, data, label, seed, 'fbp5500')


def prepare_raf(file_path, seed, feature_mat='./baseDCNN.mat'):
    """Load RAF-ML labels and precomputed features, then split and save them.

    Labels are the ``'y'`` array of ``<file_path>/RAF-ML.npz``. Features are NOT
    taken from that npz; they are the ``'baseDCNN'`` variable of the MATLAB file
    ``feature_mat``.

    Args:
        file_path: directory holding ``RAF-ML.npz``; also the output directory.
        seed: random seed forwarded to ``split_data``.
        feature_mat: path to the ``.mat`` feature file. The default is the
            original hard-coded path, which is resolved against the current
            working directory rather than ``file_path``.
    """
    # Close the NpzFile once the label array has been read into memory.
    with np.load(os.path.join(file_path, 'RAF-ML.npz')) as data_label:
        label = data_label['y']
    data = scio.loadmat(feature_mat)['baseDCNN']
    split_data(file_path, data, label, seed, 'raf')


def prepare_emotion6(file_path, seed):
    """Load Emotion6 features and labels, then split and save them.

    Features and labels are the ``'x'`` and ``'y'`` arrays of
    ``<file_path>/Emotion6.npz``.

    Args:
        file_path: directory holding the npz file; also the output directory.
        seed: random seed forwarded to ``split_data``.
    """
    # Close the NpzFile once both arrays have been read into memory.
    with np.load(os.path.join(file_path, 'Emotion6.npz')) as data_label:
        data = data_label['x']
        label = data_label['y']
    split_data(file_path, data, label, seed, 'emotion6')

    
def main(args):
    """Run every dataset preparation step, in a fixed order, with ``args.seed``.

    Each step receives ``'./'`` as its ``file_path``.
    """
    base_dir = './'
    steps = (
        prepare_flickr,
        prepare_twitter,
        prepare_fbp5500,
        prepare_raf,
        prepare_emotion6,
    )
    for step in steps:
        step(base_dir, args.seed)


if __name__ == '__main__':
    # Script entry point: parse the CLI options and prepare all datasets.
    main(parser_args())
