#!/usr/bin/env python

import itertools
import subprocess
import os
from copy import deepcopy
import time


# Directory containing this launcher; generated SLURM jobs `cd` here before
# running `adaptation.py`. `abspath` keeps this correct when the script is
# invoked through a relative path (rebuilding the path by splitting on '/'
# raised TypeError for a bare filename and produced a wrong '/'-anchored
# path for e.g. `scripts/run.py`).
RUN_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Root under which one folder per submitted SLURM job is created.
JOBS_DIR = os.path.join('/net/tscratch/people/plgdmnsjk', 'slurm_dmns', 'ada_depth')
# JOBS_DIR = os.path.join('/net/people/plgrid/plgdmnsjk', 'slurm_dmns', 'test_time_adaptation')

# /* TO MODIFY ---------------------------------------------
# SLURM=True writes one job script per run and submits it with `sbatch`;
# SLURM=False runs each command locally in a shell and waits for it.
SLURM = True
CPUS_PER_TASK = 10  # written as `#SBATCH --cpus-per-task=<N>`
MEM_PER_CPU = 3  # written as `#SBATCH --mem-per-cpu=<N>G`


# Root of the datasets; each run gets `--data_path DATADIR/DATA_FOLDERS[dataset]`.
# DATADIR = '/net/pr2/projects/plgrid/plgguseby'
DATADIR = '/net/tscratch/people/plgdmnsjk/datasets'
# NOTE(review): EXP_LOGS_DIR is not referenced by the launcher code in this file.
EXP_LOGS_DIR = '/net/people/plgrid/plgdmnsjk/ada-depth/exp_logs'
# Checkpoint root used to build the weight paths in `common_args` below.
CKPTS_DIR = '/net/tscratch/people/plgdmnsjk/ckpts_ada_depth'


# Dataset name -> data folder, relative to DATADIR.
DATA_FOLDERS = {
    'kitti': 'KITTI/kitti_data',
    'kitti_depth': 'KITTI/kitti_data',
    'kitti_c': 'KITTI',
    'waymo': 'waymo',
    'dgp': 'ddad_train_val'
}


# Datasets to run every entry of `configs` on. Each name must be a key of
# DATA_FOLDERS; names containing 'kitti' additionally get the `--png` flag.
datasets = [
    # 'kitti',
    # 'kitti_depth',
    'waymo',
    # 'dgp',
    # 'kitti_c',
]


# Appended to `<method>_<dataset>_` to form the base run name of every config
# below that sets 'RUN_NAME'. Uncomment one alternative to switch.
# RUN_NAME = 'all6_dynamicBN'
# RUN_NAME = 'all6_testBN'
RUN_NAME = 'supModelSSL_GTEgo_30xloop'
# RUN_NAME = 'supModelFrozen'
# RUN_NAME = 'customSeq'

# NOTE(review): RNG_SEED is not referenced by the launcher code in this file.
RNG_SEED = 1

# Experiments to launch, keyed by the value passed as `--adaptation_method`.
# Every entry except RUN_NAME must be a list: one run is launched per element
# of the Cartesian product of those lists, and each chosen value is passed as
# `--<key> <value>`. Values are inserted unquoted into a shell command, so
# '0 -1' becomes two words and '' emits only the bare option name.
configs = {
    # 'ssl_frozen': {
    #     'RUN_NAME': RUN_NAME,
    #     # 'frame_ids': ['0 -1']
    # },
    'custom': {
        'RUN_NAME': RUN_NAME,
        # 'frame_ids': ['0 -1']
        'scales': ['0'],
        'gt_transform': ['']
    },
    # 'ssl_naive': {
    #     'RUN_NAME': RUN_NAME,
    #     'learning_rate': [
    #         # 1e-3,
    #         # 1e-4,
    #         1e-5,
    #         # 1e-6
    #         ],
    #     # 'frame_ids': ['0 -1']
    # },
    # 'adadepth': {
    #     'RUN_NAME': RUN_NAME,
    #     'learning_rate': [
    #         # 1e-3,
    #         # 1e-4,
    #         1e-5,
    #         # 1e-6
    #         ],
    # }
}

# Config keys whose values are appended to the run name (and so to
# `--model_name`), e.g. `_LEARNING_RATE1e-05`. Keys not listed here do not
# change the run name, so runs differing only in them share a name.
args_to_run_name = [
    'frame_ids',
    'learning_rate',
]

# Arguments appended verbatim to every `python adaptation.py ...` command.
common_args = f"--load_weights_folder {os.path.join(CKPTS_DIR, 'kitti_sup/models/weights_19')} \
    --models_to_load encoder depth \
    --reg_path {os.path.join(CKPTS_DIR, 'kitti_unsup/models/weights_19')} \
    --num_workers 0"

# TO MODIFY */ ---------------------------------------------

def main():
    """Launch every configured experiment on every dataset in `datasets`.

    In SLURM mode the jobs root directory (JOBS_DIR) is created first,
    including any missing parent directories.
    """
    if SLURM:
        # makedirs(exist_ok=True) also creates missing parents and avoids the
        # check-then-create race of `exists()` followed by `mkdir()`.
        os.makedirs(JOBS_DIR, exist_ok=True)

    for dataset in datasets:
        perform_experiments(dataset)

    print('Done!')
            
