import os
import time

import jax.numpy as jnp
import jax.profiler
import numpy as np
import scipy.linalg
from jax import grad, jit, random
from jax.lax import transpose
from jax.numpy import float64
from pylanczos import PyLanczos

from environments import DIRECTIONALDERIVATIVE, FINITEDIFFERENCE, LEESELECTION, RANDOM
from utils.calculate import (
    get_hessian_with_hvp,
    get_hessian_with_hvp_2,
    get_jvp,
    get_minimum_eigenvalue,
    hessian,
    hvp,
    jax_randn,
    line_search,
    subspace_line_search,
)
from utils.logger import logger


class optimization_solver:
    """Common driver and oracle layer shared by the solvers in this file.

    Subclasses list the hyper-parameters they require in ``params_key`` and
    override ``__iter_per__`` (one iteration), optionally together with
    ``__direction__``, ``__step_size__`` and ``__run_init__``.  ``run``
    validates the parameters, initialises the state and then repeatedly calls
    ``__iter_per__`` while recording values into ``save_values``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        self.f = None  # objective, set in __run_init__
        self.f_grad = None  # grad(self.f), set in __run_init__
        self.xk = None  # current iterate
        self.dtype = dtype
        # True -> use self.f_grad; a string selects one of the modes imported
        # from `environments`.
        self.backward_mode = True
        self.finish = False  # set by __iter_per__ / run to stop the loop
        self.check_count = 0  # next slot of save_values["grad_norm"] to fill
        self.gradk_norm = None
        self.save_values = {}
        self.params_key = {}
        self.params = {}

    def __zeroth_order_oracle__(self, x):
        """Return the objective value at ``x``."""
        return self.f(x)

    def _finite_difference_gradient(self, x, h=1e-8):
        """Approximate the gradient of ``self.f`` at ``x`` by forward differences.

        Coordinate ``i`` of the result is ``(f(x + h * e_i) - f(x)) / h`` where
        ``e_i`` is the i-th unit vector.
        """
        dim = x.shape[0]
        d = np.zeros(dim, dtype=self.dtype)
        z = self.f(x)
        for i in range(dim):
            # Perturb only coordinate i; a fixed perturbation vector would
            # yield the same partial derivative for every component.
            step = jnp.zeros(dim, dtype=x.dtype).at[i].set(h)
            d[i] = (self.f(x + step) - z) / h
        return jnp.array(d)

    def __first_order_oracle__(self, x, output_loss=False):
        """Return the gradient at ``x`` according to ``self.backward_mode``.

        Args:
            x: point at which the gradient is evaluated.
            output_loss: when True, return ``(gradient, self.f(x))``.

        Raises:
            ValueError: ``backward_mode`` is a string that is not supported.

        Note: when ``backward_mode`` is a falsy non-string the gradient is
        ``None``.
        """
        x_grad = None
        if isinstance(self.backward_mode, str):
            if self.backward_mode == DIRECTIONALDERIVATIVE:
                x_grad = get_jvp(self.f, x, None)
            elif self.backward_mode == FINITEDIFFERENCE:
                x_grad = self._finite_difference_gradient(x)
            else:
                raise ValueError(f"{self.backward_mode} is not implemented.")
        elif self.backward_mode:
            x_grad = self.f_grad(x)

        if output_loss:
            return x_grad, self.f(x)
        return x_grad

    def __second_order_oracle__(self, x):
        """Return ``hessian(self.f)`` evaluated at ``x``."""
        return hessian(self.f)(x)

    def subspace_first_order_oracle(self, x, Mk):
        """Return the gradient of ``d -> f(x + Mk.T @ d)`` at ``d = 0``.

        ``Mk`` has one row per subspace direction (``Mk.shape[0]`` is the
        reduced dimension).

        Raises:
            ValueError: ``backward_mode`` is a string that is not supported.
        """
        reduced_dim = Mk.shape[0]
        if isinstance(self.backward_mode, str):
            if self.backward_mode == DIRECTIONALDERIVATIVE:
                return get_jvp(self.f, x, Mk)
            elif self.backward_mode == FINITEDIFFERENCE:
                d = np.zeros(reduced_dim, dtype=self.dtype)
                h = 1e-8
                z = self.f(x)
                for i in range(reduced_dim):
                    # Forward difference along the i-th row of Mk.
                    d[i] = (self.f(x + h * Mk[i]) - z) / h
                return jnp.array(d)
            else:
                # Fail loudly, like __first_order_oracle__, instead of
                # silently returning None.
                raise ValueError(f"{self.backward_mode} is not implemented.")
        elif self.backward_mode:
            subspace_func = lambda d: self.f(x + Mk.T @ d)
            d = jnp.zeros(reduced_dim, dtype=self.dtype)
            return grad(subspace_func)(d)

    def subspace_second_order_oracle(self, x, Mk):
        """Return the Hessian of ``d -> f(x + Mk.T @ d)`` at ``d = 0``.

        In the directional-derivative mode the helper is selected by
        ``params["subspace_hessian_func"]`` ("loop" or "hvp"); for non-string
        modes ``hessian`` is applied to the restricted function.

        Raises:
            ValueError: unsupported ``subspace_hessian_func`` or an
                unsupported string ``backward_mode``.
        """
        if isinstance(self.backward_mode, str):
            if self.backward_mode == DIRECTIONALDERIVATIVE:
                if self.params["subspace_hessian_func"] == "loop":
                    return get_hessian_with_hvp(self.f, x, Mk)
                elif self.params["subspace_hessian_func"] == "hvp":
                    return get_hessian_with_hvp_2(self.f, x, Mk)
                else:
                    raise ValueError(
                        f"{self.params['subspace_hessian_func']} is not implemented!"
                    )
            # Any other string mode used to fall through and return None.
            raise ValueError(f"{self.backward_mode} is not implemented.")
        else:
            reduced_dim = Mk.shape[0]
            d = jnp.zeros(reduced_dim, dtype=self.dtype)
            sub_func = lambda d: self.f(x + Mk.T @ d)
            return hessian(sub_func)(d)

    def __clear__(self):
        """Hook called at the start of every iteration; no-op by default."""
        return

    def __run_init__(self, f, x0, iteration, params):
        """Reset the solver state and allocate the result buffers for a run."""
        self.f = f
        self.f_grad = grad(self.f)
        self.xk = x0.copy()
        self.save_values["func_values"] = np.zeros(iteration + 1)
        self.save_values["time"] = np.zeros(iteration + 1)
        self.save_values["grad_norm"] = np.zeros(iteration + 1)
        self.finish = False
        self.backward_mode = params["backward"]
        self.params = params
        self.check_count = 0
        self.save_values["func_values"][0] = self.f(self.xk)
        self.save_values["grad_norm"][0] = jnp.linalg.norm(self.f_grad(self.xk))

        # Objectives without the flag are treated as deterministic.
        if not hasattr(self.f, "is_stochastic"):
            self.f.is_stochastic = False
        if self.f.is_stochastic:
            self.save_values["full_batch_loss"] = np.zeros(iteration + 1)
            self.save_values["full_batch_loss"][0] = self.f.full_batch_loss(self.xk)

    def __check_params__(self, params):
        """Assert that ``params`` has as many entries as ``params_key`` and
        contains every key listed there."""
        assert len(self.params_key) == len(
            params
        ), "不要,または足りないparamがあります."
        all_params = all(param_key in params for param_key in self.params_key)
        assert all_params, "パラメータが一致しません"

    def check_norm(self, d, eps):
        """Record ``||d||`` in ``save_values["grad_norm"]`` and return
        whether it is at most ``eps``."""
        d_norm = jnp.linalg.norm(d)
        self.update_save_values(self.check_count, grad_norm=d_norm)
        self.check_count += 1
        return d_norm <= eps

    def run(self, f, x0, iteration, params, save_path, log_interval=-1, max_time=None):
        """Run at most ``iteration`` iterations starting from ``x0``.

        Args:
            f: objective function.
            x0: initial point (copied).
            iteration: maximum number of iterations.
            params: hyper-parameters; must match ``self.params_key``.
            save_path: directory handed to ``save_results``.
            log_interval: log and save every ``log_interval`` iterations;
                ``-1`` (or ``0``) disables periodic logging.
            max_time: optional budget in seconds for the accumulated
                iteration time; the loop stops once it is exceeded.
        """
        self.__check_params__(params)
        self.__run_init__(f, x0, iteration, params)
        elapsed_time = 0
        if self.f.is_stochastic:
            logger.info("Run in a stochastic setting...")
        else:
            logger.info("Run in a determinic setting...")
        for i in range(iteration):
            self.__clear__()
            start_time = time.time()
            if not self.finish:
                self.__iter_per__()
            else:
                logger.info("Stop Criterion")
                break
            # Only the time spent inside the iteration is accumulated.
            elapsed_time += time.time() - start_time
            if max_time is not None and max_time < elapsed_time:
                self.finish = True
                logger.info("Max time")
            F = self.f(self.xk)
            self.update_save_values(i + 1, time=elapsed_time, func_values=F)
            if self.f.is_stochastic:
                full_batch_loss = self.f.full_batch_loss(self.xk)
                self.update_save_values(i + 1, full_batch_loss=full_batch_loss)
                self.f.update_batch_start_index()
            # Test the sentinel values first so log_interval == 0 cannot
            # trigger a modulo-by-zero.
            if log_interval not in (0, -1) and (i + 1) % log_interval == 0:
                logger.info(f'{i+1}: {self.save_values["func_values"][i+1]}')
                self.save_results(save_path)
        return

    def update_save_values(self, iter, **kwargs):
        """Store each keyword value at index ``iter`` of the matching buffer."""
        for k, v in kwargs.items():
            self.save_values[k][iter] = v

    def save_results(self, save_path):
        """Write every recorded buffer and the current iterate as ``.npy`` files."""
        for k, v in self.save_values.items():
            jnp.save(os.path.join(save_path, k + ".npy"), v)
        jnp.save(os.path.join(save_path, "last_x.npy"), self.xk)

    def __update__(self, d):
        """Move the current iterate by ``d``."""
        self.xk += d

    def __iter_per__(self):
        """Perform one iteration; overridden by subclasses."""
        return

    def __direction__(self, grad):
        """Compute a search direction; overridden by subclasses."""
        return

    def __step_size__(self):
        """Compute a step size; overridden by subclasses."""
        return


# first order method
class GradientDescent(optimization_solver):
    """Steepest descent with a fixed step size or a backtracking line search.

    Required params: ``lr``, ``eps``, ``backward``, ``linesearch``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = ["lr", "eps", "backward", "linesearch"]

    def __iter_per__(self):
        """One step: gradient, stopping test, direction, step size, update."""
        gradient = self.__first_order_oracle__(self.xk)
        converged = self.check_norm(gradient, self.params["eps"])
        if converged:
            self.finish = True
            return
        descent = self.__direction__(gradient)
        step = self.__step_size__(descent, gradient)
        self.__update__(step * descent)

    def __direction__(self, grad):
        """Return the negative gradient."""
        return -grad

    def __step_size__(self, direction, grad):
        """Return ``params["lr"]`` unless ``params["linesearch"]`` is truthy,
        in which case ``line_search`` is called with alpha=0.3, beta=0.8."""
        if not self.params["linesearch"]:
            return self.params["lr"]
        return line_search(
            xk=self.xk, func=self.f, grad=grad, dk=direction, alpha=0.3, beta=0.8
        )


