def none_or_str(value):
    """Map the literal string ``'None'`` to ``None``; return any other value unchanged.

    Used as an argparse ``type=`` callable so a command-line flag can
    explicitly select ``None``.
    """
    return None if value == 'None' else value
# NOTE(review): the transport-argument parser below is disabled. It is wrapped
# in a module-level string literal, so `parse_transport_args` is never defined
# at runtime. Either restore it as real code or delete it; code hidden in a
# string is easy to mistake for live code.
'''
def parse_transport_args(parser):
    group = parser.add_argument_group("Transport arguments")
    group.add_argument("--path-type", type=str, default="Linear", choices=["Linear", "GVP", "VP"])
    group.add_argument("--prediction", type=str, default="velocity", choices=["velocity", "score", "noise"])
    group.add_argument("--loss-weight", type=none_or_str, default=None, choices=[None, "velocity", "likelihood"])
    group.add_argument("--sample-eps", type=float)
    group.add_argument("--train-eps", type=float)
'''

def parse_ode_args(parser):
    """Add the ODE-sampler options to *parser* under an "ODE arguments" group.

    Registers:
    * ``--sampling-method``: str, default ``"dopri5"``.
    * ``--atol`` / ``--rtol``: float tolerances.
    * ``--reverse`` / ``--likelihood``: boolean switches.

    Returns ``None``; *parser* is modified in place.
    """
    ode_group = parser.add_argument_group("ODE arguments")
    option_specs = (
        ("--sampling-method", dict(
            type=str,
            default="dopri5",
            help="blackbox ODE solver methods; for full list check https://github.com/rtqichen/torchdiffeq",
        )),
        ("--atol", dict(type=float, default=1e-6, help="Absolute tolerance")),
        ("--rtol", dict(type=float, default=1e-3, help="Relative tolerance")),
        # store_true switches: False unless the flag is present.
        ("--reverse", dict(action="store_true")),
        ("--likelihood", dict(action="store_true")),
    )
    for flag, options in option_specs:
        ode_group.add_argument(flag, **options)

def parse_sde_args(parser):
    """Add the SDE-sampler options to *parser* under an "SDE arguments" group.

    Registers:
    * ``--sampling-method``: ``"Euler"`` or ``"Heun"``.
    * ``--diffusion-form``: the diffusion coefficient form.
    * ``--diffusion-norm``: float.
    * ``--last-step``: may be the literal ``None``.
    * ``--last-step-size``: float.

    Returns ``None``; *parser* is modified in place.
    """
    sde_group = parser.add_argument_group("SDE arguments")
    option_specs = (
        ("--sampling-method", dict(type=str, default="Euler", choices=["Euler", "Heun"])),
        ("--diffusion-form", dict(
            type=str,
            default="sigma",
            choices=["constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing"],
            help="form of diffusion coefficient in the SDE",
        )),
        ("--diffusion-norm", dict(type=float, default=1.0)),
        # none_or_str turns the command-line string "None" into the None choice.
        ("--last-step", dict(
            type=none_or_str,
            default="Mean",
            choices=[None, "Mean", "Tweedie", "Euler"],
            help="form of last step taken in the SDE",
        )),
        ("--last-step-size", dict(type=float, default=0.04, help="size of the last step taken")),
    )
    for flag, options in option_specs:
        sde_group.add_argument(flag, **options)

class DEConfig:
    """Sampler settings for either an ODE or an SDE sampler.

    Only the attributes relevant to the selected ``mode`` are stored:

    * ``mode == "ODE"``: ``sampling_method``, ``atol``, ``rtol``, ``reverse``,
      ``likelihood``.
    * ``mode == "SDE"``: ``sampling_method``, ``diffusion_form``,
      ``diffusion_norm``, ``last_step``, ``last_step_size``.

    ``mode`` itself is always stored.
    """

    # Allowed SDE values; these match the ``choices`` used by parse_sde_args.
    SDE_SAMPLING_METHODS = ("Euler", "Heun")
    DIFFUSION_FORMS = ("constant", "SBDM", "sigma", "linear", "decreasing", "increasing-decreasing")
    LAST_STEPS = (None, "Mean", "Tweedie", "Euler")

    def __init__(
        self,
        sampling_method="dopri5",
        atol=1e-6,
        rtol=1e-3,
        reverse=True,
        likelihood=True,

        # SDE
        diffusion_form="sigma",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        mode="ODE",
    ):
        """Build a config for the sampler selected by ``mode``.

        Args:
            sampling_method: In ODE mode, a solver name stored as-is. In SDE
                mode, it must be ``"Euler"`` or ``"Heun"``. The default
                ``"dopri5"`` is only valid for ODE mode, so SDE callers must
                pass this explicitly.
            atol: Absolute tolerance (ODE only).
            rtol: Relative tolerance (ODE only).
            reverse: Stored flag (ODE only).
            likelihood: Stored flag (ODE only).
            diffusion_form: One of ``DIFFUSION_FORMS`` (SDE only).
            diffusion_norm: Stored value (SDE only).
            last_step: One of ``LAST_STEPS``, ``None`` included (SDE only).
            last_step_size: Stored value (SDE only).
            mode: ``"ODE"`` or ``"SDE"``. Appended last with a default so
                existing positional callers keep working.

        Raises:
            ValueError: If ``mode`` is unknown, or an SDE setting is not one
                of its allowed values. Validation previously used ``assert``,
                which is skipped under ``python -O``.
        """
        # Plain super(): the old call named an undefined ODEConfig class.
        super().__init__()
        self.mode = mode

        if mode == "ODE":
            self.sampling_method = sampling_method
            self.atol = atol
            self.rtol = rtol
            self.reverse = reverse
            self.likelihood = likelihood

        elif mode == "SDE":
            self._check_choice("sampling_method", sampling_method, self.SDE_SAMPLING_METHODS)
            self.sampling_method = sampling_method
            self._check_choice("diffusion_form", diffusion_form, self.DIFFUSION_FORMS)
            self.diffusion_form = diffusion_form
            self.diffusion_norm = diffusion_norm
            self._check_choice("last_step", last_step, self.LAST_STEPS)
            self.last_step = last_step
            self.last_step_size = last_step_size

        else:
            # Fail loudly instead of returning an object with no settings.
            raise ValueError(f"mode must be 'ODE' or 'SDE', got {mode!r}")

    @staticmethod
    def _check_choice(name, value, allowed):
        """Raise ValueError unless *value* is one of *allowed*."""
        if value not in allowed:
            raise ValueError(f"{name} must be one of {list(allowed)}, got {value!r}")


        