from typing import Tuple, Optional

import gym
import torch
import numpy as np
from torch import nn, optim, Tensor
from torch.distributions.beta import Beta

from .config import PiConfig


def to_tnsr(array: np.ndarray, device: str) -> Tensor:
    """Copy ``array`` into a new float32 tensor placed on ``device``.

    Args:
        array (np.ndarray): data to convert.
        device (str): target torch device, e.g. ``"cpu"``.

    Returns:
        A ``torch.float32`` tensor holding a copy of ``array``.
    """
    tensor_kwargs = {"dtype": torch.float32, "device": device}
    return torch.tensor(array, **tensor_kwargs)


def make_net_opt(env: gym.Env, config: PiConfig) -> Tuple[nn.Module, optim.Optimizer]:
    """Create a network for ``env`` together with its optimizer.

    Args:
        env (gym.Env): environment whose observation space selects the
            network builder; only 1-D observation shapes are handled.
        config (PiConfig): supplies ``device``, ``optimizer`` (the name of a
            class in ``torch.optim``) and ``lr``, plus whatever the network
            builder reads.

    Returns:
        The network (moved to ``config.device``) and an optimizer over its
        parameters.

    Raises:
        NotImplementedError: if the observation shape is not 1-D.
    """
    # Guard first: only flat (vector) observations have a builder so far.
    if len(env.observation_space.shape) != 1:
        raise NotImplementedError

    network = fc_net(env, config).to(config.device)
    optimizer_cls = getattr(torch.optim, config.optimizer)
    return network, optimizer_cls(network.parameters(), lr=config.lr)


def fc_net(env: gym.Env, config: PiConfig) -> nn.Module:
    """Build a fully connected stack of ``Linear`` + activation pairs.

    The first ``Linear`` maps ``env.observation_space.shape[0]`` inputs to
    ``config.hidden`` units; every later one maps ``hidden`` to ``hidden``.
    Each ``Linear`` is followed by the activation, so the stack ends with an
    activation and outputs ``config.hidden`` features.

    Args:
        env (gym.Env): environment providing the observation size.
        config (PiConfig): supplies ``depth``, ``hidden`` and ``activation``
            (the name of a class in ``torch.nn``).

    Returns:
        An ``nn.Sequential`` with ``max(depth, 1)`` Linear/activation pairs.
    """
    n_layers, width = config.depth, config.hidden
    make_activation = getattr(nn, config.activation)

    layers = []
    fan_in = env.observation_space.shape[0]
    # At least one Linear/activation pair is always produced, even if depth < 1.
    for _ in range(max(n_layers, 1)):
        layers.append(nn.Linear(fan_in, width))
        layers.append(make_activation())
        fan_in = width
    return nn.Sequential(*layers)


class QNet(nn.Module):
    """Value network: a fully connected trunk followed by a scalar head."""

    def __init__(self, env: gym.Env, config: PiConfig):
        """Build the trunk with ``fc_net`` and a ``hidden -> 1`` output layer.

        Args:
            env (gym.Env): environment passed to ``fc_net`` to size the trunk.
            config (PiConfig): supplies ``hidden`` and the ``fc_net`` settings.
        """
        super().__init__()
        # NOTE(review): the trunk input is sized from the observation space
        # only, while ``forward`` feeds ``cat([obs, act])`` when an action is
        # given. Confirm callers pass an env whose observation size already
        # accounts for the action, otherwise the first Linear will mismatch.
        self.fc = fc_net(env, config)
        self.out = nn.Linear(config.hidden, 1)

    def forward(self, obs: Tensor, act: Optional[Tensor] = None) -> Tensor:
        """Evaluate the network on ``obs``, optionally joined with ``act``.

        Args:
            obs (Tensor): observation batch.
            act (Optional[Tensor]): action batch concatenated to ``obs`` along
                the last dimension. When omitted, ``obs`` alone is fed to the
                trunk (previously the default ``None`` made ``torch.cat``
                raise a ``TypeError``).

        Returns:
            Output of the scalar head, with a trailing dimension of size 1.
        """
        trunk_in = obs if act is None else torch.cat([obs, act], -1)
        return self.out(self.fc(trunk_in))


