import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import stats
from abc import ABC, abstractmethod

from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple

# def tensor(arr: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor:
#     return torch.tensor(arr, device='cpu', dtype=dtype)

def make_mlp(
    hidden_sizes: List[int],
    in_dim: int = 1,
    out_dim: int = 1,
    activator: Callable[[], nn.Module] = lambda: nn.SELU(inplace=True),
) -> nn.Sequential:
    """Build a fully connected network as an ``nn.Sequential``.

    The network is ``Linear -> activation`` for every hidden layer followed
    by a final ``Linear`` without activation. Weights are orthogonally
    initialised and biases are zeroed.

    Args:
        hidden_sizes: Width of each hidden layer, in order. Any iterable of
            ints is accepted (list, tuple, ...); an empty one yields a single
            linear layer from ``in_dim`` to ``out_dim``.
        in_dim: Size of the input features.
        out_dim: Size of the output features.
        activator: Zero-argument factory called once per hidden layer to
            create a fresh activation module.

    Returns:
        The assembled ``nn.Sequential``.
    """
    # Unpacking (rather than list concatenation) keeps tuples and other
    # iterables working as ``hidden_sizes``.
    dims = [in_dim, *hidden_sizes, out_dim]
    last_layer = len(dims) - 2
    modules: List[nn.Module] = []
    for i, (fan_in, fan_out) in enumerate(zip(dims[:-1], dims[1:])):
        linear = nn.Linear(fan_in, fan_out)
        # TruncatedNormal is used in the author's code, but orthogonal
        # initialisation also works well.
        nn.init.orthogonal_(linear.weight)
        nn.init.zeros_(linear.bias)
        modules.append(linear)
        # No activation after the output layer.
        if i < last_layer:
            modules.append(activator())
    return nn.Sequential(*modules)

def make_mlps(
    hidden_sizes: List[int],
    in_dim: int,
    out_dim: int,
    num_ensembles: int,
    activator: Callable[[], nn.Module],
) -> nn.ModuleList:
    """Create ``num_ensembles`` MLPs that share one architecture.

    Each member is built by a separate ``make_mlp`` call, so every member
    gets its own freshly initialised parameters.

    Args:
        hidden_sizes: Hidden layer widths forwarded to ``make_mlp``.
        in_dim: Input feature size of every member.
        out_dim: Output feature size of every member.
        num_ensembles: Number of members to create.
        activator: Activation factory forwarded to ``make_mlp``.

    Returns:
        An ``nn.ModuleList`` holding the members in creation order.
    """
    members = nn.ModuleList()
    for _ in range(num_ensembles):
        member = make_mlp(
            hidden_sizes, in_dim=in_dim, out_dim=out_dim, activator=activator
        )
        members.append(member)
    return members
    
