import copy

from torch.optim import Adam

from algorithms.convergence_algorithms.egl import EGL


class EGLTrain(EGL):
    """EGL variant registered under the algorithm name ``"egl_train"``.

    On construction it keeps an independent deep copy of ``helper_network``
    (``long_training_model``) together with a dedicated Adam optimizer over
    that copy's parameters. Neither attribute is referenced inside this class;
    presumably they are consumed by the parent class or external code — TODO
    confirm.
    """

    ALGORITHM_NAME = "egl_train"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Deep copy so the long-training model's parameters are separate
        # objects from those of ``helper_network``.
        self.long_training_model = copy.deepcopy(self.helper_network)
        # Adam with PyTorch's default hyperparameters (only params are passed).
        self.long_training_optimizer = Adam(self.long_training_model.parameters())

    def train_helper_model(
        self,
        samples,
        samples_value,
        num_of_minibatch,
        batch_size,
        exploration_size,
        epochs,
        new_samples_count,
    ):
        """Build a dataset from ``samples`` and train ``helper_network`` on it.

        The network is put in train mode for the duration of the call and is
        always switched back to eval mode afterwards, even if an error occurs.

        Args:
            samples: Indexable collection of sampled points (supports ``len``
                and slicing); its last ``new_samples_count`` entries are passed
                to the dataset as ``new_samples``.
            samples_value: Values for ``samples``; mapped through
                ``self.output_mapping.map`` before being stored in the dataset.
            num_of_minibatch: Not used here; presumably kept so the signature
                matches ``EGL.train_helper_model`` — TODO confirm.
            batch_size: Forwarded to ``self.train_loops``.
            exploration_size: Forwarded to ``self.database_type``.
            epochs: Forwarded to ``self.train_loops``.
            new_samples_count: How many trailing entries of ``samples`` to mark
                as new. ``0`` selects none; values larger than ``len(samples)``
                select all of them.

        Returns:
            Whatever ``self.train_loops`` returns (the training losses).
        """
        self.helper_network.train()
        try:
            mapped_evaluations = self.output_mapping.map(samples_value)
            # ``samples[-new_samples_count:]`` would return *every* sample when
            # the count is 0 (since ``-0 == 0``). Slicing from an explicit,
            # non-negative start index makes a count of 0 select nothing.
            new_samples_start = max(len(samples) - new_samples_count, 0)
            dataset = self.database_type(
                database=samples,
                values=mapped_evaluations,
                exploration_size=exploration_size,
                epsilon=self.epsilon,
                new_samples=samples[new_samples_start:],
                max_tuples=self.database_size.value,
            )
            self.logger.info(f"Created dataset with {len(dataset)}")
            losses = self.train_loops(epochs, batch_size, dataset)
        finally:
            # Restore eval mode even if dataset creation or training raised.
            self.helper_network.eval()
        return losses
