# -*- coding: utf-8 -*-
import argparse
import multiprocessing
from tqdm.contrib.concurrent import process_map
import os
from copy import deepcopy
import subprocess

from scripts import utils
# Command-line interface. Each positional (source, target) pair taken from
# --sources/--targets becomes one job, run on a GPU taken from --gpus.
parser = argparse.ArgumentParser(description='automl')
parser.add_argument('--cfg_file', default='./configs/cdan_officehome_search.yml',
                    type=str, help='the config file for hyper parameters setting')
parser.add_argument('--method', type=str, default='CDAN',
                    help="The DA training method")
parser.add_argument('--gpus', default=['0'], nargs='+',
                    help='specify the GPUs used (one concurrent job per GPU id)')
parser.add_argument('--n_trials', default=50, type=int, help='number of trials of optuna')
parser.add_argument('--opt_metric', type=str, default='accuracy',
                    help='metric to optimize; forwarded to automl_optuna.py')
parser.add_argument('--sampler', type=str, default='random',
                    help='sampler name; forwarded to automl_optuna.py')
parser.add_argument('--output_path', type=str,
                    default='/home/username/DAmetric_logs/office31_cdan_%s2%s_search_entropy',
                    help="The log dir template; must contain two %%s placeholders, filled with "
                         "the first character of the source and target domain names")
parser.add_argument('--root', metavar='DIR', default='/home/username/datasets',
                    help='root path of dataset')
parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                    help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                         ' (default: Office31)')
parser.add_argument('--sources', type=str, nargs='+', default=['A'], help="source datasets list")
parser.add_argument('--targets', type=str, nargs='+', default=['W'], help="target datasets list")
parser.add_argument('--epochs', default=10, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-i', '--iters-per-epoch', default=2000, type=int,
                    help='Number of iterations per epoch')
parser.add_argument('--seed', default=12345, type=int,
                    help='seed for initializing training. ')
# NOTE(review): --workers is not read anywhere in this script; parallelism is
# currently determined by the number of --gpus.
parser.add_argument('--workers', default=4, type=int, help='the number of process or thread')
parser.add_argument('--eval_mode', action='store_true', default=False,
                    help='evaluate the model (run evaluation.py instead of the search)')
args = parser.parse_args()

def _build_command(source, target, output_path):
    """Return the argv list for one job, read from the module-level ``args``.

    An argv list (rather than a shell string) keeps paths that contain spaces
    or shell metacharacters intact as single arguments.
    """
    if args.eval_mode:
        argv = [
            "python", "evaluation.py",
            "--method", args.method,
            "-d", args.data,
            "--output_path", output_path,
        ]
    else:
        argv = [
            "python", "automl_optuna.py",
            "--cfg_file", args.cfg_file,
            "--n_trials", args.n_trials,
            "--opt_metric", args.opt_metric,
            "--method", args.method,
            "--output_path", output_path,
            "--root", args.root,
            "-d", args.data,
            "-s", source,
            "-t", target,
            "--sampler", args.sampler,
            "--epochs", args.epochs,
            "-i", args.iters_per_epoch,
            "--seed", args.seed,
        ]
    # subprocess requires str arguments; several options are parsed as ints.
    return [str(part) for part in argv]


def process(resources):
    """Run one search/evaluation job on a GPU borrowed from a shared queue.

    Args:
        resources: ``(source, target, devices)`` where ``source`` and
            ``target`` are domain names and ``devices`` is a queue of GPU ids.

    Returns:
        The ``subprocess.CompletedProcess`` of the launched job.
    """
    source, target, devices = resources
    # Blocks until another job hands a GPU back to the queue.
    device = devices.get()
    try:
        output_path = args.output_path % (source[0], target[0])
        print(output_path)
        # Restrict the child to the borrowed GPU without going through a shell.
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(device))
        return subprocess.run(_build_command(source, target, output_path), env=env)
    finally:
        # Always return the GPU, even if formatting or launching fails, so the
        # remaining jobs cannot deadlock waiting on devices.get().
        devices.put(device)

if __name__ == "__main__":
    # Jobs pair sources and targets positionally; zip() would silently drop
    # any unmatched entries, so reject mismatched lengths up front.
    if len(args.sources) != len(args.targets):
        parser.error("--sources and --targets must have the same number of entries")
    n_gpu = len(args.gpus)
    # The context manager shuts down the manager's server process when done.
    with multiprocessing.Manager() as manager:
        gpus = manager.Queue(maxsize=n_gpu)
        for gpu in args.gpus:
            gpus.put_nowait(gpu)
        jobs = [(source, target, gpus) for source, target in zip(args.sources, args.targets)]
        results = process_map(process, jobs, max_workers=n_gpu)
    # Count failing jobs. Summing return codes would over-count codes other
    # than 1 and let negative (signal) codes cancel out real failures.
    n_failed = sum(1 for result in results if result.returncode != 0)
    if n_failed > 0:
        for result in results:
            print(result)
    print(f"finished. total: {len(results)}, failed: {n_failed}")
    

    
    