from typing import Optional

import torch
import torch.nn as nn


class ConvNet(nn.Module):
    """
    Two-layer 1D convolutional network applied along the latent axis.

    The input carries ``obs_dim`` channels and ``latent_dim`` positions.
    The kernel size of the second convolution is derived so that the output
    has exactly ``bundling_k`` positions, which ``forward`` then exposes as
    the second (time) axis.

    Args:
        obs_dim: number of input channels.
        latent_dim: length of the input along the convolved axis.
        conv1_dim: kernel size of the first convolution.
        bundling_k: desired output length (number of bundled time steps).
        out_dim: number of output channels; ``None`` falls back to
            ``obs_dim``.
        act: name of an activation class in ``torch.nn`` (e.g. ``"ReLU"``);
            it is instantiated without arguments.
        n_channel: number of hidden channels between the two convolutions;
            ``None`` falls back to ``obs_dim``.

    Raises:
        ValueError: if the sizes do not yield positive kernel sizes for
            both convolutions.
    """

    def __init__(self,
                 obs_dim: int,
                 latent_dim: int,
                 conv1_dim: int,
                 bundling_k: int,
                 out_dim: Optional[int],
                 act: str,
                 n_channel: Optional[int] = None):
        super().__init__()
        # Length left after the first (unpadded, stride-1) convolution.
        conv1_outdim = latent_dim - conv1_dim + 1
        # Second kernel size chosen so that the final length is bundling_k:
        # conv1_outdim - conv2_dim + 1 == bundling_k.
        conv2_dim = conv1_outdim - bundling_k + 1

        # Fail early with a readable message: a zero-sized kernel would
        # otherwise construct silently and only break inside forward().
        if conv1_dim < 1 or bundling_k < 1 or conv2_dim < 1:
            raise ValueError(
                "Incompatible sizes for ConvNet: "
                f"latent_dim={latent_dim}, conv1_dim={conv1_dim}, "
                f"bundling_k={bundling_k} give a second kernel size of "
                f"{conv2_dim}; conv1_dim, bundling_k and "
                "latent_dim - conv1_dim - bundling_k + 2 must all be >= 1."
            )

        n_channel = obs_dim if n_channel is None else n_channel
        out_dim = out_dim if out_dim is not None else obs_dim

        self.model = nn.Sequential(
            nn.Conv1d(obs_dim, n_channel, conv1_dim),
            getattr(nn, act)(),
            nn.Conv1d(n_channel, out_dim, conv2_dim)
        )

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """
        xs: [bs, obs_dim, latent]

        Returns a tensor of shape [bs, bundling_k, out_dim].
        """
        # Conv output is [bs, out_dim, bundling_k]; swap so time comes first.
        ds = self.model(xs).transpose(2, 1)
        return ds


class ConsistencyDec(nn.Module):
    """
    Linear-in-time decoder: pred[:, k] = u0 + (k + 1) * dt * ds[:, k].

    Args:
        dt: time-step size used to scale the step index.
    """

    def __init__(self, dt: float):
        super().__init__()
        self.dt = dt

    def forward(self, u0: torch.Tensor, ds: torch.Tensor) -> torch.Tensor:
        """
        u0: [batch x dim]
        ds: [batch x time x dim]

        Returns a tensor of shape [batch x time x dim].
        """
        # Build the time axis in a float dtype at least as wide as ds.
        # An integer arange times a Python float always lands in the default
        # dtype (usually float32), which would round the offsets before they
        # meet float64 inputs. Narrower or non-float inputs keep the default
        # dtype, matching the previous behaviour.
        if ds.is_floating_point():
            ts_dtype = torch.promote_types(ds.dtype, torch.get_default_dtype())
        else:
            ts_dtype = torch.get_default_dtype()

        n_steps = ds.shape[1]
        steps = torch.arange(1, n_steps + 1, device=u0.device, dtype=ts_dtype)
        ts = steps.view(1, -1, 1) * self.dt  # [1 x time x 1]

        # Insert a time axis into u0 so it broadcasts over every step.
        return u0.unsqueeze(dim=1) + ts * ds


class ConsistencyDec2D(nn.Module):
    """
    Linear-in-time decoder for image-like fields:
    pred[:, k] = u0 + (k + 1) * dt * ds[:, k].

    Args:
        dt: time-step size used to scale the step index.
    """

    def __init__(self, dt: float):
        super().__init__()
        self.dt = dt

    def forward(self, u0: torch.Tensor, ds: torch.Tensor) -> torch.Tensor:
        """
        u0: [batch x channel x height x width]
        ds: [batch x time x channel x height x width]

        Returns a tensor of shape [batch x time x channel x height x width].
        """
        # Build the time axis in a float dtype at least as wide as ds.
        # An integer arange times a Python float always lands in the default
        # dtype (usually float32), which would round the offsets before they
        # meet float64 inputs. Narrower or non-float inputs keep the default
        # dtype, matching the previous behaviour.
        if ds.is_floating_point():
            ts_dtype = torch.promote_types(ds.dtype, torch.get_default_dtype())
        else:
            ts_dtype = torch.get_default_dtype()

        n_steps = ds.shape[1]
        steps = torch.arange(1, n_steps + 1, device=u0.device, dtype=ts_dtype)
        ts = steps.view(1, -1, 1, 1, 1) * self.dt  # [1 x time x 1 x 1 x 1]

        # Insert a time axis into u0 so it broadcasts over every step.
        return u0.unsqueeze(dim=1) + ts * ds