class SubspaceGD(optimization_solver):
    """Gradient descent restricted to a subspace that is redrawn each iteration.

    Required params: ``lr``, ``reduced_dim``, ``mode``, ``eps``, ``backward``,
    ``linesearch``, ``random_matrix_seed``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = [
            "lr",
            "reduced_dim",
            "mode",
            "eps",
            "backward",
            "linesearch",
            "random_matrix_seed",
        ]

    def __iter_per__(self):
        """Draw a projection matrix and take one step inside its row space."""
        reduced_dim = self.params["reduced_dim"]
        mode = self.params["mode"]
        Mk = self.generate_matrix(self.dim, reduced_dim, mode)
        projected_grad = self.subspace_first_order_oracle(self.xk, Mk)
        if self.check_norm(projected_grad, self.params["eps"]):
            self.finish = True
            # Stop without moving, as the other solvers in this file do once
            # the stopping test succeeds.
            return
        d = self.__direction__(projected_grad, Mk)
        alpha = self.__step_size__(direction=d, projected_grad=projected_grad, Mk=Mk)
        self.__update__(alpha * Mk.T @ d)

    def __step_size__(self, direction, projected_grad, Mk):
        """Return ``params["lr"]`` or, when ``params["linesearch"]`` is truthy,
        the result of ``subspace_line_search`` with alpha=0.3, beta=0.8."""
        if self.params["linesearch"]:
            return subspace_line_search(
                xk=self.xk,
                func=self.f,
                projected_grad=projected_grad,
                dk=direction,
                Mk=Mk,
                alpha=0.3,
                beta=0.8,
            )
        else:
            return self.params["lr"]

    def __run_init__(self, f, x0, iteration, params):
        """Base initialisation plus the problem dimension and the PRNG key."""
        super().__run_init__(f, x0, iteration, params)
        self.dim = x0.shape[0]
        self.key = random.PRNGKey(params["random_matrix_seed"])

    def __direction__(self, projected_grad, Mk):
        """Return the negative projected gradient (``Mk`` is unused)."""
        return -projected_grad

    def generate_matrix(self, dim, reduced_dim, mode):
        """Return the projection matrix for the current iteration.

        For ``mode == "random"`` this is a ``(reduced_dim, dim)`` Gaussian
        matrix scaled by ``1 / sqrt(reduced_dim)``; the stored key is advanced
        afterwards.

        Raises:
            ValueError: ``mode`` is neither "random" nor "identity".
        """
        if mode == "random":
            P = random.normal(self.key, (reduced_dim, dim)).astype(self.dtype) / (
                reduced_dim**0.5
            )
            self.key = random.split(self.key)[0]
            return P
        elif mode == "identity":
            # NOTE(review): the callers in this class use Mk.shape / Mk.T, so
            # a None result cannot be consumed by them -- confirm the intent.
            return None
        else:
            raise ValueError("No matrix mode")


class AcceleratedGD(optimization_solver):
    """Gradient descent with a momentum sequence and an optional
    function-value restart.

    Required params: ``lr``, ``eps``, ``backward``, ``restart``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.lambda_k = 0
        self.yk = None
        self.params_key = ["lr", "eps", "backward", "restart"]

    def __run_init__(self, f, x0, iteration, params):
        """Reset the momentum state, then run the base initialisation."""
        self.yk = x0.copy()
        # Reset the momentum sequence too, so that reusing the solver for a
        # second run does not start with the previous run's momentum.
        self.lambda_k = 0
        return super().__run_init__(f, x0, iteration, params)

    def __iter_per__(self):
        """One accelerated step.

        When ``params["restart"]`` is truthy and the candidate increases the
        objective, the candidate is discarded and ``lambda_k`` is reset to 0.
        """
        lr = self.params["lr"]
        lambda_k1 = (1 + (1 + 4 * self.lambda_k**2) ** (0.5)) / 2
        gamma_k = (1 - self.lambda_k) / lambda_k1
        grad = self.__first_order_oracle__(self.xk)
        if self.check_norm(grad, self.params["eps"]):
            self.finish = True
            return
        yk1 = self.xk - lr * grad
        # Combine the new gradient step with the previous one.
        xk1 = (1 - gamma_k) * yk1 + gamma_k * self.yk
        if self.params["restart"]:
            if self.f(xk1) > self.f(self.xk):
                self.lambda_k = 0
                return
        self.xk = xk1
        self.yk = yk1
        self.lambda_k = lambda_k1


