import math
from typing import Callable, Type, Dict, Any

import torch
from torch import Tensor
from torch.nn import SmoothL1Loss

from algorithms.convergence_algorithms.basic_config import FuncConfig
from algorithms.convergence_algorithms.convergence import ConvergenceAlgorithm
from algorithms.convergence_algorithms.typing import SizedDataset
from algorithms.nn.datasets import SinglePairPerPointDataset, PairsInEpsRangeDataset
from algorithms.nn.losses import GradientLoss
from algorithms.nn.trainer import train_gradient_network, step_model_with_gradient


class EGL(ConvergenceAlgorithm):
    """Convergence algorithm that steps its model with a *learned* gradient.

    A helper network (exposed as ``grad_network``) is trained with
    ``GradientLoss`` on a dataset built from evaluated samples (the dataset
    class is ``database_type``). Its prediction at the current model
    parameters is then used as the gradient for ``model_to_train``.
    """

    ALGORITHM_NAME = "egl"

    def __init__(
        self,
        *args,
        grad_loss: Callable,
        database_size: int = 360,
        database_type: Type[SizedDataset] = SinglePairPerPointDataset,
        **kwargs,
    ):
        """Create the algorithm.

        Args:
            grad_loss: Loss callable ``(value, target) -> Tensor`` used by
                ``calc_loss`` inside ``GradientLoss``.
            database_size: Forwarded to the dataset as ``max_tuples``.
            database_type: Dataset class built in ``train_helper_model``.
            *args, **kwargs: Forwarded unchanged to ``ConvergenceAlgorithm``.
        """
        super().__init__(*args, **kwargs)
        self.grad_loss = grad_loss
        self.database_type = database_type
        self.database_size = database_size

    @property
    def grad_network(self):
        """The helper network, used here as the gradient estimator."""
        return self.helper_network

    @property
    def grad_optimizer(self):
        """The optimizer attached to the helper (gradient) network."""
        return self.helper_optimizer

    def train_helper_model(
        self,
        samples: Tensor,
        samples_value: Tensor,
        num_of_minibatch: int,
        batch_size: int,
        exploration_size: int,
        epochs: int,
        new_samples_count: int,
    ):
        """Fit the gradient network on the current sample database.

        Args:
            samples: All sampled points. The last ``new_samples_count`` rows
                (along dim 0) are passed to the dataset as ``new_samples``.
            samples_value: Evaluations of ``samples``. They are mapped through
                ``self.output_mapping`` before training.
            num_of_minibatch: Not used by this implementation; kept for
                interface compatibility.
            batch_size: Requested batch size, capped by ``self.max_batch_size``.
            exploration_size: Forwarded to the dataset constructor.
            epochs: Number of training passes over the dataset.
            new_samples_count: Number of trailing rows of ``samples`` that are
                new since the previous call.

        Returns:
            The losses from all epochs, concatenated by ``train_loops``.
        """
        self.grad_network.train()
        try:
            mapped_evaluations = self.output_mapping.map(samples_value)
            # ``samples[-k:]`` yields *every* row when k == 0, so compute the
            # start index explicitly: k == 0 must give an empty tail.
            num_samples = samples.shape[0]
            new_count = max(0, min(new_samples_count, num_samples))
            dataset = self.database_type(
                database=samples,
                values=mapped_evaluations,
                exploration_size=exploration_size,
                epsilon=self.epsilon,
                new_samples=samples[num_samples - new_count:],
                max_tuples=self.database_size,
                logger=self.logger,
            )
            self.logger.info(f"Created dataset with {len(dataset)}")
            return self.train_loops(epochs, batch_size, dataset)
        finally:
            # Restore eval mode even if dataset creation or training raised.
            self.grad_network.eval()

    def train_loops(self, epochs: int, batch_size: int, dataset: SizedDataset):
        """Run ``epochs`` passes of ``train_loop`` and concatenate their losses."""
        # Loop-invariant: the effective batch size never changes between epochs.
        effective_batch_size = min(self.max_batch_size, batch_size)
        total_loss = []
        for _ in range(epochs):
            total_loss += self.train_loop(effective_batch_size, dataset)
        return total_loss

    def train_loop(self, batch_size: int, dataset: SizedDataset):
        """Run one training pass of the gradient network over ``dataset``.

        Returns whatever ``train_gradient_network`` returns; ``train_loops``
        concatenates these with ``+=``, so it is presumably a list of
        per-batch losses (TODO confirm).
        """
        taylor_loss = GradientLoss(
            self.grad_network, self.perturb * self.epsilon, self.calc_loss
        )
        return train_gradient_network(
            taylor_loss, self.grad_optimizer, dataset, batch_size, self.logger
        )

    def calc_loss(self, value: Tensor, target: Tensor) -> Tensor:
        """Compare a prediction against its target using ``self.grad_loss``."""
        return self.grad_loss(value, target)

    def train_model(self):
        """Take one optimizer step on ``model_to_train`` with the learned gradient.

        The gradient network (in eval mode) is evaluated at the model's current
        parameter tensor. Its output is moved to ``self.device`` and handed to
        ``step_model_with_gradient`` together with the model's optimizer.
        """
        self.model_to_train.train()
        self.grad_optimizer.zero_grad()
        self.grad_network.eval()

        model_to_train_gradient = self.grad_network(
            self.model_to_train.model_parameter_tensor()
        )
        # NOTE(review): unlike ``gradient()``, NaNs are not scrubbed here before
        # stepping the model — confirm whether that is intentional.
        self.logger.info(f"Gradient: {model_to_train_gradient}")
        model_to_train_gradient = model_to_train_gradient.to(device=self.device)
        self.logger.info(
            f"Algorithm {self.__class__.__name__} "
            f"moving Gradient size: {torch.norm(model_to_train_gradient)} on {self.env}"
        )

        step_model_with_gradient(
            self.model_to_train, model_to_train_gradient, self.model_to_train_optimizer
        )
        self.model_to_train.eval()

    def gradient(self, x) -> Tensor:
        """Return the gradient network's estimate at a single point ``x``.

        ``x`` gets a leading batch dimension for the forward pass, and that
        dimension is removed from the output. NaN entries are replaced by 0.
        The network's previous train/eval mode is restored, even on error.
        """
        was_training = self.grad_network.training
        self.grad_network.eval()
        try:
            estimate = self.grad_network(x.unsqueeze(0)).squeeze(0)
        finally:
            self.grad_network.train(was_training)
        # Out-of-place NaN -> 0. Writing into the network output in place can
        # break autograd if that output is saved for the backward pass.
        return torch.where(torch.isnan(estimate), torch.zeros_like(estimate), estimate)

    @classmethod
    def object_default_values(cls) -> dict:
        """Default constructor objects: a fresh ``SmoothL1Loss`` for ``grad_loss``."""
        return {"grad_loss": SmoothL1Loss()}

    @classmethod
    def _additional_configs(cls) -> Dict[str, Dict[str, Any]]:
        """Named configuration presets.

        Each preset sets the dataset type, the helper-model epoch count, and a
        ``database_size`` computed from the problem dimension ``dims``.
        """
        return {
            "train_norm": {
                "database_type": PairsInEpsRangeDataset,
                "helper_model_training_epochs": 1,
                "database_size": FuncConfig(
                    lambda dims, train_base=10_000, **kwargs: int(train_base)
                    * (math.floor(math.sqrt(dims)) * 2)
                ),
            },
            "train_norm_2": {
                "database_type": PairsInEpsRangeDataset,
                "helper_model_training_epochs": 1,
                "database_size": FuncConfig(
                    lambda dims, **kwargs: 20_000 * (math.floor(math.sqrt(dims)) * 2)
                ),
            },
            "train_log": {
                "database_type": PairsInEpsRangeDataset,
                "helper_model_training_epochs": 1,
                # NOTE(review): ``math.log(2) ** 2`` is a constant (~0.48), so the
                # multiplier is always max(1, 0) == 1 and ``dims`` is ignored;
                # a ``math.log(dims)`` term was probably intended. Confirm before
                # changing, since it would alter the dataset size.
                "database_size": FuncConfig(
                    lambda dims, train_base=10_000, **kwargs: max(
                        1, math.floor(math.log(2) ** 2)
                    )
                    * int(train_base)
                ),
            },
        }
