
import json
import os
import pickle
from subprocess import Popen
import sys
import torch
from torchvision.models import resnet50

from Config import get_class_map, get_data_dir, get_names

sys.path.insert(0, '../Common/')
from DataUtils import DirectoryDataset, get_transform, get_loader
from ModelWrapper import ModelWrapper
from ResNet50 import get_model, get_features

def run(names):
    """Run the pretrained model over each named class's training images and pickle the outputs.

    For every class name in ``names``, the images in
    ``<get_data_dir()>/train/<folder>`` (``folder`` taken from ``get_class_map()``)
    are passed through ``ModelWrapper.predict_directory`` and the returned object is
    pickled to ``./Outputs/pretrained/<name>.pkl``.

    Args:
        names: iterable of class names; each must be a key of ``get_class_map()``.
    """
    # Setup the model; .cuda() uses the current default CUDA device
    # (restricted by CUDA_VISIBLE_DEVICES when that is set for this process).
    model = get_model()
    model.cuda()

    # Setup the representation extractor
    feature_hook = get_features(model)

    # Setup the model wrapper
    wrapper = ModelWrapper(model, transform_mode = 'imagenet', feature_hook = feature_hook)

    # Make sure the output directory exists; otherwise open(..., 'wb') below fails
    # on a fresh checkout.
    out_dir = './Outputs/pretrained'
    os.makedirs(out_dir, exist_ok = True)

    # For each of the chosen classes, get the model's learned representations and predictions
    class_map = get_class_map()
    for name in names:
        print(name)

        # class_map[name][0] is the sub-directory of <data_dir>/train holding this class's images.
        folder = class_map[name][0]
        data_dir = '{}/train/{}'.format(get_data_dir(), folder)

        out = wrapper.predict_directory(data_dir)

        with open(os.path.join(out_dir, '{}.pkl'.format(name)), 'wb') as f:
            pickle.dump(out, f)
        
if __name__ == '__main__':

    if len(sys.argv) == 1:
        # Launcher mode: split the class names across the GPUs and start one
        # worker process (this same script, given a worker index) per GPU.
        configs = list(get_names())

        gpu_ids = [0, 1]
        num_gpus = len(gpu_ids)

        # Divide the configs among the workers round-robin:
        # worker i gets configs i, i + num_gpus, i + 2 * num_gpus, ...
        configs_worker = [configs[i::num_gpus] for i in range(num_gpus)]

        # Save the assignments; worker i reads (and then deletes) ./<i>.json
        for i, assignment in enumerate(configs_worker):
            with open('./{}.json'.format(i), 'w') as f:
                json.dump(assignment, f)

        # Launch the workers. Use an argument list (no shell) with the current
        # interpreter and this script's own path, and pin each worker to one GPU
        # through its environment rather than a shell variable prefix.
        script = os.path.abspath(__file__)
        procs = []
        for i, gpu_id in enumerate(gpu_ids):
            env = dict(os.environ, CUDA_VISIBLE_DEVICES = str(gpu_id))
            procs.append(Popen([sys.executable, script, str(i)], env = env))

        # Wait for every worker (the comprehension waits on all of them) and
        # surface failures instead of silently ignoring exit codes.
        failed = [i for i, p in enumerate(procs) if p.wait() != 0]
        if failed:
            sys.exit('Workers {} exited with a non-zero status'.format(failed))

    elif len(sys.argv) == 2:
        # Worker mode: argv[1] is the worker index assigned by the launcher.
        index = sys.argv[1]

        # Get the chosen settings, then remove the one-shot assignment file.
        # os.remove avoids passing the command-line argument through a shell.
        assignment_path = './{}.json'.format(index)
        with open(assignment_path, 'r') as f:
            names = json.load(f)
        os.remove(assignment_path)

        # Run
        run(names)

    else:
        # Any other argument count used to be a silent no-op.
        sys.exit('Usage: {} [worker_index]'.format(sys.argv[0]))
