
import os
import os.path as osp
import io
import numpy as np
import argparse
import datetime
import importlib
import configparser
from tqdm import tqdm
from mmdet.models import build_detector

import torch
#import autotorch as at
from mmcv import Config

from mmdet.models import build_detector

try:
    from mmcv.cnn import get_model_complexity_info
except ImportError:
    raise ImportError('Please upgrade mmcv to >0.6.2')


def rand_scales():
    """Draw a random, sorted subset of crop scales for ``RandomSquareCrop``.

    Candidates are 0.1, 0.2, ..., 3.0 (30 values). The candidates are shuffled
    with the global ``np.random`` state and a random number of them, between
    8 and 19 inclusive, is kept.

    Returns:
        A sorted list of builtin Python floats.
    """
    # Build the grid from integers so each value is the nearest double to
    # k/10 (e.g. 0.3, not the 0.30000000000000004 a float-step arange yields);
    # these values are written verbatim into the dumped config files.
    scales = np.arange(1, 31) / 10.0
    np.random.shuffle(scales)
    # randint's upper bound is exclusive: keep between 8 and 19 scales.
    keep = np.random.randint(8, 20)
    # tolist() converts to builtin floats; numpy scalars can be rendered as
    # ``np.float64(...)`` (NumPy >= 2 repr), which is not valid in a .py config.
    return sorted(scales[:keep].tolist())


def merge_cfg(cfg, scales):
    """Install ``scales`` as the crop choices of the training pipeline.

    Every step in ``cfg['data']['train']['pipeline']`` whose ``type`` is
    ``'RandomSquareCrop'`` gets its ``crop_choice`` set to ``scales``; the same
    list object is shared by all matching steps. Other steps are untouched.
    If no matching step exists, ``cfg`` is left unchanged.

    The config is mutated in place and is also returned for convenience.
    """
    pipeline = cfg['data']['train']['pipeline']
    crop_steps = (step for step in pipeline if step['type'] == 'RandomSquareCrop')
    for step in crop_steps:
        step['crop_choice'] = scales
    return cfg

def get_args(argv=None):
    """Parse command-line options for the Auto-SCRFD config generator.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with ``group`` (str, config work dir),
        ``template`` (int, template config index) and ``num_configs``
        (int, number of configs to generate).
    """
    parser = argparse.ArgumentParser(description='Auto-SCRFD')
    # Directory containing the template config and the generated variants.
    parser.add_argument('--group', type=str, default='configs/scale2.5g', help='configs work dir')
    parser.add_argument('--template', type=int, default=0, help='template config index')
    parser.add_argument('--num-configs', type=int, default=64, help='num of expected configs')
    return parser.parse_args(argv)



def main():
    """Generate ``--num-configs`` config variants with randomized crop scales.

    Loads the template ``<group>/<group_name>_<template>.py``, where
    ``group_name`` is the last path component of ``--group``. For each variant
    it replaces the ``crop_choice`` of every ``RandomSquareCrop`` train-pipeline
    step with a freshly drawn scale list. It then dumps the result to
    ``<group>/<group_name>_<index>.py``, with indices starting at
    ``template + 1``. Files already present at those indices are overwritten.

    Raises:
        FileNotFoundError: if the group directory or template config is missing.
        ValueError: if no group name can be derived from ``--group``.
    """
    args = get_args()
    print(datetime.datetime.now())

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not osp.exists(args.group):
        raise FileNotFoundError('config group dir not found: %s' % args.group)
    # normpath drops a trailing separator, so 'configs/scale2.5g/' still
    # yields 'scale2.5g' instead of an empty name.
    group_name = osp.basename(osp.normpath(args.group))
    if not group_name:
        raise ValueError('cannot derive a group name from: %s' % args.group)
    input_template = osp.join(args.group, "%s_%d.py" % (group_name, args.template))
    if not osp.exists(input_template):
        raise FileNotFoundError('template config not found: %s' % input_template)

    first_index = args.template + 1
    print('write-index from:', first_index)

    for write_index in range(first_index, first_index + args.num_configs):
        # Reload each time: merge_cfg mutates the config in place, so every
        # variant must start from an untouched copy of the template.
        det_cfg = Config.fromfile(input_template)
        scales = rand_scales()
        det_cfg = merge_cfg(det_cfg, scales)

        output_cfg_file = osp.join(args.group, "%s_%d.py" % (group_name, write_index))
        det_cfg.dump(output_cfg_file)
        print('SUCC', write_index, scales, datetime.datetime.now())

if __name__ == '__main__':
    # Run the config generator only when executed as a script, not on import.
    main()

