import abc
import copy
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from lambda_ac.nn.common import hard_update


class Distribution(abc.ABC):
    """Abstract interface for a distribution that yields and scores tensors.

    Concrete subclasses must implement sampling (``rsample``), scoring
    (``log_prob``) and the three read-only properties ``mean``, ``std`` and
    ``entropy``. The class cannot be instantiated until all five are
    overridden.
    """

    @abc.abstractmethod
    def rsample(self, n: int = 1) -> torch.Tensor:
        """Return ``n`` samples as a tensor.

        NOTE(review): the name suggests reparameterised sampling as in
        ``torch.distributions``; this interface itself does not enforce it.
        """

    @abc.abstractmethod
    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Return the log-probability of ``actions`` under this distribution."""

    @property
    @abc.abstractmethod
    def mean(self) -> torch.Tensor:
        """Mean of the distribution."""

    @property
    @abc.abstractmethod
    def std(self) -> torch.Tensor:
        """Standard deviation of the distribution."""

    @property
    @abc.abstractmethod
    def entropy(self) -> torch.Tensor:
        """Entropy of the distribution."""


@dataclass
class EncoderOutput:
    """Pair of an encoded ``Distribution`` and an optional hidden tensor.

    ``FeatureInput.from_output`` turns an instance into a ``FeatureInput`` by
    sampling ``encoded`` and filling in ``hidden`` when it is ``None``.
    """

    # Distribution over the encoded representation.
    encoded: Distribution
    # Hidden tensor accompanying the encoding; may be None.
    hidden: Optional[torch.Tensor]


class FeatureInput:
    """Pair of an ``encoded`` tensor and a ``hidden`` tensor.

    Both attributes are always tensors after construction: a ``Distribution``
    passed as ``encoded`` is sampled once, and a missing ``hidden`` is replaced
    by Gaussian noise shaped like ``encoded``.
    """

    encoded: torch.Tensor
    hidden: torch.Tensor

    def __init__(
        self, encoded: Union[Distribution, torch.Tensor], hidden: Optional[torch.Tensor]
    ):
        """Store ``encoded`` and ``hidden``, materialising both as tensors.

        Args:
            encoded: A tensor, or a ``Distribution`` from which a single sample
                is drawn via ``rsample()``.
            hidden: Hidden tensor, or ``None`` to use ``randn_like(encoded)``.
        """
        if isinstance(encoded, Distribution):
            self.encoded = encoded.rsample()
        else:
            self.encoded = encoded
        if hidden is None:
            # NOTE(review): unit-scale noise here, while ``from_tensors`` scales
            # its fallback noise by 0.1 -- confirm the difference is intended.
            self.hidden = torch.randn_like(self.encoded)
        else:
            self.hidden = hidden

    @staticmethod
    def from_output(input: EncoderOutput) -> "FeatureInput":
        """Build a ``FeatureInput`` from the fields of an ``EncoderOutput``."""
        return FeatureInput(input.encoded, input.hidden)

    @staticmethod
    def from_tensors(
        encoded: torch.Tensor, hidden: Optional[torch.Tensor]
    ) -> "FeatureInput":
        """Build a ``FeatureInput`` from tensors.

        When ``hidden`` is ``None`` it is replaced by ``randn_like(encoded)``
        scaled by 0.1 before being handed to the constructor.
        """
        if hidden is None:
            hidden = torch.randn_like(encoded) * 0.1
        return FeatureInput(encoded, hidden)

    @staticmethod
    def from_list(inputs: Iterable["FeatureInput"], dim: int = 1) -> "FeatureInput":
        """Stack several ``FeatureInput`` objects along a new dimension.

        Args:
            inputs: Any iterable of ``FeatureInput``; one-shot iterators such
                as generators are supported.
            dim: Position of the new stacking dimension (passed to
                ``torch.stack``).

        Returns:
            A ``FeatureInput`` whose tensors are the stacked ``encoded`` and
            ``hidden`` tensors of ``inputs``.
        """
        # ``inputs`` is traversed twice below; materialise it first so that a
        # generator is not already exhausted when the hidden tensors are read.
        items = list(inputs)
        return FeatureInput(
            torch.stack([f.encoded for f in items], dim=dim),
            torch.stack([f.hidden for f in items], dim=dim),
        )

    def to(self, device) -> "FeatureInput":
        """Return a new ``FeatureInput`` with both tensors passed through ``.to(device)``."""
        return FeatureInput.from_tensors(
            self.encoded.to(device), self.hidden.to(device)
        )

    def detach(self) -> "FeatureInput":
        """Detach both tensors in place and return ``self``.

        Unlike ``torch.Tensor.detach`` this rebinds the attributes of the
        existing object instead of returning a new one.
        """
        self.encoded = self.encoded.detach()
        self.hidden = self.hidden.detach()
        return self

    def __getitem__(self, idx):
        """Return a new ``FeatureInput`` holding ``encoded[idx]`` and ``hidden[idx]``."""
        return FeatureInput(self.encoded[idx], self.hidden[idx])

    def __setitem__(self, idx, value):
        """Write ``value.encoded`` / ``value.hidden`` into both tensors at ``idx``."""
        self.encoded[idx] = value.encoded
        self.hidden[idx] = value.hidden

    def repeat(self, *n):
        """Return a new ``FeatureInput`` with ``Tensor.repeat(*n)`` applied to both tensors."""
        return FeatureInput.from_tensors(
            self.encoded.repeat(*n), self.hidden.repeat(*n)
        )


@dataclass
class ModelTrajectory:
    """Container bundling the data that describes one model trajectory.

    NOTE(review): tensor shapes and the exact semantics of each field are not
    visible in this file; the per-field comments below are inferred from the
    names and types and should be confirmed against the producing code.
    """

    # Trajectory length (presumably the number of steps).
    length: int
    # States of the trajectory, held in a single FeatureInput.
    states: FeatureInput
    # Actions tensor.
    actions: torch.Tensor
    # Rewards tensor.
    rewards: torch.Tensor
    # Masks tensor (presumably marks valid / non-terminal steps -- TODO confirm).
    masks: torch.Tensor
    # Log-probabilities tensor (presumably of ``actions``).
    log_probs: torch.Tensor
    # Optional list of encoder outputs; may be None.
    encodings: Optional[List[EncoderOutput]]
    # Optional tensor of done predictions; may be None.
    done_predictions: Optional[torch.Tensor]


class Encoder(nn.Module):
    """Base type for encoder networks; declares the attributes encoders expose.

    Only annotations are declared here: no attribute is assigned and no
    ``forward`` is defined by this class.
    """

    # Input size: a single int or a tuple of ints.
    input_dim: Union[int, Tuple[int, ...]]
    # Size of the produced features.
    feature_dim: int
    # Flag, presumably toggling the use of ``norm`` -- TODO confirm.
    normalize: bool
    # Normalisation module; how it is applied is not defined here.
    norm: nn.Module


class EncoderMixin:
    """Mixin that stores an encoder and exposes it through ``encode``.

    ``encoder`` and ``feature_dim`` are declared as annotations only; the
    encoder is attached by calling ``init_encoder``.
    """

    encoder: "Encoder"
    feature_dim: int

    def init_encoder(self, encoder: "Encoder") -> None:
        """Attach ``encoder`` to this object as ``self.encoder``."""
        self.encoder = encoder

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the attached encoder and return its output.

        Raises:
            ValueError: If no encoder is attached, i.e. ``init_encoder`` was
                never called or was called with ``None``.
        """
        # ``encoder`` is only an annotation at class level, so a plain
        # ``self.encoder`` lookup raises AttributeError before the None check
        # can run when ``init_encoder`` was never called. Looking it up with a
        # default makes both "never set" and "set to None" report the same,
        # intended error.
        encoder = getattr(self, "encoder", None)
        if encoder is None:
            raise ValueError("Encoder is not initialized.")
        return encoder(x)


