import os
import argparse
import scipy.io as scio
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def parser_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: Namespace with a single attribute ``seed``
        (int, defaults to 1).
    """
    cli = argparse.ArgumentParser(description='Split data')
    cli.add_argument(
        '--seed',
        type=int,
        default=1,
        help='seed for initializing training. ',
    )
    return cli.parse_args()

def split_data(file_path, data, seed):
    """Split ``data`` into four subsets and save each as a CSV in ``file_path``.

    Three consecutive ``train_test_split`` calls are made, all seeded with
    ``seed``:

    1. ``data`` -> training part / held-out part (``test_size=0.2``).
    2. training part -> labeled / unlabeled (``test_size=0.875``, i.e. the
       unlabeled share).
    3. held-out part -> validation / test (``test_size=0.5``).

    Args:
        file_path: Directory that receives ``train_label_data.csv``,
            ``train_unlabel_data.csv``, ``val_data.csv`` and ``test_data.csv``.
        data: Table to split; each resulting subset is written with its
            ``to_csv`` method.
        seed: Value passed as ``random_state`` to every split.
    """
    train_part, holdout_part = train_test_split(
        data, test_size=0.2, random_state=seed)
    labeled_part, unlabeled_part = train_test_split(
        train_part, test_size=0.875, random_state=seed)
    val_part, test_part = train_test_split(
        holdout_part, test_size=0.5, random_state=seed)

    # Output file name -> subset, written in this order.
    outputs = (
        ('train_label_data.csv', labeled_part),
        ('train_unlabel_data.csv', unlabeled_part),
        ('val_data.csv', val_part),
        ('test_data.csv', test_part),
    )
    for csv_name, subset in outputs:
        # index=0 is falsy, so the row index is left out of the file.
        subset.to_csv(os.path.join(file_path, csv_name), index=0)



def prepare_flickr(file_path, seed):
    """Read ``flickr_distribution.csv`` and write its data splits.

    Args:
        file_path: Directory holding ``flickr_distribution.csv``; it is also
            handed to ``split_data`` as the output directory.
        seed: Random seed forwarded to ``split_data``.
    """
    source_csv = os.path.join(file_path, 'flickr_distribution.csv')
    split_data(file_path, pd.read_csv(source_csv), seed)
    

def prepare_twitter(file_path, seed):
    """Read ``twitter_distribution.csv`` and write its data splits.

    Args:
        file_path: Directory holding ``twitter_distribution.csv``; it is also
            handed to ``split_data`` as the output directory.
        seed: Random seed forwarded to ``split_data``.
    """
    source_csv = os.path.join(file_path, 'twitter_distribution.csv')
    split_data(file_path, pd.read_csv(source_csv), seed)


def prepare_fbp5500(file_path, seed):
    """Read ``fbp5500_distribution.csv`` and write its data splits.

    Args:
        file_path: Directory holding ``fbp5500_distribution.csv``; it is also
            handed to ``split_data`` as the output directory.
        seed: Random seed forwarded to ``split_data``.
    """
    source_csv = os.path.join(file_path, 'fbp5500_distribution.csv')
    split_data(file_path, pd.read_csv(source_csv), seed)


def prepare_raf(file_path, seed):
    """Read ``distribution.txt``, point it at the aligned images, and split it.

    ``distribution.txt`` is parsed as a header-less, space-separated table.
    Its last column is discarded, and every file name in the first column is
    rewritten from ``<stem><ext>`` to ``aligned/<stem>_aligned<ext>`` before
    the table is handed to ``split_data``.

    Args:
        file_path: Directory holding ``distribution.txt``; it is also handed
            to ``split_data`` as the output directory.
        seed: Random seed forwarded to ``split_data``.
    """
    data_path = os.path.join(file_path, 'distribution.txt')
    data = pd.read_csv(data_path, sep=' ', header=None)
    # NOTE(review): the last column is dropped unconditionally -- presumably
    # an empty column produced by a trailing separator on each line; confirm
    # against the actual file.
    data = data.drop(columns=data.columns[-1])

    def to_aligned(filename):
        # splitext cuts at the last dot only, so a stem that itself contains
        # dots keeps its full name and its real extension.
        stem, ext = os.path.splitext(filename)
        return os.path.join('aligned', stem + '_aligned' + ext)

    # Rewrite the whole file-name column in one pass instead of reading and
    # writing each cell through iloc.
    name_col = data.columns[0]
    data[name_col] = data[name_col].map(to_aligned)
    split_data(file_path, data, seed)


def prepare_emotion6(file_path, seed):
    """Read ``emotion6_distribution.csv`` and write its data splits.

    Args:
        file_path: Directory holding ``emotion6_distribution.csv``; it is also
            handed to ``split_data`` as the output directory.
        seed: Random seed forwarded to ``split_data``.
    """
    source_csv = os.path.join(file_path, 'emotion6_distribution.csv')
    split_data(file_path, pd.read_csv(source_csv), seed)

    


def main(args):
    """Run the preparation routine of every dataset, one after another.

    Args:
        args: Parsed command-line namespace; only ``args.seed`` is read.
    """
    # (data directory, preparation routine) pairs, processed in this order.
    # The directories are relative paths, so they resolve against the
    # current working directory.
    jobs = (
        ('../../data/flickr', prepare_flickr),
        ('../../data/twitter', prepare_twitter),
        ('../../data/fbp5500', prepare_fbp5500),
        ('../../data/raf', prepare_raf),
        ('../../data/emotion6', prepare_emotion6),
    )
    for file_path, prepare in jobs:
        prepare(file_path, args.seed)


# Script entry point: parse the command line, then process every dataset.
if __name__ == '__main__':
    main(parser_args())
