import torch
import numpy as np
import torch.nn as nn

from einops import rearrange

# encoder_args: # fixed across all models
#   norm_last_layer: true
#   num_layers: 10
#   emb_dim: 320
#   pool_across_time_mode: max


class TSEncoder(nn.Module):  # wrapper includes layer Norm and pooling.
    """Wrap ``TSEncoder_`` with an optional LayerNorm and pooling across time.

    Args:
        config: Object read through attribute access. The constructor uses
            ``config.data_args.input_dims`` and, from ``config.encoder_args``,
            ``emb_dim``, ``hidden_dim``, ``num_layers``, ``norm_last_layer``
            and ``pool_across_time_mode`` (``"max"`` or ``"avg"``).

    Raises:
        ValueError: If ``pool_across_time_mode`` is neither ``"max"`` nor
            ``"avg"``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.norm_bool = config.encoder_args.norm_last_layer

        # Resolve the pooling layer up front so an unsupported mode fails here
        # with a clear message instead of as a missing attribute in forward().
        pool_mode = config.encoder_args.pool_across_time_mode
        if pool_mode == "max":
            pool = nn.AdaptiveMaxPool1d(1)
        elif pool_mode == "avg":
            pool = nn.AdaptiveAvgPool1d(1)
        else:
            raise ValueError(
                "pool_across_time_mode must be 'max' or 'avg', "
                f"got {pool_mode!r}"
            )

        self.ts_encoder = TSEncoder_(
            config.data_args.input_dims,
            config.encoder_args.emb_dim,
            hidden_dims=config.encoder_args.hidden_dim,
            depth=config.encoder_args.num_layers,
        )

        if self.norm_bool:
            self.context_norm = nn.LayerNorm(config.encoder_args.emb_dim)

        self.pool = pool  # pool across time dimension

    def forward(self, x):
        """Encode ``x`` and pool the per-timestep embeddings over time.

        Args:
            x: Input of shape (b, t, c).

        Returns:
            Tuple ``(embed_pool, embed)`` with shapes (b, z) and (b, t, z).
        """
        # x: b, t, c
        embed = self.ts_encoder(x)  # b, t, z

        if self.norm_bool:
            embed = self.context_norm(embed)

        # The 1d pooling layers reduce the last axis, so move time there.
        pooled = self.pool(embed.transpose(1, 2))  # b, z, 1
        # Drop only the pooled time axis: a bare squeeze() would also remove
        # the batch axis whenever b == 1.
        embed_pool = pooled.squeeze(-1)
        return embed_pool, embed  # (b z) and (b, t, z)

    # def pool_timevarying(self, context):
    # if self.config.model_args.time_vary_args.include:
    # tv, dtv = self.tv_module(context)
    # else:
    # tv, dtv = (None, None)

    # pooling function


# ========================================= FROM REBAR =========================================


class TSEncoder_(torch.nn.Module):
    """Linear input projection followed by a stack of dilated conv blocks.

    Args:
        input_dims: Size of the last axis of the input.
        output_dims: Number of channels produced by the last conv block.
        hidden_dims: Width of the input projection and of the intermediate
            conv blocks.
        depth: Number of ``hidden_dims``-wide conv blocks placed before the
            final ``output_dims`` block.
    """

    def __init__(
        self,
        input_dims,
        output_dims,
        hidden_dims=64,
        depth=10,
    ):
        super().__init__()
        self.input_dims = input_dims
        self.output_dims = output_dims
        self.hidden_dims = hidden_dims
        self.input_fc = torch.nn.Linear(input_dims, hidden_dims)
        self.feature_extractor = DilatedConvEncoder(
            hidden_dims, [hidden_dims] * depth + [output_dims], kernel_size=3
        )

    def forward(self, x, mask=None):  # x: B x T x input_dims
        """Encode a batch of sequences.

        Args:
            x: Tensor of shape B x T x input_dims. Timesteps containing a NaN
                in any channel are zeroed IN PLACE, so the caller's tensor is
                modified.
            mask: When equal to ``"binomial"``, each projected timestep is
                additionally zeroed with probability 0.5 (drawn from NumPy's
                global RNG). Any other value applies no extra masking.

        Returns:
            Tensor of shape B x T x output_dims.
        """
        # True for timesteps that have no NaN in any channel.
        nan_mask = ~x.isnan().any(
            dim=-1
        )  # this is necessary  bc TS2vec purposely introduces nans that we need to 0 out
        x[~nan_mask] = 0

        x = self.input_fc(x)  # B x T x Ch

        if mask == "binomial":
            # np.random.binomial yields an integer array. It must be cast to
            # bool, otherwise `~` is a bitwise NOT (giving -1/-2) and the
            # result is used as integer indices on the batch axis rather than
            # as a per-timestep boolean mask.
            keep = torch.from_numpy(
                np.random.binomial(1, 0.5, size=(x.size(0), x.size(1)))
            ).to(device=x.device, dtype=torch.bool)
            keep &= nan_mask
            x = x.masked_fill(~keep.unsqueeze(-1), 0)

        # conv encoder
        x = x.transpose(1, 2)  # B x Ch x T
        # x = self.repr_dropout(self.feature_extractor(x))  # B x Co x T
        x = self.feature_extractor(x)  # B x Co x T
        x = x.transpose(1, 2)  # B x T x Co
        return x


class DilatedConvEncoder(torch.nn.Module):
    """Sequential stack of ``ConvBlock`` layers with doubling dilation.

    Args:
        in_channels: Channel count fed to the first block.
        channels: Output channel count of each block, in order. Block ``i``
            takes the previous block's output width as its input width.
        kernel_size: Kernel size passed to every block.
    """

    def __init__(self, in_channels, channels, kernel_size):
        super().__init__()
        last_level = len(channels) - 1
        # Input width of each block: the encoder input, then every block's
        # predecessor output.
        widths_in = [in_channels, *channels[:-1]]

        blocks = []
        for level, (width_in, width_out) in enumerate(zip(widths_in, channels)):
            blocks.append(
                ConvBlock(
                    width_in,
                    width_out,
                    kernel_size=kernel_size,
                    dilation=2**level,  # dilation doubles with each block
                    final=(level == last_level),
                )
            )
        self.net = torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every block in order."""
        return self.net(x)


