import pandas as pd
import argparse
import os
from tqdm import tqdm
def parse_args(argv=None):
    """Parse the command-line options of the caption-CSV splitting script.

    :param argv: optional list of argument strings to parse instead of
        ``sys.argv[1:]``; ``None`` (the default) parses the real command line.
    :return: ``argparse.Namespace`` with ``data_path``, ``cc_data_path``,
        ``num_split`` and ``save_path``.
    """
    parser = argparse.ArgumentParser('')
    parser.add_argument('--data_path', type=str,
                        help='root directory of the original .jpg/.txt caption pairs',
                        default='./cc_data/train/00000/')
    parser.add_argument('--cc_data_path', type=str,
                        help='directory containing the tab-separated CSV to split',
                        default='./cc_rewrite/old/')
    parser.add_argument('--num_split', type=int,
                        help='number of interleaved parts to split the CSV into',
                        default=4)
    parser.add_argument('--save_path', type=str,
                        help='directory where the split CSV files are written',
                        default='./cc_rewrite/old/n_5/')
    # parse_args(None) falls back to sys.argv[1:], matching the old behavior.
    return parser.parse_args(argv)

# Options parsed once, at import time, from the command line.
# split_csv_files (num_split) and the __main__ block read settings from this
# module-level namespace.
# NOTE(review): because this runs on import, importing this module from other
# code (or a test runner) parses that process's sys.argv.
args = parse_args()

def read_txt_file(path, encoding="utf-8"):
    """Return the entire contents of the text file at *path*.

    :param path: path of the text file to read (e.g. a caption ``.txt``).
    :param encoding: encoding used to decode the file. Defaults to UTF-8 so
        the result does not depend on the platform's locale encoding.
    :return: the file contents as one string, with line breaks kept.
    """
    with open(path, "r", encoding=encoding) as f:
        return f.read()

def load_original_cc(data_path):
    '''
    Gather every (image, caption) pair found beneath *data_path*.

    The directory tree is walked recursively. Every file ending in ".jpg"
    becomes one row, except names starting with "._" (typically macOS
    AppleDouble metadata files). The row's title is the full text of the
    ".txt" file with the same stem in the same directory.

    :param data_path: root directory of the original data files
    :return: DataFrame with columns "filepath" and "title", one row per image
    '''
    records = []
    for directory, _subdirs, names in tqdm(os.walk(data_path)):
        for name in tqdm(names):
            # Guard clause: keep only real .jpg files.
            if not name.endswith(".jpg") or name.startswith("._"):
                continue
            image_path = os.path.join(directory, name)
            caption_path = os.path.join(directory, name[:-4] + ".txt")
            records.append({"filepath": image_path,
                            "title": read_txt_file(caption_path)})
    return pd.DataFrame(records, columns=["filepath", "title"])

def split_csv_files(path, save_path, num_split=None,
                    file_name="Train_GCC-training_output_5.csv"):
    """Split a tab-separated CSV into ``num_split`` row-interleaved parts.

    Positional row ``j`` of the input goes to part ``j % num_split`` (via
    ``iloc[idx::num_split]``), so part sizes differ by at most one row.
    Part ``idx`` is written tab-separated, without the index column, to
    ``<save_path>/Train_GCC-training_output_<idx>.csv``.

    :param path: directory containing the input CSV.
    :param save_path: output directory; created if it does not exist yet.
    :param num_split: number of parts; ``None`` uses the ``--num_split``
        value from the module-level ``args``.
    :param file_name: name of the input CSV inside *path*.
    :raises ValueError: if ``num_split`` is less than 1 (previously this
        silently wrote nothing and still printed "Done!").
    """
    if num_split is None:
        num_split = args.num_split
    if num_split < 1:
        raise ValueError("num_split must be >= 1, got %r" % (num_split,))

    original_file_path = os.path.join(path, file_name)
    # To build the table from the raw .jpg/.txt tree instead of a pre-built
    # CSV, use load_original_cc(args.data_path).
    df = pd.read_csv(original_file_path, sep="\t")

    # to_csv does not create missing directories.
    os.makedirs(save_path, exist_ok=True)
    for idx in range(num_split):
        save_file_path = os.path.join(save_path, "Train_GCC-training_output_%d.csv" % idx)
        # Rows idx, idx + num_split, idx + 2*num_split, ...
        df.iloc[idx::num_split].to_csv(save_file_path, index=False, sep="\t")
    print("Done!")
if __name__ == "__main__":
    # Split the pre-built CSV found under --cc_data_path into the output dir.
    # (For a quick sample run, pass args.data_path as `path` instead.)
    split_csv_files(path=args.cc_data_path, save_path=args.save_path)
