import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
from six.moves import cPickle as pkl
import argparse
import numpy as np

def tokenize_sentence(sentence):
    """Convert a sequence of token strings into a list of integer token ids.

    Args:
        sentence: An iterable of tokens, each accepted by ``int()``. The
            caller in this file passes the result of ``str.split()``.

    Returns:
        A new list holding ``int(token)`` for every token, in order.
    """
    return list(map(int, sentence))

# Create a custom Dataset
class TextCompletionDataset(Dataset):
    """Map-style dataset of integer token-id sequences.

    Each input sentence is a whitespace-separated string of integer token
    ids; it is split and converted once, at construction time.

    NOTE(review): the default DataLoader collate stacks items, so every
    sentence in a batch must have the same length — presumably guaranteed by
    the cached data; confirm.
    """

    def __init__(self, sentences):
        """
        Args:
            sentences: Iterable of strings whose whitespace-separated fields
                are integer literals, e.g. ``"3 1 4"``.
        """
        self.sentences = [tokenize_sentence(sentence.split()) for sentence in sentences]

    def __len__(self):
        """Return the number of sentences."""
        return len(self.sentences)

    def __getitem__(self, idx):
        """Return sentence ``idx`` as a 1-D ``torch.long`` tensor.

        The dtype is pinned explicitly: ``torch.tensor([])`` would otherwise
        infer float32, which ``nn.Embedding`` rejects as an index tensor.
        """
        return torch.tensor(self.sentences[idx], dtype=torch.long)

# Define the LSTM model
class LSTMTextCompletion(nn.Module):
    """Next-token model: token embedding -> LSTM -> per-position vocab logits.

    Args:
        vocab_size: Number of distinct token ids; inputs must lie in
            ``[0, vocab_size)`` and each position gets ``vocab_size`` logits.
        embedding_dim: Size of each token embedding vector.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Keep these attribute names and this registration order: they fix the
        # state_dict keys and the order in which parameters are initialised.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids to logits.

        ``x`` is expected to be (batch, seq) as in this file's training loop
        (``batch_first=True``); the result is (batch, seq, vocab_size).
        """
        hidden_states, _ = self.lstm(self.embedding(x))
        return self.linear(hidden_states)
    

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--v', type=int, default=10)
    parser.add_argument('--t', type=int, default=3)
    parser.add_argument('--n', type=int, default=20_000)
    parser.add_argument('--h', type=int, default=16)
    parser.add_argument('--trials', type=int, default=5)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--dump', type=str, default="DIR/TO/CACHE",
                        help="directory holding the sentence pickle and receiving model checkpoints")
    args = parser.parse_args()
    print(args)

    DUMP = args.dump

    np.random.seed(args.seed)
    # Model weight init and DataLoader shuffling draw from torch's RNG, not
    # numpy's, so torch must be seeded too for --seed to make runs repeatable.
    torch.manual_seed(args.seed)

    # First CUDA device when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_path = os.path.join(DUMP, "v={}_t={}_n={}.pkl".format(args.v, args.t, args.n))
    # NOTE: unpickling can execute arbitrary code; only load caches produced by trusted code.
    with open(data_path, "rb") as f:
        sentences = pkl.load(f)

    # Hyperparameters
    embedding_dim = 32
    hidden_dim = args.h
    batch_size = 16
    num_epochs = 30

    # Prepare the data and train the model
    dataset = TextCompletionDataset(sentences)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for tr in range(args.trials):
        model = LSTMTextCompletion(args.v, embedding_dim, hidden_dim).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0
            for batch in dataloader:
                batch = batch.to(device)
                # Predict token i+1 from tokens 0..i: shift inputs and targets by one.
                inputs, targets = batch[:, :-1], batch[:, 1:]

                optimizer.zero_grad()
                output = model(inputs)
                # CrossEntropyLoss wants class scores on dim 1: (batch, vocab, seq).
                loss = criterion(output.transpose(1, 2), targets)
                loss.backward()
                optimizer.step()

                # Accumulate on-device to avoid a host sync on every batch.
                total_loss += loss.detach()
                num_batches += 1

            # Mean over the epoch (0.0 if the loader yielded no batches).
            avg_loss = float(total_loss) / max(num_batches, 1)
            print(f"Trial [{tr+1}/{args.trials}], Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

        ckpt_path = os.path.join(DUMP, "v={}_t={}_n={}_h={}_trial={}.pt".format(args.v, args.t, args.n, args.h, tr))
        # Saves the whole pickled module (not just a state_dict), so loading it
        # requires LSTMTextCompletion to be importable from the same module path.
        torch.save(model, ckpt_path)