from argparse import ArgumentParser


class TrainOptions:
	"""Command-line options for training.

	Wraps an ``argparse.ArgumentParser`` pre-populated with the generator/camera,
	loss-weight, dataset, optimizer, model-weight and logging options defined in
	:meth:`initialize`. Call :meth:`parse` to obtain the resulting
	``argparse.Namespace``.
	"""

	def __init__(self):
		self.parser = ArgumentParser()
		self.initialize()

	def initialize(self):
		"""Register every training option on ``self.parser``."""
		# --- Generator / camera settings ---
		self.parser.add_argument('--truncation_psi', type=float, help='Truncation psi', default=0.7)
		self.parser.add_argument('--truncation_cutoff', type=int, help='Truncation cutoff', default=None)
		# Exactly three floats (e.g. `--avg_camera_pivot 0 0 0.2`). Without nargs/type a
		# command-line value would arrive as one unparsed string instead of a list like the default.
		self.parser.add_argument('--avg_camera_pivot', nargs=3, type=float, help='The point at which the camera looks', default=[0, 0, 0.2])
		self.parser.add_argument('--avg_camera_radius', type=float, help='The average radius of the camera', default=2.7)
		self.parser.add_argument('--fov_deg', type=float, help='Field of View of camera in degrees', default=18.837)

		# --- Auxiliary models used by the losses ---
		self.parser.add_argument('--correlation_path', type=str, help='Path to correlation file', default='../pretrained_models/correlation.pt')
		self.parser.add_argument('--face_parser_ckpt', type=str, help='path to face parser file', default='../pretrained_models/BiSeNet.pth')
		# One or more CLIP model names, e.g. `--clip_models ViT-B/32 ViT-B/16`.
		# NOTE(review): the original author observed NaN values when using "ViT-L/14@336px"
		# and left the cause as an open question — investigate before switching to it.
		self.parser.add_argument('--clip_models', nargs='+', type=str, help='Names of CLIP models to use for losses', default=["ViT-B/32"])
		# One or more floats; presumably paired positionally with --clip_models
		# (one weight per model) — verify against the loss code.
		self.parser.add_argument('--clip_model_weights', nargs='+', type=float, help='Relative loss weights of the clip models', default=[1.0])

		# --- Loss multipliers ---
		self.parser.add_argument('--background_lambda', default=1.0, type=float, help='Background loss multiplier factor')
		self.parser.add_argument('--attribute_loss_lambda', type=float, help='Attribute loss multiplier factor', default=1.0)
		self.parser.add_argument('--mask_loss_lambda', type=float, help='Mask loss multiplier factor', default=1.0)
		self.parser.add_argument('--id_lambda', default=0., type=float, help='ID loss multiplier factor')

		# --- Experiment output and dataset sizes ---
		self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
		self.parser.add_argument('--train_dataset_size', default=6000, type=int, help="Will be used only if no latents are given")
		self.parser.add_argument('--test_dataset_size', default=1000, type=int, help="Will be used only if no latents are given")

		# --- Data loading ---
		self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training')
		self.parser.add_argument('--test_batch_size', default=1, type=int, help='Batch size for testing and inference')
		self.parser.add_argument('--workers', default=4, type=int, help='Number of train dataloader workers')
		self.parser.add_argument('--test_workers', default=2, type=int, help='Number of test/inference dataloader workers')

		# --- Optimizer ---
		self.parser.add_argument('--learning_rate', default=0.0005, type=float, help='Optimizer learning rate')
		self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')

		# --- Pretrained weights and checkpoints ---
		self.parser.add_argument('--parsenet_weights', default='../pretrained_models/parsenet.pth', type=str, help='Path to Parsing model weights')
		self.parser.add_argument('--stylegan_weights', default='../pretrained_models/ffhqrebalanced512-128.pkl', type=str, help='Path to StyleGAN model weights')
		self.parser.add_argument('--stylegan_size', default=512, type=int)
		self.parser.add_argument('--ir_se50_weights', default='../pretrained_models/model_ir_se50.pth', type=str, help="Path to facial recognition network used in ID loss")
		self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to AT3D model checkpoint')

		# --- Training schedule and logging intervals (in steps) ---
		self.parser.add_argument('--max_steps', default=300001, type=int, help='Maximum number of training steps')
		self.parser.add_argument('--image_interval', default=1000, type=int, help='Interval for logging train images during training')
		self.parser.add_argument('--board_interval', default=500, type=int, help='Interval for logging metrics to tensorboard')
		self.parser.add_argument('--val_interval', default=20000, type=int, help='Validation interval')
		self.parser.add_argument('--save_interval', default=20000, type=int, help='Model checkpoint interval')

	def parse(self, args=None):
		"""Parse the registered options.

		Args:
			args: Optional list of argument strings to parse. ``None`` (the
				default) parses ``sys.argv[1:]``, exactly as before.

		Returns:
			argparse.Namespace with one attribute per registered option.

		On invalid arguments argparse prints a usage message and raises
		``SystemExit``.
		"""
		opts = self.parser.parse_args(args)
		return opts