from typing import Callable, List, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions.kl import kl_divergence

from lambda_ac.nn.common import MLP
from lambda_ac.rl_types import (
    EncoderActorModule,
    EncoderCriticModule,
    EncoderModelNetwork,
    EncoderOutput,
    FeatureInput,
)


class ModelLossModule(nn.Module):
    """Weighted combination of losses for training a latent dynamics model.

    Individual losses are registered by name via :meth:`register_loss`; calling
    the module evaluates every registered loss on the same inputs and returns
    their weighted sum. Every loss reduces over the last dimension with
    ``keepdim=True``, so no reduction over leading dimensions happens here.

    The module owns two auxiliary networks:

    * ``decoder`` -- an ``MLP`` from ``model.feature_dim`` to
      ``model.encoder.input_dim``, used by the ``"reconstruction"`` loss.
    * ``predictor`` -- a linear map ``feature_dim -> feature_dim``. No loss
      defined in this class uses it.

    ``parameter_list`` collects the parameters of exactly these two networks
    (presumably so a caller can hand them to an optimiser -- verify at the
    call site).
    """

    # NOTE(review): this class overrides ``__call__`` rather than ``forward``,
    # so ``nn.Module`` forward hooks do not run for it. Kept as in the original.

    def __init__(
        self,
        model: EncoderModelNetwork,
        actor: EncoderActorModule,
        critic: EncoderCriticModule,
    ) -> None:
        """Store the networks and build the auxiliary decoder/predictor.

        Args:
            model: Provides ``feature_dim`` and an ``encoder`` exposing
                ``input_dim``, ``hidden_dim`` and ``num_hidden_layers``.
            actor: Its ``head.forward_sample_log_prob`` is used by the VAML
                losses.
            critic: Its ``head`` is used by the VAML losses and is unpacked as
                two Q-value outputs.
        """
        super().__init__()
        self.model = model
        self.actor = actor
        self.critic = critic

        # The decoder mirrors the encoder's width/depth but maps features back
        # to the encoder's input dimensionality.
        encoder = model.encoder
        self.decoder = MLP(
            input_dim=model.feature_dim,
            output_dim=encoder.input_dim,
            hidden_dim=encoder.hidden_dim,
            hidden_layers=encoder.num_hidden_layers,
        )

        self.predictor = torch.nn.Linear(model.feature_dim, model.feature_dim)

        self.parameter_list = list(self.decoder.parameters()) + list(
            self.predictor.parameters()
        )

        # ``loss_list`` and ``weight_list`` are parallel lists. A zero-weight
        # placeholder is always present, so the lists are never empty.
        self.loss_list: List[
            Callable[
                [EncoderOutput, EncoderOutput, torch.Tensor],
                torch.Tensor,
            ]
        ] = [self._placeholder]
        self.weight_list: List[float] = [0.0]

        # Running scale used by the "vaml_weighted_mse" loss. It starts as a
        # float and becomes a tensor after the first update.
        # NOTE(review): this is a plain attribute, not a registered buffer, so
        # it is not part of ``state_dict()``.
        self.running_vaml_mean = 1.0

    def __call__(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """Evaluate all registered losses and return their weighted sum.

        Args:
            model_prediction: Encoding predicted by the model.
            next_feature_encoding: Encoding used as the target.
            next_state: Tensor with at least two dimensions (it is sliced as
                ``next_state[:, :1]`` to shape the accumulator).

        Returns:
            Tensor shaped like ``next_state[:, :1]`` holding
            ``sum_i weight_i * loss_i``.
        """
        assert len(self.loss_list) > 0, "No loss function is registered"
        total_loss = torch.zeros_like(next_state[:, :1])
        for loss_fn, weight in zip(self.loss_list, self.weight_list):
            loss = loss_fn(
                model_prediction,
                next_feature_encoding,
                next_state,
            )
            # In-place accumulation: every loss must broadcast to the shape of
            # ``total_loss``.
            total_loss += loss * weight
        return total_loss

    def register_loss(self, loss_type: str, weight: float = 1.0) -> None:
        """Append a loss term, selected by name, with the given weight.

        Args:
            loss_type: One of ``"mse"``, ``"kl"``, ``"reconstruction"``,
                ``"vaml_weighted_mse"``, ``"normalized_vaml"``,
                ``"huber_vaml"`` or ``"vaml"``.
            weight: Multiplier applied to this loss in :meth:`__call__`.

        Raises:
            ValueError: If ``loss_type`` is not a known name. Nothing is
                registered in that case.
        """
        known_losses = {
            "mse": self._mse_loss,
            "kl": self._kl_loss,
            "reconstruction": self._reconstruction_loss,
            "vaml_weighted_mse": self._vaml_weighted_mse_loss,
            "normalized_vaml": self._normalized_vaml_loss,
            "huber_vaml": self._huber_vaml_loss,
            "vaml": self._vaml_loss,
        }
        try:
            loss_fn = known_losses[loss_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments, which the original
            # equality chain also rejected with ValueError.
            raise ValueError("Unknown loss type: {}".format(loss_type)) from None
        # Both lists are extended only after the lookup succeeded, so they
        # always stay the same length.
        self.loss_list.append(loss_fn)
        self.weight_list.append(weight)

    @staticmethod
    def _mean_squared_difference(
        prediction: torch.Tensor, target: torch.Tensor
    ) -> torch.Tensor:
        """Return ``mean((prediction - target) ** 2)`` over the last dim (kept)."""
        return torch.mean((prediction - target) ** 2, dim=-1, keepdim=True)

    def _get_vaml_objects(
        self, model_encoding: EncoderOutput, next_feature_encoding: EncoderOutput
    ) -> Tuple[torch.Tensor, ...]:
        """Compute the Q-values and action means compared by the VAML losses.

        The target quantities (from ``next_feature_encoding``) are computed
        under ``torch.no_grad()``. The model quantities (from
        ``model_encoding``) are computed with autograd enabled, using the
        action that was sampled for the target features.

        Returns:
            ``(model_qf1, model_qf2, model_action_mean, qf1, qf2,
            target_action_mean)``.
        """
        next_feature = FeatureInput.from_tensors(
            next_feature_encoding.encoded.mean,
            next_feature_encoding.hidden,
        )
        # NOTE(review): the return value is discarded. If
        # ``FeatureInput.detach`` returns a new object instead of detaching in
        # place, this line has no effect -- confirm against ``FeatureInput``.
        # ``next_feature`` is only consumed under ``no_grad`` below either way.
        next_feature.detach()
        model_next_feature = FeatureInput.from_tensors(
            model_encoding.encoded.mean, model_encoding.hidden
        )

        with torch.no_grad():
            next_action, _, state_action_dist = self.actor.head.forward_sample_log_prob(
                next_feature
            )
            qf1, qf2 = self.critic.head(next_feature, next_action)

        # Differentiable path: the critic is evaluated on the model's features
        # with the same (target-sampled) action.
        model_qf1, model_qf2 = self.critic.head(model_next_feature, next_action)
        _, _, model_action_dist = self.actor.head.forward_sample_log_prob(
            model_next_feature
        )
        return (
            model_qf1,
            model_qf2,
            model_action_dist.mean,
            qf1,
            qf2,
            state_action_dist.mean,
        )

    def _normalized_vaml_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """VAML loss with each Q-error divided by the mean target Q-value.

        ``next_state`` is unused; it is accepted for the common loss signature.
        """
        (
            model_qf1,
            model_qf2,
            model_action,
            qf1,
            qf2,
            state_action,
        ) = self._get_vaml_objects(model_prediction, next_feature_encoding)
        # NOTE(review): ``torch.mean(qf)`` averages over every element, and
        # there is no epsilon, so a mean close to zero inflates the loss.
        value_error_1 = torch.mean(
            ((model_qf1 - qf1) / torch.mean(qf1)) ** 2, dim=-1, keepdim=True
        )
        value_error_2 = torch.mean(
            ((model_qf2 - qf2) / torch.mean(qf2)) ** 2, dim=-1, keepdim=True
        )
        value_error = 0.5 * (value_error_1 + value_error_2)
        action_error = self._mean_squared_difference(model_action, state_action)
        return value_error + action_error

    def _vaml_weighted_mse_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """Latent MSE weighted by a gradient-free, normalised value error.

        The squared Q-errors of both critic outputs are summed, divided by an
        exponential moving average of their mean (decay 0.995, updated on
        every call) and used as a weight on the MSE between the encoding
        means. ``next_state`` is unused.
        """
        with torch.no_grad():
            model_qf1, model_qf2, _, qf1, qf2, _ = self._get_vaml_objects(
                model_prediction, next_feature_encoding
            )
            value_error_1 = F.mse_loss(model_qf1, qf1, reduction="none")
            value_error_2 = F.mse_loss(model_qf2, qf2, reduction="none")
            value_error = value_error_1 + value_error_2
            # Side effect: updates the running normaliser stored on the module.
            self.running_vaml_mean = (
                0.995 * self.running_vaml_mean + 0.005 * value_error.mean()
            )
            # ``detach`` is kept as a safeguard on top of ``no_grad``.
            value_error = (value_error / self.running_vaml_mean).detach()
        mse_error = F.mse_loss(
            model_prediction.encoded.mean,
            next_feature_encoding.encoded.mean,
            reduction="none",
        ).mean(dim=-1, keepdim=True)
        return value_error * mse_error

    def _huber_vaml_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """VAML loss using a Huber penalty on the Q-value errors.

        The action term stays a plain mean squared error. ``next_state`` is
        unused.
        """
        (
            model_qf1,
            model_qf2,
            model_action,
            qf1,
            qf2,
            state_action,
        ) = self._get_vaml_objects(model_prediction, next_feature_encoding)
        value_error1 = F.huber_loss(model_qf1, qf1, reduction="none").mean(
            dim=-1, keepdim=True
        )
        value_error2 = F.huber_loss(model_qf2, qf2, reduction="none").mean(
            dim=-1, keepdim=True
        )
        value_error = 0.5 * (value_error1 + value_error2)

        action_error = self._mean_squared_difference(model_action, state_action)
        return value_error + action_error

    def _vaml_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """Average squared Q-value error of both critics plus action-mean MSE.

        ``next_state`` is unused.
        """
        (
            model_qf1,
            model_qf2,
            model_action,
            qf1,
            qf2,
            state_action,
        ) = self._get_vaml_objects(model_prediction, next_feature_encoding)
        value_error_1 = self._mean_squared_difference(model_qf1, qf1)
        value_error_2 = self._mean_squared_difference(model_qf2, qf2)
        value_error = 0.5 * (value_error_1 + value_error_2)
        action_error = self._mean_squared_difference(model_action, state_action)
        return value_error + action_error

    def _reconstruction_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """MSE between the decoded model sample and ``next_state``.

        A reparameterised sample of the model's predicted encoding is passed
        through ``self.decoder``. ``next_feature_encoding`` is unused.
        """
        # Only the model-side sample is needed; the target encoding is no
        # longer sampled here, since that draw was never used.
        model_next_feature = model_prediction.encoded.rsample()
        return self._mean_squared_difference(
            self.decoder(model_next_feature), next_state
        )

    def _mse_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """MSE between the means of the predicted and target encodings.

        No sampling is involved. ``next_state`` is unused.
        """
        return self._mean_squared_difference(
            model_prediction.encoded.mean, next_feature_encoding.encoded.mean
        )

    def _kl_loss(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """KL(target || model) between the two encodings' distributions.

        Both distributions are obtained via ``encoded.get_pdf()``; the result
        of ``kl_divergence`` is averaged over its last dimension.
        ``next_state`` is unused.
        """
        kl_estimate = kl_divergence(
            next_feature_encoding.encoded.get_pdf(),
            model_prediction.encoded.get_pdf(),
        )
        return kl_estimate.mean(dim=-1, keepdim=True)

    def _placeholder(
        self,
        model_prediction: EncoderOutput,
        next_feature_encoding: EncoderOutput,
        next_state: torch.Tensor,
    ) -> torch.Tensor:
        """Return zeros shaped like the first feature column of a model sample.

        Registered with weight 0.0 at construction so the loss list is never
        empty. ``next_feature_encoding`` and ``next_state`` are unused.
        """
        # The sample is drawn only to get a tensor of the right
        # shape/dtype/device.
        model_next_feature = model_prediction.encoded.rsample()
        return torch.zeros_like(model_next_feature[..., :1])