from typing import NamedTuple, Optional, Union, Tuple, Sequence
import torch

from dataclasses import dataclass
from typing import List
import random

# TODO: Horrible buffer implementation; it needs to be refactored.


class Transition(NamedTuple):
    """One environment step as stored in the replay memory.

    The first five fields are required; every field after ``done`` is
    optional and defaults to ``None``.
    """

    # Required core of the step.
    state: torch.Tensor
    action: torch.Tensor
    reward: float
    next_state: torch.Tensor
    done: bool
    # Optional extras.
    log_prob: Optional[torch.Tensor] = None
    # Used only if the algorithm is an Actor Critic.
    value: Optional[torch.Tensor] = None
    entropy: Optional[torch.Tensor] = None
    # Either a single tensor or a pair of tensors (presumably a recurrent
    # (h, c) state -- TODO confirm against the model code).
    hidden_state: Optional[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
    ] = None


class BatchTransition(NamedTuple):
    """A batch of environment steps, as passed to ``OffPolicyMemory.append``.

    ``reward`` and ``done`` are per-element sequences. Every field after
    ``done`` is optional and defaults to ``None``.
    """

    # Required batched core.
    state: torch.Tensor
    action: torch.Tensor
    reward: Sequence[float]
    next_state: torch.Tensor
    done: Sequence[bool]
    # Optional extras.
    log_prob: Optional[torch.Tensor] = None
    # Used only if the algorithm is an Actor Critic.
    value: Optional[torch.Tensor] = None
    entropy: Optional[torch.Tensor] = None
    # Either a single tensor or a pair of tensors (presumably a recurrent
    # (h, c) state -- TODO confirm against the model code).
    hidden_state: Optional[
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
    ] = None


@dataclass
class BatchOffPolicy:
    """A sampled mini-batch of transitions, collated into tensors on demand.

    Each accessor rebuilds its tensor from ``transitions`` every time it is
    called; nothing is cached.

    Attributes:
        transitions: The sampled ``Transition`` records, in sample order.
        device: Device on which ``reward()`` and ``done()`` are created.
            ``state()``, ``next_state()`` and ``action()`` are only stacked,
            so they stay on whatever device the stored tensors live on.
    """

    transitions: List[Transition]
    device: torch.device

    def _collect(self, attr: str) -> list:
        """Return ``attr`` of every stored transition, preserving order."""
        return [getattr(item, attr) for item in self.transitions]

    def _column(self, attr: str) -> torch.Tensor:
        """Build an ``(N, 1)`` float32 tensor on ``device`` from ``attr``."""
        flat = torch.tensor(
            self._collect(attr), dtype=torch.float32, device=self.device
        )
        return flat.view(-1, 1)

    def reward(self):
        """Return rewards as an ``(N, 1)`` float32 tensor on ``device``."""
        return self._column("reward")

    def state(self):
        """Return states stacked along a new leading batch dimension."""
        return torch.stack(self._collect("state"))

    def next_state(self):
        """Return next states stacked along a new leading batch dimension."""
        return torch.stack(self._collect("next_state"))

    def action(self):
        """Return actions stacked along a new leading batch dimension."""
        return torch.stack(self._collect("action"))

    def done(self):
        """Return done flags as an ``(N, 1)`` float32 tensor (True -> 1.0)."""
        return self._column("done")


class OffPolicyMemory:
    """Bounded FIFO replay memory of single-step ``Transition`` records.

    When more than ``memory_size`` transitions are stored, the oldest ones
    are discarded so that at most ``memory_size`` remain.

    Args:
        memory_size: Maximum number of transitions kept.
        device: Device passed to the ``BatchOffPolicy`` returned by ``sample``.
    """

    def __init__(self, memory_size: int, device: torch.device):
        self.memory_size = memory_size
        self.device = device
        self.master_memory: List[Transition] = []

    def _trim(self) -> None:
        """Drop the oldest transitions so at most ``memory_size`` remain."""
        overflow = len(self.master_memory) - self.memory_size
        if overflow > 0:
            # A single slice deletion evicts everything in excess at once,
            # instead of an O(n) ``pop(0)`` per evicted element. It also
            # behaves correctly for ``memory_size == 0``; a ``[-0:]`` slice
            # would keep the whole list.
            del self.master_memory[:overflow]

    def append(self, batch_transition: BatchTransition):
        """Split ``batch_transition`` into single steps and store them.

        Eviction runs after insertion, so the memory never exceeds
        ``memory_size``. Previously at most one element was evicted per call
        even though a whole batch was appended, which let the buffer grow
        without bound.
        """
        self.master_memory.extend(self._split_batch_transition(batch_transition))
        self._trim()

    def __add__(self, other):
        """Return a new memory with ``self`` followed by ``other``, trimmed.

        The result keeps ``self``'s ``memory_size`` and ``device``. Neither
        operand is modified.
        """
        new_memory = OffPolicyMemory(memory_size=self.memory_size, device=self.device)
        # ``+`` builds a fresh list, so trimming it in place leaves both
        # operands untouched.
        new_memory.master_memory = self.master_memory + other.master_memory
        new_memory._trim()
        return new_memory

    def __iadd__(self, other):
        """Append ``other``'s transitions in place, then trim to capacity."""
        self.master_memory += other.master_memory
        self._trim()
        return self

    def __len__(self) -> int:
        return len(self.master_memory)

    def sample(self, nb_sample: int):
        """Draw up to ``nb_sample`` distinct transitions uniformly at random.

        If fewer than ``nb_sample`` transitions are stored, all of them are
        returned (in random order).
        """
        batch = random.sample(self.master_memory, min(nb_sample, len(self)))
        return BatchOffPolicy(batch, device=self.device)

    def _split_batch_transition(
        self, batch_transition: BatchTransition
    ) -> List[Transition]:
        """Unpack a batched transition into per-element ``Transition`` records.

        ``state`` and ``next_state`` are unbound along dim 0. ``action``,
        ``reward`` and ``done`` are iterated element by element. Only these
        five fields are carried over; ``log_prob``, ``value``, ``entropy``
        and ``hidden_state`` are dropped. ``zip`` stops at the shortest
        field, so mismatched lengths are silently truncated.
        """
        rows = zip(
            batch_transition.state.unbind(dim=0),
            batch_transition.action,
            batch_transition.reward,
            batch_transition.next_state.unbind(dim=0),
            batch_transition.done,
        )
        return [
            Transition(
                state=state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            )
            for state, action, reward, next_state, done in rows
        ]