class BFGS(optimization_solver):
    """Quasi-Newton solver that keeps a dense inverse-Hessian approximation
    ``Hk`` and a backtracking line search.

    Required params: ``alpha``, ``beta``, ``backward``, ``eps``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = ["alpha", "beta", "backward", "eps"]
        self.Hk = None
        self.gradk = None

    def __run_init__(self, f, x0, iteration, params):
        """Start from ``Hk = I`` and cache the gradient at ``x0``."""
        n = x0.shape[0]
        self.Hk = jnp.eye(n, dtype=self.dtype)
        super().__run_init__(f, x0, iteration, params)
        self.gradk = self.__first_order_oracle__(x0)
        # Records the initial gradient norm; the boolean result is unused.
        self.check_norm(self.gradk, params["eps"])
        return

    def __direction__(self, grad):
        """Return ``-Hk @ grad``."""
        return -self.Hk @ grad

    def __step_size__(self, grad, dk):
        """Line search along ``dk`` with ``params["alpha"]`` / ``params["beta"]``."""
        return line_search(
            self.xk, self.f, grad, dk, self.params["alpha"], self.params["beta"]
        )

    def __iter_per__(self):
        """Take one step, test for convergence, then update ``Hk``."""
        direction = self.__direction__(self.gradk)
        step = self.__step_size__(grad=self.gradk, dk=direction)
        self.__update__(step * direction)
        new_grad = self.__first_order_oracle__(self.xk)
        if self.check_norm(new_grad, self.params["eps"]):
            self.finish = True
            return
        self.update_BFGS(sk=step * direction, yk=new_grad - self.gradk)
        self.gradk = new_grad

    def update_BFGS(self, sk, yk):
        """Apply the inverse update from step ``sk`` and gradient change ``yk``.

        When ``sk @ yk`` is below 1e-14 the approximation is reset to the
        identity instead.
        """
        curvature = sk @ yk
        if curvature < 1e-14:
            # Curvature is (numerically) non-positive: restart from identity.
            self.Hk = jnp.eye(sk.shape[0], dtype=self.dtype)
            return
        Hy = self.Hk @ yk
        cross = jnp.outer(Hy, sk)
        step_outer = jnp.outer(sk, sk)
        scale = curvature + Hy @ yk
        self.Hk = (
            self.Hk + scale * step_outer / (curvature**2) - (cross + cross.T) / curvature
        )


class LimitedMemoryBFGS(optimization_solver):
    """Limited-memory BFGS using the two-loop recursion.

    ``s`` / ``y`` hold the last ``memory_size`` steps and gradient changes
    (row 0 is the newest pair) and ``r[i]`` holds ``s[i] @ y[i]``.

    Required params: ``alpha``, ``beta``, ``backward``, ``eps``,
    ``memory_size``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.s = None
        self.y = None
        self.r = None
        self.a = None
        self.b = None
        self.gradk = None
        self.params_key = ["alpha", "beta", "backward", "eps", "memory_size"]

    def __run_init__(self, f, x0, iteration, params):
        """Allocate empty history buffers and cache the gradient at ``x0``."""
        super().__run_init__(f, x0, iteration, params)
        m = params["memory_size"]
        dim = x0.shape[0]
        self.s = jnp.zeros((m, dim), dtype=self.dtype)
        self.y = jnp.zeros((m, dim), dtype=self.dtype)
        self.r = np.zeros(m, dtype=self.dtype)
        self.a = np.zeros(m, dtype=self.dtype)
        self.gradk = self.__first_order_oracle__(x0)
        self.check_norm(self.gradk, params["eps"])

    def __direction__(self, grad):
        """Return the quasi-Newton direction for ``grad``.

        Falls back to ``-grad`` (and clears the stored curvatures) when the
        newest curvature ``r[0]`` is below 1e-14.
        """
        g = grad
        memory_size = self.a.shape[0]
        param_reset = self.r[0] < 1e-14
        # First loop: newest pair to oldest.
        for i in range(memory_size):
            if param_reset:
                self.r[i] = 0
            if self.r[i] < 1e-14:
                self.a[i] = 0
                self.r[i] = 0
            else:
                self.a[i] = jnp.dot(self.s[i], g) / self.r[i]
                # Not in place: `grad` is the caller's cached gradient and
                # must never be modified.
                g = g - self.a[i] * self.y[i]

        if param_reset:
            return -grad

        # Initial scaling taken from the newest pair.
        gamma = jnp.dot(self.s[0], self.y[0]) / jnp.dot(self.y[0], self.y[0])
        z = gamma * g
        # Second loop: oldest pair to newest.
        for i in range(1, memory_size + 1):
            if self.r[-i] < 1e-14:
                continue
            b = jnp.dot(self.y[-i], z) / self.r[-i]
            z = z + self.s[-i] * (self.a[-i] - b)

        return -z

    def __iter_per__(self):
        """Take one step, test for convergence, then push the new pair."""
        dk = self.__direction__(self.gradk)
        s = self.__step_size__(grad=self.gradk, dk=dk)
        self.__update__(s * dk)
        gradk1 = self.__first_order_oracle__(self.xk)
        if self.check_norm(gradk1, self.params["eps"]):
            self.finish = True
            return
        yk = gradk1 - self.gradk
        self.update_BFGS(sk=s * dk, yk=yk)
        self.gradk = gradk1

    def __step_size__(self, grad, dk):
        """Line search along ``dk`` with ``params["alpha"]`` / ``params["beta"]``."""
        alpha = self.params["alpha"]
        beta = self.params["beta"]
        return line_search(self.xk, self.f, grad, dk, alpha, beta)

    def update_BFGS(self, sk, yk):
        """Shift the history by one slot and store ``(sk, yk)`` in slot 0."""
        self.s = jnp.roll(self.s, 1, axis=0)
        self.y = jnp.roll(self.y, 1, axis=0)
        # `r` is a NumPy buffer; rolling it with NumPy avoids a round trip
        # through JAX, which may change its dtype when 64-bit mode is off.
        self.r = np.roll(self.r, 1)
        self.s = self.s.at[0].set(sk)
        self.y = self.y.at[0].set(yk)
        self.r[0] = jnp.dot(sk, yk)


class AcceleratedGDRestart(optimization_solver):
    """Accelerated gradient descent with two restart rules.

    An epoch starts at ``x0 = y0``.  Each iteration takes a gradient step from
    ``yk`` with step ``1 / L`` and extrapolates.  The epoch is restarted

    * from the previous iterate with ``L <- alpha * L`` when the objective did
      not decrease enough, or
    * from the new iterate with ``L <- beta * L`` when the estimate ``Mk``
      indicates that the accumulated movement ``Sk`` is too large.

    Required params: ``L``, ``M``, ``alpha``, ``beta``, ``backward``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.k = 0  # iteration counter inside the current epoch
        self.K = 0
        self.L = 0  # current gradient-Lipschitz estimate
        self.Sk = 0  # sum of squared step lengths in the current epoch
        self.Mk = 0  # running estimate updated by update_Mk
        self.yk = None
        self.grad_xk = None
        self.grad_yk = None
        self.initial_loss = None  # objective at the start of the epoch
        self.loss_xk = None  # objective at the current iterate xk
        self.params_key = ["L", "M", "alpha", "beta", "backward"]

    def __run_init__(self, f, x0, iteration, params):
        """Base initialisation plus the first epoch's state."""
        super().__run_init__(f, x0, iteration, params)
        self.yk = x0.copy()
        self.k = 0
        self.L = self.params["L"]
        self.Sk = 0
        self.Mk = self.params["M"]
        self.grad_xk = self.__first_order_oracle__(self.xk)
        self.grad_yk = self.grad_xk.copy()
        self.initial_loss = self.save_values["func_values"][0]
        self.loss_xk = self.initial_loss

    def update_Sk(self, xk1, xk):
        """Add ``||xk1 - xk||^2`` to the accumulated movement."""
        self.Sk += jnp.linalg.norm(xk1 - xk) ** 2

    def __iter_per__(self):
        """One accelerated step followed by the two restart tests."""
        self.k += 1
        theta_k = self.k / (self.k + 1)
        xk1 = self.yk - self.grad_yk / self.L
        # Extrapolate from the NEW iterate.  update_Mk compares grad(yk1)
        # with (1 + theta) * grad(xk1) - theta * grad(xk), which presumes
        # yk1 = (1 + theta) * xk1 - theta * xk.
        yk1 = xk1 + theta_k * (xk1 - self.xk)
        self.update_Sk(xk1, self.xk)
        grad_xk1, loss_xk1 = self.__first_order_oracle__(xk1, output_loss=True)
        if loss_xk1 > self.initial_loss - self.L * self.Sk / (2 * (self.k + 1)):
            # Insufficient decrease: go back to xk with a larger L.  The new
            # yk equals xk, so its gradient is grad_xk.
            self.restart_iter(
                self.xk,
                self.params["alpha"] * self.L,
                grad_x0=self.grad_xk,
                grad_y0=self.grad_xk,
                loss_x0=self.loss_xk,
            )
            return

        grad_yk1, loss_yk1 = self.__first_order_oracle__(yk1, output_loss=True)
        self.update_Mk(
            loss_xk1=loss_xk1,
            loss_yk1=loss_yk1,
            grad_yk1=grad_yk1,
            grad_xk1=grad_xk1,
            grad_xk=self.grad_xk,
            xk1=xk1,
            xk=self.xk,
            yk1=yk1,
            theta_k=theta_k,
        )
        if (self.k + 1) ** 5 * self.Mk**2 * self.Sk > self.L**2:
            # Restart from the new iterate; again yk == x0, so both cached
            # gradients are the gradient at xk1.
            self.restart_iter(
                xk1,
                self.params["beta"] * self.L,
                grad_x0=grad_xk1,
                grad_y0=grad_xk1,
                loss_x0=loss_xk1,
            )
            return
        self.xk = xk1
        self.yk = yk1
        self.grad_xk = grad_xk1
        self.grad_yk = grad_yk1
        self.loss_xk = loss_xk1

    def update_Mk(
        self, loss_xk1, loss_yk1, grad_yk1, grad_xk1, grad_xk, xk1, xk, yk1, theta_k
    ):
        """Raise ``Mk`` to the larger of two first-order estimates.

        ``a`` is 12 times the trapezoidal-rule error of the objective between
        ``xk1`` and ``yk1`` divided by ``||yk1 - xk1||^3``; ``b`` measures how
        far ``grad(yk1)`` is from the affine extrapolation of the gradients at
        ``xk`` and ``xk1``.
        """
        a = (
            12
            * (loss_yk1 - loss_xk1 - 0.5 * jnp.dot(grad_yk1 + grad_xk1, yk1 - xk1))
            / (jnp.linalg.norm(yk1 - xk1) ** 3)
        )
        b = jnp.linalg.norm(grad_yk1 + theta_k * grad_xk - (1 + theta_k) * grad_xk1) / (
            theta_k * jnp.linalg.norm(xk1 - xk) ** 2
        )
        self.Mk = max(self.Mk, a, b)

    def restart_iter(self, x0, L, grad_x0, grad_y0, loss_x0=None):
        """Begin a new epoch at ``x0`` with Lipschitz estimate ``L``.

        Args:
            x0: start of the new epoch (used for both ``xk`` and ``yk``).
            L: new Lipschitz estimate.
            grad_x0: gradient stored for ``xk``.
            grad_y0: gradient stored for ``yk``.
            loss_x0: objective value at ``x0``; evaluated here when omitted.
        """
        self.xk = x0.copy()
        self.yk = x0.copy()
        self.L = L
        self.k = 0
        # The movement sum and the reference loss belong to one epoch.
        self.Sk = 0
        self.initial_loss = self.f(x0) if loss_x0 is None else loss_x0
        self.loss_xk = self.initial_loss
        self.grad_yk = grad_y0
        self.grad_xk = grad_x0


