from typing import Dict, Optional

from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import MultiplicativeLR

from algorithms.convergence_algorithms.egl import EGL
from algorithms.nn.datasets import PairFromDistributionDataset, TuplesDataset
from algorithms.nn.distributions import WeightsDistributionBase, QuantileWeights
from algorithms.nn.losses import loss_with_quantile


class EGLScheduler(EGL):
    """EGL variant with optional learning-rate decay and a quantile-based loss.

    On top of the base ``EGL`` behaviour this class:

    * optionally attaches a ``MultiplicativeLR`` scheduler to the optimizer
      passed as ``model_to_train_optimizer`` and steps it from
      ``after_shrinking_hook``;
    * optionally wraps the training dataset in ``PairFromDistributionDataset``
      before delegating to the base ``train_loop``;
    * computes the loss through ``loss_with_quantile``.
    """

    ALGORITHM_NAME = "egl_scheduler"

    def __init__(
        self,
        *args,
        model_to_train_optimizer: Optimizer,
        train_quantile: int = 83,
        model_lr_factor: Optional[float] = None,
        weights_creator: WeightsDistributionBase,
        dist_sample: bool = False,
        **kwargs
    ):
        """Initialise the algorithm.

        Args:
            *args: Positional arguments forwarded unchanged to ``EGL.__init__``.
            model_to_train_optimizer: Optimizer forwarded to ``EGL.__init__``
                and, when ``model_lr_factor`` is set, driven by the
                learning-rate scheduler.
            train_quantile: Quantile argument handed to ``loss_with_quantile``.
            model_lr_factor: Multiplicative factor applied to the optimizer's
                learning rate on every scheduler step. A falsy value
                (``None`` or ``0``) disables the scheduler.
            weights_creator: Weights distribution handed to
                ``loss_with_quantile``.
            dist_sample: When true, ``train_loop`` wraps its dataset in
                ``PairFromDistributionDataset``.
            **kwargs: Keyword arguments forwarded unchanged to
                ``EGL.__init__``.
        """
        super().__init__(*args, model_to_train_optimizer=model_to_train_optimizer, **kwargs)
        self.train_quantile = train_quantile
        self.weights_creator = weights_creator
        self.dist_sample = dist_sample

        # Truthiness is intentional: both None and 0 mean "no LR decay".
        # A zero factor would otherwise drive the learning rate to zero on
        # the first scheduler step.
        self.value_scheduler: Optional[MultiplicativeLR] = None
        if model_lr_factor:
            self.value_scheduler = MultiplicativeLR(
                model_to_train_optimizer, lambda epoch: model_lr_factor
            )

    def after_shrinking_hook(self):
        """Run the base hook, then advance the LR scheduler if one is configured.

        NOTE(review): presumably invoked by ``EGL`` after the trust region
        shrinks -- confirm against the base class.
        """
        super().after_shrinking_hook()
        if self.value_scheduler is not None:
            self.value_scheduler.step()

    def train_loop(self, batch_size: int, dataset: TuplesDataset):
        """Delegate to the base ``train_loop``, optionally re-wrapping the dataset.

        Args:
            batch_size: Forwarded unchanged to the base ``train_loop``.
            dataset: Training dataset; wrapped in
                ``PairFromDistributionDataset`` when ``dist_sample`` is true.

        Returns:
            Whatever the base ``train_loop`` returns.
        """
        if self.dist_sample:
            dataset = PairFromDistributionDataset(dataset)
        return super().train_loop(batch_size, dataset)

    def calc_loss(self, value: Tensor, target: Tensor) -> Tensor:
        """Return the loss of ``value`` against ``target`` via ``loss_with_quantile``.

        ``self.grad_loss`` is not assigned in this class; it is presumably
        provided by ``EGL`` -- confirm against the base class.
        """
        return loss_with_quantile(
            value, target, self.train_quantile, self.weights_creator, self.grad_loss
        )

    @classmethod
    def object_default_values(cls) -> dict:
        """Return this algorithm's default values, keyed by parameter name."""
        return {
            "exploration_size": 8,
            "epsilon": 0.8,
            "min_trust_region_size": 0,
        }

    @classmethod
    def _default_types(cls) -> Dict[str, type]:
        """Return the default type to use for each named constructor argument."""
        return {"weights_creator": QuantileWeights}
