from utils import do_args, run_initialize,get_datetime_str_simplified
from ddp_train import Train
import os

def main():
    """Configure a DeepLabv3+ (``boostno`` variant) run on iSAID and start training.

    Steps:
      1. Print the current working directory.
      2. Build the argument namespace with ``do_args()`` and overwrite the
         experiment-specific fields in place.
      3. Derive ``args.name`` from a timestamp, ``args.name_prefix``, the key
         hyper-parameters (epochs / steps per epoch / batch size) and
         ``args.name_suffix``.
      4. Call ``run_initialize(args, __name__)``, construct ``Train(args)`` and
         call its ``run()`` method.
    """
    print(os.getcwd())
    args = do_args()

    # --- Model / optimisation settings --------------------------------------
    args.net_model = 'DeepLabv3p_boostno'  # alternative: "Semantic_FPN_boost"
    args.name_prefix = 'DeepLabv3p_boostno_3_'  # alternative: "Semantic_FPN_boost_1_"
    args.name_suffix = ""
    args.batch_size = 2
    args.training_steps_per_epoch = 40
    args.validation_steps_per_epoch = 20
    args.epochs = 100
    args.workers = 0
    args.learning_rate = 0.01
    args.valid_all = 1
    # Optional overrides (disabled):
    # args.resume_model = '/home/hul/results/21-04-29-Thu_14-13_DeepLabv3p_boost_similar_1_e80s500b2/model_state_dict_epoch_0.tar'
    # args.stagerate = 1.2

    # --- Dataset selection ---------------------------------------------------
    # Only iSAID is enabled here; uncomment a line below to switch datasets.
    # args.cityscapes = True
    # args.potsdam = True
    args.isaid = True

    # Class count / input channels per dataset. The first truthy flag wins
    # (cityscapes > isaid > potsdam/vaihingen). Because isaid is forced True
    # above, the potsdam/vaihingen branch is not reached from this script.
    # NOTE(review): if do_args() leaves `cityscapes` truthy, the 19-class
    # cityscapes setting overrides the iSAID one — confirm that default.
    if args.cityscapes:
        args.hr_nclasses = 19
        args.input_nchannels = 3
    elif args.isaid:
        args.hr_nclasses = 16
        args.input_nchannels = 3
    elif args.potsdam or args.vaihingen:
        args.hr_nclasses = 6
        args.input_nchannels = 3

    # --- Output locations ----------------------------------------------------
    project_root = '/home/hul/Forest-segmentation'
    data_dir = f'{project_root}/data'
    args.output_dir = f'{project_root}/results'
    args.log_dir = f'{project_root}/logs_local/'
    args.name = (
        f"{get_datetime_str_simplified()}_"
        f"{args.name_prefix}"
        f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"
        f"{args.name_suffix}"
    )

    # Keys for transformation of labels
    args.hr_label_key = f"{data_dir}/cheaseapeake_to_hr_labels.txt"
    args.lr_label_key = f"{data_dir}/nlcd_to_lr_labels.txt"

    # COLORMAP files for labels
    args.hr_color = f"{data_dir}/hr_color.txt"
    args.lr_color = f"{data_dir}/nlcd_color.txt"

    # LR files used for superres loss
    args.lr_stats_mu = f"{data_dir}/nlcd_mu.txt"
    args.lr_stats_sigma = f"{data_dir}/nlcd_sigma.txt"
    args.lr_class_weights = f"{data_dir}/nlcd_class_weights.txt"

    run_initialize(args, __name__)
    train = Train(args)
    train.run()



if __name__ == "__main__":
    # Launch training only when executed as a script, not when imported.
    main()
