"""
Simple operations expressed as PyTorch modules.
"""

from collections import namedtuple

import torch
import torch.nn as nn
from typing import List, Tuple

class NNLayerScale(torch.nn.Module):
    """
    Scales the input element-wise by the provided scalars.

    Args:
        scalars: Scale factors, multiplied with the input in ``forward``
            (standard broadcasting applies). A tensor is registered as-is;
            anything else is converted with ``torch.tensor(scalars, dtype=dtype)``.
        dtype: dtype used only when ``scalars`` is not already a tensor.
    """
    def __init__(self, scalars, dtype=torch.float32):
        super().__init__()
        if not isinstance(scalars, torch.Tensor):
            # Convert the argument itself: ``self.scalars`` does not exist
            # until the buffer is registered below.
            scalars = torch.tensor(scalars, dtype=dtype)
        # Buffer (not Parameter): moves with .to()/.cuda() and is saved in the
        # state_dict, but is not trained.
        self.register_buffer("scalars", scalars)

    def extra_repr(self):
        return f"scalars: {self.scalars}"

    def forward(self, x):
        return self.scalars * x


class NNTanhRefit(torch.nn.Module):
    """
    Refits the output of Tanh such that the output is distributed within the
    high and low bounds of the environment.

    ``forward`` maps ``[-1, 1]`` affinely onto ``[low, high]``; ``undo``
    applies the inverse mapping.

    Args:
        low: Lower bound(s); scalar, sequence, or tensor.
        high: Upper bound(s); scalar, sequence, or tensor.
        shape: Optional target shape of the bounds (without the batch
            dimension). Bounds that are 0-d or have shape ``(1,)`` are
            broadcast to it; any other mismatch raises ``ValueError``.
        dtype: dtype used when ``low``/``high`` are not already tensors.

    The registered buffers ``scale`` (``high - low``) and ``shift`` (``low``)
    get an extra leading dimension of size 1 so they broadcast over a batch.
    """
    def __init__(self, low, high, shape=None, dtype=torch.float32):
        super().__init__()
        if not isinstance(low, torch.Tensor):
            low = torch.tensor(low, dtype=dtype)
        if not isinstance(high, torch.Tensor):
            high = torch.tensor(high, dtype=dtype)

        if shape is not None:
            # torch.Size compares equal to tuples but not to lists, so
            # normalize first; otherwise shape=[3] would reject a (3,) bound.
            target = torch.Size(shape)

            def fit(bound):
                """Return ``bound`` with shape ``target``, broadcasting scalars."""
                if bound.shape == target:
                    return bound
                if bound.shape == (1,):
                    bound = bound[0]
                if bound.dim() != 0:
                    raise ValueError(f"Dimensionality mismatch between bound ({bound.shape}) and shape ({shape})")
                # expand + clone (rather than torch.full) keeps the bound's
                # dtype and device, and gives the buffer its own storage so
                # load_state_dict can copy into it.
                return bound.expand(target).clone()

            low = fit(low)
            high = fit(high)

        # Unsqueeze the batch dimensions
        low = low.unsqueeze(0)
        high = high.unsqueeze(0)

        self.register_buffer("scale", high - low)
        self.register_buffer("shift", low)

    def extra_repr(self):
        return f"scale: {self.scale}, shift: {self.shift}"

    def forward(self, x):
        # [-1,1] -> [0,1]
        x = (x + 1)/2
        # [0,1] -> [low, high]
        return (x * self.scale) + self.shift

    def undo(self, x):
        # [low, high] -> [0,1]
        x = (x - self.shift) / self.scale
        # [0,1] -> [-1,1]
        return 2*x - 1


class NNLayerRK4(torch.nn.Module):
    """
    A single classical Runge-Kutta (RK4) step built around a wrapped network.

    With ``net`` acting as the increment function, one step computes:

        def RK4(p, y):
            k1 = net(p, y)
            k2 = net(p, y + k1/2)
            k3 = net(p, y + k2/2)
            k4 = net(p, y + k3)
            return y + (k1 + 2*k2 + 2*k3 + k4)/6

    ``net`` receives ``p`` and the current state concatenated along dim 1:
    net : [N, pDim + yDim] -> [N, yDim]
    """
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, p, y):
        """
        Dimensions:
          - p: [N, pDim]
          - y: [N, yDim]
          - return: [N, yDim]
        """
        def increment(state):
            # p is the same for all four stages; only the state varies.
            return self.net(torch.cat([p, state], dim=1))

        k1 = increment(y)
        k2 = increment(y + k1 / 2)
        k3 = increment(y + k2 / 2)
        k4 = increment(y + k3)
        weighted_sum = k1 + (2 * k2) + (2 * k3) + k4
        return y + weighted_sum / 6