# prox(x,t):
class BacktrackingProximalGD(optimization_solver):
    """Proximal gradient method whose step is found by backtracking.

    ``prox`` is called as ``prox(point, step)``.
    Required params: ``eps``, ``beta``, ``backward``, ``alpha``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.prox = None
        self.params_key = ["eps", "beta", "backward", "alpha"]

    def __run_init__(self, f, prox, x0, iteration, params):
        """Remember ``prox`` and delegate the rest to the base class."""
        self.prox = prox
        return super().__run_init__(f, x0, iteration, params)

    def run(self, f, prox, x0, iteration, params, save_path, log_interval=-1):
        """Iterate until ``finish`` is set or ``iteration`` steps were taken,
        recording wall-clock time since the start and the objective value."""
        self.__check_params__(params)
        self.__run_init__(f, prox, x0, iteration, params)
        began = time.time()
        for i in range(iteration):
            self.__clear__()
            if self.finish:
                break
            self.__iter_per__()
            self.save_values["time"][i + 1] = time.time() - began
            self.save_values["func_values"][i + 1] = self.f(self.xk)
            if (i + 1) % log_interval == 0 and log_interval != -1:
                logger.info(f'{i+1}: {self.save_values["func_values"][i+1]}')
                self.save_results(save_path)
        return

    def backtracking_with_prox(self, x, grad, beta, t=1, max_iter=10000, loss=None):
        """Shrink ``t`` by ``beta`` until the proximal point satisfies the
        quadratic upper-bound test; return ``(prox_point, t)``.

        ``loss`` is ``f(x)`` when already known.  The search gives up (with a
        log message) once the ``max_iter`` budget is exhausted.
        """
        base_loss = self.f(x) if loss is None else loss

        def violates_bound(candidate, step):
            # The sufficient-decrease inequality multiplied through by `step`.
            diff = x - candidate
            bound = step * base_loss - step * grad @ diff + 1 / 2 * (diff @ diff)
            return step * self.f(candidate) > bound

        prox_x = self.prox(x - t * grad, t)
        while violates_bound(prox_x, t):
            t *= beta
            max_iter -= 1
            prox_x = self.prox(x - t * grad, t)
            if max_iter < 0:
                logger.info("Error: Backtracking is stopped because of max_iteration.")
                break
        return prox_x, t

    def __iter_per__(self):
        """One proximal step; stops when ``||xk - prox_x|| <= t * eps``."""
        gradient, loss = self.__first_order_oracle__(self.xk, output_loss=True)
        prox_x, t = self.backtracking_with_prox(
            self.xk, gradient, self.params["beta"], t=self.params["alpha"], loss=loss
        )
        if self.check_norm(self.xk - prox_x, t * self.params["eps"]):
            self.finish = True
        # The iterate is replaced even on the final iteration.
        self.xk = prox_x.copy()
        return


class BacktrackingAcceleratedProximalGD(BacktrackingProximalGD):
    """Accelerated proximal gradient method with backtracking and an optional
    function-value restart.

    The step size ``tk`` starts at ``params["alpha"]`` and only ever shrinks.
    Required params: ``restart``, ``beta``, ``eps``, ``backward``, ``alpha``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.tk = 1
        self.vk = None
        self.k = 0
        self.xk1 = None  # previous iterate
        self.params_key = ["restart", "beta", "eps", "backward", "alpha"]

    def run(self, f, prox, x0, iteration, params, save_path, log_interval=-1):
        """Initialise ``tk`` from ``params["alpha"]`` and run the parent loop."""
        self.tk = params["alpha"]
        return super().run(f, prox, x0, iteration, params, save_path, log_interval)

    def __run_init__(self, f, prox, x0, iteration, params):
        """Reset the momentum state before the parent initialisation."""
        self.k = 0
        self.xk1 = x0.copy()
        return super().__run_init__(f, prox, x0, iteration, params)

    def __iter_per__(self):
        """One accelerated proximal step from the extrapolated point ``vk``."""
        self.k += 1
        beta = self.params["beta"]
        eps = self.params["eps"]
        restart = self.params["restart"]
        k = self.k
        self.vk = self.xk + (k - 2) / (k + 1) * (self.xk - self.xk1)
        grad_v, loss_v = self.__first_order_oracle__(self.vk, output_loss=True)
        # loss_v must be passed by keyword: the fifth positional slot of
        # backtracking_with_prox is max_iter, not loss_v.
        prox_x, t = self.backtracking_with_prox(
            self.xk, self.vk, grad_v, beta, loss_v=loss_v
        )
        if self.check_norm(prox_x - self.vk, t * eps):
            self.finish = True
        self.xk1 = self.xk
        self.xk = prox_x.copy()
        self.v = None
        if restart:
            # NOTE(review): only k is reset here; xk1 keeps the previous
            # iterate, so the next extrapolation uses (k - 2) / (k + 1) = -1/2
            # with a non-zero difference -- confirm this is intended.
            if self.f(self.xk) > self.f(self.xk1):
                self.k = 0

    def backtracking_with_prox(self, x, v, grad_v, beta, max_iter=10000, loss_v=None):
        """Shrink ``self.tk`` by ``beta`` until the proximal point of ``v``
        satisfies the quadratic upper-bound test; return ``(prox_point, tk)``.

        Args:
            x: current iterate (not used by the test itself).
            v: extrapolated point the gradient step starts from.
            grad_v: gradient at ``v``.
            beta: shrink factor for the step size.
            max_iter: maximum number of shrink steps before giving up.
            loss_v: ``f(v)`` when already known; evaluated here otherwise.
        """
        if loss_v is None:
            loss_v = self.f(v)
        prox_x = self.prox(v - self.tk * grad_v, self.tk)
        while self.tk * self.f(prox_x) > self.tk * loss_v + self.tk * grad_v @ (
            prox_x - v
        ) + 1 / 2 * ((prox_x - v) @ (prox_x - v)):
            self.tk *= beta
            max_iter -= 1
            prox_x = self.prox(v - self.tk * grad_v, self.tk)
            if max_iter < 0:
                # Same safeguard as the non-accelerated parent class.
                logger.info("Error: Backtracking is stopped because of max_iteration.")
                break
        return prox_x, self.tk


# second order method
class NewtonMethod(optimization_solver):
    """Newton's method with a backtracking line search.

    Required params: ``alpha``, ``beta``, ``eps``, ``backward``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = ["alpha", "beta", "eps", "backward"]

    def __iter_per__(self):
        """One step: gradient, stopping test, Hessian solve, line search."""
        gradient = self.__first_order_oracle__(self.xk)
        if self.check_norm(gradient, self.params["eps"]):
            self.finish = True
            return
        hess_matrix = self.__second_order_oracle__(self.xk)
        newton_dir = self.__direction__(grad=gradient, hess=hess_matrix)
        step = self.__step_size__(grad=gradient, dk=newton_dir)
        self.__update__(step * newton_dir)

    def __direction__(self, grad, hess):
        """Return the solution ``d`` of ``hess @ d = grad``, negated."""
        solution = jnp.linalg.solve(hess, grad)
        return -solution

    def __step_size__(self, grad, dk):
        """Line search along ``dk`` with ``params["alpha"]`` / ``params["beta"]``."""
        return line_search(
            self.xk, self.f, grad, dk, self.params["alpha"], self.params["beta"]
        )


class SubspaceNewton(SubspaceGD):
    """Newton's method applied inside a subspace redrawn each iteration.

    Required params: ``dim``, ``reduced_dim``, ``mode``, ``backward``,
    ``alpha``, ``beta``, ``eps``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = [
            "dim",
            "reduced_dim",
            "mode",
            "backward",
            "alpha",
            "beta",
            "eps",
        ]

    def __run_init__(self, f, x0, iteration, params):
        """Initialise the run without requiring a PRNG seed.

        ``SubspaceGD.__run_init__`` reads ``params["random_matrix_seed"]``,
        which is not one of this class's ``params_key`` entries (and
        ``__check_params__`` rejects extra keys), so that step is skipped.
        This class draws its matrices with ``jax_randn`` instead of a key.
        """
        optimization_solver.__run_init__(self, f, x0, iteration, params)
        self.dim = x0.shape[0]

    def __iter_per__(self):
        """Draw a projection matrix and take one Newton step in its row space."""
        reduced_dim = self.params["reduced_dim"]
        dim = self.params["dim"]
        mode = self.params["mode"]
        Mk = self.generate_matrix(dim, reduced_dim, mode)
        grad = self.subspace_first_order_oracle(self.xk, Mk)
        if self.check_norm(grad, self.params["eps"]):
            self.finish = True
            return
        H = self.subspace_second_order_oracle(self.xk, Mk)
        dk = self.__direction__(grad=grad, hess=H)
        lr = self.__step_size__(grad=grad, dk=dk, Mk=Mk)
        self.__update__(lr * Mk.T @ dk)

    def __direction__(self, grad, hess):
        """Return the solution ``d`` of ``hess @ d = grad``, negated."""
        return -jnp.linalg.solve(hess, grad)

    def __step_size__(self, grad, dk, Mk):
        """Subspace line search with ``params["alpha"]`` / ``params["beta"]``."""
        alpha = self.params["alpha"]
        beta = self.params["beta"]
        return subspace_line_search(
            self.xk, self.f, projected_grad=grad, dk=dk, Mk=Mk, alpha=alpha, beta=beta
        )

    def generate_matrix(self, dim, reduced_dim, mode):
        """Return the projection matrix for the current iteration.

        For ``mode == "random"`` this is ``jax_randn(reduced_dim, dim)``
        scaled by ``1 / sqrt(reduced_dim)``.

        Raises:
            ValueError: ``mode`` is neither "random" nor "identity".
        """
        if mode == "random":
            return jax_randn(reduced_dim, dim, dtype=self.dtype) / (reduced_dim**0.5)
        elif mode == "identity":
            # NOTE(review): the callers in this class use Mk.shape / Mk.T, so
            # a None result cannot be consumed by them -- confirm the intent.
            return None
        else:
            raise ValueError("No matrix mode")


