from utils import do_args, run_initialize,get_datetime_str_simplified
from train import Train
import os

def main(*, apply_local_overrides=True, project_root="/home/hul/Forest-segmentation"):
    """Parse arguments, optionally apply local experiment overrides, then train.

    Args:
        apply_local_overrides: When True (the default, matching the previous
            hard-coded ``if 1:`` toggle), the parsed arguments are overwritten
            with the experiment settings in ``_apply_local_overrides``. Pass
            False to run purely from the command-line arguments.
        project_root: Directory containing the ``data`` folder and the
            ``results`` / ``logs_local`` output folders used by the overrides.
    """
    print(os.getcwd())
    args = do_args()
    if apply_local_overrides:
        _apply_local_overrides(args, project_root)

    run_initialize(args, __name__)
    train = Train(args)
    train.run()


def _apply_local_overrides(args, project_root):
    """Overwrite ``args`` in place with the current iSAID DeepLabv3+ experiment settings."""
    data_dir = os.path.join(project_root, "data")

    # Model selection and run naming.
    # (An earlier UNet_boost_resnet configuration used batch_size=4 and
    # 400/400 train/validation steps per epoch.)
    args.net_model = "DeepLabv3p_boostno"
    args.name_prefix = "DeepLabv3p_boostno_0_"
    args.name_suffix = ""
    # NOTE(review): epochs=0 combined with resume=True below looks like an
    # evaluation-only run of a checkpoint -- confirm against Train.
    args.epochs = 0
    args.workers = 4

    # Dataset: iSAID with 16 high-resolution classes and 3-channel input.
    # Non-iSAID runs previously used batch_size=2 with 400/500 steps per epoch.
    # A Cityscapes variant would use hr_nclasses=20, input_nchannels=3.
    args.isaid = True
    args.hr_nclasses = 16
    args.input_nchannels = 3
    args.batch_size = 12
    args.training_steps_per_epoch = 40
    args.validation_steps_per_epoch = 6000

    # Resume from an existing checkpoint.
    args.resume = True
    args.resume_model = '/home/hul/results/21-04-08-Thu_08-26-21_DeepLabv3p_boostno_syncBN_0_e100s500b8/model_state_dict_epoch_0.tar'

    args.output_dir = os.path.join(project_root, "results")
    # The trailing "" keeps the trailing separator of the original log_dir string.
    args.log_dir = os.path.join(project_root, "logs_local", "")
    # Built after epochs / steps / batch size are final so the name reflects them.
    args.name = (
        f"{get_datetime_str_simplified()}_"
        f"{args.name_prefix}"
        f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"
        f"{args.name_suffix}"
    )

    # Keys for transformation of labels.
    args.hr_label_key = os.path.join(data_dir, "cheaseapeake_to_hr_labels.txt")
    args.lr_label_key = os.path.join(data_dir, "nlcd_to_lr_labels.txt")

    # Colormap files for labels.
    args.hr_color = os.path.join(data_dir, "hr_color.txt")
    args.lr_color = os.path.join(data_dir, "nlcd_color.txt")

    # Low-resolution statistics used for the super-resolution loss.
    args.lr_stats_mu = os.path.join(data_dir, "nlcd_mu.txt")
    args.lr_stats_sigma = os.path.join(data_dir, "nlcd_sigma.txt")
    args.lr_class_weights = os.path.join(data_dir, "nlcd_class_weights.txt")



if __name__ == "__main__":
    # Launch training only when executed as a script, not when imported.
    main()
