import os 
import random

import pandas as pd

# Root directory of the dataset: one sub-directory per synset, each holding
# that synset's image files. Taken from the DATA_PATH environment variable;
# replace the lookup with a hard-coded path if preferred.
DATA_PATH = os.environ.get('DATA_PATH')
if not DATA_PATH:
    print('add DATA_PATH')

def rmv_file(synset, filename, data_path=None):
    """Delete ``filename`` from the ``synset`` sub-directory of the dataset.

    Args:
        synset: Name of the synset sub-directory.
        filename: Name of the file inside that sub-directory.
        data_path: Dataset root directory; defaults to the module-level
            ``DATA_PATH`` when omitted.

    Returns:
        The full path of the file that was deleted.

    Raises:
        OSError: Propagated from ``os.remove`` (e.g. ``FileNotFoundError`` if
            the file does not exist).
    """
    root = DATA_PATH if data_path is None else data_path
    file_path = os.path.join(root, synset, filename)
    os.remove(file_path)
    return file_path

if __name__ == '__main__':
    # Cap every synset directory under DATA_PATH at `max_obs` images by
    # deleting a random sample of the surplus files. Directories already at or
    # below the cap (e.g. the 100, 200, ..., 900-image subsets) keep all their
    # files. Each deleted file is logged to DATA_PATH/removed_obs.csv, which is
    # appended to on repeated runs.
    if not DATA_PATH:
        raise SystemExit('DATA_PATH is not set; point it at the dataset root.')
    max_obs = 1000
    random.seed(42)

    # os.listdir order is filesystem-dependent; sorting both the directories
    # and their contents is what makes the seeded sampling reproducible.
    data_dirs = sorted(
        d for d in os.listdir(DATA_PATH)
        if os.path.isdir(os.path.join(DATA_PATH, d))
    )
    files_by_synset = {
        d: sorted(os.listdir(os.path.join(DATA_PATH, d))) for d in data_dirs
    }
    df = pd.DataFrame({
        'synset': data_dirs,
        'numb_obs': [len(files_by_synset[d]) for d in data_dirs],
    })
    print(f'total images: {df.numb_obs.sum()}')

    removed = []  # (synset, filename) for every file actually deleted
    try:
        for d in data_dirs:
            obs_synset = files_by_synset[d]
            # Clamp at 0: a directory below the cap has nothing to remove
            # (a negative count would make random.sample raise ValueError).
            numb_files_2_rmv = max(len(obs_synset) - max_obs, 0)
            for f in random.sample(obs_synset, numb_files_2_rmv):
                rmv_file(d, f)
                removed.append((d, f))
    finally:
        # Write the log even if a deletion fails part-way, so files that were
        # already removed are never left unrecorded.
        rmv_df = pd.DataFrame(removed, columns=['synset', 'files_removed'])
        csv_path = os.path.join(DATA_PATH, 'removed_obs.csv')
        rmv_df.to_csv(
            csv_path, mode='a', index=False, header=not os.path.exists(csv_path)
        )
    print(f'removed images: {len(removed)}')

    # Post-run check (replaces the leftover pdb breakpoint): recount every
    # directory and report any synset still above the cap.
    numb_obs = [len(os.listdir(os.path.join(DATA_PATH, d))) for d in data_dirs]
    df_2 = pd.DataFrame({'synset': data_dirs, 'numb_obs': numb_obs})
    print(f'total images after removal: {df_2.numb_obs.sum()}')
    over_cap = list(df_2.loc[df_2.numb_obs > max_obs, 'synset'])
    if over_cap:
        print(f'synsets still above {max_obs} images: {over_cap}')