class LimitedMemoryNewton(optimization_solver):
    """Newton-type method in a subspace stored in ``Pk``.

    ``Pk`` is either built from past iterates and gradients (``LEESELECTION``)
    or redrawn at random every iteration (``RANDOM``).

    Required params: ``reduced_dim``, ``threshold_eigenvalue``, ``alpha``,
    ``beta``, ``backward``, ``mode``, ``eps``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.Pk = None
        self.index = 0  # next row of Pk to overwrite (LEESELECTION mode)
        self.params_key = [
            "reduced_dim",
            "threshold_eigenvalue",
            "alpha",
            "beta",
            "backward",
            "mode",
            "eps",
        ]

    def __run_init__(self, f, x0, iteration, params):
        """Base initialisation plus a fresh subspace history.

        Without this reset a second ``run`` would reuse the rows (and the
        shape) of ``Pk`` collected during the previous run.
        """
        super().__run_init__(f, x0, iteration, params)
        self.Pk = None
        self.index = 0

    def generate_matrix(self, matrix_size, gk, mode):
        """Update ``self.Pk`` in place for the current iterate.

        ``LEESELECTION`` writes ``xk`` and ``gk`` into two consecutive rows,
        cycling through the ``matrix_size`` rows; ``RANDOM`` draws a new
        matrix scaled by ``1 / sqrt(matrix_size)``.

        Raises:
            ValueError: ``mode`` is not supported.
        """
        # P^T = [x_0, grad f(x_0), ..., x_k, grad f(x_k)]
        dim = self.xk.shape[0]
        if mode == LEESELECTION:
            if self.Pk is None:
                self.Pk = jnp.zeros((matrix_size, dim), dtype=self.dtype)
            self.Pk = self.Pk.at[self.index].set(self.xk)
            self.Pk = self.Pk.at[self.index + 1].set(gk)
            self.index += 2
            self.index %= matrix_size
        elif mode == RANDOM:
            self.Pk = jax_randn(matrix_size, dim, dtype=self.dtype) / matrix_size**0.5
        else:
            raise ValueError(f"{mode} is not implemented.")

    def subspace_second_order_oracle(self, x, Mk, threshold_eigenvalue):
        """Return the subspace Hessian, shifted by a multiple of the identity
        so that its smallest eigenvalue is at least ``threshold_eigenvalue``."""
        matrix_size = Mk.shape[0]
        H = super().subspace_second_order_oracle(x, Mk)
        sigma_m = get_minimum_eigenvalue(H)
        if sigma_m < threshold_eigenvalue:
            return H + (threshold_eigenvalue - sigma_m) * jnp.eye(
                matrix_size, dtype=self.dtype
            )
        else:
            return H

    def __iter_per__(self):
        """One step; convergence is tested on the full gradient ``gk``."""
        matrix_size = self.params["reduced_dim"]
        threshold_eigenvalue = self.params["threshold_eigenvalue"]
        mode = self.params["mode"]
        gk = self.__first_order_oracle__(self.xk)
        self.generate_matrix(matrix_size, gk, mode)
        if self.check_norm(gk, self.params["eps"]):
            self.finish = True
            return
        # Project only after the stopping test so the product is not wasted.
        proj_gk = self.Pk @ gk
        Hk = self.subspace_second_order_oracle(self.xk, self.Pk, threshold_eigenvalue)
        dk = self.__direction__(grad=proj_gk, hess=Hk)
        lr = self.__step_size__(grad=proj_gk, dk=dk, Mk=self.Pk)
        self.__update__(lr * self.Pk.T @ dk)

    def __direction__(self, grad, hess):
        """Return the solution ``d`` of ``hess @ d = grad``, negated."""
        return -jnp.linalg.solve(hess, grad)

    def __step_size__(self, grad, dk, Mk):
        """Subspace line search with ``params["alpha"]`` / ``params["beta"]``."""
        alpha = self.params["alpha"]
        beta = self.params["beta"]
        return subspace_line_search(
            self.xk, self.f, projected_grad=grad, dk=dk, Mk=Mk, alpha=alpha, beta=beta
        )


class RandomizedBFGS(optimization_solver):
    """Quasi-Newton method whose inverse-Hessian estimate ``Bk_inv`` is
    refreshed from a random sketch of the exact Hessian every iteration.

    Required params: ``reduced_dim``, ``dim``, ``backward``, ``eps``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.Bk_inv = None
        self.params_key = ["reduced_dim", "dim", "backward", "eps"]

    def run(self, f, x0, iteration, params, save_path, log_interval=-1):
        """Run at most ``iteration`` iterations, recording the wall-clock time
        since the start, the objective and the gradient norm after each one.

        ``log_interval`` of ``-1`` (or ``0``) disables periodic logging.
        """
        # Validate before touching any state.
        self.__check_params__(params)
        self.__run_init__(f, x0, iteration, params)
        start_time = time.time()
        for i in range(iteration):
            self.__clear__()
            if not self.finish:
                # __iter_per__ reads its settings from self.params.
                self.__iter_per__()
            else:
                logger.info("Stop Criterion")
                break
            T = time.time() - start_time
            F = self.f(self.xk)
            # The buffer holds one scalar per iteration, so store the norm
            # rather than the gradient vector.
            G = jnp.linalg.norm(self.__first_order_oracle__(self.xk))
            self.update_save_values(i + 1, time=T, func_values=F, grad_norm=G)
            # `and`, not `&`: the bitwise operator binds tighter than the
            # comparisons and made the -1 sentinel ineffective.
            if log_interval not in (0, -1) and (i + 1) % log_interval == 0:
                logger.info(f'{i+1}: {self.save_values["func_values"][i+1]}')
                self.save_results(save_path)
        return

    def __run_init__(self, f, x0, iteration, params):
        """Start from ``Bk_inv = I`` and run the base initialisation."""
        self.Bk_inv = jnp.eye(x0.shape[0], dtype=self.dtype)
        return super().__run_init__(f, x0, iteration, params)

    def __iter_per__(self):
        """Take a guarded step along ``-Bk_inv @ grad`` and refresh ``Bk_inv``."""
        reduced_dim = self.params["reduced_dim"]
        dim = self.params["dim"]
        grad, loss_k = self.__first_order_oracle__(self.xk, output_loss=True)
        if self.check_norm(grad, self.params["eps"]):
            self.finish = True
            return
        # The Hessian is evaluated at the point the step starts from.
        Hk = self.__second_order_oracle__(self.xk)
        self.__update__(-self.Bk_inv @ grad, loss_k)
        Sk = self.generate_matrix(reduced_dim=reduced_dim, dim=dim)
        self.update_rbfgs(Hk, Sk)

    def update_rbfgs(self, Hk, Sk):
        """Sketched BFGS update of the inverse estimate.

        With ``G = Sk (Sk^T Hk Sk)^{-1} Sk^T`` and ``J = I - G Hk`` the new
        estimate is ``G + J Bk_inv J^T``, which agrees with the inverse of
        ``Hk`` on the range of ``Hk @ Sk``.
        """
        dim = Hk.shape[0]
        G = Sk @ jnp.linalg.solve(Sk.T @ Hk @ Sk, Sk.T)
        J = jnp.eye(dim, dtype=self.dtype) - G @ Hk
        self.Bk_inv = G + J @ self.Bk_inv @ J.T

    def __update__(self, d, loss_k):
        """Accept ``xk + d`` only when it strictly decreases the objective."""
        xk1 = self.xk + d
        loss_k1 = self.f(xk1)
        if loss_k1 < loss_k:
            self.xk = xk1
        return

    def generate_matrix(self, reduced_dim, dim):
        """Return the sketch matrix ``jax_randn(dim, reduced_dim)``."""
        return jax_randn(dim, reduced_dim, dtype=self.dtype)


