#!/usr/bin/env python3

from dmc_gen.collect_offline_data import main
from dmc_gen.config import Args
from invr_thru_inf.config import Adapt, CollectData
from params_proto.neo_hyper import Sweep

# Wider AWS region list, kept for reference (currently disabled):
# AWS_REGIONS = ["ap-northeast-1", "ap-northeast-2", "ap-south-1", "ap-southeast-1", "ap-southeast-2", "eu-central-1", "eu-west-1", "sa-east-1", "us-east-1", "us-east-2"]
# Active AWS region list. It is not referenced anywhere else in this script.
AWS_REGIONS = [
    "us-east-1",
    "us-east-2",
    "us-west-1",
    "us-west-2",
]
# Per-region integer table for Google Compute Engine, keyed by region name.
# NOTE(review): the values look like per-region quota limits -- confirm what
# they measure before relying on them. The table is not read by this script.
GCE = {'us-central1': 512, 'asia-east1': 512, 'asia-east2': 128,
       'asia-northeast1': 128, 'asia-northeast2': 128, 'asia-northeast3': 128,
       # 'asia-south2' was previously misspelled as 'aisa-south2'.
       'asia-south1': 128, 'asia-south2': 128, 'asia-southeast1': 128, 'asia-southeast2': 128,
       'australia-southeast1': 128, 'australia-southeast2': 128,
       'europe-central2': 128, 'europe-north1': 128,
       'europe-west1': 510, 'europe-west2': 255, 'europe-west3': 253, 'europe-west4': 128, 'europe-west6': 128,
       'northamerica-northeast1': 128, 'northamerica-northeast2': 128, 'southamerica-east1': 128,
       'us-east1': 256, 'us-east4': 256, 'us-west1': 128, 'us-west2': 128, 'us-west3': 128, 'us-west4': 128}
# GCE = {'us-central1-b': 512, }
# GCE_REGIONS = list(GCE.keys())

# T4 regions.
# Despite the name, every entry is a zone identifier (region name plus a
# letter suffix); the launcher passes the chosen entry as the `zone` of the
# launch configuration.
# Zones that were tried earlier and are currently disabled:
#   'us-central1-b', 'us-central1-f'
GCE_REGIONS = [
    'us-east1-c',
    'us-east1-d',
    'us-east1-b',
    'us-central1-a',
    'us-west1-a',
    'us-west1-b',
    'us-west2-b',
    'us-west2-c',
    'us-west4-a',
    'us-west4-b',
    'us-east4-b',
]

if __name__ == '__main__':
    # Launch every run of a sweep file on GCP through jaynes, then poll the
    # resulting zone operations and relaunch the runs whose operation reported
    # an error, until no run is left to (re)launch.
    #
    # Usage: <this script> SWEEP_FILE
    #
    # Zones are drawn at random from GCE_REGIONS. A zone is dropped from the
    # pool when a launch attempt in it raises, or when the operation it
    # returned later reports an error.
    import os
    import argparse
    import hashlib
    import random
    import time
    from copy import deepcopy
    from os.path import basename

    import googleapiclient.discovery
    import jaynes
    from ml_logger import instr, logger, USER

    from model_free_analysis import IMAGE_IDS
    from model_free_analysis.baselines import RUN

    # RUN.prefix = RUN.prefix.replace('{file_stem}', 'baselines/dmc_gen/run')

    parser = argparse.ArgumentParser()
    parser.add_argument('sweep_file')
    args = parser.parse_args()

    sweep = Sweep(RUN, Args, Adapt, CollectData).load(args.sweep_file)

    launch_logs = []
    while len(sweep) > 0:
        for kwargs in sweep:
            thunk = instr(main, **kwargs)

            # NOTE(review): RUN.job_name is read here on every iteration --
            # presumably iterating the Sweep applies `kwargs` to the config
            # objects; confirm against params_proto.
            # The instance name is the tail of the job name with '_', '/' and
            # '.' replaced by '-', lower-cased, plus a short hash of
            # (sweep file name + job name).
            jobname_hash = hashlib.sha1((basename(args.sweep_file) + RUN.job_name).encode('utf-8')).hexdigest()[:10]
            job_name = RUN.job_name[-51:].replace('_', '-').replace('/', '-').replace('.', '-').lower() + '--' + jobname_hash

            if job_name.startswith('-'):
                job_name = job_name[1:]

            # Keep trying zones until the launch request is accepted.
            while True:
                # Without this guard an exhausted pool made the sampling call
                # raise inside the `try`, where the broad handler swallowed it
                # and the loop spun forever.
                if not GCE_REGIONS:
                    raise RuntimeError(
                        f'No GCE zones left to try while launching {job_name!r}; '
                        f'every zone has failed or been blacklisted.')

                # Sample randomly
                REGION = random.choice(GCE_REGIONS)
                try:
                    jaynes.config('gcp-gen',
                                  launch={'zone': REGION,
                                          'name': job_name})
                    print('launching now')
                    operation = jaynes.run(thunk)
                    logger.job_requested(job=dict(region=REGION, spot_request_id=operation['id']))
                    launch_logs.append({'operation': operation['name'], 'zone': REGION, 'kwargs': deepcopy(kwargs)})
                    break
                except Exception as e:
                    print(e)
                    if 'alreadyExists' in str(e):
                        print('============= instance already exists error captured. skip launching this =============')
                        break
                    # Any other failure: stop using this zone and retry in another.
                    GCE_REGIONS = [r for r in GCE_REGIONS if r != REGION]

        print('waiting for 8 seconds to make sure the resource exhaust error can show up on the log...')
        time.sleep(8)
        # Check if everything is running fine
        project = os.environ.get('JYNS_GCP_PROJECT')
        compute = googleapiclient.discovery.build('compute', 'v1')
        remainings = []
        blacklist_zones = []
        for launched in launch_logs:
            print(launched['kwargs']['RUN.job_name'])
            result = compute.zoneOperations().get(
                project=project,
                zone=launched['zone'],
                operation=launched['operation']).execute()
            if 'error' in result:
                print(result['error'])
                # Queue the run for relaunch and avoid the zone that failed it.
                remainings.append(launched['kwargs'])
                blacklist_zones.append(launched['zone'])
        print('blacklist zones', blacklist_zones)
        GCE_REGIONS = [region for region in GCE_REGIONS if region not in blacklist_zones]
        print(f'=========== Relaunching {len(remainings)} runs =============')
        sweep = Sweep(RUN, Args, Adapt, CollectData).load(remainings)
        # The former `REGION = GCE_REGIONS.pop(0)` was removed here: the zone
        # is re-sampled before every launch, so the pop only threw away an
        # untried zone each round and raised IndexError on an empty pool.
        launch_logs = []  # Reset launch logs

    jaynes.listen()
