from typing import Iterator, Any
from functools import partial

from omegaconf import DictConfig
import torch
from torch import nn
from torch_misc.modules import Sequential, Normalizer

from flow import Flow, Sequential as SequentialFlow, inv_flow
from flow.transformer import RQ_Spline as Spline, Affine as AffineFlow
from flow import modules as flow_modules
from flow.prior import Normal as NormalPrior


# Maps an activation name (e.g. ``cfg.activation`` in ``ff``) to the
# ``torch.nn`` module class; callers instantiate the class themselves.
ACT_DICT = dict(
    sigmoid=nn.Sigmoid,
    tanh=nn.Tanh,
    relu=nn.ReLU,
    elu=nn.ELU,
    leaky_relu=nn.LeakyReLU,
)


def nblocks(*steps, N: int) -> Iterator[Any]:
    """Lazily yield the whole ``steps`` sequence ``N`` times, in order.

    Example: ``nblocks(a, b, N=2)`` yields ``a, b, a, b``.
    When ``N <= 0`` (or no steps are given) nothing is yielded.
    """
    for _repeat in range(N):
        yield from steps


def linear(input_dim: int, output_dim: int, init=None, cfg: DictConfig = ...):
    """Build a single affine layer ``nn.Linear(input_dim, output_dim)``.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        init: optional initial value for the bias. Anything accepted by
            ``torch.as_tensor`` that broadcasts to ``(output_dim,)`` (a tensor,
            a sequence or a scalar) is copied into the existing bias parameter,
            which keeps its own dtype and device.
        cfg: unused here; presumably accepted so the signature matches ``ff``.

    Returns:
        The ``nn.Linear`` module.
    """
    net = nn.Linear(input_dim, output_dim)

    if init is not None:
        # Copy in place rather than rebinding ``bias.data``: rebinding lets
        # ``init`` silently change the bias' shape/dtype/device (and alias its
        # storage), while ``copy_`` casts to the parameter and checks shapes.
        with torch.no_grad():
            net.bias.copy_(torch.as_tensor(init))

    return net


def ff(input_dim: int, output_dim: int, init=None, cfg: DictConfig = ...):
    """Build a feed-forward network ending in a linear projection.

    Layout: a ``Normalizer(input_dim)`` input layer (from ``torch_misc``),
    then for every width ``h`` in ``cfg.hidden_dim`` a
    ``Linear -> activation [-> Dropout]`` stage, and finally a
    ``Linear`` mapping to ``output_dim``.

    Args:
        input_dim: number of input features.
        output_dim: number of output features.
        init: optional initial value for the final layer's bias; anything
            ``torch.as_tensor`` accepts that broadcasts to ``(output_dim,)``.
        cfg: config providing ``activation`` (a key of ``ACT_DICT``),
            ``hidden_dim`` (sequence of hidden widths, possibly empty) and
            ``dropout`` (dropout layers are added only when it is > 0).

    Returns:
        A ``torch_misc.modules.Sequential`` holding the layers above.

    Raises:
        KeyError: if ``cfg.activation`` is not a key of ``ACT_DICT``.
    """
    act = ACT_DICT[cfg.activation]

    layers = [Normalizer(input_dim)]
    dims = (input_dim,) + tuple(cfg.hidden_dim)
    for i, o in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(i, o), act()]
        if cfg.dropout > 0:
            layers.append(nn.Dropout(cfg.dropout))

    # dims[-1] is the last hidden width, or input_dim when there are no hidden
    # layers (the loop variable ``o`` would be unbound in that case).
    head = nn.Linear(dims[-1], output_dim)
    if init is not None:
        # In-place copy keeps the bias' dtype/device and validates its shape,
        # unlike rebinding ``bias.data``.
        with torch.no_grad():
            head.bias.copy_(torch.as_tensor(init))
    layers.append(head)

    return Sequential(*layers)


class Softclip(Flow):
    """Softclip flow: a smooth, elementwise, invertible squashing map.

    Forward (x -> u): ``u = x / (1 + |x| / B * (1 - eps))``.
    Inverse (u -> x): ``x = u / (1 - |u| / B * (1 - eps))``.
    """

    def __init__(self, B=5., eps=1e-6, **kwargs):
        r"""
        Args:
            B (float): approximate clipping radius. Outputs lie in
                ``(-B / (1 - eps), B / (1 - eps))``, i.e. just outside
                ``(-B, B)``; the ``eps`` slack keeps the inverse finite at
                ``|u| = B``.
            eps (float): small positive slack described above.
            **kwargs: forwarded unchanged to ``Flow.__init__``.
        """
        super().__init__(**kwargs)

        self.B, self.eps = B, eps

    def _log_abs_det(self, x):
        """Return log|det J_T| of T: x -> u, summed over dim 1."""
        # Elementwise du/dx = 1 / (1 + |x| / B * (1 - eps)) ** 2.
        denom = 1 + torch.abs(x) / self.B * (1 - self.eps)
        return torch.sum(-2 * torch.log(denom), dim=1)

    # Flow hooks overridden below.
    def _transform(self, x, log_abs_det=False, **kwargs):
        denom = 1 + torch.abs(x) / self.B * (1 - self.eps)
        u = x / denom
        return (u, self._log_abs_det(x)) if log_abs_det else u

    def _invert(self, u, log_abs_det=False, **kwargs):
        x = u / (1 - (torch.abs(u) / self.B) * (1 - self.eps))
        if not log_abs_det:
            return x
        # log|det| of the inverse is minus the forward one, taken at x.
        return x, -self._log_abs_det(x)


def _spline_affine_blocks(cfg: DictConfig) -> Iterator[Any]:
    """Yield ``cfg.blocks`` repetitions of (RQ-spline, affine) layer factories.

    The spline factory is ``partial(Spline, K=cfg.K, A=-cfg.interval,
    B=cfg.interval)``; ``A``/``B`` presumably bound the spline's interval.
    """
    return nblocks(
        partial(Spline, K=cfg.K, A=-cfg.interval, B=cfg.interval),
        AffineFlow,
        N=cfg.blocks,
    )


def flow(dim: int, cond_dim: int = 0, cfg: DictConfig = ...):
    """Build a flow: an affine layer followed by the spline/affine stack.

    NOTE(review): ``cond_dim`` is accepted but not used, so no conditioning
    is wired in here — confirm this is intended.
    """
    return SequentialFlow(
        AffineFlow,
        *_spline_affine_blocks(cfg),
        dim=dim,
        prior=NormalPrior,
    )


def flow01(dim: int, cond_dim: int = 0, cfg: DictConfig = ...):
    """Build a flow starting with an inverted Sigmoid, then the spline stack.

    Presumably meant for data in (0, 1), mapped to the real line by the
    inverse sigmoid — TODO confirm against ``inv_flow`` semantics.
    NOTE(review): ``cond_dim`` is accepted but not used.
    """
    return SequentialFlow(
        inv_flow(flow_modules.Sigmoid),
        *_spline_affine_blocks(cfg),
        dim=dim,
        prior=NormalPrior,
    )


def flow_pos(dim: int, cond_dim: int = 0, cfg: DictConfig = ...):
    """Build a flow starting with an inverted Softplus, then the spline stack.

    Presumably meant for positive data, mapped to the real line by the
    inverse softplus — TODO confirm against ``inv_flow`` semantics.
    NOTE(review): ``cond_dim`` is accepted but not used.
    """
    return SequentialFlow(
        inv_flow(flow_modules.Softplus),
        *_spline_affine_blocks(cfg),
        dim=dim,
        prior=NormalPrior,
    )