import json
import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
import pickle

from language_utils import line_to_indices
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import os

# dir = os.path.dirname(os.getcwd())

# Pre-computed Sent140 word embeddings, loaded once at import time.
# The default path is the original author's location; set the
# SENT140_EMBS_PATH environment variable to load the file from elsewhere.
EMBS_PATH = os.environ.get(
    'SENT140_EMBS_PATH',
    '/home/tyin/Upcycled_objective/Upcycled_objective/data/sent140/data/embs.json',
)
# JSON text is UTF-8 by specification; do not depend on the locale default.
with open(EMBS_PATH, 'r', encoding='utf-8') as inf:
    embs = json.load(inf)
# id2word = embs['vocab']
# word2id = {v: k for k,v in enumerate(id2word)}

# def process_x(raw_x_batch, max_words=25):
#     x_batch = [e[4] for e in raw_x_batch]
#     x_batch = [line_to_indices(e, word2id, max_words) for e in x_batch]
#     # x_batch = torch.LongTensor(x_batch)
#     return x_batch
#
# def process_y(raw_y_batch):
#     y_batch = [1 if e=='4' else 0 for e in raw_y_batch]
#     # y_batch = np.array(y_batch)
#
#     return y_batch

class LSTM(nn.Module):
    """Two-layer bidirectional LSTM binary classifier over frozen word embeddings.

    Token ids are embedded with a frozen pre-trained ``nn.Embedding`` and then
    run through a 2-layer bidirectional LSTM. Two outputs are concatenated:
    the forward-direction output at the last valid position and the
    reverse-direction output at position 0. The result passes through
    dropout, a linear layer and a sigmoid, giving p = P(class 1). Each row of
    the returned tensor is ``[1 - p, p]``.

    Args:
        dimension: Hidden size of each LSTM direction.
        device: Device that the whole module (all submodules) is moved to.
        embeddings: Optional pre-trained embedding matrix with shape
            ``(vocab_size, emb_dim)``, given as a nested list, array or tensor.
            When omitted, the module-level ``embs['emba']`` is used.
    """

    def __init__(self, dimension=128, device='cpu', embeddings=None):
        super().__init__()

        if embeddings is None:
            # Fall back to the embedding matrix loaded at import time.
            embeddings = embs['emba']
        word_emb = torch.as_tensor(embeddings, dtype=torch.float32)
        # from_pretrained freezes the weights by default (freeze=True).
        self.embedding = nn.Embedding.from_pretrained(word_emb)
        self.dimension = dimension
        # Take the LSTM input width from the embedding matrix rather than
        # hard-coding 300, so other embedding sizes work unchanged.
        self.lstm = nn.LSTM(input_size=word_emb.size(1),
                            hidden_size=dimension,
                            num_layers=2,
                            batch_first=True,
                            bidirectional=True)
        self.drop = nn.Dropout(p=0.5)

        # The forward and reverse features are concatenated -> 2 * dimension.
        self.fc = nn.Linear(2 * dimension, 1)

        # Move every submodule, not only the embedding weight. Otherwise
        # forward() would feed device tensors into CPU layers.
        self.to(device)

    def forward(self, text, text_len=25):
        """Return class probabilities for a batch of token-id sequences.

        Args:
            text: Integer token ids with shape ``(batch, seq_len)``. The LSTM
                is built with ``batch_first=True``.
            text_len: Number of valid tokens per sample. Either an int applied
                to every sample, or a list/tensor of shape ``(batch,)``. Each
                value must be in ``[1, seq_len]``; 0 would wrap to the last
                position.

        Returns:
            Tensor of shape ``(batch, 2)`` holding ``[1 - p, p]``, where ``p``
            is the sigmoid output for class 1.
        """
        text_emb = self.embedding(text)

        # NOTE: sequences are not packed, so the reverse direction also reads
        # any padding positions after text_len.
        output, _ = self.lstm(text_emb)

        batch_idx = torch.arange(output.size(0), device=output.device)
        last_idx = torch.as_tensor(text_len, device=output.device) - 1
        # Forward direction: output at the last valid token of each sample.
        out_forward = output[batch_idx, last_idx, :self.dimension]
        # Reverse direction: its output at position 0 has read the sequence.
        out_reverse = output[:, 0, self.dimension:]
        out_reduced = torch.cat((out_forward, out_reverse), 1)
        text_fea = self.drop(out_reduced)

        text_fea = self.fc(text_fea)
        text_fea = torch.squeeze(text_fea, 1)
        text_out = torch.sigmoid(text_fea)
        # Stacking along dim=1 gives (batch, 2) directly. Same values as the
        # former stack(...).T, but as a contiguous tensor.
        return torch.stack((1 - text_out, text_out), dim=1)
#
# with open(dir + '/data/sent140/data/test/all_data_niid_3_keep_30_test_8.json', 'r') as inf:
#     cdata = json.load(inf)
# x = []
# y = []
# for key, item in cdata['user_data'].items():
#     x += item['x']
#     y += item['y']
# x = process_x(x, max_words=25)
# y = process_y(y)
# with open(dir+'/data/sent140/data/test/mytest_2.json', 'w') as fp:
#     json.dump([x, y], fp)


# with open(dir + '/data/sent140/data/train/all_data_niid_3_keep_30_train_8.json', 'r') as inf:
#     cdata = json.load(inf)
# i = 0
# user = 0
# user_groups = []
# x = []
# y = []
# train_x = []
# train_y = []
# for key, item in cdata['user_data'].items():
#     x += item['x']
#     y += item['y']
#     try:
#         user_groups[user] += len(item['y'])
#     except:
#         user_groups.append(len(item['y']))
#     i += 1
#     if i == 15:
#         i = 0
#         user += 1
#         train_x.append(x)
#         train_y.append(y)
#         x = []
#         y = []
# train_x.append(x)
# train_y.append(y)
#
#
# for i, x in enumerate(train_x):
#     x = process_x(x, max_words=25)
#     y = process_y(train_y[i])
#     train_x[i] = x
#     train_y[i] = y
#
# with open(dir+'/data/sent140/data/train/mytrain_2.json', 'w') as fp:
#     json.dump(list(zip(train_x, train_y)), fp)
