import random
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Word classes for the synthetic sentence corpus.
# `classes[i]` is a list of interchangeable words for class index i; the
# trailing `# N` comments give the index of each class as it is appended.
# Index 0 is an empty placeholder, so real classes are numbered from 1 and the
# `paths` templates below refer to them by those indices.
# NOTE(review): num_classes (60) does not equal len(classes) (31 once all the
# appends below have run) and is not read anywhere in this section --
# presumably an upper bound / model output size used elsewhere; confirm.
num_classes = 60
classes = [[]]
classes = classes + [['Harry', 'Ron', 'Sirius', 'Hagrid', 'Fred', 'George', 'Neville', 'Draco', 'Albus', 'Snape']] # 1
classes = classes + [['woke']] + [['at']] # 2  3
classes = classes + [['one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'eleven', 'twelve']] # 4
# NOTE(review): 'o’clock' (and 'didn’t' below) use the typographic apostrophe
# U+2019, not ASCII "'"; presumably intentional, but any tokenizer must match it.
classes = classes + [['o’clock']] + [['and']] + [['was']] + [['too']] # 5  6  7  8
classes = classes + [['excited', 'angry', 'elated', 'agitated', 'nervous', 'troubled', 'upset']] # 9
classes = classes + [['to']] + [['go']] + [['back']] + [['sleep', 'bed']] # 10  11  12  13
classes = classes + [['He']] + [['got', 'stood', 'rose']] + [['up']] + [['put', 'pulled']] + [['on']] + [['his']] # 14 15 16 17 18 19
classes = classes + [['jeans', 'trousers', 'pants', 'denims', 'slacks']] + [['because']] + [['didn’t']] + [['want']] # 20 21 22 23
classes = classes + [['walk', 'run', 'jog', 'stroll', 'sprint', 'saunter', 'amble']] # 24
classes = classes + [['into']] + [['the']] + [['station', 'platform', 'house', 'building', 'apartment']] # 25 26 27
classes = classes + [['in']] + [['robes', 'pyjamas', 'shorts', 'underpants', 'boxers', 'trunks']] # 28 29
# The sentence terminator must stay the LAST class: the loop below appends
# `len(classes) - 1` to every path and asserts this entry is exactly ['.'].
classes = classes + [['.']] # 30
# Legal sentence templates: each path is a sequence of indices into `classes`.
#   path 0: <name> woke at <hour> o’clock and was too <feeling> to go back to <sleep|bed>
#   path 1: He <got> up and <put> on his <trousers> because He didn’t want to
#           <walk> into the <place> in his <clothes>
paths = [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 10, 13], [14, 15, 16, 6, 17, 18, 19, 20, 21, 14, 22, 23, 10, 24, 25, 26, 27, 28, 19, 29]]
""" Add a period to the end of sentence! """
# Append the index of the final '.' class to every path. This mutates the
# lists inside `paths` in place, so it must run exactly once.
for path in paths:
  # Sanity check that the last class is the single-word terminator ['.'].
  # NOTE(review): this does not depend on `path`; it could be hoisted out of the loop.
  assert len(classes[-1]) == 1 and '.' in classes[-1]
  # print('Ori path: ', path)
  path.append(len(classes) - 1)
  # print('After change: ', path)

def generate_sentence(path):
  """Return every sentence that can be produced by following ``path``.

  ``path`` is a sequence of indices into the module-level ``classes``. The
  result is the left-to-right Cartesian product of those word classes, built
  with ``product_classes`` (words joined by a single space).

  For a one-element path the class list itself is returned, not a copy.
  Raises IndexError for an empty path.
  """
  head_idx, tail = path[0], path[1:]
  sentences = classes[head_idx]
  for class_idx in tail:
    # Extend every partial sentence with every word of the next class.
    sentences = product_classes(sentences, classes[class_idx])
  return sentences

def product_classes(class1, class2):
  """Join every word/prefix in ``class1`` with every word in ``class2``.

  Returns a new list of ``left + ' ' + right`` strings, ordered with
  ``class1`` as the outer loop and ``class2`` as the inner loop.
  """
  return [left + ' ' + right for left in class1 for right in class2]

""" Hyper parameters of Dataset """
# Module-level dataset settings. SEED is read by generate_dataset (printed and
# used to seed the shuffle).
# NOTE(review): generate_dataset's own defaults are train_size=200 and
# test_size=800, not these values -- callers presumably pass these explicitly; confirm.
train_size=40
test_size=400
SEED=1234

def generate_dataset(train_size=200, test_size=800, paths = paths, seed=None):
  """Enumerate all sentences allowed by ``paths`` and split them into train/test.

  Every path is expanded with ``generate_sentence``, and each sentence is
  paired with the path that produced it. The pairs are shuffled
  deterministically; the first ``train_size`` pairs form the training split
  and the next ``test_size`` pairs form the test split.

  Args:
    train_size: number of sentences in the training split.
    test_size: number of sentences in the test split.
    paths: class-index paths to expand (defaults to the module-level ``paths``).
    seed: seed for the shuffle; ``None`` (default) uses the module-level SEED.

  Returns:
    ``(training_set, legal_training_path, test_set, legal_test_path)``. The
    ``*_set`` lists hold sentence strings; the ``legal_*_path`` lists hold, at
    the same position, the path that generated that sentence. Path entries are
    the very list objects from ``paths`` (not copies), so do not mutate them.

  Raises:
    ValueError: if a size is negative, or ``train_size + test_size`` exceeds
      the number of sentences available.
  """
  if seed is None:
    seed = SEED
  # (sentence, path) pairs, in the same order the original build/zip produced.
  datasets = [(sen, path) for path in paths for sen in generate_sentence(path)]
  if train_size < 0 or test_size < 0:
    raise ValueError(
        f'train_size and test_size must be non-negative, got {train_size} and {test_size}')
  stop = train_size + test_size
  # An exact fit (stop == len) is a valid split, hence '>' rather than '>='.
  if stop > len(datasets):
    raise ValueError(
        f'train_size + test_size ({stop}) exceeds the {len(datasets)} sentences available')
  print('using seed:', seed, end='  ')
  # Reseeds the module-level RNG (a global side effect, kept as-is) so the
  # split is reproducible and identical to earlier runs with the same seed.
  random.seed(seed)
  random.shuffle(datasets)
  datasets_sentence = [sen for sen, _ in datasets]
  datasets_path = [path for _, path in datasets]
  training_set = datasets_sentence[:train_size]
  legal_training_path = datasets_path[:train_size]
  test_set = datasets_sentence[train_size:stop]
  legal_test_path = datasets_path[train_size:stop]
  return training_set, legal_training_path, test_set, legal_test_path

class Corpus:
  """Two-way vocabulary assigning consecutive integer ids to words in insertion order."""

  def __init__(self):
    # Next id to hand out; equals the number of distinct words added so far.
    self.idx = 0
    self.word2idx = {}  # word -> id
    self.idx2word = {}  # id -> word

  def add_word(self, word):
    """Give ``word`` the next free id; a word that is already present is ignored."""
    if word in self.word2idx:
      return
    new_id = self.idx
    self.word2idx[word] = new_id
    self.idx2word[new_id] = word
    self.idx = new_id + 1

  def __len__(self):
    """Number of distinct words stored."""
    return len(self.word2idx)

  def get_word(self, idx):
    """Return the word whose id is ``idx``, or ``None`` if no such id exists."""
    return self.idx2word.get(idx)

  def get_idx(self, word):
    """Return the id of ``word``, or ``None`` if it was never added."""
    return self.word2idx.get(word)

def build_corpus(paths = paths, classes = classes):
    """Build a Corpus containing every word reachable through ``paths``.

    Words are added in the order they are first met while walking the paths
    and their classes, so ids are deterministic. The extra token 'empty' is
    added last.
    """
    corpus = Corpus()
    reachable_words = (
        word
        for path in paths
        for class_idx in path
        for word in classes[class_idx]
    )
    for word in reachable_words:
        corpus.add_word(word)  # duplicates are ignored by Corpus.add_word
    corpus.add_word('empty')
    return corpus

def test_one_sentence(legal_path, pred_words):
  """Check whether ``pred_words`` is a valid realisation of ``legal_path``.

  Returns True only if both sequences have the same length and every
  predicted word belongs to the class (an index into the module-level
  ``classes``) at the same position. Otherwise returns False.
  """
  if len(legal_path) != len(pred_words):
    return False
  return all(word in classes[class_idx]
             for class_idx, word in zip(legal_path, pred_words))