import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from pulse.mlp import MLP


import random
import numpy as np

# Pin the Python, NumPy and PyTorch global RNGs to seed 1 at import time so
# random draws made by this module are repeatable from run to run.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(1)
del _seed_fn


class DynamicAugmentations(nn.Module):
    """On-the-fly sequence augmentations (time stretching, additive noise) and
    helpers for building reconstruction inputs.

    Config attributes read:
        config.data_args.subseq_size: length that stretched outputs are cropped to.
        config.model_args.augmentation_args.stretch_range: ``(low, high)`` bounds
            for randomly drawn stretch factors.
        config.model_args.recon_args.hidden_dim: stored as ``self.hidden_dim``.
        config.device: device the stretch factor is moved to in ``stretch``.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__()

        self.config = config
        self.subseq_size = self.config.data_args.subseq_size
        self.stretch_range = self.config.model_args.augmentation_args.stretch_range

        self.hidden_dim = self.config.model_args.recon_args.hidden_dim

        # Projection applied to the context in get_recon_inputs; identity here.
        self.projector = nn.Identity()

    def stretch(self, x, stretch_factor=None):
        """Time-stretch a batch of sequences with linear interpolation.

        Args:
            x: tensor of shape (b, t, c).
            stretch_factor: multiplier for the output length. Either a tensor
                (only its first element is used) or a Python number. When
                ``None``, one factor is drawn uniformly from
                ``[stretch_range[0], stretch_range[1])`` and shared by the
                whole batch.

        Returns:
            ``(x_stretch, stretch_factor)``. ``x_stretch`` has shape
            ``(b, min(target_len, subseq_size), c)``, where
            ``target_len = int(stretch_factor * subseq_size)``. Factors below 1
            therefore give a shorter (unpadded) output. ``stretch_factor`` is
            returned as a tensor on ``config.device``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"stretch expects a (b, t, c) tensor, got shape {tuple(x.shape)}"
            )

        if stretch_factor is None:
            stretch_factor = (
                torch.rand(1) * (self.stretch_range[1] - self.stretch_range[0])
                + self.stretch_range[0]
            )
        elif not isinstance(stretch_factor, torch.Tensor):
            # Accept plain numbers; the device move below needs a tensor.
            stretch_factor = torch.tensor([float(stretch_factor)])

        stretch_factor = stretch_factor.to(self.config.device)
        # The target length is relative to the configured subseq_size, not to
        # x's own length (assumes t == subseq_size; TODO confirm with callers).
        target_len = (stretch_factor * self.subseq_size).int()
        # F.interpolate expects a Python int size. Only the first factor is
        # used, so the whole batch shares one target length.
        size = int(target_len.reshape(-1)[0])

        x = rearrange(x, "b t c -> b c t")
        x_stretch = F.interpolate(x, size, mode="linear", align_corners=False)

        # Crop back to at most subseq_size steps.
        x_stretch = rearrange(x_stretch, "b c t -> b t c")[:, : self.subseq_size]
        return x_stretch, stretch_factor

    def add_noise(self, x, std=1):
        """Return ``x`` plus i.i.d. standard-normal noise scaled by ``std``."""
        return x + torch.randn_like(x) * std

    def get_recon_inputs(self, context, n_steps):
        """Repeat a per-sample context vector across ``n_steps`` time steps.

        Args:
            context: tensor of shape (b, c).
            n_steps: number of time steps to repeat the context over.

        Returns:
            Contiguous tensor of shape (b, n_steps, n), where n is the last
            dimension after ``self.projector`` (identity here, so n == c).
        """
        context = rearrange(context, "b c -> b 1 c")
        context = self.projector(context)

        b, _, n = context.shape
        # expand() only creates a broadcast view. contiguous() materializes
        # real copies so the time steps do not alias one row in memory.
        recon_inputs = context.expand(b, n_steps, n).contiguous()

        return recon_inputs
