import torch

from algorithms.convergence_algorithms.egl import EGL
from algorithms.convergence_algorithms.typing import SizedDataset
from algorithms.nn.datasets import PairsInEpsRangeDataset
from algorithms.nn.losses import GradientLoss
from algorithms.nn.trainer import train_gradient_network


class TrainableEGL(EGL):
    """EGL variant that maintains a separate test set for gradient-network training.

    Each call to ``explore`` draws an extra, smaller batch through the parent
    ``explore`` and appends it to a rolling test database. ``train_loop``
    wraps that database in a ``PairsInEpsRangeDataset`` and passes it to
    ``train_gradient_network`` as the test set.
    """

    ALGORITHM_NAME = "trainable_egl"

    def __init__(self, *args, test_ratio: float = 0.1, **kwargs) -> None:
        """Initialise the parent algorithm and empty test buffers.

        Args:
            *args: Forwarded unchanged to ``EGL.__init__``.
            test_ratio: Fraction of ``exploration_size`` used to size the
                extra test exploration and the rolling test database.
            **kwargs: Forwarded unchanged to ``EGL.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.test_ratio = test_ratio
        # Start empty; samples are appended on every call to ``explore``.
        self.test_database = torch.tensor([], device=self.device)
        self.test_evaluations = torch.tensor([], device=self.device)

    def explore(self, exploration_size: int):
        """Collect test samples, then run the regular exploration.

        Args:
            exploration_size: Size passed to the parent ``explore`` for the
                regular exploration; the test exploration is a
                ``test_ratio`` fraction of it, with a minimum of 2.

        Returns:
            Whatever the parent ``explore(exploration_size)`` returns.
        """
        test_exploration_size = max(int(exploration_size * self.test_ratio), 2)
        # The raw cap can truncate to 0 for small sizes/ratios, and a slice of
        # ``[-0:]`` keeps *everything*, letting the buffers grow without bound.
        # Flooring the cap at the requested test batch size keeps it positive
        # and never smaller than the batch that was just asked for.
        test_database_size = max(
            int(self.num_of_batch_reply * exploration_size * self.test_ratio),
            test_exploration_size,
        )
        test_samples, test_values = super().explore(test_exploration_size)
        # Keep only the most recent ``test_database_size`` entries.
        self.test_database = torch.cat((self.test_database, test_samples))[
            -test_database_size:
        ]
        self.test_evaluations = torch.cat((self.test_evaluations, test_values))[
            -test_database_size:
        ]
        return super().explore(exploration_size)

    def train_loop(self, batch_size: int, dataset: SizedDataset):
        """Train the gradient network on ``dataset`` with the stored test set.

        Args:
            batch_size: Batch size forwarded to ``train_gradient_network``.
            dataset: Training dataset forwarded to ``train_gradient_network``.

        Returns:
            The result of ``train_gradient_network``.
        """
        # Test evaluations are stored unmapped; map them before building the
        # test dataset.
        mapped_evaluations = self.output_mapping.map(self.test_evaluations)
        testset = PairsInEpsRangeDataset(
            database=self.test_database,
            values=mapped_evaluations,
            epsilon=self.epsilon,
        )
        taylor_loss = GradientLoss(
            self.grad_network, self.perturb * self.epsilon, self.calc_loss
        )

        return train_gradient_network(
            taylor_loss,
            self.grad_optimizer,
            dataset,
            batch_size,
            self.logger,
            testset,
            self.max_batch_size,
        )

    def before_shrinking_hook(self):
        """Apply ``input_mapping.inverse`` to the test database, if a mapping is set.

        NOTE(review): presumably invoked by the parent class right before the
        input mapping is changed -- confirm against ``EGL``.
        """
        if self.input_mapping:
            self.test_database = self.input_mapping.inverse(self.test_database.detach())

    def after_shrinking_hook(self):
        """Apply ``input_mapping.map`` to the test database, if a mapping is set.

        NOTE(review): presumably invoked by the parent class right after the
        input mapping is changed -- confirm against ``EGL``.
        """
        if self.input_mapping:
            self.test_database = self.input_mapping.map(self.test_database.detach())

    @classmethod
    def object_default_values(cls) -> dict:
        """Return the default configuration values for this algorithm."""
        return {
            "exploration_size": 8,
            "helper_model_training_epochs": 1,
            "database_type": "PairsInEpsRangeDataset",
            "database_size": 20_000,
            "min_trust_region_size": 0,
            "maximum_movement_for_shrink_type": 2,
        }
