
# run 'python 1-TrainBaseline.py $dataset' to train the baseline model for $dataset
# $dataset can be COCO or CelebA

import json
import matplotlib
matplotlib.use('agg')
import numpy as np
import os
from pathlib import Path
from subprocess import Popen
import sys
import time
    
def build_configs(modes, trials):
    """Return one run configuration per (mode, trial) pair.

    Each configuration is a dict ``{'mode': mode, 'trial': trial}``. The
    result is ordered mode-major: all trials of ``modes[0]`` first, then all
    trials of ``modes[1]``, and so on.
    """
    return [{'mode': mode, 'trial': trial} for mode in modes for trial in trials]


def split_round_robin(items, num_workers):
    """Deal the sequence ``items`` out to ``num_workers`` lists round-robin.

    Item ``k`` goes to worker ``k % num_workers``; a worker that receives
    nothing gets an empty list.

    Raises:
        ValueError: if ``num_workers`` is less than 1.
    """
    if num_workers < 1:
        raise ValueError('num_workers must be at least 1, got {}'.format(num_workers))
    return [list(items[worker::num_workers]) for worker in range(num_workers)]


if __name__ == '__main__':
    # Two modes of operation:
    #   python 1-TrainBaseline.py <dataset>          -> parent: plan the runs,
    #       write one assignment file per GPU and launch one worker per GPU.
    #   python 1-TrainBaseline.py <dataset> <index>  -> worker: execute the
    #       configs stored in the assignment file for worker <index>.
    import shutil

    if len(sys.argv) not in (2, 3):
        sys.exit('usage: python 1-TrainBaseline.py <dataset> [worker_index]')

    # Load dataset configuration
    dataset = sys.argv[1]

    if len(sys.argv) == 2:
        # Only the parent changes directory; the workers it launches inherit
        # this working directory.
        os.chdir(dataset)

    # Dataset-specific hooks are imported from the current directory; the
    # shared training loop is expected in ../Common/.
    sys.path.insert(0, os.getcwd())
    from Config import get_data_dir, id_from_path, get_out_features

    sys.path.insert(0, '../Common/')
    from Train import run

    base_dir = './Outputs'

    if len(sys.argv) == 2:
        # Setup the basic configuration
        gpu_ids = [0]
        num_gpus = len(gpu_ids)

        trials = [0]
        modes = ['pretrained', 'initial-transfer', 'initial-tune']  # , 'adv-tune']

        # Start from an empty output directory. If removal fails, mkdir
        # below raises FileExistsError instead of silently reusing old data.
        shutil.rmtree(base_dir, ignore_errors=True)
        Path(base_dir).mkdir(parents=True)

        # Generate all of the configurations and divide them among the workers
        configs_worker = split_round_robin(build_configs(modes, trials), num_gpus)

        # Save each worker's assignment as <base_dir>/<worker>.json
        for i, assigned in enumerate(configs_worker):
            with open('{}/{}.json'.format(base_dir, i), 'w') as f:
                json.dump(assigned, f)

        # Launch one worker per GPU by re-running this script in worker mode.
        # An argument list plus an explicit env keeps 'dataset' out of a
        # shell, and sys.executable reuses the interpreter running this script.
        procs = []
        for i, gpu_id in enumerate(gpu_ids):
            env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
            cmd = [sys.executable, '../1-TrainBaseline.py', dataset, str(i)]
            procs.append(Popen(cmd, env=env))
            # Stagger launches by a random 4-6 second delay.
            time.sleep(np.random.uniform(4, 6))

        # Wait for every worker, then report any that failed.
        failed = []
        for i, p in enumerate(procs):
            returncode = p.wait()
            if returncode != 0:
                failed.append('worker {} exited with code {}'.format(i, returncode))
        if failed:
            sys.exit('; '.join(failed))

    else:
        index = sys.argv[2]

        # Setup dataset config
        dc = (get_data_dir(), id_from_path, get_out_features())

        # Get the chosen settings, then delete the assignment file
        assignment_path = '{}/{}.json'.format(base_dir, index)
        with open(assignment_path, 'r') as f:
            configs = json.load(f)
        os.remove(assignment_path)

        for config in configs:
            mode = config['mode']
            trial = config['trial']

            print(config)

            run(mode, trial, dc)

            time.sleep(np.random.uniform(4, 6))