class NNLayerConcat2(torch.nn.Module):
    """
    Concatenates two inputs along the specified axis and passes the result to
    ``next``.

    By default every input is flattened from dimension 1 onwards before
    concatenation, so with the default ``dim=-1`` the output is a
    2-dimensional tensor of shape Nx(prod(x.shape[1:]) + prod(y.shape[1:]))
    where N is the batch dimension.

    Args:
        dim: Dimension along which the preprocessed inputs are concatenated.
        init_left: Module applied to the first input. ``None`` (default)
            creates a new ``torch.nn.Flatten(start_dim=1, end_dim=-1)``.
        init_right: Module applied to the second input. ``None`` (default)
            creates a new ``torch.nn.Flatten(start_dim=1, end_dim=-1)``.
        next: Module applied to the concatenated tensor. ``None`` (default)
            creates a new ``torch.nn.Identity()``.
    """
    def __init__(self, dim=-1,
                 init_left=None,
                 init_right=None,
                 next=None):
        super().__init__()
        # Build default submodules per instance: default-argument expressions
        # are evaluated once, so module defaults in the signature would be
        # shared (and registered) by every instance of this class.
        if init_left is None:
            init_left = torch.nn.Flatten(start_dim=1, end_dim=-1)
        if init_right is None:
            init_right = torch.nn.Flatten(start_dim=1, end_dim=-1)
        if next is None:
            next = torch.nn.Identity()
        self.dim = dim
        self.next = next
        self.init_left = init_left
        self.init_right = init_right

    def extra_repr(self):
        return f"dim: {self.dim}"

    def forward(self, x, y):
        x = self.init_left(x)
        y = self.init_right(y)
        return self.next(torch.cat([x,y], dim=self.dim))


class NNLayerConcat(torch.nn.Module):
    """
    Concatenates any number of inputs along ``dim`` and passes the result to
    ``next``.

    Every input is first passed through the same ``init_all`` module. Unlike
    NNLayerConcat2, inputs are NOT flattened by default (``init_all`` is an
    identity); pass ``init_all=torch.nn.Flatten(start_dim=1)`` for that.

    Args:
        dim: Dimension along which the preprocessed inputs are concatenated.
        init_all: Module applied to each input. ``None`` (default) creates a
            new ``torch.nn.Identity()``.
        next: Module applied to the concatenated tensor. ``None`` (default)
            creates a new ``torch.nn.Identity()``.
    """
    def __init__(self, dim=-1,
                 init_all=None,
                 next=None):
        super().__init__()
        # Build default submodules per instance: default-argument expressions
        # are evaluated once, so module defaults in the signature would be
        # shared (and registered) by every instance of this class.
        if init_all is None:
            init_all = torch.nn.Identity()
        if next is None:
            next = torch.nn.Identity()
        self.dim = dim
        self.next = next
        self.init_all = init_all

    def extra_repr(self):
        return f"dim: {self.dim}"

    def forward(self, *inputs):
        return self.next(torch.cat(tuple(self.init_all(inp) for inp in inputs), dim=self.dim))


class NNLayerHeadSplit(torch.nn.Module):
    """
    Applies the same input to multiple named heads and returns the results
    as a namedtuple whose fields are the head names, in the order given.

    Example::

        split = NNLayerHeadSplit(mean=nn.Linear(8, 2), log_std=nn.Linear(8, 2))
        out = split(x)  # out.mean, out.log_std
    """
    # namedtuple classes keyed by their field names. Caching keeps the return
    # type identical across calls instead of building a new class on every
    # forward pass. Kept on the class (not the instance) so pickling a module
    # never has to serialize a dynamically created class.
    _tuple_types = {}

    def __init__(self, **heads):
        super().__init__()
        # Pass (name, module) pairs so the keyword order is preserved.
        self.heads = nn.ModuleDict(list(heads.items()))

    def forward(self, x : torch.Tensor) -> Tuple[torch.Tensor, ...]:
        names = tuple(self.heads.keys())
        T = self._tuple_types.get(names)
        if T is None:
            T = self._tuple_types.setdefault(names, namedtuple("Heads", names))
        return T(**{k: head(x) for k, head in self.heads.items()})


class NNLayerSqueeze(torch.nn.Module):
    """
    Removes dimension ``dim`` from the input via ``torch.squeeze``.

    As with ``torch.squeeze``, the input is returned unchanged in shape when
    that dimension does not have size 1.
    """
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def extra_repr(self):
        dim = self.dim
        return f"dim: {dim}"

    def forward(self, x):
        squeezed = torch.squeeze(x, self.dim)
        return squeezed


class NNLayerExp(torch.nn.Module):
    """Applies the element-wise exponential ``e**x`` to the input."""
    def forward(self, x):
        return x.exp()


class NNLayerClipExp(torch.nn.Module):
    """
    Exponentiates the input after clamping it to ``[min, max]``.

    Args:
        min: Lower clamp applied before ``exp`` (default -20).
        max: Upper clamp applied before ``exp`` (default 2).
    """
    def __init__(self, min=-20, max=2):
        super().__init__()
        self.min, self.max = min, max

    def extra_repr(self):
        return f"min={self.min}, max={self.max}"

    def forward(self, x):
        clamped = torch.clamp(x, self.min, self.max)
        return clamped.exp()


class NNLayerSqrt(torch.nn.Module):
    """Takes the element-wise square root of the input (NaN for negatives)."""
    def forward(self, x):
        return x.sqrt()


class NNLayerClipSiLU(torch.nn.Module):
    """
    Applies SiLU (``x * sigmoid(x)``) after clamping the input from below.

    Args:
        lower: Minimum value the input is clamped to before SiLU
            (default -20.0). No upper clamp is applied.
    """
    def __init__(self, lower=-20.0):
        super().__init__()
        self.lower = lower

    def extra_repr(self):
        return f"lower={self.lower}"

    def forward(self, x):
        return torch.nn.functional.silu(x.clip(min=self.lower))
