#!/usr/bin/env python
import sys
import argparse
import random
import torch
import math
import numpy as np
import logging
import copy

from models.models import Module
from core.numpy_dataset import NumpyDataset
from learner.gradient_descent_learner import GradientDescentLearner


###############################################################################

class OptimizeGradientDescentLearner(GradientDescentLearner):
    """Gradient-descent learner that also records the model's divergences.

    Each call to ``_optimize`` runs one optimisation step on a batch. It then
    stores the model's ``div_renyi`` and ``div_rivasplata`` attributes, as
    read right after the model call, in ``self._log``.
    """

    def __init__(
        self, model, loss, criteria, optim, device,
        epoch=10, batch_size=None
    ):
        """Build the learner; every argument is forwarded to the parent.

        Args:
            model: model object, called on a batch in ``_optimize``.
            loss: callable applied as ``loss(model.out, batch["y"])``.
            criteria: forwarded unchanged to ``GradientDescentLearner``.
            optim: optimizer providing ``zero_grad()`` and ``step()``.
            device: forwarded unchanged to ``GradientDescentLearner``.
            epoch: forwarded as a keyword (default 10).
            batch_size: forwarded as a keyword (default None).
        """
        super().__init__(
            model, loss, criteria, optim, device,
            epoch=epoch, batch_size=batch_size,
        )

    def _optimize(self, batch):
        """Perform a single optimisation step on ``batch``.

        Calls the model on the batch, evaluates ``self.loss`` against
        ``batch["y"]``, back-propagates, and steps the optimizer. Finally it
        replaces ``self._log`` with the model's divergence attributes.

        Args:
            batch: mapping that must contain the key ``"y"``. The whole
                mapping is passed to the model.
        """
        net = self.model
        net(batch)  # presumably populates net.out and the div_* attributes

        objective = self.loss(net.out, batch["y"])

        # Read the divergences now, before the parameter update, so the
        # logged values are the ones observed after this batch's model call.
        # NOTE(review): if these are autograd tensors, keeping them in _log
        # holds references to their graph; consider .detach() once the
        # consumers of _log are confirmed.
        record = dict(
            div_renyi=net.div_renyi,
            div_rivasplata=net.div_rivasplata,
        )

        opt = self.optim
        opt.zero_grad()
        objective.backward()
        opt.step()

        # Assigned only after a successful step, as before.
        self._log = record

###############################################################################
