import random
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from dataloader import build_corpus, test_one_sentence

# Module-level state shared by train() and test().
# Number of leading words of every sentence that are treated as the prompt;
# train() only starts computing the loss from this word position onwards.
shown_size = 6
# Vocabulary object from dataloader.build_corpus(); the code below uses its
# get_idx (word -> index) and get_word (index -> word) lookups.
corpus = build_corpus()

def train(epochs, model, training_set, test_set, optimizer, criterion, lr_scheduler, legal_test_path):
  """Train `model` on next-word prediction and evaluate it after every epoch.

  Each training sentence is split on single spaces. For every position
  i >= shown_size the model receives the ground-truth prefix words[:i] and is
  optimized (one optimizer step per position) to predict words[i].

  Args:
    epochs: number of passes over `training_set`.
    model: callable as `model(list_of_words)` returning `(logits, hidden)`.
    training_set: iterable of space-separated sentence strings.
    test_set: sentences forwarded to `test()` for evaluation.
    optimizer: torch optimizer over the model parameters.
    criterion: loss called as `criterion(y_hat, y)`, where `y` is a 0-d long
      tensor holding the target word index.
    lr_scheduler: optional scheduler, stepped once per epoch (None to skip).
    legal_test_path: forwarded to `test()`, indexed per test sentence.

  Returns:
    Test accuracy (as returned by `test()`) of the model after the final
    epoch. When `epochs <= 0` the untrained model is evaluated instead of
    raising UnboundLocalError.
  """
  acc = None
  for epoch in range(epochs):
    # test() switches the model to eval mode, so re-enable training mode
    # at the start of every epoch.
    model.train()
    for ori_sentence in training_set:
      words = ori_sentence.split(' ')
      for i in range(shown_size, len(words)):
        input_sentence = words[:i]
        y = torch.tensor(corpus.get_idx(words[i]), dtype=torch.long)
        optimizer.zero_grad()
        y_hat, _ = model(input_sentence)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
    if lr_scheduler is not None:
      lr_scheduler.step()
    # Evaluate *after* the epoch's training so the log line and the returned
    # accuracy describe the model that was just trained. Use the same prompt
    # length as the training loop.
    acc = test(test_dataset=test_set, legal_test_path=legal_test_path, model=model, shown_size=shown_size)
    print('Epoch:{} Done, Test Acc:{}'.format(epoch, acc))
  if acc is None:
    # No epochs ran: report the accuracy of the model as given.
    acc = test(test_dataset=test_set, legal_test_path=legal_test_path, model=model, shown_size=shown_size)
  return acc

@torch.no_grad()
def test(test_dataset, legal_test_path, model, shown_size=6, hidden=False):
  """Evaluate `model` by greedy sentence completion.

  For each test sentence, the first `shown_size` space-separated words are
  used as the prompt. Words are then generated greedily (argmax over
  `y_hat`) one at a time. Generation stops after a '.' is produced or after
  len(words) - shown_size new words. The completed word list is checked with
  `test_one_sentence(legal_test_path[sentence_id], predicted_sentence)`.

  Args:
    test_dataset: iterable of space-separated sentence strings.
    legal_test_path: indexable by sentence position; each entry is passed to
      test_one_sentence alongside that sentence's prediction.
    model: callable as `model(list_of_words)` returning `(logits, hidden)`.
    shown_size: number of leading words given to the model as the prompt.
    hidden: if True, return the prompt hidden states instead of accuracy.

  Returns:
    If `hidden` is False, the fraction of sentences accepted by
    test_one_sentence (0.0 for an empty `test_dataset`). If `hidden` is True,
    a list of `[tuple(prompt_words), hidden_state]` pairs, one per sentence.
  """
  model.eval()
  total_sentence = 0
  right_sentence = 0
  hidden_states = []
  for sentence_id, sentence in enumerate(test_dataset):
    words = sentence.split(' ')
    predicted_sentence = words[:shown_size]
    if hidden:
      # Extra forward pass over the bare prompt, only needed to record its
      # hidden state; skipped otherwise to save compute and memory.
      _, hidden_state = model(predicted_sentence)
      hidden_states.append([tuple(predicted_sentence), hidden_state])
    for _step in range(shown_size, len(words)):
      y_hat, _ = model(predicted_sentence)
      pred_next_word = corpus.get_word(torch.argmax(y_hat).item())
      predicted_sentence.append(pred_next_word)
      if pred_next_word == '.':
        break
    if test_one_sentence(legal_test_path[sentence_id], predicted_sentence):
      right_sentence += 1
    total_sentence += 1
  if hidden:
    return hidden_states
  if total_sentence == 0:
    # Avoid ZeroDivisionError on an empty test set.
    return 0.0
  return float(right_sentence) / total_sentence