import argparse

class TrainOptions:
    """Command-line options for training.

    Options are registered in ``__init__`` and grouped by the five training
    stages. Call ``parse()`` to get an ``argparse.Namespace`` holding one
    attribute per option.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()

        # ---------------------------------------- step 1/5 : parameters preparing... ----------------------------------------
        self.parser.add_argument("--seed", type=int, default=42, help="random seed")
        self.parser.add_argument("--resume", action='store_true', help="if specified, resume the training")
        self.parser.add_argument("--results_dir", type=str, default='../results', help="path of saving models, images, log files")
        self.parser.add_argument("--experiment", type=str, default='experiment', help="name of experiment")

        # ---------------------------------------- step 2/5 : data loading... ------------------------------------------------
        # These three are required, so argparse never falls back to a default;
        # none is declared.
        self.parser.add_argument("--train_source", type=str, required=True, help="dataset root for training")
        self.parser.add_argument("--val_source1", type=str, required=True, help="dataset root 1 for validating")
        self.parser.add_argument("--val_source2", type=str, required=True, help="dataset root 2 for validating")
        self.parser.add_argument("--train_bs", type=int, default=4, help="size of the training batches")
        self.parser.add_argument("--val_bs", type=int, default=2, help="size of the validating batches")
        self.parser.add_argument("--crop", type=int, default=256, help="image size after cropping")
        self.parser.add_argument("--random_resize", action='store_true', help="if specified, randomly resize the given image")
        self.parser.add_argument("--num_workers", type=int, default=8, help="number of cpu threads to use during batch generation")

        # ---------------------------------------- step 3/5 : model defining... ------------------------------------------------
        self.parser.add_argument("--data_parallel", action='store_true', help="if specified, train with data parallelism")
        self.parser.add_argument("--pretrained", type=str, default=None, help="pretrained model path")
        self.parser.add_argument("--num_res", type=int, default=2, help="number of resblocks after each convolution")
        self.parser.add_argument("--model_mode", type=str, default='perc', choices=['perc', 'dist'], help="mode for model definition, 'perc' for perceptual performance and 'dist' for distortion performance")

        # ---------------------------------------- step 4/5 : requisites defining... ------------------------------------------------
        self.parser.add_argument("--lr", type=float, default=0.0002, help="learning rate")
        self.parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")

        self.parser.add_argument("--gan_type", type=str, default='lsgan', help="the type of gan, optional: lsgan, vanilla")

        # ---------------------------------------- step 5/5 : training... ------------------------------------------------
        self.parser.add_argument("--print_gap", type=int, default=250, help="the gap between two print operations, in iterations")
        self.parser.add_argument("--val_gap", type=int, default=10, help="the gap between two validations, also the gap between two saving operations, in epochs")

        self.parser.add_argument("--lambda_cycle", type=float, default=10.0, help="cycle loss weight")
        self.parser.add_argument("--lambda_sr", type=float, default=5.0, help="self-regression loss weight")

    def parse(self, show=True, args=None):
        """Parse the command line and return the resulting namespace.

        Args:
            show: if True, print every parsed option via ``show()``.
            args: optional list of argument strings to parse. When None
                (the default), argparse reads ``sys.argv[1:]``, as before.
                Passing a list allows programmatic use and testing.

        Returns:
            argparse.Namespace with one attribute per registered option.

        Raises:
            SystemExit: raised by argparse when the arguments are invalid or
                a required option is missing.
        """
        opt = self.parser.parse_args(args)
        if show:
            self.show(opt)
        return opt

    def show(self, opt):
        """Print all options of ``opt`` as ``name: value`` lines.

        The lines are sorted by option name and framed by banner lines.
        """
        print('************ Options ************')
        for key, value in sorted(vars(opt).items()):
            print('%s: %s' % (key, value))
        print('************** End **************')
        
class TestOptions:
    """Command-line options for testing a trained model.

    Options are registered in ``__init__`` and grouped by the four testing
    stages. Call ``parse()`` to get an ``argparse.Namespace`` holding one
    attribute per option.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()

        # ---------------------------------------- step 1/4 : parameters preparing... ----------------------------------------
        self.parser.add_argument("--outputs_dir", type=str, default='../outputs', help="path of saving models, images, log files")
        self.parser.add_argument("--experiment", type=str, default='experiment', help="name of experiment")

        # ---------------------------------------- step 2/4 : data loading... ------------------------------------------------
        # Required, so argparse never falls back to a default; none is declared.
        self.parser.add_argument("--data_source", type=str, required=True, help="dataset root")
        self.parser.add_argument("--random_resize", action='store_true', help="if specified, randomly resize the given image")
        self.parser.add_argument("--num_workers", type=int, default=8, help="number of cpu threads to use during batch generation")

        # ---------------------------------------- step 3/4 : model defining... ------------------------------------------------
        self.parser.add_argument("--pretrained_dir", type=str, default='../pretrained', help="pretrained model root")
        self.parser.add_argument("--model_name", type=str, required=True, help="name of the model to be loaded")

        self.parser.add_argument("--num_res", type=int, default=2, help="number of resblocks after each convolution")
        self.parser.add_argument("--model_mode", type=str, default='perc', choices=['perc', 'dist'], help="mode for model definition, 'perc' for perceptual performance and 'dist' for distortion performance")

        # ---------------------------------------- step 4/4 : testing... ------------------------------------------------
        self.parser.add_argument("--save_image", action='store_true', help="if specified, save image when testing")

    def parse(self, show=True, args=None):
        """Parse the command line and return the resulting namespace.

        Args:
            show: if True, print every parsed option via ``show()``.
            args: optional list of argument strings to parse. When None
                (the default), argparse reads ``sys.argv[1:]``, as before.
                Passing a list allows programmatic use and testing.

        Returns:
            argparse.Namespace with one attribute per registered option.

        Raises:
            SystemExit: raised by argparse when the arguments are invalid or
                a required option is missing.
        """
        opt = self.parser.parse_args(args)
        if show:
            self.show(opt)
        return opt

    def show(self, opt):
        """Print all options of ``opt`` as ``name: value`` lines.

        The lines are sorted by option name and framed by banner lines.
        """
        print('************ Options ************')
        for key, value in sorted(vars(opt).items()):
            print('%s: %s' % (key, value))
        print('************** End **************')
    