from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.egl_scheduler import EGLScheduler
from algorithms.convergence_algorithms.typing import SizedDataset
from algorithms.nn.losses import NaturalHessianLoss
from algorithms.nn.trainer import train_gradient_network


class HEGL(EGLScheduler):
    """``EGLScheduler`` variant whose gradient network is fit with ``NaturalHessianLoss``.

    Only :meth:`train_loop` is overridden; all other behaviour comes from
    ``EGLScheduler``. The variant is identified by ``ALGORITHM_NAME == "hegl"``.
    """

    ALGORITHM_NAME = "hegl"

    def train_loop(self, batch_size: int, dataset: SizedDataset):
        """Train ``self.grad_network`` on *dataset* with a natural-Hessian loss.

        Args:
            batch_size: Mini-batch size forwarded to ``train_gradient_network``.
            dataset: Samples the gradient network is trained on.

        Returns:
            The value returned by ``train_gradient_network`` (its type is not
            visible from this module).
        """
        network = self.grad_network
        # Scale handed to the loss; presumably the perturbation radius
        # (perturb * epsilon) -- confirm against NaturalHessianLoss.
        radius = self.perturb * self.epsilon
        objective = NaturalHessianLoss(network, radius, self.calc_loss)
        return train_gradient_network(
            objective,
            self.grad_optimizer,
            dataset,
            batch_size,
            self.logger,
        )


class HEGLNorm(EGL):
    """``EGL`` variant whose gradient network is fit with ``NaturalHessianLoss``.

    Only :meth:`train_loop` is overridden; all other behaviour comes from
    ``EGL``. The variant is identified by ``ALGORITHM_NAME == "hegl_norm"``.
    """

    ALGORITHM_NAME = "hegl_norm"

    def train_loop(self, batch_size: int, dataset: SizedDataset):
        """Train ``self.grad_network`` on *dataset* with a natural-Hessian loss.

        Args:
            batch_size: Mini-batch size forwarded to ``train_gradient_network``.
            dataset: Samples the gradient network is trained on.

        Returns:
            The value returned by ``train_gradient_network`` (its type is not
            visible from this module).
        """
        network = self.grad_network
        # Scale handed to the loss; presumably the perturbation radius
        # (perturb * epsilon) -- confirm against NaturalHessianLoss.
        radius = self.perturb * self.epsilon
        objective = NaturalHessianLoss(network, radius, self.calc_loss)
        return train_gradient_network(
            objective,
            self.grad_optimizer,
            dataset,
            batch_size,
            self.logger,
        )
