from __future__ import absolute_import, division, print_function
import os
import argparse

# Absolute path of the directory containing this options module. abspath()
# guards against a relative __file__ (e.g. running ``python options.py`` on
# older Pythons), where dirname() would return "" and path defaults built from
# it (such as --data_path) would resolve against the cwd at use time instead.
file_dir = os.path.dirname(os.path.abspath(__file__))


class GreeceOptions:
    """Command-line options for GREECE training runs.

    Builds an ``argparse.ArgumentParser`` pre-populated with every flag the
    training scripts accept; call :meth:`parse` to obtain the resulting
    ``argparse.Namespace``.

    argparse only applies ``type=`` conversion to *string* defaults, so numeric
    defaults here are written with the same Python type as their ``type=``
    (e.g. ``1.0`` for a float option) so that a defaulted value and a
    command-line value have the same type.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(description="GREECE options")

        # Paths
        self.parser.add_argument("--data_path",
                                 type=str,
                                 help="path to the training data",
                                 default=os.path.join(file_dir, "mnist"))
        self.parser.add_argument("--log_dir",
                                 type=str,
                                 help="log directory",
                                 default=os.path.join(os.path.expanduser("~"), "tmp"))

        # Training Options
        self.parser.add_argument("--optimizer",
                                 type=str,
                                 help="Type of optimizer to use",
                                 default="Adam",
                                 choices=["SGD", "Adam"])
        self.parser.add_argument("--trainer",
                                 type=str,
                                 help="Trainer to be used for training",
                                 default="regular",
                                 choices=["regular", "wilds"])
        self.parser.add_argument("--dataset",
                                 type=str,
                                 help="dataset to train on",
                                 default="Mnist",
                                 choices=["ImageNet", "Mnist", "Cifar10", "Cifar100", "rotMNIST",
                                          'SVHN', 'iwildcam', 'camelyon17'])
        self.parser.add_argument("--dataset_size",
                                 type=int,
                                 help="training toy dataset size",
                                 default=200)
        self.parser.add_argument('--split_dataset',
                                 type=float,
                                 help='Percentage to split MNIST dataset',
                                 default=-1.0)
        self.parser.add_argument("--twodig",
                                 type=str,
                                 help="Which two digits of Mnist to use as dataset. Default None, whole MNIST",
                                 default=None)
        self.parser.add_argument("--clamped",
                                 help='constrain thetas to have maximum of 0.5',
                                 action="store_true")

        # Optimization Options
        self.parser.add_argument("--batch_size",
                                 type=int,
                                 help="batch_size from dataset",
                                 default=10)
        self.parser.add_argument("--lambda_reg",
                                 type=float,
                                 help="Lambda coefficient for regularizer",
                                 default=100.)
        self.parser.add_argument("--reduced",
                                 type=float,
                                 help="Dataset Percentage to keep",
                                 default=1.)
        self.parser.add_argument("--learning_rate_class",
                                 type=float,
                                 help="Classifier's learning rate",
                                 default=1e-3)
        self.parser.add_argument("--learning_rate_sampler",
                                 type=float,
                                 help="Sampler's learning rate",
                                 default=1e-3)
        # Float default to match type=float (argparse does not convert
        # non-string defaults, so a bare 1 would stay an int).
        self.parser.add_argument("--slow_mix_rate",
                                 type=float,
                                 help="Rate of decreasing slow mix coef",
                                 default=1.0)
        # NOTE(review): no default, so this is None unless passed on the
        # command line — confirm consumers handle None.
        self.parser.add_argument('--lr_param',
                                 type=float,
                                 help='Rate for updating parameters in parametrized distributions')
        self.parser.add_argument("--num_epochs",
                                 type=int,
                                 help="Number of Epochs",
                                 default=2000)
        self.parser.add_argument("--extra_epochs",
                                 type=int,
                                 help="Number of extra epochs where the augmentations are frozen",
                                 default=50)
        self.parser.add_argument("--number_samples",
                                 type=int,
                                 help="Number of Samples from Mixture",
                                 default=5)
        self.parser.add_argument("--samples_per_group",
                                 type=int,
                                 help="Number of Samples Per Group",
                                 default=3)
        self.parser.add_argument("--param_reg",
                                 type=float,
                                 help="Regularizer applied in parameters learning inside the distribution",
                                 default=0.001)
        self.parser.add_argument("--warmup",
                                 type=int,
                                 help="Number of epochs to train classifier without updating sampler",
                                 default=8)
        self.parser.add_argument("--no_rotating_reg",
                                 help="if set, train without using the covariance regularizer",
                                 action='store_true')
        self.parser.add_argument("--grad_clip_epochs",
                                 type=int,
                                 help="number of epochs that sampler's gradient is being clipped",
                                 default=0)
        self.parser.add_argument("--tau",
                                 type=float,
                                 help="Gumbel Softmax tau",
                                 default=0.3)
        self.parser.add_argument('--no_augment',
                                 help='if set, perform a regular training without augmentation',
                                 action='store_true')
        self.parser.add_argument('--var_reg',
                                 help='If set, perform regularization using variance regularizer',
                                 action='store_true')
        # 'None' here is the literal string, not Python None (it is one of the choices).
        self.parser.add_argument('--theta_loss',
                                 type=str,
                                 default='None',
                                 choices=['None', 'NLLLoss'])
        self.parser.add_argument('--exp_loss',
                                 help='Train using the expected loss',
                                 action='store_true')
        # Default is the literal string "None", not Python None.
        self.parser.add_argument('--groups',
                                 type=str,
                                 default="None")

        # Network Options
        self.parser.add_argument("--sampler",
                                 type=str,
                                 help="sampler network to use",
                                 default="AffineSampler",
                                 choices=['AffineSampler', 'ParamSamplerCol'])
        self.parser.add_argument("--action",
                                 type=str,
                                 help="Group Action to Augment",
                                 default="AffineActions",
                                 choices=['AffineActions', 'ParamActionsCol'])
        self.parser.add_argument("--clean",
                                 help='Set to true to use the clean trainer',
                                 action='store_true')
        self.parser.add_argument("--classifier",
                                 type=str,
                                 help="Classifier",
                                 default="ResNet",
                                 choices=["ResNet", "ResNet34", "realResNet18", "realResNet34",
                                          "smallnet", "wideResNet"])
        self.parser.add_argument("--param_tr",
                                 help="if set, train with parameters",
                                 action="store_true")

        # System Options
        self.parser.add_argument("--no_cuda",
                                 help="if set disables CUDA",
                                 action="store_true")
        self.parser.add_argument("--plot_every",
                                 type=int,
                                 help="Epochs between successive train/test plots",
                                 default=5)

        # Logging Options
        self.parser.add_argument("--visualizer",
                                 type=str,
                                 help="Visualizer",
                                 default="Mnist",
                                 choices=["Mnist", "Toy2d", "Toy3d", "Wilds"])
        self.parser.add_argument("--no_vis",
                                 help="if set no visualization. Higher priority than --visualizer",
                                 action="store_true")

        # Evaluation Options
        self.parser.add_argument("--grid_size",
                                 type=int,
                                 help="Meshgrid size to evaluate points for contour visualization",
                                 default=1000)
        self.parser.add_argument("--grid_range",
                                 type=int,
                                 help="Meshgrid range to evaluate points for contour visualization",
                                 default=10)
        self.parser.add_argument("--test_size",
                                 type=int,
                                 help="size of test set for toy datasets",
                                 default=10)
        self.parser.add_argument("--set_const_group",
                                 type=int,
                                 help="index of group to be constantly picked, default -1 (no group picked)",
                                 default=-1)
        self.parser.add_argument("--exploration_rate",
                                 type=float,
                                 help="Negative rate used to update the exploration threshold",
                                 default=0.005)
        self.parser.add_argument("--reg_sched",
                                 type=float,
                                 help="Negative rate used to update the regularizer",
                                 default=0.005)

    def parse(self, args=None):
        """Parse command-line arguments into an options namespace.

        Args:
            args: optional list of argument strings. When None (the default),
                argparse reads ``sys.argv[1:]``, exactly as before this
                parameter existed.

        Returns:
            The ``argparse.Namespace`` produced by the parser. It is also
            stored on ``self.options``.
        """
        self.options = self.parser.parse_args(args)
        return self.options

if __name__ == "__main__":
    # When run as a script, parse sys.argv with the GREECE option set and
    # print the resulting namespace.
    parsed_options = GreeceOptions().parse()
    print(parsed_options)










