from timeit import default_timer as timer
from typing import Union, Optional, Callable, Dict, Any

# import line_profiler
# import atexit
# profile = line_profiler.LineProfiler()
# atexit.register(profile.print_stats)

import torch
from torch import Tensor
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim.lr_scheduler import LambdaLR
from torch.optim import Optimizer

from robustopt_torch.AGDMizer import AGDMizer

class GenericSolver:
    """Gradient-based solver configured through keyword dictionaries.

    ``solver_params`` holds the solver-wide configuration:

    * ``"verbose"``: print timing / final objective information.
    * ``"optimizer"``: ``"gd"`` or ``"agd"``.
    * ``"optimizer_params"``: extra keyword arguments for the optimizer
      constructor; they take precedence over the default ``lr``.
    * ``"lr_schedule_mode"``: ``None``, ``"plateau"`` or ``"sublinear"``.
    * ``"scheduler_params"``: extra keyword arguments for the scheduler.
    * ``"improvement_thresh"`` (optional): ``threshold`` of the plateau
      scheduler, ``1e-4`` when absent.
    * ``"error_check"`` (optional): enable the debugging checks of the
      regularized solve.
    """

    def __init__(self, **params):
        """Store the default solver parameters, overridden by ``params``."""
        # Parameters used configuring the solver
        self.solver_params = {
            "verbose": False,
            "optimizer": "gd",
            "optimizer_params": {},
            "lr_schedule_mode": None,
            "scheduler_params": {},
        }
        self.solver_params.update(params)

        # Step size used when neither the problem nor the optimizer
        # parameters provide one
        self._default_lr = 1e-1

    def set_solver_params(self, **params):
        """Convenience method to set solver parameters."""
        self.solver_params.update(params)

    def solve(self, **problem_dict) -> Tensor:
        """Minimize ``problem_dict["objective"]`` starting from ``init_val``.

        Recognized keys of ``problem_dict``:

        * ``"init_val"`` (required): starting tensor; it is cloned, never
          modified.
        * ``"num_iter"`` (required): maximum number of iterations.
        * ``"objective"`` (required): callable mapping the variables to a
          scalar tensor.
        * ``"projection"``: callable applied to the detached variables after
          every step; its return value is ignored, so it has to work in
          place.
        * ``"stopping_callback"``: callable ``(output, variables) -> bool``;
          a true value ends the loop before the step is taken.
        * ``"lr"``: step size handed to the optimizer (``optimizer_params``
          still take precedence).
        * ``"regularizer"``: when present the whole call is delegated to
          :meth:`_regularized_solve`.

        Returns:
            The detached tensor of optimized variables.
        """
        # Custom optimization loop for a regularized solve with gradient stopping
        if "regularizer" in problem_dict:
            return self._regularized_solve(**problem_dict)

        # Variables for optimization
        target_vars = problem_dict["init_val"].detach().clone().requires_grad_(True)
        # Detached alias of the same storage, used for in-place updates
        # that must not be tracked by autograd
        dtch_target_vars = target_vars.detach()

        # Function to project onto some feasible set, if needed
        projection = problem_dict.get("projection", lambda x: None)
        projection(dtch_target_vars)

        # Total number of iterations
        num_iter = problem_dict["num_iter"]

        # Configure the optimizer; the problem dictionary is forwarded so
        # that a caller supplied "lr" is honoured
        optimizer = self._create_optimizer(target_vars, problem_dict)

        # Create step size scheduler
        scheduler = self._create_scheduler(optimizer)

        objective = problem_dict["objective"]

        def obj_closure():
            """Evaluate the objective at the current iterate."""
            return objective(target_vars)

        # Call back to stop early
        stopping_callback = problem_dict.get("stopping_callback",
                                             lambda x, y: False)

        starti = timer()
        for j in range(num_iter):
            # Get gradients
            optimizer.zero_grad()
            output = objective(target_vars)
            output.backward()

            # Determine if early stopping should occur
            if stopping_callback(output, target_vars):
                print("Iterations reached stopping criterion. "
                      "Terminating early at epoch {}".format(j))
                break

            # Take gradient step
            optimizer.step()

            projection(dtch_target_vars)
            self._update_lr(j, scheduler, obj_closure)

        endi = timer()

        # Perform final optimizer actions if needed
        self._finalize_optimizer(optimizer, obj_closure)
        projection(dtch_target_vars)

        if self.solver_params["verbose"]:
            print("Completed generic solve in time: {:05.4f}".format(endi - starti))
            print("Final objective value: {}".format(objective(dtch_target_vars)))

        return dtch_target_vars

    def _create_optimizer(self, target_vars: Tensor,
                          problem_dict: Optional[Dict[str, Any]] = None):
        """Build the optimizer selected by ``solver_params["optimizer"]``.

        Args:
            target_vars: tensor to optimize.
            problem_dict: optional problem description; only its ``"lr"``
                entry is read here.

        Returns:
            A ``torch.optim.SGD`` (``"gd"``) or ``AGDMizer`` (``"agd"``)
            instance over ``[target_vars]``.

        Raises:
            NotImplementedError: for any other optimizer name.
        """
        if problem_dict is None:
            problem_dict = {}

        optimizer_name = self.solver_params["optimizer"]
        if optimizer_name == "gd":
            optimizer_cls = torch.optim.SGD
        elif optimizer_name == "agd":
            optimizer_cls = AGDMizer
        else:
            raise NotImplementedError("Specified optimizer is not implemented")

        # Problem-level step size first, explicit optimizer parameters last
        o_params = {"lr": problem_dict.get("lr", self._default_lr)}
        o_params.update(self.solver_params["optimizer_params"])

        return optimizer_cls([target_vars], **o_params)

    def _finalize_optimizer(self, optimizer: Optimizer,
                            closure: Optional[Callable[[], Tensor]] = None):
        """Run the optimizer's closing action, if it has one.

        Only the ``"agd"`` optimizer is finalized, through its
        ``set_final_iterate(closure)`` method.
        """
        if self.solver_params["optimizer"] == "agd":
            optimizer.set_final_iterate(closure)

    def _create_scheduler(self, optimizer: Optimizer):
        """Build the step size scheduler for ``lr_schedule_mode``.

        Returns:
            ``ReduceLROnPlateau`` for ``"plateau"``, ``LambdaLR`` for
            ``"sublinear"`` and ``None`` for every other mode.
        """
        mode = self.solver_params["lr_schedule_mode"]
        scheduler = None

        if mode == "plateau":
            # Default plateau scheduler parameters
            scheduler_params = {
                "mode": "min",
                "factor": 0.5,
                "patience": 5,
                "threshold": self.solver_params.get("improvement_thresh", 1e-4),
                "cooldown": 3,
            }
            scheduler_params.update(self.solver_params["scheduler_params"])

            # The solver-level verbosity flag decides, whatever the
            # scheduler parameters say
            scheduler_params.pop("verbose", None)
            if self.solver_params["verbose"]:
                try:
                    scheduler = ReduceLROnPlateau(optimizer, verbose=True,
                                                  **scheduler_params)
                except TypeError:
                    # NOTE(review): recent PyTorch releases deprecate the
                    # `verbose` argument and may reject it; retry without
                    # it. Any unrelated bad argument raises again here.
                    scheduler = ReduceLROnPlateau(optimizer, **scheduler_params)
            else:
                scheduler = ReduceLROnPlateau(optimizer, **scheduler_params)

        elif mode == "sublinear":
            # Default multiplier decreases the step size as 1 / epoch
            scheduler_params = {"lr_lambda": lambda epch: 1.0 / max(epch, 1.0)}
            scheduler_params.update(self.solver_params["scheduler_params"])

            scheduler = LambdaLR(optimizer, **scheduler_params)

        return scheduler

    def _update_lr(self, epoch: int, scheduler,
                   closure: Optional[Callable[[], Tensor]] = None):
        """Advance ``scheduler`` by one step; do nothing when it is ``None``.

        The plateau scheduler is fed the value returned by ``closure``; the
        sublinear scheduler is stepped without arguments. ``epoch`` is not
        used.
        """
        if scheduler is None:
            return

        mode = self.solver_params["lr_schedule_mode"]
        if mode == "plateau":
            scheduler.step(closure())
        elif mode == "sublinear":
            scheduler.step()

    def _get_max_reduc_num(self, initial_lr, min_lr, dec_fac):
        """Count the line search reductions between two step sizes.

        Computes ``floor(log(initial_lr / min_lr) / log(dec_fac))`` as a
        Python float, i.e. how many divisions by ``dec_fac`` keep the step
        size at or above ``min_lr`` when ``dec_fac > 1``. For a factor in
        ``(0, 1)`` the result is not positive.

        Raises:
            ValueError: if ``initial_lr < min_lr``, if ``dec_fac`` is not
                positive, or if ``dec_fac`` equals one while
                ``initial_lr > min_lr`` (the minimum would never be reached
                and a line search could loop forever).
        """
        if initial_lr < min_lr:
            raise ValueError("Initial step size is less than minimum step size.")
        elif dec_fac <= 0.0:
            raise ValueError("Step size decrease factor is non-positive.")

        # No room for a reduction; also avoids 0 / 0 when dec_fac is one
        if initial_lr == min_lr:
            return 0.0
        if dec_fac == 1.0:
            raise ValueError("Step size decrease factor of one can never "
                             "reach the minimum step size.")

        log_init = torch.log(torch.as_tensor(initial_lr, dtype=torch.double))
        log_min = torch.log(torch.as_tensor(min_lr, dtype=torch.double))
        log_dec = torch.log(torch.as_tensor(dec_fac, dtype=torch.double))

        return ((log_init - log_min) / log_dec).floor_().item()

    # NOTE: a standalone sufficient decrease line search helper was sketched
    # here but never completed; the line search is implemented inline in
    # `_regularized_solve`.

    @staticmethod
    def _bump_stat(problem_dict, key):
        """Increment ``problem_dict["solver_stats"][key]`` when it exists."""
        stats = problem_dict.get("solver_stats")
        if stats is not None and key in stats:
            stats[key] += 1

    def _regularized_solve(self, **problem_dict) -> Tensor:
        """Minimize ``objective + regularizer`` with gradient based stopping.

        Implemented separately from :meth:`solve` to eliminate possible
        extra computation of gradients. No projection step has been
        implemented.

        Recognized keys of ``problem_dict``:

        * ``"init_val"``, ``"num_iter"``, ``"objective"``, ``"regularizer"``
          (required).
        * ``"regularizer_grad"``: callable returning the regularizer
          gradient; without it the gradient comes from autograd.
        * ``"lr"``: initial step size (solver default otherwise).
        * ``"min_lr"``: smallest line search step size; defaults to ``"lr"``,
          which disables the line search.
        * ``"line-search-suf-dec-factor"`` (``1e-4``) and
          ``"line-search-reduc-factor"`` (``2.0``): line search constants.
        * ``"stopping_threshold"`` (``1e-2``): the loop stops once the norm
          of the regularized gradient drops below this multiple of the
          objective gradient norm (or below ``1e-6``).
        * ``"solver_stats"``: dict whose ``"grad_evals"``, ``"early_stops"``
          and ``"ls_fail"`` counters are incremented when present.

        With ``solver_params["optimizer"] == "agd"`` a momentum update
        follows every step and one extra gradient step closes the solve.

        Returns:
            The detached tensor of optimized variables. When the line
            search fails, the iterate reached so far is returned at once.
        """
        # Variables for optimization
        target_vars = problem_dict["init_val"].detach().clone().requires_grad_(True)
        # Detached alias of the same storage, used for in-place updates
        dtch_target_vars = target_vars.detach()

        # Maximum number of iterations to execute
        num_iter = problem_dict["num_iter"]

        # Step sizes to use
        default_lr = problem_dict.get("lr", self._default_lr)
        min_lr = problem_dict.get("min_lr", default_lr)

        # Line search constants and number of allowed reductions
        suff_dec_factor = problem_dict.get("line-search-suf-dec-factor", 1e-4)
        ls_reduc_factor = problem_dict.get("line-search-reduc-factor", 2.0)
        ls_max_reduc = self._get_max_reduc_num(default_lr, min_lr, ls_reduc_factor)

        # Objective function
        objective = problem_dict["objective"]

        # Regularization, like squared euclidean norm
        regularizer = problem_dict["regularizer"]
        regularizer_grad = problem_dict.get("regularizer_grad", None)

        # Multiplicative threshold at which a smaller gradient norm of the
        # regularized problem, relative to the original problem will cause stopping
        stop_thresh = problem_dict.get("stopping_threshold", 1e-2)

        use_momentum = self.solver_params["optimizer"] == "agd"
        verbose = self.solver_params["verbose"]

        # Momentum buffer of the accelerated variant
        v = torch.zeros_like(dtch_target_vars)

        prev_obj = float("inf")
        starti = timer()
        for j in range(num_iter):
            # Get new gradients
            if target_vars.grad is not None:
                target_vars.grad.zero_()
            obj_val = objective(target_vars)
            obj_val.backward()
            self._bump_stat(problem_dict, "grad_evals")

            # Norm of the objective gradient alone, taken before the
            # regularizer gradient is accumulated
            grad_norm = torch.linalg.norm(target_vars.grad)

            if regularizer_grad is not None:
                reg_val = regularizer(dtch_target_vars)
                target_vars.grad.add_(regularizer_grad(dtch_target_vars))
            else:
                reg_val = regularizer(target_vars)
                reg_val.backward()

            # Determine if early stopping should occur
            prox_norm = torch.linalg.norm(target_vars.grad)
            if prox_norm < max(stop_thresh * grad_norm, 1e-6):
                self._bump_stat(problem_dict, "early_stops")
                if verbose:
                    print("Gradient stopping criterion reached. "
                          "Terminating early at epoch {}".format(j))
                break

            step_sz = default_lr
            next_iter = torch.sub(dtch_target_vars, target_vars.grad, alpha=step_sz)
            # Do line search if specified
            if ls_max_reduc >= 1:
                suff_dec_val = suff_dec_factor * (prox_norm ** 2)
                num_reduc = 0

                curr_obj = obj_val.detach() + reg_val.detach()
                next_obj = objective(next_iter) + regularizer(next_iter)
                while next_obj >= curr_obj - suff_dec_val * step_sz:
                    step_sz = step_sz / ls_reduc_factor
                    next_iter = torch.sub(dtch_target_vars, target_vars.grad,
                                          alpha=step_sz)
                    next_obj = objective(next_iter) + regularizer(next_iter)
                    num_reduc += 1
                    if num_reduc > ls_max_reduc:
                        self._bump_stat(problem_dict, "ls_fail")
                        if verbose:
                            print("Minimum step size reached. Gradient may not be a"
                                  " direction of sufficient descent")
                        return dtch_target_vars

            dtch_target_vars.copy_(next_iter)

            # Optional debugging checks; they drop into the debugger
            if self.solver_params.get("error_check", False):
                obj_check = obj_val.detach() + reg_val.detach()
                if obj_check > (prev_obj + 1e-7):
                    print("Error objective value increased")
                    breakpoint()

                if torch.isnan(dtch_target_vars).any():
                    print("Error nan encountered")
                    breakpoint()

                if torch.isinf(dtch_target_vars).any():
                    print("Error inf encountered")
                    breakpoint()

                prev_obj = obj_check

            # Accelerated/momentum step
            if use_momentum:
                v.sub_(target_vars.grad, alpha=default_lr)
                v.mul_((j + 1) / (j + 4))
                dtch_target_vars.add_(v)

        endi = timer()

        # Do a final gradient step
        if use_momentum:
            if target_vars.grad is not None:
                target_vars.grad.zero_()
            obj_val = objective(target_vars) + regularizer(target_vars)
            obj_val.backward()

            # Gradient step with the configured step size
            dtch_target_vars.sub_(target_vars.grad, alpha=default_lr)

        if verbose:
            print("Completed regularized solve in time: {:05.4f}".format(endi - starti))
            print("Final regularized objective value: {}".format(
                       objective(dtch_target_vars) +
                       regularizer(dtch_target_vars)))

        return dtch_target_vars