class SubspaceRNM(optimization_solver):
    """Regularized Newton method in a random subspace.

    The subspace Hessian is shifted by ``c1 * max(0, -lambda_min)`` and by
    ``c2 * ||projected gradient|| ** gamma`` times the identity before the
    Newton system is solved.

    Required params: ``reduced_dim``, ``subspace_hessian_func``, ``gamma``,
    ``c1``, ``c2``, ``alpha``, ``beta``, ``eps``, ``backward``,
    ``random_matrix_seed``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = [
            "reduced_dim",
            "subspace_hessian_func",
            "gamma",
            "c1",
            "c2",
            "alpha",
            "beta",
            "eps",
            "backward",
            "random_matrix_seed",
        ]

    def __iter_per__(self):
        """Draw a projection, regularize the subspace Hessian and step."""
        sub_dim = self.params["reduced_dim"]
        full_dim = self.xk.shape[0]
        Pk = self.generate_matrix(full_dim, sub_dim)
        sub_hessian = self.subspace_second_order_oracle(self.xk, Pk)
        sub_grad = self.subspace_first_order_oracle(self.xk, Pk)
        if self.check_norm(sub_grad, self.params["eps"]):
            self.finish = True
            return
        smallest = get_minimum_eigenvalue(sub_hessian)
        neg_curvature = max(0, -smallest)
        curvature_shift = (
            self.params["c1"] * neg_curvature * jnp.eye(sub_dim, dtype=self.dtype)
        )
        gradient_shift = (
            self.params["c2"]
            * jnp.linalg.norm(sub_grad) ** self.params["gamma"]
            * jnp.eye(sub_dim, dtype=self.dtype)
        )
        regularized = sub_hessian + curvature_shift + gradient_shift
        dk = self.__direction__(sub_grad, regularized)
        step = self.__step_size__(projected_grad=sub_grad, Mk=Pk, dk=dk)
        self.__update__(step * Pk.T @ dk)

    def __direction__(self, projected_grad, subspace_H):
        """Return ``-inv(subspace_H) @ projected_grad`` using an explicit
        NumPy inverse."""
        inverse = np.linalg.inv(subspace_H)
        return -jnp.dot(inverse, projected_grad)

    def __step_size__(self, projected_grad, dk, Mk):
        """Subspace line search with ``params["alpha"]`` / ``params["beta"]``."""
        return subspace_line_search(
            xk=self.xk,
            func=self.f,
            projected_grad=projected_grad,
            dk=dk,
            Mk=Mk,
            alpha=self.params["alpha"],
            beta=self.params["beta"],
        )

    def __run_init__(self, f, x0, iteration, params):
        """Base initialisation plus the PRNG key for the projections."""
        super().__run_init__(f, x0, iteration, params)
        seed = params["random_matrix_seed"]
        self.key = random.PRNGKey(seed)

    def generate_matrix(self, dim, reduced_dim):
        """Return a ``(reduced_dim, dim)`` Gaussian matrix scaled by
        ``1 / sqrt(reduced_dim)`` and advance the stored key."""
        gaussian = random.normal(self.key, (reduced_dim, dim)).astype(self.dtype)
        projection = gaussian / (reduced_dim**0.5)
        self.key = random.split(self.key)[0]
        return projection


def eig_solve_scipy(
    g_sub: jnp.ndarray, H_sub: jnp.ndarray, delta: float
) -> jnp.ndarray:
    """Return the eigenvector for the smallest eigenvalue of the bordered matrix
    ``[[H_sub, g_sub], [g_sub^T, -delta]]``, computed with SciPy.

    Only the first eigenpair is requested (``subset_by_index=[0, 0]``).
    """
    top_row = [H_sub, g_sub[:, None]]
    bottom_row = [g_sub[None, :], -delta]
    bordered = jnp.block([top_row, bottom_row])
    eigen_pairs = scipy.linalg.eigh(bordered, subset_by_index=[0, 0])
    eigenvectors = eigen_pairs[1]
    return eigenvectors[:, 0]


def eig_solve_jax(g_sub: jnp.ndarray, H_sub: jnp.ndarray, delta: float) -> jnp.ndarray:
    """Return the lowest eigenvector of ``[[H_sub, g_sub], [g_sub^T, -delta]]``.

    The decomposition is done with ``jax.lax.linalg.eigh`` reading the upper
    triangle (``lower=False``) without symmetrizing the input. Column 0 of the
    first returned element is handed back.
    """
    top_row = [H_sub, g_sub[:, None]]
    bottom_row = [g_sub[None, :], -delta]
    augmented = jnp.block([top_row, bottom_row])
    decomposition = jax.lax.linalg.eigh(
        augmented, lower=False, symmetrize_input=False
    )
    # lax.linalg.eigh puts the eigenvector matrix first in its result.
    eigenvectors = decomposition[0]
    return eigenvectors[:, 0]


def eig_solve_lanczos(
    g_sub: jnp.ndarray, H_sub: jnp.ndarray, delta: float
) -> jnp.ndarray:
    """Run PyLanczos on the operator ``[[H_sub, g_sub], [g_sub^T, -delta]]``.

    The matrix is never assembled; only its action on a vector is supplied.

    Args:
        g_sub: Vector forming the border of the augmented operator.
        H_sub: Square matrix forming the leading block.
        delta: Scalar whose negation is the trailing diagonal entry.

    Returns:
        Column 0 of the second element returned by ``PyLanczos.run()``.
    """
    dim = g_sub.shape[0]

    def apply_augmented(src, dst):
        # Split the input into its leading block and trailing scalar.
        head = src[:dim]
        tail = src[dim]
        dst[:dim] = jnp.dot(H_sub, head) + tail * g_sub
        dst[dim] = jnp.dot(g_sub, head) - tail * delta

    # NOTE(review): the trailing (False, 1) arguments presumably request the
    # single smallest eigenpair -- confirm against the PyLanczos API.
    engine = PyLanczos.create_custom(apply_augmented, dim + 1, "single", False, 1)
    outcome = engine.run()
    return outcome[1][:, 0]


def eig_solve_lanczos_hvp_sub(
    obj,
    x: jnp.ndarray,
    P: jnp.ndarray,
    g_sub: jnp.ndarray,
    delta: float,
    precision="single",
) -> jnp.ndarray:
    """Run PyLanczos on the projected augmented operator without forming it.

    The leading block is applied as ``P @ hvp(obj, x, P.T @ v)``, so the
    projected matrix is never built.

    Args:
        obj: Objective passed through to ``hvp``.
        x: Point passed through to ``hvp``.
        P: Projection matrix; its row count sets the reduced dimension.
        g_sub: Vector forming the border of the augmented operator.
        delta: Scalar whose negation is the trailing diagonal entry.
        precision: Precision string forwarded to ``PyLanczos.create_custom``.

    Returns:
        Column 0 of the second element returned by ``PyLanczos.run()``.
    """
    dim = P.shape[0]

    def apply_augmented(src, dst):
        head = src[:dim]
        tail = src[dim]
        # Lift to the full space, apply hvp, then project back down.
        lifted = jnp.dot(P.T, head)
        curvature = hvp(obj, (x,), (lifted,))
        dst[:dim] = jnp.dot(P, curvature) + tail * g_sub
        dst[dim] = jnp.dot(g_sub, head) - tail * delta

    engine = PyLanczos.create_custom(apply_augmented, dim + 1, precision, False, 1)
    outcome = engine.run()
    return outcome[1][:, 0]


def eig_solve_lanczos_hvp(
    obj, x: jnp.ndarray, g: jnp.ndarray, delta: float, precision="single"
) -> jnp.ndarray:
    """Run PyLanczos on the full-space augmented operator without forming it.

    The leading block is applied through ``hvp(obj, x, v)``. The border is
    ``g`` and the trailing diagonal entry is ``-delta``.

    Args:
        obj: Objective passed through to ``hvp``.
        x: Point passed through to ``hvp``; its length sets the dimension.
        g: Vector forming the border of the augmented operator.
        delta: Scalar whose negation is the trailing diagonal entry.
        precision: Precision string forwarded to ``PyLanczos.create_custom``.

    Returns:
        Column 0 of the second element returned by ``PyLanczos.run()``.
    """
    dim = x.shape[0]

    def apply_augmented(src, dst):
        # The incoming vector is cast to float64 before any arithmetic.
        src64 = src.astype(jnp.float64)
        head = src64[:dim]
        tail = src64[dim]
        dst[:dim] = hvp(obj, (x,), (head,)) + tail * g
        dst[dim] = jnp.dot(g, head) - tail * delta

    engine = PyLanczos.create_custom(apply_augmented, dim + 1, precision, False, 1)
    outcome = engine.run()
    return outcome[1][:, 0]


class SubspaceTRM(optimization_solver):
    """Random-subspace solver driven by an augmented eigenproblem.

    Each iteration draws a Gaussian sketch ``P`` and evaluates the projected
    gradient. Unless the ``"lanczos_hvp"`` solver is chosen, it also evaluates
    the projected Hessian. One of the ``eig_solve_*`` routines then supplies
    an eigenvector, which is turned into the step ``P.T @ d_sub``.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = [
            "eig_solver",
            "subspace_hessian_func",
            "reduced_dim",
            "delta",
            "Delta",
            "nu",
            "alpha",
            "beta",
            "eps",
            "backward",
            "random_matrix_seed",
            "lanczos_precision",
        ]
        # Becomes True after the first step taken without a line search.
        self.local = False

    def __iter_per__(self):
        """Perform one iteration: sketch, solve the eigen-subproblem, update."""
        sketch_rows = self.params["reduced_dim"]
        P = self.generate_matrix(self.xk.shape[0], sketch_rows)
        g_sub = self.subspace_first_order_oracle(self.xk, P)

        # Stop as soon as the projected gradient passes the norm check.
        if self.check_norm(g_sub, self.params["eps"]):
            self.finish = True
            return

        # Once the local phase has started, delta is replaced by 0.
        delta = self.params["delta"] if not self.local else 0
        eig_solver = self.params["eig_solver"]
        if eig_solver == "lanczos_hvp":
            subprob_sol = eig_solve_lanczos_hvp_sub(
                self.f, self.xk, P, g_sub, delta, self.params["lanczos_precision"]
            )
        else:
            # The projected Hessian is built before the solver name is checked.
            H_sub = self.subspace_second_order_oracle(self.xk, P)
            dense_solvers = (
                ("lanczos", eig_solve_lanczos),
                ("scipy", eig_solve_scipy),
                ("jax", eig_solve_jax),
            )
            for solver_name, solver in dense_solvers:
                if eig_solver == solver_name:
                    subprob_sol = solver(g_sub, H_sub, delta)
                    break
            else:
                raise ValueError(f"{eig_solver} is not a valid eigen solver.")

        v_sub, t = subprob_sol[:-1], subprob_sol[-1]

        # Rescale by the last component when it is large enough; otherwise
        # use the eigenvector itself, oriented against the projected gradient.
        if abs(t) >= self.params["nu"]:
            d_sub = v_sub / t
        else:
            sign = 1 if jnp.dot(g_sub, v_sub) < 0 else -1
            d_sub = sign * v_sub
        d = jnp.dot(P.T, d_sub)

        needs_line_search = (not self.local) and jnp.linalg.norm(
            d_sub
        ) > self.params["Delta"]
        if needs_line_search:
            eta = self.__step_size__(g_sub, d, d_sub)
            self.__update__(eta * d)
            return

        self.__update__(d)
        self.local = True

    def __step_size__(self, g_sub, d, d_sub):
        """Backtrack from step 1.0 until the sufficient-decrease test passes.

        Returns 0 when the step shrinks below 1e-12 without passing.
        """
        alpha = self.params["alpha"]
        beta = self.params["beta"]

        f_k = self.f(self.xk)
        stepsize = 1.0
        while True:
            achieved = f_k - self.f(self.xk + stepsize * d)
            required = -alpha * stepsize * g_sub @ d_sub
            if not (achieved < required):
                return stepsize
            stepsize *= beta
            if stepsize < 1e-12:
                return 0

    def __run_init__(self, f, x0, iteration, params):
        """Run the parent initialization, then seed the sketch PRNG key."""
        super().__run_init__(f, x0, iteration, params)
        seed = params["random_matrix_seed"]
        self.key = random.PRNGKey(seed)

    def generate_matrix(self, dim, reduced_dim):
        """Draw a ``(reduced_dim, dim)`` Gaussian sketch scaled by
        ``1 / sqrt(reduced_dim)`` and advance the stored PRNG key."""
        current_key = self.key
        gaussian = random.normal(current_key, (reduced_dim, dim)).astype(self.dtype)
        scale = reduced_dim**0.5
        self.key = random.split(current_key)[0]
        return gaussian / scale


