import torch
from eos_line_search.optimizers import LineSearch as LS
from eos_line_search.utils import *
import numpy as np
import sys


class PoNoS(LS.LineSearch):
    """
    PoNoS line search optimizer.

    Extends ``LineSearch`` with a nonmonotone Armijo line search (the
    reference value is selected by ``nonmonotone_option``), optional forward
    (step-enlarging) searches, and optional heuristics that adapt ``c``,
    ``zhang_xi`` and the window ``M`` during training.

    PoNoS Arguments (defaults shown):
         c=0.5, # line search sufficient decrease scaling constant
         c_p=0.1, # Polyak step size scaling constant
         beta=0.5, # cutting step (forward searches enlarge the step by 1/beta)
         zhang_xi=1, # Zhang xi, controlling the nonmonotonicity
         max_eta=10, # maximum step size
         min_eta=1e-06, # minimum step size
         f_star=0, # estimate of the min value of f (stored, not used in this class)
         save_backtracks=False # activate the memory-based resetting technique

    Options handled in this class:
         nonmonotone_option: how the Armijo reference value is built
             0: Zhang & Hager average C_k (weighted by zhang_xi), floored at
                the current loss
             1: Grippo, max of the last M losses (M adapted per adapt_M)
             2: as 1, blended with the current loss using
                xi = 1 - exp(-grad_norm)
             3: mean of the last M losses, floored at the current loss
             4: (zhang_xi * 100)-th percentile of the last M losses
             5: as 4, but the full loss history is kept and sliced
         M: window size for nonmonotone_option 1-5
         adapt_M (nonmonotone_option=1 only):
             0: fixed M
             1: M -= 1 / M += 1 when the last 3 sharpness values strictly
                decrease / increase
             2: M -= 1 when the gradient inf-norm is > 1e-1, M += 1 when it
                lies in (1e-3, 1e-1)
         adapt_c:
             0/False: off
             1: c = max(1 - step_size / 2, 1e-3)
             2: c *= 0.95 when the running sharpness average exceeds the
                previous two averages
             3: as 2, and c *= 1.05 when it is below both
             4: c = 0.9 while the last `a` is below 0.202, else the original c
             5: switch c / zhang_xi / M based on the trend of past `a` values
         forward_option:
             7/9: one forward search when state["backtracks"] == 0
             8/10: repeat forward searches while state["backtracks"] == 0
             11/12: as 8/10, but only while the enlarged step stays below
                    num_classes (11 stops after the first successful search)
         num_classes: cap on forward steps; must be set for forward_option
             11/12
    All remaining arguments are forwarded to ``LineSearch`` unchanged.

    Note that PoNoS is like LBFGS from the LBFGS optimizer from pytorch,
    the step needs to be called like in the following:
    closure = lambda: loss_function(model, images, labels, backwards=False)
    opt.step(closure)
    """

    def __init__(
        self,
        params,
        init_step_size=1,
        max_eta=10,
        c=0.5,
        beta=0.5,
        reset_option=0,
        forward_option=0,
        n_batches_per_epoch=500,
        gamma=2.0,
        min_eta=1e-06,
        c_p=0.1,
        save_backtracks=False,
        zhang_xi=1,
        f_star=0,
        eps=0,
        adapt_c=False,
        nonmonotone_option=0,
        M=10,
        adapt_M=0,
        cdat_optimizer=None,  # CDAT optimizer instance for reset_option=2
        malmis_optimizer=None,  # MalMis optimizer instance for reset_option=3
        num_classes=None,
    ):
        super().__init__(
            params,
            init_step_size=init_step_size,
            max_eta=max_eta,
            c=c,
            beta=beta,
            reset_option=reset_option,
            forward_option=forward_option,
            n_batches_per_epoch=n_batches_per_epoch,
            gamma=gamma,
            min_eta=min_eta,
            c_p=c_p,
            save_backtracks=save_backtracks,
            eps=eps,
            cdat_optimizer=cdat_optimizer,
            malmis_optimizer=malmis_optimizer,
        )
        self.zhang_xi = zhang_xi
        # Running Zhang & Hager terms (weight Q_k, reference C_k) used by
        # nonmonotone_option=0.
        self.state["Q_k"] = 0
        self.state["C_k"] = 0
        self.f_star = f_star
        self.adapt_c = adapt_c
        self.nonmonotone_option = nonmonotone_option
        self.M = M
        self.adapt_M = adapt_M
        # Upper bound on forward steps for forward_option 11/12.
        self.num_classes = num_classes

        if self.adapt_c in (2, 3):
            # Recent sharpness values and their running averages.
            self.L_past = []
            self.L_avgs_past = []

        if self.nonmonotone_option in (1, 2, 3, 4, 5):
            # fvals_past: recent losses for the windowed reference values.
            # L_past: recent sharpness values used by adapt_M=1.
            # NOTE(review): when adapt_c is also 2/3 this rebinds L_past and
            # both heuristics then append to the same list -- confirm intended.
            self.L_past = []
            self.fvals_past = []

        # State for the adapt_c=4/5 heuristics, which watch the `a` values
        # returned by the Armijo check.
        self.last_a = 10
        self.fast_window_size = 5
        self.slow_window_size = 50
        self.list_of_a = []
        self.orig_c = self.c
        self.orig_xi = zhang_xi
        # line_search calls since adapt_c=5 last detected a fast decrease.
        self.whithin_n_it = np.inf

    def _perform_armijo_line_search(
        self,
        step_size,
        params_current,
        grad_current,
        loss,
        closure_deterministic,
        suff_dec,
        ref_value,
    ):
        """
        Backtrack from ``step_size`` until the Armijo check accepts a step.

        At most 100 candidates are tried. Each candidate overwrites the
        parameters via ``gd_update`` and re-evaluates the loss with
        ``closure_deterministic`` (counted in ``state["func_evals"]``).

        Returns: (found, step_size, a, e)
            found: 1 when the check accepted the step.
            step_size: the step returned by the last check.
            a: the fourth value returned by ``check_armijo_conditions``.
            e: index of the last attempt (0-based).
        """
        # Constant across attempts; .item() forces a host sync on device
        # tensors, so read it once instead of up to 100 times.
        loss_value = loss.item()
        for e in range(100):
            LS.LineSearch.gd_update(
                self, self.params, step_size, params_current, grad_current
            )
            loss_next = closure_deterministic()
            self.state["func_evals"] += 1
            found, step_size, _, a = LS.LineSearch.check_armijo_conditions(
                self=self,
                step_size=step_size,
                loss=loss_value,
                suff_dec=suff_dec,
                loss_next=loss_next.item(),
                c=self.c,
                beta=self.beta,
                ref_value=ref_value,
                eps=self.eps,
            )
            if found == 1:
                break
        return found, step_size, a, e

    def _append_to_window(self, lst, val, window_size):
        """
        Append ``val`` to ``lst`` in place, keeping at most ``window_size`` items.

        The oldest entries are dropped first. The bound is re-applied on
        every call, so the list also shrinks when ``window_size`` decreases
        between calls (e.g. when ``adapt_M`` lowers ``self.M``). Window sizes
        below 1 are treated as 1.
        """
        window_size = max(window_size, 1)
        while len(lst) >= window_size:
            lst.pop(0)
        lst.append(val)

    def _handle_forward_options(
        self,
        step_size,
        params_current,
        grad_current,
        loss,
        closure_deterministic,
        suff_dec,
        ref_value,
        found=1,
    ):
        """
        Run the forward (step-enlarging) search selected by ``forward_option``.

        Each forward attempt multiplies the current step by ``1 / beta`` and
        then backtracks from there with ``_perform_armijo_line_search``.

        Args:
            found: outcome of the preceding line search. It is returned
                unchanged when no forward attempt runs, so a failed base
                search is not reported as a success.

        Returns: (step_size, found, a, e), where ``a`` is None and ``e`` is 0
            when no forward attempt ran.
        """
        a = None
        e = 0

        def forward_attempt(current_step):
            # Enlarge the step by 1/beta, then backtrack from the enlarged step.
            return self._perform_armijo_line_search(
                (1 / self.beta) * current_step,
                params_current,
                grad_current,
                loss,
                closure_deterministic,
                suff_dec,
                ref_value,
            )

        option = self.forward_option
        if option in (7, 9):
            if self.state["backtracks"] == 0:
                found, step_size, a, e = forward_attempt(step_size)
        elif option in (8, 10):
            # NOTE(review): termination relies on the Armijo check eventually
            # making state["backtracks"] non-zero.
            while self.state["backtracks"] == 0:
                found, step_size, a, e = forward_attempt(step_size)
        elif option in (11, 12):
            while (
                self.state["backtracks"] == 0
                and ((1 / self.beta) * step_size) < self.num_classes
            ):
                found, step_size, a, e = forward_attempt(step_size)
                if option == 11 and found == 1:
                    break

        return step_size, found, a, e

    def _adapt_c(self, step_size, sharpness, iteration):
        """
        Update ``self.c`` according to ``self.adapt_c``.

        For adapt_c=5 this also updates ``zhang_xi``, ``M`` and
        ``whithin_n_it``.
        """
        if self.adapt_c == 1:
            self.c = max(1 - step_size / 2, 1e-3)
        elif self.adapt_c in (2, 3):
            self.L_past.append(sharpness)
            if iteration == 0:
                self.L_avgs_past.append(sharpness)
            else:
                L_avg = sum(self.L_past) / len(self.L_past)
                self.L_avgs_past.append(L_avg)
                if iteration != 1:
                    # Compare against the two previous running averages, then
                    # drop the oldest entries to keep both windows short.
                    if L_avg > self.L_avgs_past[0] and L_avg > self.L_avgs_past[1]:
                        self.c = self.c * 0.95
                    elif (
                        self.adapt_c == 3
                        and L_avg < self.L_avgs_past[0]
                        and L_avg < self.L_avgs_past[1]
                    ):
                        self.c = self.c * 1.05
                    self.L_past.pop(0)
                    self.L_avgs_past.pop(0)
        elif self.adapt_c == 4:
            if self.last_a - 0.2 < 0.002:
                self.c = 0.9
            else:
                self.c = self.orig_c
        elif self.adapt_c == 5:
            if (
                len(self.list_of_a) >= self.fast_window_size
                and is_decreasing_fast(self.list_of_a[-self.fast_window_size :])
                and self.state["grad_evals"] < 50
            ):
                print("Yes, it is decreasing fast")
                self.zhang_xi = 0.1
                self.c = 0.01
                self.whithin_n_it = 0
            elif len(
                self.list_of_a
            ) >= self.slow_window_size and is_increasing_slowly(
                self.list_of_a[-self.slow_window_size :]
            ):
                print("No, it is increasing slowly")
                self.whithin_n_it += 1
                self.zhang_xi = min(
                    self.orig_xi + (self.whithin_n_it * 0.001), 0.999
                )
                self.c = self.orig_c
                self.M += 1
            else:
                self.whithin_n_it += 1
                if self.whithin_n_it < 50:
                    self.zhang_xi = 0.1
                    self.c = 0.01
                else:
                    self.zhang_xi = min(
                        self.orig_xi + (self.whithin_n_it * 0.001), 0.999
                    )
                    self.c = self.orig_c

    def _adapt_window_size(self, sharpness):
        """Adapt the Grippo window ``self.M`` (nonmonotone_option=1) per ``self.adapt_M``."""
        if self.adapt_M == 1:
            if len(self.L_past) < 3:
                self.L_past.append(sharpness)
            else:
                self.L_past.pop(0)
                self.L_past.append(sharpness)
            if len(self.L_past) == 3:
                if (
                    self.L_past[2] < self.L_past[1]
                    and self.L_past[1] < self.L_past[0]
                    and self.M > 1
                ):
                    self.M = self.M - 1
                    if len(self.fvals_past) > self.M:
                        self.fvals_past.pop(0)
                elif (
                    self.L_past[2] > self.L_past[1]
                    and self.L_past[1] > self.L_past[0]
                ):
                    self.M = self.M + 1
        elif self.adapt_M == 2:
            grad_norm_inf = compute_grad_inf_norm(self.params)
            # NOTE(review): `wandb` is not imported here; presumably it comes
            # from the eos_line_search.utils star import -- confirm.
            wandb.log({"Grad Norm Inf": grad_norm_inf}, commit=False)

            if grad_norm_inf > 1e-1 and self.M > 1:
                self.M = self.M - 1
            elif grad_norm_inf < 1e-1 and grad_norm_inf > 1e-3:
                self.M = self.M + 1

    def _compute_reference_value(self, loss, loss_value, grad_norm, sharpness):
        """
        Return the nonmonotone Armijo reference value for ``self.nonmonotone_option``.

        ``loss`` is the loss tensor and ``loss_value`` its ``.item()``.
        Updates the stored Zhang & Hager terms or the loss window as a side
        effect.

        Raises:
            ValueError: if ``nonmonotone_option`` is not one of 0-5.
        """
        option = self.nonmonotone_option
        if option == 0:
            # Zhang & Hager: Q_{k+1} = xi*Q_k + 1,
            # C_{k+1} = (xi*Q_k*C_k + f_k) / Q_{k+1}
            q_kplus1 = self.zhang_xi * self.state["Q_k"] + 1
            self.state["C_k"] = (
                self.zhang_xi * self.state["Q_k"] * self.state["C_k"] + loss_value
            ) / q_kplus1
            self.state["Q_k"] = q_kplus1
            return max(self.state["C_k"], loss_value)

        if option == 1:
            # Grippo: max over the last M losses, then optionally adapt M.
            self._append_to_window(self.fvals_past, loss_value, self.M)
            ref_value = max(self.fvals_past)
            self._adapt_window_size(sharpness)
            return ref_value

        if option == 2:
            self._append_to_window(self.fvals_past, loss_value, self.M)
            ref_value = max(self.fvals_past)
            wandb.log({"M": self.M}, commit=False)

            # Blend towards the current loss as the gradient norm shrinks.
            xi = 1 - torch.exp(torch.tensor(1)) ** (-grad_norm)
            ref_value = xi * ref_value + (1 - xi) * loss
            wandb.log({"xi": xi}, commit=False)
            return ref_value

        if option == 3:
            # Average of the last M losses, never below the current loss.
            self._append_to_window(self.fvals_past, loss_value, self.M)
            avg = sum(self.fvals_past) / len(self.fvals_past)
            ref_value = max(avg, loss_value)
            wandb.log({"M": self.M}, commit=False)
            return ref_value

        if option == 4:
            # Percentile Grippo over the last M losses.
            self._append_to_window(self.fvals_past, loss_value, self.M)
            return np.percentile(self.fvals_past, self.zhang_xi * 100)

        if option == 5:
            # Like option 4, but the full history is kept, so a growing M
            # (adapt_c=5) can reach older losses. NOTE: fvals_past grows by
            # one entry per call.
            self.fvals_past.append(loss_value)
            return np.percentile(self.fvals_past[-self.M :], self.zhang_xi * 100)

        raise ValueError(
            f"Unsupported nonmonotone_option: {option!r} (expected one of 0-5)"
        )

    def line_search(
        self,
        step_size,
        params_current,
        grad_current,
        loss,
        closure_deterministic,
        grad_norm,
        sharpness,
        iteration,
    ):
        """
        Run one nonmonotone Armijo line search (plus any forward search).

        Args:
            step_size: initial trial step.
            params_current, grad_current: passed to ``gd_update`` to form
                candidate parameters.
            loss: current loss tensor (read via ``.item()``).
            closure_deterministic: callable returning the loss at the
                current (updated) parameters.
            grad_norm: gradient norm (converted with ``maybe_torch``); its
                square is the sufficient-decrease term.
            sharpness: used by adapt_c=2/3 and adapt_M=1.
            iteration: used by adapt_c=2/3.

        When the gradient norm or the loss is below 1e-8, no search is run and
        ``a`` is 0. If the search fails (``found == 0``), a fallback step of
        1e-6 is applied.

        Returns:
            (step_size, state["backtracks"], state["func_evals"], a)
        """
        with torch.no_grad():
            grad_norm = maybe_torch(grad_norm)
            loss_value = loss.item()

            if grad_norm >= 1e-8 and loss_value >= 1e-8:
                suff_dec = grad_norm**2

                # Adapting c may also change zhang_xi and M (adapt_c=5), which
                # the reference value below depends on, so order matters.
                self._adapt_c(step_size, sharpness, iteration)
                ref_value = self._compute_reference_value(
                    loss, loss_value, grad_norm, sharpness
                )

                found, step_size, a, e = self._perform_armijo_line_search(
                    step_size,
                    params_current,
                    grad_current,
                    loss,
                    closure_deterministic,
                    suff_dec,
                    ref_value,
                )

                # Forward step selection; `found` passes through unchanged
                # when no forward attempt runs.
                step_size, found, fwd_a, fwd_e = self._handle_forward_options(
                    step_size,
                    params_current,
                    grad_current,
                    loss,
                    closure_deterministic,
                    suff_dec,
                    ref_value,
                    found=found,
                )
                # Prefer the forward search's `a`/`e` when it produced one.
                if fwd_a is not None:
                    a = fwd_a
                    e = fwd_e

                # The line search exhausted its 100 internal iterations.
                if found == 0:
                    step_size = 1e-6
                    LS.LineSearch.gd_update(
                        self, self.params, step_size, params_current, grad_current
                    )

                self.lk = max(self.lk + e - 1, 0)

            else:
                a = 0
                print("Grad norm is {} and loss is {}".format(grad_norm, loss_value))

        self.last_a = a
        self.list_of_a.append(a)

        return step_size, self.state["backtracks"], self.state["func_evals"], a
