import argparse
import os
from util import util
import torch
import models
import data


class BaseOptions():
	"""This class defines options used during both training and test time.

	It also implements several helper functions such as parsing, printing, and saving the options.
	It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.

	``self.isTrain`` is read by gather_options, print_options and parse but is never
	set in this class; it must be provided (presumably by a train/test subclass)
	before those methods are called. print_options also reads ``opt.phase`` and,
	when ``isTrain`` is False, ``opt.results_dir``, which are not defined here either.
	"""

	def __init__(self):
		"""Reset the class; indicates the class hasn't been initialized."""
		self.initialized = False

	def initialize(self, parser):
		"""Define the common options that are used in both training and test.

		Args:
			parser: an argparse.ArgumentParser to register the options on.

		Returns:
			The same parser, with the options added. Also sets self.initialized.
		"""
		# basic parameters
		parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
		parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
		# No default: None unless given on the command line.
		parser.add_argument('--exp_id', type=str, help='id of the experiment for distinguishing different experiment settings. It decides where to store.')
		parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
		parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
		parser.add_argument('--seed', type=int, default=2023, help='Random seed for the experiment')
		# model parameters
		parser.add_argument('--model', type=str, default='uorf', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
		parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
		parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
		# dataset parameters
		parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
		parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
		parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data')
		parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
		parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
		# The non-string default (inf) is not passed through `type`, so it stays a float.
		parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
		parser.add_argument('--white_bkgd', action='store_true', help='use white background')

		# additional parameters
		parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
		# Int default (not the string '0') so get_default() matches the parsed value.
		parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
		parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
		parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
		parser.add_argument('--custom_lr', action='store_true', help='Custom lr(per step) scheduler for slot model. Currently hack.')
		parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing visuals on screen')
		parser.add_argument('--display_ncols', type=int, default=4,
							help='if positive, display all images in a single visdom web panel with certain number of images per row.')
		parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
		parser.add_argument('--display_server', type=str, default="http://localhost",
							help='visdom server of the web display')
		parser.add_argument('--display_env', type=str, default='main',
							help='visdom display environment name (default is "main")')
		parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
		# model settings
		parser.add_argument('--project', action='store_true', help='project the slot coord and add to slot latent')
		# store_false: relative_position is True by default; passing the flag turns it off.
		parser.add_argument('--relative_position', action='store_false', help='disable relative position with slot position')
		parser.add_argument('--pos_emb', action='store_true', help='apply position embedding on encoder')
		parser.add_argument('--emb_path', type=str, default='', help='path to pretrained embedding')

		# additional dummy info, save into the log file
		parser.add_argument('--dummy_info', type=str, default='', help='dummy info for code description')

		# uOCF
		parser.add_argument('--not_strict', action='store_true', help='not strict load')
		parser.add_argument('--load_intrinsics', action='store_true', help='load camera intrinsics')
		parser.add_argument('--fg_object_size', type=float, default=3, help='size of the foreground object')
		parser.add_argument('--slot_attn_pos_emb', action='store_true', help='use position embedding in slot attention')
		parser.add_argument('--no_learnable_pos', action='store_true', help='disable learnable position embedding')

		# uOCF - transformer
		parser.add_argument('--n_feat_layers', type=int, default=1, help='number of feature layers')
		parser.add_argument('--attn_momentum', type=float, default=0.5, help='momentum in slot attention')
		parser.add_argument('--pos_init', type=str, choices=['random', 'learnable', 'zero'], default='zero', help='position initialization')
		parser.add_argument('--camera_modulation', action='store_true', help='use camera modulation in the slot attention')
		parser.add_argument('--enc_kernel_size', type=int, default=3, help='encoder kernel size')
		parser.add_argument('--enc_mode', type=str, choices=['sum', 'stack'], default='sum', help='encoder mode for MultiDINOStackEncoder')
		parser.add_argument('--dec_mlp_act', type=str, choices=['relu', 'silu'], default='relu', help='activation function in decoder mlp')
		parser.add_argument('--dec_density_act', type=str, choices=['relu', 'softplus'], default='relu', help='activation function in decoder density')

		# uOCF - predict depth
		parser.add_argument('--scaled_depth', action='store_true', help='predict depth')
		parser.add_argument('--depth_scale_pred', action='store_true', help='predict depth scale')
		parser.add_argument('--depth_scale_param', type=float, default=2., help='depth scale prediction parameter')
		parser.add_argument('--depth_scale_pred_in', type=int, default=0, help='pred scale starts from epoch x')
		parser.add_argument('--depth_scale', type=float, default=None, help='depth scale')
		parser.add_argument('--remove_duplicate', action='store_true', help='remove duplicate slots')
		parser.add_argument('--remove_duplicate_in', type=int, default=10, help='remove duplicate starts from epoch x')

		self.initialized = True
		return parser

	def gather_options(self):
		"""Build the full parser and parse the command line.

		Basic options come from <initialize>. Model-specific and dataset-specific
		options are then added through the setters returned by
		models.get_option_setter and data.get_option_setter, i.e. the
		<modify_commandline_options> functions of the model and dataset classes.

		Returns:
			argparse.Namespace with every registered option. The final parser is
			stored in self.parser so print_options can look up defaults.
		"""
		# Always start from a fresh parser, even if initialize() already ran.
		# The option setters below presumably add arguments to the parser, so
		# re-using one from an earlier call could trip argparse's duplicate-option
		# check.
		parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
		parser = self.initialize(parser)

		# get the basic options; unknown args may belong to options added below
		opt, _ = parser.parse_known_args()

		# modify model-related parser options
		model_option_setter = models.get_option_setter(opt.model)
		parser = model_option_setter(parser, self.isTrain)
		opt, _ = parser.parse_known_args()  # parse again with new defaults

		# modify dataset-related parser options
		dataset_option_setter = data.get_option_setter(opt.dataset_mode)
		parser = dataset_option_setter(parser, self.isTrain)

		# save and return the parser; the final parse is strict
		self.parser = parser
		return parser.parse_args()

	def print_options(self, opt):
		"""Print the options and save them to disk.

		Every option is listed. Options whose value differs from the parser
		default are annotated with that default. The same text is written to
		<checkpoints_dir | results_dir>/<name>/[<exp_id>/]<phase>_opt.txt,
		using checkpoints_dir when self.isTrain is truthy.

		Args:
			opt: parsed options namespace (as returned by gather_options).
		"""
		lines = ['----------------- Options ---------------']
		for key, value in sorted(vars(opt).items()):
			default = self.parser.get_default(key)
			comment = ('\t[default: %s]' % str(default)) if value != default else ''
			lines.append('{:>25}: {:<30}{}'.format(str(key), str(value), comment))
		lines.append('----------------- End -------------------')
		message = '\n'.join(lines)
		print(message)

		# save to the disk
		root_dir = opt.checkpoints_dir if self.isTrain else opt.results_dir
		# --exp_id has no default and is None unless given. os.path.join raises
		# TypeError on None, so only add the component when it is set.
		path_parts = [root_dir, opt.name]
		if opt.exp_id:
			path_parts.append(opt.exp_id)
		expr_dir = os.path.join(*path_parts)
		util.mkdirs(expr_dir)
		file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
		with open(file_name, 'wt', encoding='utf-8') as opt_file:
			opt_file.write(message)
			opt_file.write('\n')

	def parse(self):
		"""Parse options, apply the name suffix, print/save them, and set up the GPU.

		Returns:
			argparse.Namespace with ``isTrain`` added and ``gpu_ids`` converted
			from a comma-separated string to a list of non-negative ints. The
			result is also stored in self.opt.
		"""
		opt = self.gather_options()
		opt.isTrain = self.isTrain   # train or test

		# process opt.suffix: placeholders such as '{model}_size{load_size}' are
		# filled from the option values, and the result is appended to opt.name.
		if opt.suffix:
			opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))

		self.print_options(opt)

		# set gpu ids: '0,2' -> [0, 2]. Negative ids (e.g. -1 for CPU) are
		# dropped, and empty entries such as a trailing comma are ignored.
		gpu_ids = []
		for str_id in opt.gpu_ids.split(','):
			if not str_id.strip():
				continue
			gpu_id = int(str_id)
			if gpu_id >= 0:
				gpu_ids.append(gpu_id)
		opt.gpu_ids = gpu_ids
		if opt.gpu_ids:
			torch.cuda.set_device(opt.gpu_ids[0])

		self.opt = opt
		return self.opt
