import random
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from dataloader import generate_dataset, build_corpus
from models import NextWordPredSSM
from trainer import train
# Experiment: train NextWordPredSSM several times per training-set size and
# collect the value returned by `train` for each run (presumably the final
# test accuracy — confirm against trainer.train).

train_size = 40       # training-set size used by default (see `datasizes`)
test_size = 800       # test sentences per run; matches the value previously hard-coded in the loop
SEED = 1234           # base seed; repetition `run` uses SEED + run
corpus = build_corpus()
input_size = 10
hidden_size = 145     # previously tried: 12
num_layers = 1
shown_size = 6        # NOTE(review): not used in this script — confirm whether it is still needed
num_runs = 5          # independent repetitions per training-set size
epochs = 10
learning_rate = 0.005
lr_decay = 0.8        # ExponentialLR gamma: lr is multiplied by this factor on each scheduler step
datasizes = [train_size]  # previously swept: [20, 40, 60, 80]


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs so a run can be reproduced."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Maps training-set size -> list of values returned by `train`, one per run.
results = {dsize: [] for dsize in datasizes}
for datasize in datasizes:
    for run in range(num_runs):
        # Seed before dataset generation and model construction so that both the
        # sampled data and the initial weights are reproducible, while each
        # repetition still uses a different seed.
        _seed_everything(SEED + run)
        training_set, legal_training_path, test_set, legal_test_path = generate_dataset(
            train_size=datasize, test_size=test_size
        )
        print(f'Training set has {len(training_set)} sentences. Test set has {len(test_set)} sentences')

        model = NextWordPredSSM(input_size, hidden_size, num_layers, len(corpus))
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f'Model has {num_params} parameters to be trained.')

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr_decay)

        final_acc = train(epochs, model, training_set, test_set, optimizer, criterion, lr_scheduler, legal_test_path)
        results[datasize].append(final_acc)
print(results)