from typing import Callable, Dict

import numpy as np
from torch import Tensor
from torch.nn import SmoothL1Loss
from tqdm.auto import trange, tqdm

from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.nn.modules import BasicSurrogateModel


class IGL(ConvergenceAlgorithm):
    """Convergence algorithm driven by a learned value (surrogate) network.

    The helper network inherited from ``ConvergenceAlgorithm`` is exposed here
    as ``value_network`` (and its optimizer as ``value_optimizer``). It is
    regressed onto mapped evaluations of sampled points in
    ``train_helper_model``; ``train_model`` then updates the model under
    training by back-propagating the network's prediction at the model's
    current parameter tensor and stepping ``model_to_train_optimizer``.
    """

    ALGORITHM_NAME = "igl"

    def __init__(self, *args, loss: Callable, **kwargs):
        """Create the algorithm.

        Args:
            loss: Criterion called as ``loss(prediction, target)`` when fitting
                the value network. ``object_default_values`` lists
                ``SmoothL1Loss()`` for this argument.
            *args: Forwarded unchanged to ``ConvergenceAlgorithm``.
            **kwargs: Forwarded unchanged to ``ConvergenceAlgorithm``.
        """
        super().__init__(*args, **kwargs)
        self.loss = loss

    @property
    def value_network(self):
        """Alias for ``self.helper_network``."""
        return self.helper_network

    @property
    def value_optimizer(self):
        """Alias for ``self.helper_optimizer``."""
        return self.helper_optimizer

    def train_helper_model(
        self,
        samples: Tensor,
        samples_value: Tensor,
        num_of_minibatch: int,
        batch_size: int,
        exploration_size: int,
        epochs: int,
        new_samples_count: int,
    ):
        """Fit the value network to the (mapped) evaluations of ``samples``.

        Each epoch draws ``num_of_minibatch`` minibatches of ``batch_size``
        indices uniformly *with replacement* from the replay buffer and takes
        one ``value_optimizer`` step per minibatch. The network is put in train
        mode for the duration and is always returned to eval mode, even if
        training raises. An empty replay buffer is logged and skipped.

        Args:
            samples: Replay-buffer inputs, indexed along their first dimension.
            samples_value: Raw evaluations of ``samples``; passed through
                ``self.output_mapping.map`` before being used as targets.
            num_of_minibatch: Number of minibatches per epoch.
            batch_size: Number of samples per minibatch.
            exploration_size: Not used by this implementation (presumably part
                of a signature shared with other algorithms — confirm).
            epochs: Number of epochs.
            new_samples_count: Not used by this implementation.
        """
        self.value_network.train()
        try:
            mapped_evaluations = self.output_mapping.map(samples_value)
            len_replay_buffer = len(mapped_evaluations)
            if len_replay_buffer == 0:
                # np.random.choice(0, size) raises ValueError; there is
                # nothing to fit on, so skip training explicitly.
                self.logger.warning(
                    f"Algorithm {self.__class__.__name__} has an empty replay "
                    f"buffer; skipping value network training"
                )
                return

            for _ in trange(
                epochs,
                leave=False,
                desc=f"Training the gradient network {epochs} loops",
            ):
                # Shape (num_of_minibatch, batch_size); sampled with replacement.
                i_indexes = np.random.choice(
                    len_replay_buffer, (num_of_minibatch, batch_size)
                )

                for i_index in tqdm(i_indexes, leave=False):
                    x_i = samples[i_index]
                    y_i = mapped_evaluations[i_index]

                    self.value_optimizer.zero_grad()
                    # NOTE(review): this also clears the model's gradients on
                    # every minibatch; it looks redundant since only
                    # value_optimizer steps here — confirm before removing.
                    self.model_to_train_optimizer.zero_grad()
                    predicted_value = self.value_network(x_i)

                    # Reshape to the target's shape so each prediction is
                    # paired with its own target. A bare ``squeeze()`` would
                    # turn a (B, 1) output into (B,), which the loss silently
                    # broadcasts against a (B, 1) target into (B, B), and it
                    # collapses a batch of one into a 0-d tensor.
                    loss = self.loss(predicted_value.reshape_as(y_i), y_i)
                    loss.backward()
                    self.value_optimizer.step()
        finally:
            self.value_network.eval()

    def train_model(self):
        """Take one optimization step on the model under training.

        The value network's prediction at the model's current parameter tensor
        is used directly as the loss: it is back-propagated (so it must have a
        single element) and ``model_to_train_optimizer`` is stepped. The value
        network's own parameter gradients are not zeroed here; they are zeroed
        before each step in ``train_helper_model``.
        """
        self.model_to_train_optimizer.zero_grad()
        loss = self.value_network(self.model_to_train.model_parameter_tensor())
        loss.backward()
        self.model_to_train_optimizer.step()
        self.logger.info(
            f"Algorithm {self.__class__.__name__} updated after loss {loss}"
        )

    def gradient(self, x) -> Tensor:
        """Back-propagate the value network's output at ``x``.

        Zeroes the model optimizer's gradients, evaluates
        ``self.value_network(x)``, calls ``backward()`` on it (so the output
        must have a single element) and returns the ``.grad`` of
        ``self.model_to_train.model_parameter_tensor()``.

        NOTE(review): the returned gradient belongs to the model parameter
        tensor, not to ``x``; it only reflects ``x`` when ``x`` is (or is
        computed from) that tensor — confirm against callers.
        """
        self.model_to_train_optimizer.zero_grad()
        loss = self.value_network(x)
        loss.backward()
        return self.model_to_train.model_parameter_tensor().grad

    @classmethod
    def object_default_values(cls) -> dict:
        """Return default constructor objects; a fresh ``SmoothL1Loss`` per call."""
        return {"loss": SmoothL1Loss()}

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Return default component types: the helper network is a ``BasicSurrogateModel``."""
        return {"helper_network": BasicSurrogateModel}

