from utils import do_args, run_initialize,get_datetime_str_simplified
from train import Train
import os

def main():
    """Configure a hard-coded training run and launch it.

    Obtains the argument object from ``do_args()``, overrides a fixed set of
    its attributes in place (model selection, schedule, dataset switches,
    output/log locations and the label, colormap and statistics file paths),
    then passes it to ``run_initialize`` and runs ``Train``.

    NOTE(review): every attribute assigned below replaces whatever value
    ``do_args()`` returned for it, so these settings cannot be changed from
    outside this function - confirm that this is intended.
    """
    print(os.getcwd())
    args = do_args()

    # All fixed locations live under this machine-specific checkout; keeping
    # the root in one place means a move only needs a single edit.
    project_root = "/home/hul/Forest-segmentation"
    data_dir = f"{project_root}/data"

    # Model selection and run naming.
    args.net_model = "UNet_boost_resnet"
    args.name_prefix = "UNet_boost_resnet1"
    args.name_suffix = ""

    # Training schedule.
    args.batch_size = 3
    args.training_steps_per_epoch = 400
    args.validation_steps_per_epoch = 400
    args.epochs = 10
    args.workers = 4

    # Dataset switch; the class/channel counts are only set when it is on.
    args.cityscapes = True
    if args.cityscapes:
        args.hr_nclasses = 20
        args.input_nchannels = 3

    args.output_dir = f"{project_root}/results"
    args.log_dir = f"{project_root}/logs_local/"

    # Run name layout: <timestamp>_<prefix>e<epochs>s<steps>b<batch><suffix>
    args.name = (
        f"{get_datetime_str_simplified()}_"
        f"{args.name_prefix}"
        f"e{args.epochs}s{args.training_steps_per_epoch}b{args.batch_size}"
        f"{args.name_suffix}"
    )

    # Keys for transformation of labels
    args.hr_label_key = f"{data_dir}/cheaseapeake_to_hr_labels.txt"
    args.lr_label_key = f"{data_dir}/nlcd_to_lr_labels.txt"

    # COLORMAP files for labels
    args.hr_color = f"{data_dir}/hr_color.txt"
    args.lr_color = f"{data_dir}/nlcd_color.txt"

    # LR files used for superres loss
    args.lr_stats_mu = f"{data_dir}/nlcd_mu.txt"
    args.lr_stats_sigma = f"{data_dir}/nlcd_sigma.txt"
    args.lr_class_weights = f"{data_dir}/nlcd_class_weights.txt"

    run_initialize(args, __name__)
    train = Train(args)
    train.run()



# Script entry point: launch the training run only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