class BetaPolicyNet(nn.Module):
    """Stochastic policy that draws actions from per-dimension Beta distributions.

    A fully connected trunk feeds two linear heads producing ``log(alpha)`` and
    ``log(beta)``. Samples lie in ``[0, 1]`` (the Beta support) and are
    rescaled affinely to ``[low, high]``.
    """

    def __init__(self, env: gym.Env, config: PiConfig):
        """Build the trunk, the two concentration heads and the action bounds.

        Args:
            env (gym.Env): environment providing observation/action spaces.
            config (PiConfig): supplies ``device``, ``hidden`` and the
                ``fc_net`` settings.
        """
        super().__init__()
        self.device = config.device
        self.fc = fc_net(env, config)
        act_dim = env.action_space.shape[0]
        # Only the first component of each bound is kept, so every action
        # dimension is rescaled with the same scalar [low, high] interval.
        # These are plain tensor attributes created on ``config.device``.
        self.high = to_tnsr(env.action_space.high[0], device=self.device)
        self.low = to_tnsr(env.action_space.low[0], device=self.device)
        self.to_log_alpha = nn.Linear(config.hidden, act_dim)
        self.to_log_beta = nn.Linear(config.hidden, act_dim)
        self.config = config

    def forward(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        """Sample an action and its log-density for each observation.

        Args:
            obs (Tensor): observation batch.

        Returns:
            Tuple of the action rescaled to ``[low, high]`` and its
            log-density summed over action dimensions.
        """
        pol_dist = self.calc_pol_dist(obs)
        # rsample keeps the sample differentiable w.r.t. alpha and beta.
        raw_act = pol_dist.rsample()

        logp_pol = self.calc_log_pol(pol_dist, raw_act)
        pol_act = self.squash_act(raw_act)
        return pol_act, logp_pol

    def squash_act(self, raw_act: Tensor) -> Tensor:
        """Rescale raw_act of [0, 1] to [low, high]

        Args:
            raw_act (Tensor): vector of range [0, 1] (the Beta support)

        Returns:
            vector of range [low, high]
        """
        return raw_act * (self.high - self.low) + self.low

    def unsquash_act(self, act: Tensor) -> Tensor:
        """Rescale act of [low, high] back to [0, 1]

        Args:
            act (Tensor): vector of range [low, high]

        Returns:
            vector of range [0, 1]
        """
        return (act - self.low) / (self.high - self.low)

    def calc_pol_dist(self, obs: Tensor) -> Beta:
        """Compute the Beta distribution over raw actions for ``obs``.

        Args:
            obs (Tensor): observation batch.

        Returns:
            ``Beta(alpha, beta)`` with one component per action dimension.
        """
        fc_out = self.fc(obs)
        # The heads output log-concentrations; exp keeps alpha, beta positive.
        alpha = self.to_log_alpha(fc_out).exp()
        beta = self.to_log_beta(fc_out).exp()
        return Beta(alpha, beta)

    def calc_log_pol(self, dist: Beta, pol_act: Tensor) -> Tensor:
        """Log-density of the rescaled action corresponding to ``pol_act``.

        Args:
            dist (Beta): distribution returned by ``calc_pol_dist``.
            pol_act (Tensor): raw action in [0, 1] (before ``squash_act``).

        Returns:
            Log-density summed over the last (action) dimension.
        """
        # Change of variables for act = low + (high - low) * raw: every action
        # dimension contributes its own -log(high - low) term, so the
        # correction is applied per dimension *before* summing. Subtracting it
        # once after the sum is only right for one-dimensional actions.
        log_scale = (self.high - self.low).log()
        return (dist.log_prob(pol_act) - log_scale).sum(dim=-1)
