from timeit import default_timer as timer
from typing import Union, Optional, Callable

import torch
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Optimizer
from robustopt_torch.MDMizer import MDMizer

class StochasticSolver:
    def __init__(self, **params):
        """Build the solver configuration.

        Any keyword argument overrides the default of the same name; keys
        that are not among the defaults are stored as given.
        """
        defaults = {
            # Print progress / timing information while solving
            "verbose" : False,
            # Requested number of progress checks during a solve
            "progress_checks" : 0,
            # Enable the no-improvement counter used to stop early
            "early_stopping" : False,
            # Relative improvement threshold (also the default plateau
            # scheduler threshold)
            "improvement_thresh" : 0.0001,
            # Optimizer name: "sgd" or "mirror"
            "optimizer" : "sgd",
            # Extra keyword arguments for the optimizer
            "optimizer_params" : {},
            # Step size schedule: None, "plateau" or "sublinear"
            "lr_schedule_mode" : None,
            # Extra keyword arguments for the scheduler
            "scheduler_params" : {},
        }

        self.solver_params = {**defaults, **params}

    def set_solver_params(self, **params):
        """Overwrite (or add) solver settings from the given keyword arguments."""
        for name, value in params.items():
            self.solver_params[name] = value

    def solve(self, **problem_dict) -> Tensor:
        """Solve for optimal dual function values for a given distribution using SGD/MD.
Sampler is a function which draws samples from the target distribution. The
method returns a detached version of the optimal v.

        """
        # Variables for optimization
        target_vars = problem_dict["init_val"].detach().clone().requires_grad_(True)
        # Used for performing calculations without gradient computations
        dtch_target_vars = target_vars.detach()

        # Function to project onto some feasible set, if needed
        projection = problem_dict.get("projection", lambda x : None)
        projection(dtch_target_vars)

        # Total number of iterations
        num_iter = problem_dict["num_iter"]

        # Configure the optimizer
        optimizer = self._create_optimizer(target_vars, num_iter)

        # Create step size scheduler
        scheduler = self._create_scheduler(optimizer)

        # Determine how many progress checks to make and set the period accordingly
        p_period = self._get_prog_period(num_iter)

        objective = problem_dict["objective"]
        grad_sampler = problem_dict["grad_sampler"]

        # Sampler to draw validation samples
        valid_sampler = problem_dict.get("valid_sampler", grad_sampler)

        # Closure which evaluates the objective for validation
        obj_closure = lambda: objective(valid_sampler(), dtch_target_vars).mean()

        # Keeps track of the minimum objective value which has been seen and
        # whether insufficient progress was made
        obj_min = float('inf')
        no_prog = 0

        starti = timer()
        for j in range(num_iter):
            # Perform step of gradient based method
            optimizer.zero_grad()
            output = objective(grad_sampler(), target_vars).mean()
            output.backward()
            optimizer.step()

            projection(dtch_target_vars)
            self._update_lr(j, scheduler, obj_closure)

            # Do a progress check if needed
            if j % p_period == 0:
                no_prog, obj_min = self._prog_check(j, obj_closure, no_prog, obj_min)
                if no_prog > 2:
                    if self.solver_params["verbose"]:
                        print("No improvement was made for 3 consecutive progress checks. Terminating.")
                    break

        endi = timer()

        # Perform final optimizer actions if needed
        self._finalize_optimizer(optimizer)
        projection(dtch_target_vars)

        if self.solver_params["verbose"]:
            print("Completed stochastic solve in time: {:05.4f}".format(endi - starti))
            print("Final objective estimate: {}".format(obj_closure()))

        return dtch_target_vars

    def _create_optimizer(self, target_vars : Tensor, num_iter : Optional[int] = None):
        """Convenience method for creating different optimizers
        """
        if self.solver_params["optimizer"] == "sgd":
            o_params = {"lr" : 40.0}
            o_params.update(self.solver_params["optimizer_params"])
            optimizer = torch.optim.SGD([target_vars], **o_params)
        elif self.solver_params["optimizer"] == "mirror":
            o_params = {"lr" : 40.0}
            o_params.update(self.solver_params["optimizer_params"])
            # Whether or not to use the theoretically optimal step-size for
            # non-smooth functions
            scaling = o_params.pop("theoretical_scaling", False)
            optimizer = MDMizer([target_vars], **o_params)
            if scaling:
                optimizer.set_total_num_steps(num_iter)
        else:
            raise NotImplementedError("Specified optimizer is not implemented")

        return optimizer

    def _finalize_optimizer(self, optimizer : Optimizer, closure :
                            Optional[Callable[[], Tensor]] = None):
        """Run the optimizer's closing action, if it has one.

        Only the ``"mirror"`` optimizer has one: its ``set_final_iterate``
        method is called. ``closure`` is accepted but not used.
        """
        if self.solver_params["optimizer"] != "mirror":
            return
        optimizer.set_final_iterate()

    def _create_scheduler(self, optimizer : Optimizer):
        """Convenience method for creating different schedulers according to the
        scheduler mode which is specified.
        """
        scheduler = None
        if self.solver_params["lr_schedule_mode"] == "plateau":
            # Default plateau scheduler parameters
            scheduler_params = {'mode' : 'min', 'factor' : 0.5, 'patience' : 5,
                                'threshold' : self.solver_params["improvement_thresh"],
                                'cooldown' : 3}
            scheduler_params.update(self.solver_params["scheduler_params"])
            # Extra plateau scheduler parameters
            scheduler_params.pop('period', None)
            scheduler_params["verbose"] = self.solver_params["verbose"]

            scheduler = ReduceLROnPlateau(optimizer, **scheduler_params)

        elif self.solver_params["lr_schedule_mode"] == "sublinear":
            # Default plateau scheduler parameters, lr_lambda decreases the size
            # sublinearly
            scheduler_params = {'lr_lambda' : lambda epch: 1.0 / max(epch, 1.0)}
            scheduler_params.update(self.solver_params["scheduler_params"])

            scheduler = LambdaLR(optimizer, **scheduler_params)

        return scheduler

    def _get_prog_period(self, num_iter : int):
        """Convenience method to get the period for progress checks.

        Returns ``num_iter / progress_checks`` rounded up to the next integer
        when the ``"progress_checks"`` setting is positive, and ``num_iter``
        itself otherwise.
        """
        requested = self.solver_params["progress_checks"]
        if requested > 0:
            ratio = num_iter / requested
            whole = int(ratio)
            # Round up whenever the ratio has a fractional part
            return whole + 1 if ratio > whole else whole
        return num_iter

    def _prog_check(self, epoch : int, closure : Callable[[], Tensor], no_prog : int,
                    obj_min : float):
        """Convenience method to check for progress on the objective function.
        """
        obj_value = closure()
        if self.solver_params["verbose"]:
            print("Epoch {} objective estimate: {}".format(epoch, obj_value))

        no_progress = no_prog
        # Determine if no progress has been made and stop early
        if self.solver_params["early_stopping"]:
            no_improv = obj_value > obj_min - self.solver_params["improvement_thresh"] * abs(obj_min)
            no_progress = no_improv * no_progress + no_improv

        obj_minimum = min(obj_value, obj_min)
        return no_progress, obj_minimum

    def _update_lr(self, epoch : int, scheduler,
                   closure : Optional[Callable[[], Tensor]] = None):
        """Advance the step size scheduler, if there is one.

        In ``"sublinear"`` mode the scheduler is stepped on every call. In
        ``"plateau"`` mode it is stepped with the value of ``closure()``, but
        only when ``epoch`` is a multiple of the ``"period"`` scheduler
        parameter (100 when not given). Nothing happens when ``scheduler`` is
        ``None`` or for any other mode.
        """
        if scheduler is None:
            return

        mode = self.solver_params["lr_schedule_mode"]
        if mode == "sublinear":
            scheduler.step()
        elif mode == "plateau":
            period = self.solver_params["scheduler_params"].get("period", 100)
            if epoch % period == 0:
                scheduler.step(closure())