class ConvBlock(torch.nn.Module):
    """Residual block: two GELU-activated ``SamePadConv`` layers plus a skip.

    Args:
        in_channels: Channels of the block input.
        out_channels: Channels produced by both convolutions.
        kernel_size: Kernel size of both convolutions.
        dilation: Dilation of both convolutions.
        final: When true, the skip path always goes through a 1x1 convolution,
            even if the channel counts already match.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False):
        super().__init__()
        self.conv1 = SamePadConv(
            in_channels, out_channels, kernel_size, dilation=dilation
        )
        self.conv2 = SamePadConv(
            out_channels, out_channels, kernel_size, dilation=dilation
        )
        # The skip path needs a 1x1 projection when the channel count changes
        # or when this is flagged as the final block; otherwise it is identity.
        if final or in_channels != out_channels:
            self.projector = torch.nn.Conv1d(in_channels, out_channels, 1)
        else:
            self.projector = None

    def forward(self, x):
        """Return ``conv2(gelu(conv1(gelu(x))))`` plus the skip connection."""
        activation = torch.nn.functional.gelu
        if self.projector is not None:
            shortcut = self.projector(x)
        else:
            shortcut = x
        hidden = self.conv1(activation(x))
        hidden = self.conv2(activation(hidden))
        return hidden + shortcut


class SamePadConv(torch.nn.Module):
    """1-D convolution whose output has the same length as its input.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Kernel size of the convolution.
        dilation: Dilation of the convolution.
        groups: Groups of the convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
        super().__init__()
        span = dilation * (kernel_size - 1) + 1
        self.receptive_field = span
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=span // 2,
            dilation=dilation,
            groups=groups,
        )
        # With an even receptive field the symmetric padding yields one extra
        # output step, which forward() trims from the end.
        self.remove = int(span % 2 == 0)

    def forward(self, x):
        """Convolve ``x`` and trim the surplus trailing step, if any."""
        convolved = self.conv(x)
        if self.remove <= 0:
            return convolved
        return convolved[:, :, : -self.remove]
