from dataset import *
from parse_args import parse_arguments
import shutil
from skimage import io
import pickle
import argparse
import tarfile


def _merge_by_video(frames):
    """Group every frame by ``video_id`` and concatenate the groups per video, in ``frames`` order.

    Keys are inserted in order of first appearance (groups within one frame are
    visited in sorted ``video_id`` order by ``groupby``). Rows that occur in more
    than one frame are kept once per frame.
    """
    merged = {}
    for frame in frames:
        for video_id, group in frame.groupby('video_id'):
            group = group.reset_index(drop=True)
            if video_id in merged:
                merged[video_id] = pd.concat([merged[video_id], group], ignore_index=True)
            else:
                merged[video_id] = group
    return merged


def extract_dict(args=None):
    """Build the per-video annotation dict and cache it in ``../data/epickitchens/dict_df.pkl``.

    Returns immediately if the cache file already exists. NOTE: the cache is
    keyed only by its path, so changing ``select_actions`` / ``nature_seed`` /
    ``ood`` / ``train_size`` has no effect until the pickle is deleted.

    Steps (randomness seeded with ``args.nature_seed``):
      1. load annotations and attach verb/noun class names from the category tables;
      2. if ``args.select_actions`` is not None, keep a random subset of that many verbs;
      3. re-index nouns/verbs densely, in order of first appearance;
      4. rebalance verbs with ``balance_stat``;
      5. split into iid / ood along ``args.ood`` and carve out the train, test
         (iid held-out), valid and ood sets;
      6. merge the four splits per ``video_id`` and pickle the resulting dict
         (video_id -> DataFrame).

    Args:
        args: parsed CLI namespace providing ``select_actions``, ``nature_seed``,
            ``ood`` and ``train_size``. Defaults to the module-level ``args``
            created when this file runs as a script.

    Raises:
        NameError: if ``args`` is not given and no module-level ``args`` exists.
    """
    out_path = '../data/epickitchens/dict_df.pkl'
    if os.path.exists(out_path):
        return

    if args is None:
        try:
            args = globals()['args']
        except KeyError:
            raise NameError(
                "extract_dict() needs `args`: pass it explicitly or run this file as a script"
            ) from None

    df_verb, df_noun = load_categories('/home/sshirahm/projects/rrg-mcrowley/sshirahm/data/epickitchens/annotations')
    df = load_annotations('/home/sshirahm/projects/rrg-mcrowley/sshirahm/data/epickitchens/prepro')
    # align var name: the raw *_class columns are the ids used to look up
    # df_noun / df_verb; *_class is then reused for the category `key`
    df.rename(columns={'noun_class': 'noun_index', 'verb_class': 'verb_index'}, inplace=True)
    df['noun_class'] = df.apply(lambda row: df_noun.loc[row.noun_index].key, axis=1)
    df['verb_class'] = df.apply(lambda row: df_verb.loc[row.verb_index].key, axis=1)

    if args.select_actions is not None:
        # reproducible random subset of verbs
        random.seed(args.nature_seed)
        unique_indices = df['verb_index'].unique()
        selected_indices = random.sample(list(unique_indices), min(args.select_actions, len(unique_indices)))
        df = df[df['verb_index'].isin(selected_indices)].reset_index(drop=True)

    # attributes: dense 0..K-1 ids in order of first appearance
    dict_noun_index = {k: v for v, k in enumerate(df['noun_class'].unique())}
    dict_verb_index = {k: v for v, k in enumerate(df['verb_class'].unique())}
    df['noun_index'] = df['noun_class'].map(dict_noun_index)
    df['verb_index'] = df['verb_class'].map(dict_verb_index)

    # rebalance data
    dict_verb, _, _ = df_to_dict(df)
    stat_verb = dict_to_stat(dict_verb)
    dict_verb = balance_stat(dict_verb, stat_verb, seed=args.nature_seed)
    stat_verb = dict_to_stat(dict_verb)  # NOTE(review): post-balance stats are not used below

    # keep only the rows selected by the rebalancing
    indices = [name for names in dict_verb.values() for name in names]
    df = df[df.index.isin(indices)].reset_index(drop=True)

    # ood instances
    df_iid, df_ood = split_df(df, axis=args.ood, seed=args.nature_seed)

    # held-out size: at most 5000, at most half of the iid data, at least 1
    max_num_ood = min(5000, int(0.5 * len(df_iid)))
    num_valid = max(min(len(df_ood), max_num_ood), 1)

    df_iid = df_iid.sample(frac=1, random_state=args.nature_seed)  # shuffle order
    df_train = df_iid[num_valid:]
    df_test = df_iid[:num_valid]

    # ood validation set for model selection comes from the tail of df_ood and
    # the ood test set from its head; they overlap when df_ood is too small
    if len(df_ood) < num_valid * 2:
        logging.warning("duplication between ood validation and ood test")
    else:
        logging.info("disjoint ood validation and test set")
    df_valid = df_ood[-num_valid:]
    df_ood = df_ood[:num_valid]

    df_train = df_train[:args.train_size]

    video_dict = _merge_by_video([df_train, df_valid, df_test, df_ood])

    # the cache directory may not exist yet on a fresh checkout
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, 'wb') as f:
        pickle.dump(video_dict, f)


