import functools
from typing import List

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from algorithms.convergence_algorithms.egl_scheduler import EGLScheduler
from algorithms.convergence_algorithms.utils import ball_perturb_between_radius
from algorithms.nn.datasets import NewPairEpsDataset
from algorithms.nn.losses import loss_with_quantile, GradientLoss
from algorithms.nn.trainer import step_model_with_gradient, train_gradient_network


class MultipleEpsEGL(EGLScheduler):
    """EGL variant that keeps one gradient network per epsilon.

    Each entry of ``epsilons`` is paired (by position) with a gradient network,
    an optimizer and a training quantile. Samples are drawn once per epsilon,
    every network is trained on the samples of its own epsilon and of all
    preceding ones, and the optimisation step follows the network whose
    predicted gradient has the largest norm.
    """

    ALGORITHM_NAME = "multiple_eps_egl"

    def __init__(
        self,
        *args,
        gradient_networks: List[Module],
        gradient_optimizers: List[Optimizer],
        train_quantiles: List[int],
        epsilons: List[float] = None,
        **kwargs,
    ):
        """Store the per-epsilon components and initialise the parent scheduler.

        Args:
            *args: Forwarded unchanged to ``EGLScheduler.__init__``.
            gradient_networks: One gradient network per epsilon.
            gradient_optimizers: One optimizer per gradient network.
            train_quantiles: One quantile per gradient network, passed to
                ``loss_with_quantile`` during training.
            epsilons: One radius per gradient network. Despite the ``None``
                default this argument is required.
            **kwargs: Forwarded unchanged to ``EGLScheduler.__init__``.

        Raises:
            ValueError: If ``epsilons`` is ``None`` or empty.
        """
        # The first element is handed to the parent below, so fail with a clear
        # message instead of an opaque TypeError/IndexError on ``epsilons[0]``.
        if epsilons is None or len(epsilons) == 0:
            raise ValueError("epsilons must contain at least one value")

        # The parent works with a single network/optimizer/epsilon; it receives
        # the first of each.
        super().__init__(
            *args,
            value_optimizer=gradient_optimizers[0],
            helper_network=gradient_networks[0],
            epsilon=epsilons[0],
            **kwargs,
        )
        self.train_quantiles = train_quantiles
        self.gradient_networks = gradient_networks
        self.gradient_optimizers = gradient_optimizers
        self.epsilons = epsilons

    def samples_points(self, base_point: Tensor, exploration_size: int):
        """Sample ``exploration_size`` points per epsilon around ``base_point``.

        For the i-th epsilon, ``ball_perturb_between_radius`` is called with
        the previous epsilon (``0`` for the first one) and the current epsilon
        as its two radius arguments.

        Args:
            base_point: Point to perturb.
            exploration_size: Number of samples requested per epsilon.

        Returns:
            The per-epsilon sample tensors stacked along a new leading axis.
            The very first sample of the first epsilon is overwritten with the
            detached ``base_point`` itself.
        """
        self.logger.info(
            f"Exploring new data points. Sampling {exploration_size} for {len(self.epsilons)} epsilons"
        )
        # Inner radius of every shell is the previous epsilon; the first shell
        # starts at 0.
        inner_radii = [0, *self.epsilons[:-1]]
        new_model_samples = torch.stack(
            [
                ball_perturb_between_radius(
                    base_point,
                    inner_radius,
                    outer_radius,
                    exploration_size,
                    self.device,
                )
                for inner_radius, outer_radius in zip(inner_radii, self.epsilons)
            ]
        )
        # Make sure the unperturbed base point is part of the sample set.
        new_model_samples[0][0] = base_point.detach()
        return new_model_samples

    def train_helper_model(
        self,
        samples: Tensor,
        samples_value: Tensor,
        num_of_minibatch: int,
        batch_size: int,
        exploration_size: int,
        epochs: int,
        new_samples_count: int,
    ):
        """Train every gradient network on the samples relevant to its epsilon.

        Network ``i`` is trained on ``samples[0..i]`` (indexing the leading
        axis) with its own optimizer, epsilon and quantile.

        Args:
            samples: Sample database; the leading axis is indexed per epsilon.
            samples_value: Evaluations matching ``samples``.
            num_of_minibatch: Not read by this implementation.
            batch_size: Not read by this implementation; ``self.max_batch_size``
                is passed to the trainer instead.
            exploration_size: Number of trailing entries treated as new samples.
            epochs: Number of ``train_gradient_network`` passes per network.
            new_samples_count: Not read by this implementation.
        """
        for i, (grad_network, optimizer, eps, quantile) in enumerate(
            zip(self.gradient_networks, self.gradient_optimizers, self.epsilons, self.train_quantiles)
        ):
            grad_network.train()
            relevant_database_mask = torch.arange(0, i + 1)
            relevant_samples = samples[relevant_database_mask].detach()
            # NOTE(review): this slices the leading (per-epsilon) axis, not the
            # per-shell sample axis - confirm that is the intended selection.
            new_samples = relevant_samples[-exploration_size:].reshape(-1, samples.shape[-1])
            relevant_samples = relevant_samples.reshape(-1, samples.shape[-1])
            relevant_evaluations = samples_value[relevant_database_mask].detach().reshape((-1))
            mapped_evaluations = self.output_mapping.map(relevant_evaluations)

            dataset = NewPairEpsDataset(
                database=relevant_samples,
                values=mapped_evaluations,
                epsilon=eps,
                new_samples=new_samples,
            )
            # Build the loss around *this* iteration's network and epsilon.
            # Previously self.grad_network / self.epsilon / self.grad_optimizer
            # were used, so every iteration trained the same single network and
            # the loop's own network and optimizer were never updated.
            taylor_loss = GradientLoss(  # TODO why not using calc_loss?
                grad_network,
                self.perturb * eps,
                functools.partial(
                    loss_with_quantile,
                    quantile=quantile,
                    weights_creator=self.weights_creator,
                    loss=self.grad_loss,
                ),
            )

            self.logger.info(f"Created dataset with {len(dataset)} for eps {eps}")
            for _ in range(epochs):
                train_gradient_network(
                    taylor_loss,
                    optimizer,
                    dataset,
                    self.max_batch_size,
                    self.logger,
                )
            grad_network.eval()

    def train_model(self):
        """Step the trained model along the largest predicted gradient.

        Every gradient network is evaluated at the model's current parameter
        tensor; NaN entries are zeroed, and the prediction with the largest
        norm is passed to ``step_model_with_gradient``.
        """
        self.model_to_train.train()
        self.model_to_train_optimizer.zero_grad()

        for gradient_net, optimizer in zip(self.gradient_networks, self.gradient_optimizers):
            optimizer.zero_grad()
            gradient_net.eval()

        model_to_train_gradient = torch.stack(
            [
                gradient_net(self.model_to_train.model_parameter_tensor())
                for gradient_net in self.gradient_networks
            ]
        )
        # Zero out NaN predictions so they cannot win the argmax below.
        model_to_train_gradient[torch.isnan(model_to_train_gradient)] = 0
        magnitudes = torch.norm(model_to_train_gradient, dim=1)
        max_magnitude_index = torch.argmax(magnitudes)
        vector_with_max_magnitude = model_to_train_gradient[max_magnitude_index, :]

        self.logger.info(
            f"Algorithm {self.__class__.__name__} "
            f"moving Gradient size: {torch.norm(vector_with_max_magnitude)} {max_magnitude_index} on {self.env}"
        )

        step_model_with_gradient(
            self.model_to_train,
            vector_with_max_magnitude,
            self.model_to_train_optimizer,
        )
        self.model_to_train.eval()