def run_command(command, method, run_name, dataset):
    """Execute one experiment command on SLURM or on the local machine.

    Args:
        command: Complete shell command line (``python adaptation.py ...``).
        method: Adaptation method name. Not used here; kept for callers.
        run_name: Run name; names the SLURM job and its job/log files.
        dataset: Dataset name. Not used here; kept for callers.
    """
    if SLURM:
        _submit_slurm_job(command, run_name)
    else:
        _run_locally(command)


def _make_job_folder(run_name):
    """Create and return a new folder under JOBS_DIR for one SLURM job."""
    current_time = time.strftime("%H:%M:%S", time.localtime())
    base = os.path.join(JOBS_DIR, run_name + '_' + current_time)
    # Never reuse an existing folder: two jobs with the same run name submitted
    # in the same second (or at the same clock time on another day) would
    # otherwise share, and overwrite, each other's .job/.out/.err files.
    folder, suffix = base, 1
    while True:
        try:
            os.mkdir(folder)
            return folder
        except FileExistsError:
            folder = f"{base}_{suffix}"
            suffix += 1


def _submit_slurm_job(command, run_name):
    """Write a SLURM job script that runs *command* and submit it via sbatch."""
    job_folder = _make_job_folder(run_name)
    job_file = os.path.join(job_folder, f"{run_name}.job")
    output_file = os.path.join(job_folder, f"{run_name}.out")
    error_file = os.path.join(job_folder, f"{run_name}.err")

    script_lines = [
        "#!/bin/bash",
        f"#SBATCH --job-name={run_name}.job",
        f"#SBATCH --output={output_file}",
        f"#SBATCH --error={error_file}",
        "#SBATCH --gpus=1",
        "#SBATCH --account=plgttaautopilot2-gpu-a100",
        "#SBATCH --partition=plgrid-gpu-a100",
        "#SBATCH --gres=gpu",
        "#SBATCH --time=48:00:00",
        f"#SBATCH --cpus-per-task={CPUS_PER_TASK}",
        f"#SBATCH --mem-per-cpu={MEM_PER_CPU}G",
        f"cd {RUN_SCRIPT_DIR}",
        "module load Miniconda3/4.9.2",
        "eval \"$(conda shell.bash hook)\"",
        "conda activate /net/tscratch/people/plgdmnsjk/ada_depth_env",
        command,
    ]
    with open(job_file, 'w', encoding='utf-8') as fh:
        fh.write("\n".join(script_lines) + "\n")

    # List form runs sbatch without a shell; a non-zero exit (e.g. a rejected
    # submission) is reported instead of being silently ignored.
    returncode = subprocess.run(["sbatch", job_file]).returncode
    if returncode != 0:
        print(f"--ERROR-- sbatch exited with code {returncode} for {job_file}")


def _run_locally(command):
    """Run *command* in a shell, wait for it, then print its stdout and stderr."""
    # shell=True is intentional: the command relies on shell word splitting
    # (e.g. a '0 -1' config value becomes two arguments) and is built from
    # this script's own configuration rather than external input.
    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    for stream in (stdout, stderr):
        if stream:
            # errors='replace' so stray non-UTF-8 bytes cannot abort the sweep.
            print(stream.decode('utf-8', errors='replace'))
    # Any non-zero status is a failure, not just 1 (argparse, for instance,
    # exits with 2 on usage errors).
    if process.returncode != 0:
        print("--ERROR--" * 20)

def perform_experiments(dataset):
    """Launch every hyper-parameter combination of every method on *dataset*.

    For each entry of ``configs`` the values of all keys except ``RUN_NAME``
    are expanded with a Cartesian product. Every combination becomes one
    ``python adaptation.py ...`` command, which is printed and handed to
    ``run_command``. Values of keys listed in ``args_to_run_name`` are
    appended to the run name with spaces removed. Other keys only affect the
    command line, so combinations differing only in them share a run name.
    """
    for method, params in configs.items():
        sweep = deepcopy(params)
        prefix = '_'.join((method, dataset, ''))
        if "RUN_NAME" in params:
            prefix += sweep.pop("RUN_NAME")

        option_names = list(sweep)
        for combo in itertools.product(*sweep.values()):
            run_name = prefix
            pieces = [f"--dataset {dataset}", f"--adaptation_method {method}"]

            for name, value in zip(option_names, combo):
                pieces.append(f"--{name} {value!s}")
                if name in args_to_run_name:
                    tag = name.upper() + str(value).replace(' ', '')
                    run_name = f"{run_name}_{tag}" if run_name else tag

            pieces.append(f"--data_path {os.path.join(DATADIR, DATA_FOLDERS[dataset])}")
            if run_name:
                pieces.append(f"--model_name {run_name}")
            if 'kitti' in dataset:
                pieces.append("--png")

            # Each option is followed by exactly one space, so `common_args`
            # can be appended directly.
            arguments = "".join(piece + " " for piece in pieces)
            command = f"python adaptation.py {arguments}{common_args}"

            print(command)
            run_command(command, method, run_name, dataset)
                
if __name__ == '__main__':
    # Launch experiments only when executed as a script, never on import.
    main()

