import torch
import numpy as np
from tqdm import trange, tqdm

from k_level_policy_gradients.src.utils.serialization import Serializable
from k_level_policy_gradients.src.utils.torch import (
    get_weights,
    set_weights,
)


class TorchApproximator(Serializable):
    """
    Class to interface a pytorch model to the mushroom Regressor interface.
    This class implements all that is needed to use a generic pytorch model
    and train it using a specified optimizer and objective function.

    """

    def __init__(
        self,
        input_shape,
        output_shape,
        network,
        optimizer,
        loss=None,
        n_features=None,
        use_cuda=False,
        **network_params
    ):
        """
        Constructor.

        Args:
            input_shape (tuple): shape of the input of the network;
            output_shape (tuple): shape of the output of the network;
            network (torch.nn.Module): the network class to use; it is
                instantiated as ``network(input_shape, output_shape,
                n_features, use_cuda=use_cuda, **network_params)``;
            optimizer (dict): the optimizer used for every fit step, given as
                a dictionary with a ``"class"`` entry (the optimizer class)
                and a ``"params"`` entry (keyword arguments for it);
            loss (torch.nn.functional, None): the loss function to optimize
                in the fit method, called as ``loss(y_hat, y_target)``;
            n_features (None): forwarded to the network constructor as its
                third positional argument (presumably the number of hidden
                features -- confirm against the network class);
            use_cuda (bool, False): if True, runs the network on the GPU;
            **network_params: dictionary of parameters needed to construct the
                network.

        """

        self._use_cuda = use_cuda

        self.network = network(
            input_shape, output_shape, n_features, use_cuda=use_cuda, **network_params
        )

        if self._use_cuda:
            self.network.cuda()

        self._optimizer = optimizer["class"](
            self.network.parameters(), **optimizer["params"]
        )
        self._loss = loss

        self._add_save_attr(
            _use_cuda="primitive",
            network="torch",
            _optimizer="torch",
            _loss="torch",
        )

    def _as_tensor(self, x, dtype=None):
        """
        Convert a single input to a tensor placed on the configured device.

        Args:
            x (np.ndarray, torch.Tensor): the input to convert;
            dtype (torch.dtype, None): if given, numpy inputs are cast to this
                type; inputs that are not numpy arrays are never cast.

        Returns:
            The input as a tensor, moved to the GPU if ``use_cuda`` is True.

        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
            if dtype is not None:
                x = x.type(dtype)

        return x.cuda() if self._use_cuda else x

    def predict(self, *args, output_tensor=False, **kwargs):
        """
        Predict.

        Args:
            *args: input; numpy arrays are converted to float32 tensors, any
                other input is passed through as it is (moved to the GPU if
                ``use_cuda`` is True);
            output_tensor (bool, False): whether to return the output as tensor
                or not;
            **kwargs: other parameters used by the predict method
                the regressor.

        Returns:
            The predictions of the model, as returned by the network if
            ``output_tensor`` is True, otherwise converted to numpy array (or
            to a tuple of numpy arrays if the network returns a tuple).

        """
        torch_args = [self._as_tensor(x, dtype=torch.float32) for x in args]

        # Guard against a call without positional inputs before indexing.
        if torch_args and torch_args[0].ndim == 1:
            torch_args[0] = torch_args[0].unsqueeze(0)  # Make single state 2D
        val = self.network.forward(*torch_args, **kwargs)

        if output_tensor:
            return val

        # A CUDA tensor cannot be converted to numpy directly: bring it back
        # to host memory first (``.cpu()`` is a no-op for CPU tensors).
        if isinstance(val, tuple):
            return tuple(x.detach().cpu().numpy() for x in val)

        return val.detach().cpu().numpy()

    def fit(self, *args, **kwargs):
        """
        Fit the model with a single optimization step.

        Args:
            *args: input, where the last element is considered as the target,
                while the others are considered as input of the network;
                numpy arrays are converted to tensors keeping their dtype;
            **kwargs: other parameters passed to the network when computing
                the prediction.

        Returns:
            The value of the loss computed before the optimization step.

        """
        torch_args = [self._as_tensor(x) for x in args]

        x = torch_args[:-1]
        y_target = torch_args[-1]
        y_hat = self.network(*x, **kwargs)

        loss = self._loss(y_hat, y_target)

        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()

        return loss.item()

    def parameters(self):
        """
        Get the parameters of the network.

        Returns:
            The list of the parameters of the network.

        """
        return list(self.network.parameters())

    def gradient(self):
        """
        Get the gradients of the network.

        Every parameter must already hold a gradient, i.e. a backward pass
        must have been executed before calling this method.

        Returns:
            A 1D tensor with the flattened gradients of all the parameters of
            the network, concatenated in the order of ``parameters()``.

        """
        # ``parameters`` is a method: it has to be called to get the list.
        return torch.cat([p.grad.detach().flatten() for p in self.parameters()])

    def gradient_norm(self):
        """
        Get the norm of the gradients of the network.

        Every parameter must already hold a gradient, i.e. a backward pass
        must have been executed before calling this method.

        Returns:
            The 2-norm, as a float, of the per-parameter gradient 2-norms.

        """

        total_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), 2) for p in self.parameters()]),
            2.0,
        ).item()

        return total_norm

    def set_weights(self, weights):
        """
        Setter.

        Args:
            weights (np.ndarray): the set of weights to set.

        """
        set_weights(self.network.parameters(), weights, self._use_cuda)

    def get_weights(self):
        """
        Getter.

        Returns:
            The set of weights of the approximator.

        """
        return get_weights(self.network.parameters())

    def set_primary_approximator(self, primary_approximator):
        """
        Setter.

        The network and the optimizer are shared by reference, not copied:
        after this call both approximators use the very same objects.

        Args:
            primary_approximator (TorchApproximator): the primary
                approximator, whose network and optimizer are taken.

        """
        self.network = primary_approximator.network
        self._optimizer = primary_approximator._optimizer

    @property
    def use_cuda(self):
        """
        Returns:
            True if the approximator was configured to run on the GPU.

        """
        return self._use_cuda