class HSODM(optimization_solver):
    """Full-space solver driven by an augmented eigenproblem.

    Each iteration evaluates the gradient ``g``. Unless the ``"lanczos_hvp"``
    solver is chosen, it also evaluates the Hessian ``H``. The eigenvector
    returned by the selected ``eig_solve_*`` routine for the matrix
    ``[[H, g], [g^T, -delta]]`` is then turned into a step.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = [
            "eig_solver",
            "delta",
            "Delta",
            "nu",
            "alpha",
            "beta",
            "eps",
            "backward",
            "lanczos_precision",
        ]
        # NOTE(review): nothing in this class ever sets ``local`` to True, so
        # ``delta`` is always taken from params and the undamped branch never
        # disables the line search. SubspaceTRM flips the flag after its first
        # undamped step -- confirm whether the same was intended here.
        self.local = False

    def __iter_per__(self):
        """Perform one iteration: solve the eigen-subproblem and update."""
        g = self.__first_order_oracle__(self.xk)

        # Stop as soon as the gradient passes the norm check.
        if self.check_norm(g, self.params["eps"]):
            self.finish = True
            return

        # solve subproblem
        delta = 0 if self.local else self.params["delta"]
        eig_solver = self.params["eig_solver"]
        if eig_solver == "lanczos_hvp":
            subprob_sol = eig_solve_lanczos_hvp(
                self.f, self.xk, g, delta, self.params["lanczos_precision"]
            )
        else:
            H = self.__second_order_oracle__(self.xk)
            # The dense solvers take the full-space gradient and Hessian.
            # (The names g_sub / H_sub used here before do not exist in this
            # method and raised NameError.)
            if eig_solver == "lanczos":
                subprob_sol = eig_solve_lanczos(g, H, delta)
            elif eig_solver == "scipy":
                subprob_sol = eig_solve_scipy(g, H, delta)
            elif eig_solver == "jax":
                subprob_sol = eig_solve_jax(g, H, delta)
            else:
                raise ValueError(f"{eig_solver} is not a valid eigen solver.")
        v = subprob_sol[:-1]
        t = subprob_sol[-1]

        # compute direction: rescale by the last component when it is large
        # enough, otherwise use the eigenvector oriented against the gradient.
        if abs(t) >= self.params["nu"]:
            d = v / t
        else:
            sign = 1 if -jnp.dot(g, v) > 0 else -1
            d = sign * v

        # update: long steps are damped by a backtracking line search.
        if (not self.local) and jnp.linalg.norm(d) > self.params["Delta"]:
            eta = self.__step_size__(g, d)
            self.__update__(eta * d)
        else:
            self.__update__(d)

    def __step_size__(self, g, d):
        """Backtrack from step 1.0 until the sufficient-decrease test passes.

        Args:
            g: Gradient at the current iterate.
            d: Search direction.

        Returns:
            The accepted step size, or 0 when the step shrinks below 1e-12
            without passing the test.
        """
        alpha = self.params["alpha"]
        beta = self.params["beta"]

        # line search
        f_k = self.f(self.xk)
        stepsize = 1.0
        while f_k - self.f(self.xk + stepsize * d) < -alpha * stepsize * g @ d:
            stepsize *= beta
            if stepsize < 1e-12:
                return 0
        return stepsize


import os
import time

import jax.numpy as jnp
import numpy as np
from jax import grad, jit, random
from jax.lax import transpose
from jax.numpy import float64

# from jax.scipy.optimize import minimize_scalar
from scipy.optimize import root_scalar

from environments import DIRECTIONALDERIVATIVE, FINITEDIFFERENCE, LEESELECTION, RANDOM
from utils.calculate import (
    get_hessian_with_hvp,
    get_jvp,
    get_minimum_eigenvalue,
    hessian,
    hvp,
    jax_randn,
    line_search,
    subspace_line_search,
)
from utils.logger import logger


class CubicRegularizedNewtonLS(optimization_solver):
    """Cubic-regularized Newton solver with an adaptive regularization weight.

    The weight is multiplied by ``params["beta"]`` at the start of each
    iteration. It is then divided by ``params["beta"]`` until the trial point
    achieves at least the decrease predicted by the cubic model.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = ["reg_coef", "eps", "solver_eps", "beta", "backward"]
        self.r0 = 0.1  # starting point for the scalar root solve
        self.value = None  # objective value at the current iterate
        self.reg_coef = None  # last accepted regularization weight

    def __iter_per__(self):
        """Perform one iteration of the regularized Newton scheme."""
        # Lazily initialize the cached objective value and the weight.
        if self.value is None:
            self.value = self.f(self.xk)
        if self.reg_coef is None:
            self.reg_coef = self.params["reg_coef"]

        self.grad = self.__first_order_oracle__(self.xk)
        if self.check_norm(self.grad, self.params["eps"]):
            self.finish = True
            return
        self.hessian = self.__second_order_oracle__(self.xk)

        beta = self.params["beta"]
        reg_coef = self.reg_coef * beta

        # NOTE(review): this search has no iteration cap, unlike the one in
        # KrylovCubicRegularizedNewtonLS -- confirm that is intended.
        while True:
            s_new, _, r0_new, model_decrease = self.cubic_solver_root(
                self.grad,
                self.hessian,
                reg_coef,
                epsilon=self.params["solver_eps"],
                r0=self.r0,
            )
            x_new = self.xk + s_new
            value_new = self.__zeroth_order_oracle__(x_new)
            if not (value_new > self.value - model_decrease):
                break
            reg_coef = reg_coef / beta

        self.xk = x_new
        self.reg_coef = reg_coef
        self.value = value_new
        self.r0 = r0_new

    def cubic_solver_root(self, g, H, M, epsilon=1e-8, r0=0.1):
        """Solve ``min_s <g, s> + 1/2 <s, H s> + M/3 ||s||^3``.

        A scalar ``r`` with ``r^2 = M^2 ||s(r)||^2`` is found by Newton's
        method, where ``s(r) = -(H + r I)^{-1} g``.

        Args:
            g: Gradient vector of the model.
            H: Hessian matrix of the model.
            M: Cubic regularization weight.
            epsilon: ``xtol`` handed to ``root_scalar``.
            r0: Initial guess for the scalar root.

        Returns:
            Tuple ``(s, iterations, r, model_decrease)``.
        """
        identity = jnp.eye(len(g))

        def step_for(lam):
            return -jnp.linalg.solve(H + lam * identity, g)

        def residual(lam):
            s_lam = step_for(lam)
            return lam**2 - M**2 * jnp.linalg.norm(s_lam) ** 2

        def residual_prime(lam):
            s_lam = step_for(lam)
            # Derivative of ||s(lam)||^2 with respect to lam.
            norm_sq_prime = -2 * jnp.dot(
                s_lam, jnp.linalg.solve(H + lam * identity, s_lam)
            )
            return 2 * lam - M**2 * norm_sq_prime

        sol = root_scalar(
            residual,
            fprime=residual_prime,
            x0=r0,
            method="newton",
            maxiter=100,
            xtol=epsilon,
        )
        r = sol.root
        s = step_for(r)
        norm_s = jnp.linalg.norm(s)
        model_decrease = r / 2 * norm_s**2 - M / 3 * norm_s**3 - jnp.dot(g, s) / 2
        return s, sol.iterations, r, model_decrease