class EnsembleMLP(nn.Module):
    """Ensemble of independently initialised MLPs with concatenated outputs.

    Args:
        hidden_sizes: Hidden layer widths shared by every ensemble member.
        num_ensembles: Number of member networks.
        in_dim: Input feature size of every member.
        out_dim: Output feature size of every member.
        activator: Zero-argument factory producing the activation module
            used after each hidden layer.
    """

    def __init__(
        self,
        hidden_sizes: List[int],
        num_ensembles: int,
        in_dim: int = 1,
        out_dim: int = 1,
        activator: Callable[[], nn.Module] = lambda: nn.SELU(inplace=True),
    ) -> None:
        super().__init__()
        self.models = make_mlps(hidden_sizes, in_dim, out_dim, num_ensembles, activator)
        self.num_ensembles = num_ensembles

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every member on ``x`` and concatenate the results on dim 1.

        For a 2-D input of shape ``[B, in_dim]`` the result has shape
        ``[B, num_ensembles * out_dim]``, with members laid out in order.
        """
        outputs = [member(x) for member in self.models]
        return torch.cat(outputs, dim=1)
    
class QNetwork(ABC):
    """Abstract interface for action-value networks over discrete actions.

    Subclasses provide ``num_actions`` and ``q_s``; the remaining methods
    are default implementations built on top of ``q_s``.
    """

    @property
    @abstractmethod
    def num_actions(self) -> int:
        """Number of discrete actions."""

    @abstractmethod
    def q_s(self, obs: torch.Tensor) -> torch.Tensor:
        """Return the Q-values of every action for each observation."""

    def q_s_a(self, obs: torch.Tensor, action: np.ndarray) -> torch.Tensor:
        """Return the Q-value of ``action[i]`` for ``obs[i]``.

        The default implementation indexes the first two dimensions of
        ``q_s(obs)``, i.e. it treats dim 0 as batch and dim 1 as action.
        """
        batch_size = obs.size(0)
        return self.q_s(obs)[np.arange(batch_size), action]

    def q_target(
        self,
        next_obs: torch.Tensor,
        reward: torch.Tensor,
        pcont: torch.Tensor,
        discount: float,
    ) -> torch.Tensor:
        """Compute ``reward + discount * pcont * max_a Q(next_obs, a)``.

        Args:
            next_obs: Observations passed to ``q_s``.
            reward: Rewards added to the discounted next-state value.
            pcont: Multiplicative factor applied to the next-state value
                (presumably the continuation probability, 0 at terminal
                steps; confirm against the caller).
            discount: Scalar discount factor.
        """
        q_next, _ = self.q_s(next_obs).max(dim=-1)
        # In-place scaling is applied to the fresh tensor returned by max(),
        # not to the output of q_s itself.
        return reward + q_next.mul_(pcont).mul_(discount)

    def reset(self) -> None:
        """Hook for per-episode state; the default does nothing."""

    def greedy(self, state: torch.Tensor) -> int:
        """Return the index of the highest Q-value as a Python int.

        The Q-values are squeezed and arg-maxed over all remaining
        elements, so this is meant for a single state rather than a batch.
        """
        return self.q_s(state).squeeze().argmax().item()
    
    # @torch.no_grad()
    # def eps_greedy(self, epsilon: float, step: TimeStep) -> int:
    #     if epsilon > 0 and np.random.rand() < epsilon:
    #         return np.random.randint(self.num_actions)
    #     else:
    #         return self.greedy(tensor(step.observation).view(1, -1))
        
class EnsembleQNetwork(nn.Module, QNetwork):
    """Q-network made of ``num_ensembles`` independent MLP heads.

    Every head maps a state to one Q-value per action. One head is marked
    active (see ``reset``) and is the one used by ``greedy``.

    Args:
        hidden_size: Hidden layer widths shared by every head.
        state_dim: Size of the state feature vector.
        num_actions: Number of discrete actions (output size of each head).
        num_ensembles: Number of heads.
        activator: Zero-argument factory producing the activation module
            used after each hidden layer.
    """

    def __init__(
        self,
        hidden_size: List[int],
        state_dim: int,
        num_actions: int,
        num_ensembles: int = 10,
        activator: Callable[[], nn.Module] = lambda: nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        self.models = make_mlps(
            hidden_size, state_dim, num_actions, num_ensembles, activator
        )
        self.active_head = 0
        self.num_ensembles = num_ensembles
        self._num_actions = num_actions

    @property
    def num_actions(self) -> int:
        """Number of discrete actions."""
        return self._num_actions

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate every head on ``x`` and stack the results on dim 1."""
        return torch.stack([m(x) for m in self.models], dim=1)

    def q_s(self, state: torch.Tensor) -> torch.Tensor:  # [B, K, A]
        """Return the Q-values of all heads and actions for ``state``."""
        return self(state)

    def q_s_a(self, state: torch.Tensor, action: np.ndarray) -> torch.Tensor:
        """Return each head's Q-value for the chosen action, shape ``[B, K]``.

        Args:
            state: Batch of states; the network output must be ``[B, K, A]``.
            action: Integer action indices, one per batch element, given as
                shape ``[B]`` or ``[B, 1]``. NumPy arrays and torch tensors
                are both accepted.
        """
        qs = self(state)
        # F.one_hot needs a LongTensor; reshape(-1) (unlike squeeze) keeps
        # the batch dimension even when the batch has a single element.
        index = torch.as_tensor(
            action, dtype=torch.long, device=qs.device
        ).reshape(-1)
        one_hot = F.one_hot(index, num_classes=self._num_actions)
        # Select the chosen action per head by contracting the action axis.
        return torch.einsum("bka,ba->bk", qs, one_hot.to(qs.dtype))

    def q_target(
        self,
        next_obs: torch.Tensor,
        reward: torch.Tensor,
        pcont: torch.Tensor,
        discount: float,
    ) -> torch.Tensor:
        """Compute the per-head target ``reward + discount * pcont * max_a Q``.

        ``reward`` and ``pcont`` are reshaped to ``[B, 1]`` so they broadcast
        across the head dimension of the ``[B, K]`` next-state values.
        """
        q_next, _ = self.q_s(next_obs).max(dim=-1)
        return reward.view(-1, 1) + q_next * pcont.view(-1, 1) * discount

    def reset(self) -> None:
        """Pick a new active head uniformly at random."""
        self.active_head = np.random.randint(self.num_ensembles)

    def greedy(self, state: torch.Tensor) -> int:
        """Return the active head's best action index as a Python int.

        The selected Q-values are squeezed and arg-maxed over all remaining
        elements, so this is meant for a single state rather than a batch.
        """
        return self(state)[:, self.active_head, :].squeeze().argmax().item()
    
