import torch
import torch.nn as nn
import torch.nn.functional as F
import random


class Encoder(nn.Module):
  """LSTM encoder that compresses a sequence into a fixed-size embedding.

  Two stacked single-layer LSTMs map n_features -> 2*embedding_dim ->
  embedding_dim; the embedding is the final hidden state of the second one.

  Args:
    seq_len: number of time steps per sequence.
    n_features: number of features per time step.
    embedding_dim: size of the returned embedding.
  """

  def __init__(self, seq_len, n_features, embedding_dim=64):
    super().__init__()
    self.seq_len = seq_len
    self.n_features = n_features
    self.embedding_dim = embedding_dim
    self.hidden_dim = embedding_dim * 2
    # Creation order (rnn1, then rnn2) is kept so seeded initialization of
    # the parameters is unchanged.
    self.rnn1 = nn.LSTM(input_size=n_features, hidden_size=self.hidden_dim,
                        num_layers=1, batch_first=True)
    self.rnn2 = nn.LSTM(input_size=self.hidden_dim, hidden_size=embedding_dim,
                        num_layers=1, batch_first=True)

  def forward(self, x):
    """Encode ``x`` into its embedding.

    ``x`` is reshaped to (batch, seq_len, n_features), for example
    torch.Size([16, 128, 3]). Returns the last hidden state of ``rnn2`` with
    every size-1 dimension squeezed out: (batch, embedding_dim) in general,
    but (embedding_dim,) when batch == 1.
    """
    seq = x.reshape(-1, self.seq_len, self.n_features)
    seq = self.rnn1(seq)[0]
    _, (last_hidden, _) = self.rnn2(seq)
    # last_hidden is (num_layers=1, batch, embedding_dim).
    return last_hidden.squeeze()


class Decoder(nn.Module):
  """LSTM decoder that expands an embedding back into a sequence.

  The embedding is repeated at every time step, passed through two stacked
  single-layer LSTMs (input_dim -> input_dim -> 2*input_dim), and projected
  per time step to n_features by a linear layer.

  Args:
    seq_len: number of time steps to generate.
    input_dim: size of the input embedding.
    n_features: number of features per generated time step.
  """

  def __init__(self, seq_len, input_dim=64, n_features=1):
    super(Decoder, self).__init__()
    self.seq_len, self.input_dim = seq_len, input_dim
    self.hidden_dim, self.n_features = 2 * input_dim, n_features
    self.rnn1 = nn.LSTM(
      input_size=input_dim,
      hidden_size=input_dim,
      num_layers=1,
      batch_first=True
    )
    self.rnn2 = nn.LSTM(
      input_size=input_dim,
      hidden_size=self.hidden_dim,
      num_layers=1,
      batch_first=True
    )
    self.output_layer = nn.Linear(self.hidden_dim, n_features)

  def forward(self, x):
    """Decode embeddings into sequences.

    Args:
      x: embeddings reshapeable to (batch, input_dim); a 1-D (input_dim,)
        tensor is treated as a batch of one.

    Returns:
      Tensor of shape (batch, seq_len, n_features).
    """
    z = x.reshape(-1, self.input_dim)
    # Repeat each sample's OWN embedding along a new time axis. Tiling the
    # whole batch (x.repeat(seq_len, 1)) and reshaping would interleave
    # different samples' embeddings across time steps whenever batch > 1.
    z = z.unsqueeze(1).repeat(1, self.seq_len, 1)
    seq, _ = self.rnn1(z)
    seq, _ = self.rnn2(seq)
    # nn.Linear acts on the last dim: (batch, seq_len, hidden_dim) -> n_features.
    out = self.output_layer(seq)
    return out.reshape(-1, self.seq_len, self.n_features)



class RecurrentAutoencoder(nn.Module):
  """LSTM sequence autoencoder: Encoder -> embedding -> Decoder.

  Args:
    seq_len: number of time steps per sequence.
    n_features: number of features per time step.
    embedding_dim: size of the latent embedding.
    device: device every submodule is moved to at construction.
  """

  def __init__(self, seq_len, n_features, embedding_dim=64, device = 'cuda:0'):
    super(RecurrentAutoencoder, self).__init__()
    self.encoder = Encoder(seq_len, n_features, embedding_dim).to(device)
    self.decoder = Decoder(seq_len, embedding_dim, n_features).to(device)
    # 1x1 convolution treating the seq_len axis as channels; only used by the
    # disabled masked-reconstruction path in forward(). Moved to `device`
    # like the other submodules to avoid a device mismatch if re-enabled.
    self.projection = nn.Conv1d(seq_len, seq_len,
                                kernel_size=1, stride=1, bias=False).to(device)

  def creatMaskEvenSplit(self, x, part=6):
    """Split the element positions of ``x`` into ``part`` disjoint random masks.

    Each mask selects ``b*l*c // part`` positions, sampled with ``random``
    without replacement from the positions no earlier mask has selected.
    When ``b*l*c`` is not divisible by ``part``, the leftover positions are
    not covered by any mask.

    Args:
      x: tensor of shape (b, l, c).
      part: number of masks to produce.

    Returns:
      (MaskX, PartMask): two lists of length ``part``. ``PartMask[i]`` is a
      bool tensor shaped like ``x`` (on ``x.device``) that is True at the
      positions chosen for part ``i``; ``MaskX[i]`` is ``x`` with those
      positions filled with 0.
    """
    b, l, c = x.shape
    numel = b * l * c
    slice_num = numel // part
    # Row-major flat offsets (b_ind*l*c + l_ind*c + c_ind) not yet claimed,
    # kept in ascending order. random.sample depends only on the population's
    # length and order, so this draws the same positions as sampling
    # [b_ind, l_ind, c_ind] triples, without an O(n) list.remove per element.
    remaining = list(range(numel))

    PartMask = []
    MaskX = []
    for _ in range(part):
      chosen = random.sample(remaining, slice_num)
      chosen_set = set(chosen)
      remaining = [i for i in remaining if i not in chosen_set]

      flat = torch.zeros(numel, dtype=torch.bool, device=x.device)
      flat[torch.tensor(chosen, dtype=torch.long, device=x.device)] = True
      Mask = flat.reshape(b, l, c)

      PartMask.append(Mask)
      MaskX.append(x.masked_fill(Mask, 0))

    return MaskX, PartMask

  def forward(self, x):
    """Reconstruct ``x`` by encoding it to an embedding and decoding it back."""
    # NOTE: disabled alternative path -- reconstruct each random partition
    # from creatMaskEvenSplit separately, keep only the reconstructed values
    # at that partition's masked positions, sum the parts, then apply
    # self.projection.
    # part = 4
    # b, l, c = x.shape
    # point_processed_x = torch.zeros(b, l, c, device=x.device)
    #
    # point_mask_x, point_mask = self.creatMaskEvenSplit(x, part=part)
    #
    # for i in range(part):
    #   x = self.encoder(point_mask_x[i])
    #   mask_x_process = self.decoder(x)
    #
    #   unmask = (point_mask[i] == False)
    #
    #   mask_x_process = mask_x_process.masked_fill(unmask, 0)
    #   point_processed_x = point_processed_x + mask_x_process
    #
    # point_processed_x = self.projection(point_processed_x)

    x = self.encoder(x)
    point_processed_x = self.decoder(x)
    return point_processed_x


