from abc import ABC
from typing import Dict, Any

import torch
from scipy.optimize import minimize
from torch import Tensor
from torch.optim import Optimizer, Adam, SGD

from algorithms.convergence_algorithms.basic_config import FuncConfig
from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.typing import SizedDataset
from algorithms.nn.grad import hessian_from_gradient_network
from algorithms.nn.losses import NaturalHessianLoss, GradientLoss
from algorithms.nn.modules import ConfigurableModule, BasicNetwork
from algorithms.nn.trainer import train_gradient_network, step_model_with_gradient


class NewtonHEGL(EGL, ABC):
    """Abstract base for EGL variants that additionally learn a Hessian model.

    Besides the gradient network inherited from ``EGL``, a ``hessian_network``
    is trained with ``NaturalHessianLoss``. ``calc_hessian`` obtains Hessian
    estimates from it through ``hessian_from_gradient_network``.

    Args:
        hessian_network: network passed to ``hessian_from_gradient_network``
            to produce Hessian estimates.
        hessian_optimizer: optimizer used when training ``hessian_network``.
        hessian_inverse_normalizer: lower bound used by subclasses (e.g. as a
            singular-value floor) when regularizing the Hessian before
            inverting it.
    """

    def __init__(
        self,
        *args,
        hessian_network: ConfigurableModule,
        hessian_optimizer: Optimizer,
        hessian_inverse_normalizer: float = 1e-1,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.hessian_network = hessian_network
        self.hessian_optimizer = hessian_optimizer
        self.hessian_inverse_normalizer = hessian_inverse_normalizer

    def calc_hessian(self, x: Tensor = None):
        """Evaluate the learned Hessian at ``x``.

        When ``x`` is omitted, the current parameter tensor of
        ``model_to_train`` is used as the evaluation point.
        """
        if x is None:
            x = self.model_to_train.model_parameter_tensor()
        return hessian_from_gradient_network(self.hessian_network, x)

    def train_loop(self, batch_size: int, dataset: SizedDataset):
        """Fit the gradient network, then the Hessian network, on ``dataset``.

        Returns whatever ``train_gradient_network`` returns for the Hessian pass.
        """
        scale = self.perturb * self.epsilon
        hessian_loss = NaturalHessianLoss(self.hessian_network, scale, self.calc_loss)
        gradient_loss = GradientLoss(self.grad_network, scale, self.calc_loss)

        # The gradient network is trained first; its result is not returned.
        train_gradient_network(
            gradient_loss, self.grad_optimizer, dataset, batch_size, self.logger
        )
        hessian_result = train_gradient_network(
            hessian_loss, self.hessian_optimizer, dataset, batch_size, self.logger
        )
        return hessian_result

    @classmethod
    def object_default_values(cls) -> dict:
        """Default ``hessian_optimizer``: Adam over ``helper_network``'s parameters.

        Settings: lr=1e-3, betas=(0.9, 0.999), eps=1e-4.
        """
        hessian_optimizer_config = FuncConfig(
            lambda helper_network, **kwargs: Adam(
                helper_network.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
            )
        )
        return {"hessian_optimizer": hessian_optimizer_config}

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Use ``BasicNetwork`` as the default type for ``hessian_network``."""
        return dict(hessian_network=BasicNetwork)


# Note: by default the model-to-train optimizer is plain SGD with lr=1, so it applies the step it is given unchanged (the "momentum" config swaps in Adam).
class NewtonEGL(NewtonHEGL):
    """Newton-method EGL: move ``model_to_train`` along ``-H^-1 g``.

    ``g`` is predicted by ``grad_network`` and ``H`` by the Hessian model (via
    ``calc_hessian``). Before inversion, ``H`` is regularized by clipping its
    singular values from below at ``hessian_inverse_normalizer``.
    """

    ALGORITHM_NAME = "newton_egl"

    def train_model(self):
        """Apply one regularized Newton step to ``model_to_train``.

        The Newton direction ``H^-1 g`` is handed to ``step_model_with_gradient``
        in place of a plain gradient. With the default ``SGD(lr=1)`` optimizer,
        the update is therefore (presumably) ``x <- x - H^-1 g``.
        """
        self.model_to_train.train()
        self.hessian_network.train()
        self.model_to_train_optimizer.zero_grad()
        self.hessian_optimizer.zero_grad()
        self.grad_optimizer.zero_grad()
        self.grad_network.eval()

        hessian_inverse = self._regularized_hessian_inverse()
        curr_point = self.model_to_train.model_parameter_tensor().detach()
        # The predicted gradient is only used numerically, so no autograd graph
        # is needed through grad_network.
        with torch.no_grad():
            model_to_train_gradient = self.grad_network(curr_point)
        # Zero out NaN entries only (inf entries are left untouched).
        model_to_train_gradient = model_to_train_gradient.masked_fill(
            torch.isnan(model_to_train_gradient), 0.0
        )

        newton_direction = hessian_inverse @ model_to_train_gradient
        step = -newton_direction
        self.logger.info(
            f"Algorithm {self.__class__.__name__} moving Gradient size: {torch.norm(step)} on {self.env}"
        )
        # Pass the Newton direction, not the raw gradient, so the learned
        # Hessian actually shapes the update. Assumes step_model_with_gradient
        # treats its argument as the descent gradient (the optimizer subtracts
        # it) -- TODO confirm against algorithms.nn.trainer.
        step_model_with_gradient(
            self.model_to_train, newton_direction, self.model_to_train_optimizer
        )
        self.model_to_train.eval()

    def _regularized_hessian_inverse(self) -> Tensor:
        """Invert the learned Hessian after clipping its singular values from below.

        Assumes ``hessian_inverse_normalizer > 0``, so every clipped singular
        value is strictly positive and the inverse exists.
        """
        hessian = self.calc_hessian().detach()
        # torch.linalg.svd returns (U, S, Vh) such that hessian == U @ diag(S) @ Vh.
        u, s, vh = torch.linalg.svd(hessian)
        s = torch.clip(s, min=self.hessian_inverse_normalizer)
        # U and Vh are orthogonal, so (U diag(s) Vh)^-1 == Vh^T diag(1/s) U^T.
        return torch.linalg.multi_dot([vh.T, torch.diag(1.0 / s), u.T])

    @classmethod
    def object_default_values(cls) -> dict:
        """Defaults: the parent's defaults plus SGD(lr=1) for ``model_to_train``.

        Merging with the parent keeps the inherited ``hessian_optimizer``
        default, which the inherited ``train_loop`` relies on.
        """
        return {
            **super().object_default_values(),
            "model_to_train_optimizer": FuncConfig(
                lambda model_to_train, **kwargs: SGD(model_to_train.parameters(), lr=1)
            ),
        }

    @classmethod
    def _additional_configs(cls) -> Dict[str, Dict[str, Any]]:
        """Named config variants; "momentum" swaps in Adam(lr=0.01) for the model."""
        return {
            "momentum": {
                "model_to_train_optimizer": FuncConfig(
                    lambda model_to_train, **kwargs: Adam(model_to_train.parameters(), lr=0.01)
                )
            }
        }


class NewtonGHEGL(NewtonHEGL):
    """Newton-CG variant that delegates optimization to ``scipy.optimize.minimize``.

    ``self.env`` is used as the objective, ``grad_network`` supplies the
    Jacobian and ``calc_hessian`` supplies the Hessian. The minimizer's result
    becomes the new parameter tensor of ``model_to_train``.
    """

    ALGORITHM_NAME = "newton_cg_hegl"

    # TODO - should I stop to shrink?
    def train_model(self):
        """Run Newton-CG from the current parameters and store the result."""
        curr_point = self.model_to_train.model_parameter_tensor().detach()
        # SciPy iterates in numpy (typically float64). Cast its points back to
        # the dtype of the model's parameter tensor, which is what the networks
        # are fed elsewhere, to avoid a Double-vs-Float mismatch inside them.
        param_dtype = curr_point.dtype

        def to_tensor(x):
            return torch.from_numpy(x).to(device=self.device, dtype=param_dtype)

        def hess(x):
            # calc_hessian is left outside no_grad: it may rely on autograd.
            hessian = self.calc_hessian(to_tensor(x))
            return hessian.detach().cpu().numpy()

        def gradient(x):
            with torch.no_grad():
                grad = self.grad_network(to_tensor(x))
            return grad.detach().cpu().numpy()

        res = minimize(
            self.env,
            curr_point.cpu().numpy(),
            method="Newton-CG",
            jac=gradient,
            hess=hess,
        )
        # Return the result in the same dtype/device as the original parameters.
        new_point = torch.from_numpy(res.x).to(
            device=curr_point.device, dtype=param_dtype
        )
        self.model_to_train.from_parameter_tensor(new_point)