class PriorEnsembleQNetwork(EnsembleQNetwork):
    """Ensemble Q-network whose heads are each paired with a prior network.

    The output of every head is the sum of a trainable network and a scaled
    prior network that is evaluated without gradient tracking.

    Args:
        hidden_size: Hidden layer widths shared by all networks.
        state_dim: Size of the state feature vector.
        num_actions: Number of discrete actions.
        num_ensembles: Number of heads (trainable/prior pairs).
        prior_scale: Multiplier applied to the prior networks' outputs.
        activator: Zero-argument factory producing the activation module
            used after each hidden layer.
    """

    def __init__(
        self,
        hidden_size: List[int],
        state_dim: int,
        num_actions: int,
        num_ensembles: int = 10,
        prior_scale: float = 10.0,
        activator: Callable[[], nn.Module] = lambda: nn.ReLU(inplace=True),
    ) -> None:
        # Initialise nn.Module directly: the parent constructor would build
        # its own ``models`` list, which this class replaces with the
        # ``model``/``prior`` pair below.
        nn.Module.__init__(self)
        self.model = make_mlps(
            hidden_size, state_dim, num_actions, num_ensembles, activator
        )
        self.prior = make_mlps(
            hidden_size, state_dim, num_actions, num_ensembles, activator
        )
        self._num_actions = num_actions
        self.prior_scale = prior_scale
        self.active_head = 0
        self.num_ensembles = num_ensembles

    def parameters(self, recurse: bool = True) -> Iterable[nn.Parameter]:
        """Yield only the trainable networks' parameters.

        The prior networks are left out, so an optimiser built from
        ``net.parameters()`` never updates them. ``recurse`` is accepted for
        compatibility with ``nn.Module.parameters`` and is forwarded to the
        trainable module list.
        """
        return self.model.parameters(recurse=recurse)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``model(x) + prior_scale * prior(x)`` stacked over heads.

        Heads are stacked on dim 1. The prior term is computed under
        ``torch.no_grad()``, so gradients flow only through ``self.model``.
        """
        raw = torch.stack([m(x) for m in self.model], dim=1)
        with torch.no_grad():
            prior = torch.stack([m(x) for m in self.prior], dim=1)
        return raw + prior * self.prior_scale


if __name__ == '__main__':
    # Smoke test: build a small prior-ensemble network, run a forward pass
    # on random states and print the shape of the result.
    num_ensembles = 5
    net = PriorEnsembleQNetwork(
        hidden_size=[32, 32],
        state_dim=2,
        num_actions=4,
        num_ensembles=num_ensembles,
    )

    states = torch.randn(10, 10, 2)
    actions = torch.randint(low=0, high=4, size=(5,))

    q_s = net(states)
    print(q_s.shape)
    # Further checks kept for manual debugging:
    # print(actions)
    # print(net.active_head)

    # print(net(states).shape)
    # print(net.q_s_a(states, actions).shape)
    
    