class KrylovCubicRegularizedNewtonLS(optimization_solver):
    """Cubic-regularized Newton solver restricted to a Krylov subspace.

    A Lanczos pass seeded with the gradient yields a basis ``V`` and a
    tridiagonal subspace Hessian. The cubic subproblem is solved in that
    subspace and the step ``V @ s`` is applied to the iterate.
    """

    def __init__(self, dtype=jnp.float64) -> None:
        super().__init__(dtype)
        self.params_key = [
            "reg_coef",
            "reduced_dim",
            "eps",
            "solver_eps",
            "beta",
            "backward",
        ]
        self.r0 = 0.1  # starting point for the scalar root solve
        self.value = None  # objective value at the current iterate
        self.reg_coef = None  # last accepted regularization weight

    def __iter_per__(self):
        """Perform one iteration of the Krylov-subspace scheme."""
        # Lazily initialize the cached objective value and the weight.
        if self.value is None:
            self.value = self.f(self.xk)
        if self.reg_coef is None:
            self.reg_coef = self.params["reg_coef"]

        self.grad = self.__first_order_oracle__(self.xk)
        if self.check_norm(self.grad, self.params["eps"]):
            self.finish = True
            return

        def hess_vec_prod(vec):
            return hvp(self.f, (self.xk,), (vec,))

        # Lanczos gives the basis and the tridiagonal coefficients.
        V, alphas, betas, _ = self.lanczos(
            hess_vec_prod, self.grad, self.params["reduced_dim"]
        )
        H_sub = jnp.diag(alphas) + jnp.diag(betas, -1) + jnp.diag(betas, 1)

        # The first basis vector is the normalized gradient, so the subspace
        # gradient is ||grad|| times the first unit vector.
        e1 = jnp.zeros(len(alphas)).at[0].set(1)
        g_sub = jnp.linalg.norm(self.grad) * e1

        beta = self.params["beta"]
        reg_coef = self.reg_coef * beta

        s_new, _, r0_new, model_decrease = self.cubic_solver_root(
            g_sub, H_sub, reg_coef, self.params["solver_eps"], self.r0
        )
        x_new = self.xk + V @ s_new
        value_new = self.__zeroth_order_oracle__(x_new)

        # Raise the weight (divide by beta) at most max_iter times while the
        # trial point falls short of the model's predicted decrease.
        max_iter = 20
        for _ in range(max_iter):
            if not (value_new > self.value - model_decrease):
                break
            reg_coef = reg_coef / beta
            s_new, _, r0_new, model_decrease = self.cubic_solver_root(
                g_sub, H_sub, reg_coef, epsilon=self.params["solver_eps"], r0=self.r0
            )
            x_new = self.xk + V @ s_new
            value_new = self.__zeroth_order_oracle__(x_new)

        self.xk = x_new
        self.reg_coef = reg_coef
        self.value = value_new
        self.r0 = r0_new

    def cubic_solver_root(self, g, H, M, epsilon=1e-8, r0=0.1):
        """Solve ``min_s <g, s> + 1/2 <s, H s> + M/3 ||s||^3``.

        A scalar ``r`` with ``r^2 = M^2 ||s(r)||^2`` is found by Newton's
        method, where ``s(r) = -(H + r I)^{-1} g``.

        Args:
            g: Gradient vector of the model.
            H: Hessian matrix of the model.
            M: Cubic regularization weight.
            epsilon: ``xtol`` handed to ``root_scalar``.
            r0: Initial guess for the scalar root.

        Returns:
            Tuple ``(s, iterations, r, model_decrease)``.
        """
        identity = jnp.eye(len(g))

        def step_for(lam):
            return -jnp.linalg.solve(H + lam * identity, g)

        def residual(lam):
            s_lam = step_for(lam)
            return lam**2 - M**2 * jnp.linalg.norm(s_lam) ** 2

        def residual_prime(lam):
            s_lam = step_for(lam)
            # Derivative of ||s(lam)||^2 with respect to lam.
            norm_sq_prime = -2 * jnp.dot(
                s_lam, jnp.linalg.solve(H + lam * identity, s_lam)
            )
            return 2 * lam - M**2 * norm_sq_prime

        sol = root_scalar(
            residual,
            fprime=residual_prime,
            x0=r0,
            method="newton",
            maxiter=100,
            xtol=epsilon,
        )
        r = sol.root
        s = step_for(r)
        norm_s = jnp.linalg.norm(s)
        model_decrease = r / 2 * norm_s**2 - M / 3 * norm_s**3 - jnp.dot(g, s) / 2
        return s, sol.iterations, r, model_decrease

    def lanczos(self, A, v, m):
        """Run the Lanczos iteration for the operator ``A`` started at ``v``.

        Args:
            A: Callable applying the operator to a vector.
            v: Starting vector; it is normalized before use.
            m: Requested number of Lanczos vectors.

        Returns:
            Tuple ``(V, alphas, betas, beta)``: the basis (one vector per
            column), the diagonal and off-diagonal coefficients, and the last
            off-diagonal value computed. When the off-diagonal value drops
            below 1e-6 before the final step, the outputs are truncated to
            the vectors built so far.
        """
        beta = 0
        q_prev = jnp.zeros_like(v)
        q = v / jnp.linalg.norm(v)

        # Column 0 of the basis is the normalized start vector.
        V = jnp.zeros((len(q), m))
        V = V.at[:, 0].set(q)

        alphas = jnp.zeros(m)
        betas = jnp.zeros(m - 1)

        for j in range(m - 1):
            w = A(q) - beta * q_prev
            alpha = jnp.dot(q, w)
            alphas = alphas.at[j].set(alpha)
            w = w - alpha * q
            beta = jnp.linalg.norm(w)
            # A (numerically) zero residual ends the recurrence early.
            if jnp.abs(beta) < 1e-6:
                break
            betas = betas.at[j].set(beta)
            q_prev = q
            q = w / beta
            V = V.at[:, j + 1].set(q)

        # Trim the unused tail when the loop stopped before its final step.
        stopped_early = m > 1 and j < m - 2
        if stopped_early:
            kept = j + 1
            V = V[:, :kept]
            alphas = alphas[:kept]
            betas = betas[:j]

        # The last diagonal entry comes from the most recent Lanczos vector.
        last_alpha = jnp.dot(q, A(q))
        alphas = alphas.at[-1].set(last_alpha)

        return V, alphas, betas, beta
