# Optimizer for Gaussian Processes (GP)
import torch
import gpytorch


def get_optimizer(args, model, num_data=None):
    """
    Build the optimizer(s) and the optional learning-rate scheduler for a GP model.

    :param args: configuration object; the attributes ``model_type``,
        ``initial_lr`` and ``milestones`` are read
    :param model: model to optimize. For the NGD model types
        (``"ngd_model"``, ``"deepgp_ngd"``) it must provide
        ``variational_parameters()`` and ``hyperparameters()``; for every other
        model type only ``parameters()`` is used
    :param num_data: number of data points in the training set; only forwarded
        to the NGD optimizer
    :returns: tuple ``(opt, sched)``. ``opt`` is itself a tuple
        ``(hyperparameter_optimizer, variational_ngd_optimizer)`` whose second
        entry is ``None`` for non-NGD model types. ``sched`` is a
        ``ReduceLROnPlateau`` scheduler attached to the hyperparameter
        optimizer, or ``None`` when ``args.milestones`` is ``None``
    """
    if args.model_type in ("ngd_model", "deepgp_ngd"):
        # Variational parameters are updated with natural gradient descent at a
        # fixed step size; the remaining hyperparameters use Adam.
        variational_ngd_optimizer = gpytorch.optim.NGD(
            model.variational_parameters(), num_data=num_data, lr=0.1
        )
        hyperparameter_optimizer = torch.optim.Adam(
            [{'params': model.hyperparameters()}], lr=args.initial_lr
        )
    else:  # [indep_exact, exact, hetero, DeepGP]
        # A single Adam optimizer over all model parameters.
        variational_ngd_optimizer = None
        hyperparameter_optimizer = torch.optim.Adam(model.parameters(), lr=args.initial_lr)
    opt = (hyperparameter_optimizer, variational_ngd_optimizer)

    sched = None
    # NOTE: ``args.milestones`` only acts as an on/off switch here; its values
    # are not used because the scheduler is plateau-based, not milestone-based.
    if args.milestones is not None:
        # "indep_exact" waits ten times longer before halving the learning rate.
        patience = 1000 if args.model_type == "indep_exact" else 100
        # ``verbose`` is intentionally not passed: ``False`` was its default and
        # the keyword is deprecated in recent PyTorch releases.
        # Only the hyperparameter optimizer is scheduled; the NGD optimizer
        # (if any) keeps its fixed learning rate.
        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
            hyperparameter_optimizer, factor=0.5, patience=patience
        )

    return opt, sched
