
# Run 'python 2-TrainSCVIS.py $dataset' to train scvis on the model representations for $dataset.
# $dataset must be one of: COCO, CelebA, or ImageNet.

# Environment: scvis
import os
from multiprocessing import Pool
import pandas as pd
import sys

# Reduce log noise: ask Python to ignore warnings and raise the
# TF_CPP_MIN_LOG_LEVEL threshold (presumably read by TensorFlow, which scvis
# uses -- confirm). Values placed in os.environ are inherited by child
# processes started later.
os.environ.update({
    'PYTHONWARNINGS': 'ignore',
    'TF_CPP_MIN_LOG_LEVEL': '2',
})

if __name__ == '__main__':
    
    # The dataset name is the single required command-line argument; exit with
    # a usage message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit('usage: python {} <COCO|CelebA|ImageNet>'.format(sys.argv[0]))
    dataset = sys.argv[1]

    # Every relative path used below ('./Outputs/...', '../Common/') is
    # resolved from inside the dataset directory.
    os.chdir(dataset)

    # Make modules in the dataset directory and in ../Common/ importable.
    sys.path.insert(0, os.getcwd())
    sys.path.insert(0, '../Common/')
    if dataset == 'ImageNet':
        from Config import get_names, get_class_map
        from Load import load_imagenet

        def load(name):
            """Load the data for ``name`` via ``load_imagenet``.

            The class map is fetched with ``get_class_map()`` on every call.
            The caller reads the 'reps' entry of the returned object.
            """
            return load_imagenet(name, get_class_map())

    else:
        from Config import get_names, get_data_dir, get_out_features
        from Load import load_standard

        def load(name):
            """Load the data for ``name`` via ``load_standard``.

            The data directory and output-feature setting come from the
            dataset's Config module. The caller reads the 'reps' entry of the
            returned object.
            """
            return load_standard(name, get_data_dir(), get_out_features())

        
    def run(name):
        """Train scvis on the representations loaded for ``name``.

        Writes the 'reps' entry of ``load(name)`` to a temporary TSV file in
        ``./Outputs/scvis/<name>/``, runs ``scvis train`` on that file with the
        same directory as its output directory, and finally removes the
        temporary file.

        A failing or missing ``scvis`` command is reported on stderr instead of
        being raised, which keeps the original best-effort behaviour (the exit
        status used to be ignored) while making the failure visible.
        """
        # Local import so this function does not depend on edits elsewhere.
        import subprocess

        print(name)
        base_dir = './Outputs/scvis/{}'.format(name)
        # Creates missing parents and tolerates an already existing directory.
        os.makedirs(base_dir, exist_ok=True)
        out_tmp = '{}/tmp.tsv'.format(base_dir)

        try:
            # Get the model's representation of images from this class
            out = load(name)
            reps = out['reps']

            df = pd.DataFrame(reps)
            df.to_csv(out_tmp, sep='\t', index=False)

            # Run scvis. Passing an argument list (no shell) keeps names with
            # spaces or shell metacharacters from corrupting the command.
            command = [
                'scvis', 'train',
                '--data_matrix_file', out_tmp,
                '--out_dir', base_dir,
            ]
            try:
                status = subprocess.call(command)
            except OSError as err:
                # Raised when the scvis executable cannot be started at all.
                sys.stderr.write(
                    'scvis could not be started for {}: {}\n'.format(name, err))
            else:
                if status != 0:
                    sys.stderr.write(
                        'scvis exited with status {} for {}\n'.format(status, name))
        finally:
            # Cleanup: remove the temporary matrix even if a step above failed.
            if os.path.exists(out_tmp):
                os.remove(out_tmp)
        
    # Setup the output directory: start from an empty ./Outputs/scvis.
    # shutil/os calls replace the shelled-out 'rm -rf' / 'mkdir'; makedirs also
    # creates ./Outputs when it does not exist yet (plain 'mkdir' failed
    # silently in that case, so every later step failed too).
    import shutil
    base_dir = './Outputs/scvis'
    shutil.rmtree(base_dir, ignore_errors=True)
    os.makedirs(base_dir, exist_ok=True)

    # Run: one task per name, spread over 6 worker processes.
    # NOTE(review): `run` is defined under the __main__ guard, so workers only
    # see it when they are forked; the 'spawn' start method would not find it.
    names = get_names()
    p = Pool(6)
    try:
        p.map(run, names)
    finally:
        # Release the worker processes instead of leaving them to interpreter exit.
        p.close()
        p.join()
    