class CriticModule(nn.Module):
    """Marker base class for critic networks; adds nothing to ``nn.Module``."""

    pass


class QHead(nn.Module):
    """Base class for Q-value heads operating on a state/action pair."""

    @abc.abstractmethod
    def forward(
        self, state: "FeatureInput", action: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the head's outputs for ``state`` and ``action``.

        Subclasses must override this method.

        Returns:
            A pair of tensors (presumably two Q-value estimates -- TODO
            confirm against the concrete heads).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``nn.Module`` does not use ``ABCMeta``, so ``abstractmethod`` is not
        # enforced at instantiation time. Raise explicitly (as
        # ``ActorModule.forward`` does) instead of silently returning None
        # when a subclass forgets to override ``forward``.
        raise NotImplementedError


class EncoderCriticModule(EncoderMixin, CriticModule):
    """Critic module that carries an encoder through ``EncoderMixin``.

    ``head`` is declared as an annotation only; this class never assigns it.
    """

    # Q-value head; annotated here, to be assigned by subclasses.
    head: QHead

    def __init__(self, encoder: Encoder) -> None:
        """Initialise the module and attach ``encoder`` via ``init_encoder``."""
        # The base initialiser runs first, before the encoder is attached.
        super().__init__()
        self.init_encoder(encoder)


class ActorModule(nn.Module):
    """Base class for policy networks whose ``forward`` yields a ``Distribution``.

    Provides thin helpers for sampling from, and scoring actions under, the
    distribution returned by ``forward``.
    """

    action_dim: int

    def sample(self, distribution: Distribution, n: int = 1) -> torch.Tensor:
        """Draw ``n`` samples from ``distribution`` using its ``rsample``."""
        drawn = distribution.rsample(n)
        return drawn

    def log_prob(
        self, action: torch.Tensor, distribution: Distribution
    ) -> torch.Tensor:
        """Return ``distribution.log_prob(action)``."""
        score = distribution.log_prob(action)
        return score

    def forward_sample_log_prob(
        self, *args, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, Distribution]:
        """Run the module, sample one action and score it.

        All arguments are forwarded to ``self(...)``.

        Returns:
            ``(action, log_prob, distribution)`` where ``action`` comes from
            ``self.sample`` with its default ``n`` and ``log_prob`` is the
            score of that action under ``distribution``.
        """
        dist = self(*args, **kwargs)
        act = self.sample(dist)
        return act, self.log_prob(act, dist), dist

    @abc.abstractmethod
    def forward(self, *args, **kwargs) -> Distribution:
        """Return the action distribution; must be overridden by subclasses."""
        raise NotImplementedError()


class EncoderActorModule(EncoderMixin, ActorModule):
    """Actor module that carries an encoder through ``EncoderMixin``.

    ``head`` is declared as an annotation only; this class never assigns it.
    """

    # Policy head; annotated here, to be assigned by subclasses.
    head: ActorModule

    def __init__(self, encoder: Encoder) -> None:
        """Initialise the module and attach ``encoder`` via ``init_encoder``."""
        # The base initialiser runs first, before the encoder is attached.
        super().__init__()
        self.init_encoder(encoder)


class ExplorationScheduler(abc.ABC):
    """Abstract interface for a callable exploration schedule.

    Subclasses must implement ``step``, ``set`` and ``__call__``; the class
    cannot be instantiated until all three are overridden.
    """

    @abc.abstractmethod
    def step(self):
        """Advance the schedule (presumably by one step -- defined by subclasses)."""

    @abc.abstractmethod
    def set(self, n: int):
        """Set the schedule from the integer ``n`` (presumably a step count)."""

    @abc.abstractmethod
    def __call__(self, action: torch.Tensor) -> torch.Tensor:
        """Map ``action`` to a tensor (presumably the explored action)."""


class ModelNetwork(nn.Module):
    """Base class for model networks exposing ``forward`` and ``sample``.

    Both methods return a triple of an ``EncoderOutput`` and two tensors.
    NOTE(review): the two tensors are presumably reward and done/mask
    predictions -- confirm against the concrete implementations.
    """

    @abc.abstractmethod
    def forward(
        self, *args, **kwargs
    ) -> Tuple["EncoderOutput", torch.Tensor, torch.Tensor]:
        """Run the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # ``nn.Module`` does not use ``ABCMeta``, so ``abstractmethod`` is not
        # enforced; raise explicitly rather than silently returning None when
        # a subclass forgets to override this method.
        raise NotImplementedError

    @abc.abstractmethod
    def sample(
        self, *args, **kwargs
    ) -> Tuple["EncoderOutput", torch.Tensor, torch.Tensor]:
        """Sample from the model; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # Same reasoning as in ``forward``: fail loudly instead of returning None.
        raise NotImplementedError


class EncoderModelNetwork(EncoderMixin, ModelNetwork):
    """Model network that carries an encoder through ``EncoderMixin``.

    ``head`` and ``latent_network`` are declared as annotations only; this
    class never assigns them.
    """

    # Output head module; annotated here, to be assigned by subclasses.
    head: nn.Module
    # Latent network module; annotated here, to be assigned by subclasses.
    latent_network: nn.Module

    def __init__(self, encoder: Encoder) -> None:
        """Initialise the module and attach ``encoder`` via ``init_encoder``."""
        # The base initialiser runs first, before the encoder is attached.
        super().__init__()
        self.init_encoder(encoder)


class Agent(abc.ABC):
    """Abstract interface for an agent that picks an action for a state."""

    @abc.abstractmethod
    def select_action(self, state: torch.Tensor, eval: bool = False) -> torch.Tensor:
        """Return the action chosen for ``state``.

        Args:
            state: State tensor.
            eval: Evaluation-mode flag; its effect is defined by subclasses.
        """


class ActorCriticAgent(Agent, abc.ABC):
    """Abstract agent holding a critic, a target critic and an actor.

    The constructor stores the three networks, passes the target critic and
    the critic to ``hard_update`` and turns off gradients on the target
    critic. Action selection, the update steps and checkpointing are left to
    subclasses.
    """

    critic: CriticModule
    critic_target: CriticModule
    actor: ActorModule

    def __init__(
        self, critic: CriticModule, critic_target: CriticModule, actor: ActorModule
    ):
        """Store the networks and prepare the target critic.

        Args:
            critic: Critic network.
            critic_target: Target critic; modified by ``hard_update`` and
                ``requires_grad_(False)``.
            actor: Actor network.
        """
        self.critic = critic
        self.critic_target = critic_target
        # Presumably copies the critic's parameters into the target (helper
        # is defined outside this file -- TODO confirm argument order).
        hard_update(critic_target, critic)
        self.actor = actor

        # Exclude the target critic's parameters from gradient computation.
        critic_target.requires_grad_(False)

    @abc.abstractmethod
    def select_action(
        self, state: Union[torch.Tensor, FeatureInput], eval: bool = False
    ) -> torch.Tensor:
        """Return the action for ``state``, given as a tensor or ``FeatureInput``."""

    @abc.abstractmethod
    def update_critic(self, *args, **kwargs) -> Dict[str, float]:
        """Run a critic update and return a name-to-float dictionary."""

    @abc.abstractmethod
    def update_actor(self, *args, **kwargs) -> Dict[str, float]:
        """Run an actor update and return a name-to-float dictionary."""

    @abc.abstractmethod
    def load_checkpoint(self, path):
        """Load agent state from ``path``."""

    @abc.abstractmethod
    def save_checkpoint(self, path):
        """Save agent state to ``path``."""


class ModelBasedActorCriticAgent(ActorCriticAgent):
    """Actor-critic agent that additionally owns a ``ModelNetwork``.

    ``__init__`` is marked abstract, so concrete subclasses must define their
    own constructor; they can delegate to this one through ``super()``.
    """

    model: ModelNetwork

    @abc.abstractmethod
    def __init__(
        self,
        critic: CriticModule,
        critic_target: CriticModule,
        actor: ActorModule,
        model: ModelNetwork,
    ):
        """Initialise the actor-critic part, then store ``model``."""
        super().__init__(critic=critic, critic_target=critic_target, actor=actor)
        self.model = model

    @abc.abstractmethod
    def update_model(self, *args, **kwargs) -> Dict[str, float]:
        """Run a model update and return a name-to-float dictionary."""


class PlanningStrategy(metaclass=abc.ABCMeta):
    """Abstract planner holding references to a model, a critic and an actor.

    Subclasses implement ``plan``. Note that ``__init__`` takes the networks
    as ``(model, critic, actor)`` while ``set_networks`` takes them as
    ``(actor, critic, model)``.
    """

    def __init__(
        self,
        model: EncoderModelNetwork,
        critic: EncoderCriticModule,
        actor: EncoderActorModule,
    ) -> None:
        """Store the three networks on the instance."""
        self.model, self.critic, self.actor = model, critic, actor

    def set_networks(self, actor, critic, model):
        """Replace the stored networks; argument order is actor, critic, model."""
        self.actor, self.critic, self.model = actor, critic, model

    @abc.abstractmethod
    def plan(
        self,
        state: torch.Tensor,
        alpha: torch.Tensor,
        eval: bool,
        step: int,
        episode: int,
    ) -> torch.Tensor:
        """Return the planner's output tensor for ``state``.

        NOTE(review): the meaning of ``alpha``, ``step`` and ``episode`` is
        defined by the concrete strategies and is not visible in this file.
        """


class EncoderModelBasedActorCriticAgent(ActorCriticAgent):
    """Actor-critic agent built from encoder-carrying networks plus a model.

    ``__init__`` is marked abstract, so concrete subclasses must define their
    own constructor; they can delegate to this one through ``super()``.
    ``planning_strategy`` is declared as an annotation only and is not
    assigned by this class.
    """

    critic: EncoderCriticModule
    critic_target: EncoderCriticModule
    actor: EncoderActorModule
    model: EncoderModelNetwork
    planning_strategy: PlanningStrategy

    @abc.abstractmethod
    def __init__(
        self,
        critic: EncoderCriticModule,
        critic_target: EncoderCriticModule,
        actor: EncoderActorModule,
        model: EncoderModelNetwork,
    ):
        """Initialise the actor-critic part, then store ``model``."""
        super().__init__(critic=critic, critic_target=critic_target, actor=actor)
        self.model = model

    @abc.abstractmethod
    def update_model(self, *args, **kwargs) -> Dict[str, float]:
        """Run a model update and return a name-to-float dictionary."""
