from dmc_gen.train import main
from dmc_gen.config import Args
from params_proto.neo_hyper import Sweep

# Candidate AWS regions. A wider list used previously:
# AWS_REGIONS = ["ap-northeast-1", "ap-northeast-2", "ap-south-1", "ap-southeast-1", "ap-southeast-2", "eu-central-1", "eu-west-1", "sa-east-1", "us-east-1", "us-east-2"]
AWS_REGIONS = [
    "us-east-1",
    "us-east-2",
    "us-west-1",
    "us-west-2",
]

# Region -> integer table (presumably a per-region quota; TODO confirm).
# NOTE: this binding is replaced by the single-zone table right below, so it
# is reference material only.
GCE = {
    'us-central1': 512,
    'asia-east1': 512,
    'europe-west1': 512,
    'europe-west3': 256,
    'europe-west2': 256,
    'us-east1': 256,
    'us-east4': 256,
    'asia-northeast2': 128,
    'us-west3': 128,
    'southamerica-east1': 128,
    'europe-central2': 128,
    'us-west4': 128,
    'europe-west4': 1024,
}
# The table that is actually in effect: a single zone.
GCE = {'us-central1-b': 512}

# Zones consumed one at a time by the launcher (one entry per pop), so the
# same zone is repeated 512 times here.
# Alternatives that were tried:
#   GCE_REGIONS = list(GCE.keys())
#   GCE_REGIONS = ['us-central1-b'] * 512 + ['us-west4-b'] * 4
#   GCE_REGIONS = ['us-west4-b']
GCE_REGIONS = 512 * ['us-central1-b']

if __name__ == '__main__':
    # Launcher: load a sweep file, wrap each configuration of `main` with
    # `instr`, and submit it through jaynes using the 'gcp-gen' mode. When a
    # submission fails, the next entry of GCE_REGIONS is tried.
    #
    # Usage: python <this file> <sweep_file>
    import argparse
    import hashlib
    import os  # unused in this block; kept from the original
    from os.path import basename

    import jaynes
    from ml_logger import instr, logger, USER  # USER unused here; kept from the original

    from model_free_analysis import IMAGE_IDS  # unused here; kept from the original
    from model_free_analysis.baselines import RUN

    parser = argparse.ArgumentParser()
    parser.add_argument('sweep_file')
    args = parser.parse_args()

    sweep = Sweep(RUN, Args).load(args.sweep_file)

    # Pick the first zone up front. Previously REGION was only ever assigned
    # inside the `except` branch below, so the very first launch attempt
    # always failed with a NameError that was silently treated as a launch
    # failure.
    if not GCE_REGIONS:
        raise RuntimeError('GCE_REGIONS is empty: there is no zone to launch in')
    REGION = GCE_REGIONS.pop(0)

    for kwargs in sweep:
        thunk = instr(main, **kwargs)
        # SHA-1 gives a suffix that is deterministic for a given sweep file
        # name + job name.
        # NOTE(review): the original comment here reported different hashes
        # for the same string -- presumably from the per-process salted
        # builtin hash(); confirm.
        jobname_hash = hashlib.sha1((basename(args.sweep_file) + RUN.job_name).encode('utf-8')).hexdigest()[:10]
        # Last 51 chars of the job name, with '_', '/' and '.' turned into
        # '-', lower-cased, plus '--' and the 10-char hash: 63 chars at most.
        # NOTE(review): presumably sized for the GCE 63-character instance
        # name limit -- confirm.
        job_name = RUN.job_name[-51:].replace('_', '-').replace('/', '-').replace('.', '-').lower() + '--' + jobname_hash
        # job_name = RUN.job_name[-61:].replace('_', '-').replace('/', '-').replace('.', '-').lower()
        print(RUN.prefix + RUN.job_name, jobname_hash)

        # Retry this job, moving to the next zone on every failure, until it
        # is submitted or the instance turns out to exist already.
        while True:
            try:
                jaynes.config('gcp-gen',
                              launch={'zone': REGION,
                                      'name': job_name})
                print('launching now')
                logger.job_requested(job=dict(region=REGION, spot_request_id=jaynes.run(thunk)))
                break
            except Exception as e:
                print(e)
                if 'alreadyExists' in str(e):
                    print('============= instance already exists error captured. skip launching this =============')
                    break
                if not GCE_REGIONS:
                    # Fail with the real launch error attached instead of an
                    # opaque "pop from empty list" IndexError.
                    raise RuntimeError('no zones left in GCE_REGIONS to retry ' + job_name) from e
                REGION = GCE_REGIONS.pop(0)
                print("switch to", REGION)

    jaynes.listen()
