import torch
import torch.nn as nn
from einops import rearrange

from pulse.conv import Conv1D
from utils.common import get_true_rolled


class TimeVaryingModule(nn.Module):
    """
    This module takes the output of an encoder and extracts a 1d time-varying
    signal that is used to model nonstationarity in the signal.

    The encoder output is passed through a grouped ``Conv1D``. The result is
    max-pooled over time into ``subseq_size // pool_denom`` bins, and each bin
    is repeated ``pool_denom`` times. The pooled output is therefore piecewise
    constant over time.

    Config fields read:
        encoder_args.emb_dim
        model_args.time_vary_args.tv_dim
        model_args.time_vary_args.tv_kernel_size
        model_args.time_vary_args.pool_denom
        data_args.subseq_size
    """

    def __init__(self, config):
        super().__init__()

        tv_args = config.model_args.time_vary_args
        subseq_size = config.data_args.subseq_size
        pool_denom = tv_args.pool_denom

        # Validate before building any submodules. Raise explicitly rather
        # than assert so the check survives `python -O`.
        if pool_denom <= 0:
            raise ValueError("pool_denom must be a positive integer")
        if subseq_size % pool_denom != 0:
            raise ValueError("subseq_size must be divisible by pool_denom")

        kernel_size = tv_args.tv_kernel_size
        self.conv = Conv1D(
            config.encoder_args.emb_dim,
            tv_args.tv_dim,
            tv_args.tv_dim,
            # Previously hard-coded to 3, which silently ignored the
            # configured tv_kernel_size.
            kernel_size=kernel_size,
            dilation=1,
            groups=tv_args.tv_dim,
        )

        self.config = config
        # Fixed output length: the pooled signal always has
        # subseq_size // pool_denom steps, regardless of the input length.
        self.tv_pool = nn.AdaptiveMaxPool1d(subseq_size // pool_denom)

    def forward(self, x: torch.Tensor):
        """Run the conv and build the piecewise-constant time-varying signal.

        Args:
            x: context vector laid out as (batch, time, channels).

        Returns:
            A tuple ``(x, o)``:

            - ``x`` is the ``Conv1D`` output, treated here as
              (batch, time, channels).
            - ``o`` is that output max-pooled over time and then repeated
              ``pool_denom`` times per bin. Its time length is always
              ``subseq_size`` from the config, independent of the time length
              of the input.
        """
        x = self.conv(x)

        # (b, t, c) -> (b, c, t): adaptive pooling reduces the last axis.
        pooled = self.tv_pool(x.transpose(1, 2))

        # Upsample by repeating each pooled step pool_denom times, then go
        # back to (b, t, c).
        pool_denom = self.config.model_args.time_vary_args.pool_denom
        o = torch.repeat_interleave(pooled, pool_denom, dim=-1).transpose(1, 2)
        return x, o

    def shift_start(self, x, start_ix):
        """Return ``get_true_rolled(x, start_ix)``.

        Thin wrapper; the exact roll semantics are defined by
        ``utils.common.get_true_rolled`` (not visible here).
        """
        return get_true_rolled(x, start_ix)
