import math
import random
import torch
import torch.nn as nn
import numpy as np

from einops import rearrange
from pulse.conv import Conv1D

# Seed the Python, NumPy and PyTorch global RNGs with 1 at import time so runs
# are repeatable. NOTE(review): this mutates process-wide RNG state as a side
# effect of merely importing this module.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(1)
del _seed_fn


class InitConditionEncoder(nn.Module):
    """Encode one time step of a sequence into an initial hidden state ``h0``.

    The input sequence goes through a grouped ``Conv1D`` projection. One time
    step is then selected per batch element: step 0, or a random step when
    ``sample_init`` is set. That vector becomes the first layer of ``h0``, and
    the remaining ``num_layers - 1`` layers are zero-filled.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.in_dim = config.data_args.input_dims
        self.context_dim = config.encoder_args.emb_dim
        self.recon_hidden_dim = config.model_args.recon_args.hidden_dim
        self.recon_gru_num_layers = config.model_args.recon_args.num_layers
        self.in_proj_kernel_size = config.model_args.init_args.in_proj_kernel_size
        self.in_proj_dilation = config.model_args.init_args.in_proj_dilation

        self.init_norm_bool = config.model_args.init_args.init_norm
        self.init_hidden_dim = config.model_args.init_args.hidden_dim

        # gcd(...) divides both hidden widths. Presumably Conv1D requires the
        # group count to divide each of them (TODO confirm in pulse.conv).
        self.init_proj = Conv1D(
            self.in_dim,
            self.init_hidden_dim,
            self.recon_hidden_dim,
            kernel_size=self.in_proj_kernel_size,
            dilation=self.in_proj_dilation,
            groups=math.gcd(self.recon_hidden_dim, self.init_hidden_dim),
            padding_mode="replicate",
        )

        if self.init_norm_bool:
            # NOTE(review): this LayerNorm is registered, so it appears in the
            # state dict, but `forward` never applies it. Confirm intent.
            self.init_norm = nn.LayerNorm(self.recon_hidden_dim)

    def forward(self, x, sample_init=False, sample_right_boundary=20):
        """Build the initial hidden state from one selected time step of ``x``.

        Args:
            x: 3-D tensor ``(b, t, c)``.
            sample_init: if True, draw a random start step per batch element
                from ``[0, t - sample_right_boundary)``. Otherwise start at 0.
                When ``t <= sample_right_boundary`` that range is empty, and
                every start index falls back to 0.
            sample_right_boundary: number of trailing steps excluded from the
                random start range. When ``t > sample_right_boundary``, this
                guarantees ``n_steps >= sample_right_boundary``.

        Returns:
            ``(h0, start_ix, n_steps)``:
            ``h0`` has shape ``(max(num_layers, 1), b, n)``, where ``n`` is the
            last dim of ``init_proj``'s output and layers 1.. are zeros.
            ``start_ix`` is a long tensor of shape ``(b,)`` created on torch's
            default device (not necessarily ``x.device``).
            ``n_steps`` is ``t - start_ix - 1``.
        """
        b, t, _ = x.shape

        if sample_init:
            # Clamp so torch.randint always receives a positive upper bound.
            high = max(t - sample_right_boundary, 1)
            start_ix = torch.randint(high, (b,))
        else:
            start_ix = torch.zeros(b, dtype=torch.long)

        n_steps = t - start_ix - 1

        x0 = self.init_proj(x)  # presumably (b, t, n): batch, then time
        # Advanced indexing already yields (b, n). Do not squeeze(): it would
        # drop the batch axis when b == 1 (or the feature axis when n == 1).
        x0 = x0[torch.arange(b), start_ix]

        # Layer 0 of h0 gets the projection; the other layers start at zero.
        x0 = x0.unsqueeze(0)
        pad_layers = max(self.recon_gru_num_layers - 1, 0)
        padding = x0.new_zeros(pad_layers, *x0.shape[1:])
        return torch.cat([x0, padding]), start_ix, n_steps


class SharedInitConditionEncoder(InitConditionEncoder):
    """Initial-condition encoder driven by the unpooled context sequence.

    It behaves like ``InitConditionEncoder``, but its input projection is
    rebuilt with ``context_dim`` (``encoder_args.emb_dim``) as the first
    ``Conv1D`` argument, the slot the parent fills with ``input_dims``. It is
    meant to consume ``context_unpooled`` of shape ``(b, t, z)``. The ``h0``
    it yields for the reconstruction net is ``(num_layers, b, hidden_size)``.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the parent __init__ already built an `init_proj` for
        # `input_dims`, which is discarded here. Its parameter init likely
        # consumes RNG draws too.
        proj_groups = math.gcd(self.recon_hidden_dim, self.init_hidden_dim)
        self.init_proj = Conv1D(
            self.context_dim,
            self.init_hidden_dim,
            self.recon_hidden_dim,
            kernel_size=self.in_proj_kernel_size,
            dilation=self.in_proj_dilation,
            groups=proj_groups,
            padding_mode="replicate",
        )