def main(args):
    """Copy the start/stop frames of one video from ``$SLURM_TMPDIR`` into ``args.path_data``.

    Loads the per-video dict cached by ``extract_dict`` and looks up
    ``args.p_video_id``. For every annotated segment it copies
    ``frame_<start_frame>.jpg`` and ``frame_<stop_frame>.jpg`` (10-digit,
    zero-padded) into ``<path_data>/<participant_id>/rgb_frames/<video_id>/``.

    Args:
        args: namespace providing ``p_video_id`` and ``path_data``.

    Raises:
        RuntimeError: if the ``SLURM_TMPDIR`` environment variable is not set.
        KeyError: if ``args.p_video_id`` is not a key of the cached dict.
        FileNotFoundError: if a source frame is missing.
    """
    with open('../data/epickitchens/dict_df.pkl', 'rb') as f:
        videos = pickle.load(f)

    key = args.p_video_id
    try:
        dataframe = videos[key]
    except KeyError:
        raise KeyError(f"video {key!r} not found in ../data/epickitchens/dict_df.pkl") from None

    temp = os.environ.get('SLURM_TMPDIR')
    if temp is None:
        # os.path.join(None, ...) would otherwise fail with an opaque TypeError
        raise RuntimeError("SLURM_TMPDIR is not set; cannot locate the extracted frames")

    for sample in dataframe.itertuples(index=False):
        save_folder_name = f'{args.path_data}/{sample.participant_id}/rgb_frames/{sample.video_id}'
        os.makedirs(save_folder_name, exist_ok=True)
        for frame in (sample.start_frame, sample.stop_frame):
            figname = f'frame_{frame:010d}.jpg'
            # Byte-exact copy: decoding and re-saving the JPEG would recompress
            # it (lossy) and is much slower.
            shutil.copyfile(os.path.join(temp, figname), os.path.join(save_folder_name, figname))

    print(f"Saved all images in {key}")


def check_all_files(args=None):
    """Print the id of every video that has a missing start/stop frame under ``args.path_data``.

    Loads ``../data/epickitchens/dict_df.pkl`` and, for each segment, reads
    ``frame_<start_frame>.jpg`` and ``frame_<stop_frame>.jpg`` from
    ``<path_data>/<participant_id>/rgb_frames/<video_id>/``. A video id is
    printed once, and checking of that video stops at its first missing frame.

    Args:
        args: namespace providing ``path_data``. Defaults to the module-level
            ``args`` created when this file runs as a script.

    Returns:
        list: the printed video ids, in dict order.

    Raises:
        NameError: if ``args`` is not given and no module-level ``args`` exists.
    """
    if args is None:
        try:
            args = globals()['args']
        except KeyError:
            raise NameError(
                "check_all_files() needs `args`: pass it explicitly or run this file as a script"
            ) from None

    with open('../data/epickitchens/dict_df.pkl', 'rb') as f:
        videos = pickle.load(f)

    missing = []
    for k, dataframe in videos.items():
        for sample in dataframe.itertuples(index=False):
            save_folder_name = f'{args.path_data}/{sample.participant_id}/rgb_frames/{sample.video_id}'
            save_start_figname = os.path.join(save_folder_name, f'frame_{sample.start_frame:010d}.jpg')
            save_stop_figname = os.path.join(save_folder_name, f'frame_{sample.stop_frame:010d}.jpg')

            try:
                # files are decoded, not just stat'ed
                io.imread(save_start_figname)
                io.imread(save_stop_figname)
            except FileNotFoundError:
                print(k)
                missing.append(k)
                break  # one report per video is enough
    return missing


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is `prog`, so the text is
    # passed as `description`; help stays enabled so `--help` works.
    parser = argparse.ArgumentParser(description='Parse main configuration file')
    parser.add_argument("--path_data", default='./data/epickitchens/', type=str,
                        help='root folder that frames are copied to and checked in')
    parser.add_argument("--p_video_id", default='P01_01', type=str,
                        help='video id whose frames main() copies')
    parser.add_argument("--train_size", default=5000, type=int, help='size of training data')
    parser.add_argument("--ood", default='noun', type=str,
                        help='axis passed to split_df for the iid/ood split')
    parser.add_argument("--nature_seed", default=1, type=int,
                        help='seed for action selection, rebalancing and data splits')
    parser.add_argument("--select_actions", default=20, type=int,
                        help='number of verbs randomly kept')
    # `args` stays module-global: extract_dict() and check_all_files() fall back to it.
    args = parser.parse_args()

    extract_dict()
    main(args)
    check_all_files()
