import random
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from dataloader import build_corpus

# Model hyperparameters (names mirror NextWordPredSSM's constructor arguments).
input_size = 10
hidden_size = 64  # 12 was noted here as an alternative value
num_layers = 1  # NextWordPredSSM accepts this argument but does not use it

# Module-level corpus used by NextWordPredSSM.sentence_to_index (via
# corpus.get_idx). Built once, at import time of this module.
corpus = build_corpus()

class NextWordPredSSM(nn.Module):
    """Next-word predictor built on a linear (activation-free) recurrence.

    For each token embedding ``x_t`` the hidden state is updated as
    ``h_t = W_ih x_t + W_hh h_{t-1}``, starting from ``h_0 = 0``; neither
    projection has a bias term. After the last token, ReLU is applied to the
    final hidden state and a linear output layer maps it to one logit per
    vocabulary entry.

    Args:
        input_size: Embedding dimension of each word.
        hidden_size: Dimension of the recurrent hidden state.
        num_layers: Currently unused; only a single recurrent layer is built.
        voca_size: Vocabulary size (number of embeddings and of output logits).
    """

    def __init__(self, input_size, hidden_size, num_layers, voca_size):
        super().__init__()
        self.word_embedded = nn.Embedding(num_embeddings=voca_size, embedding_dim=input_size)
        self.hidden_size = hidden_size
        # Recurrence weights. Both projections are bias-free, so the state
        # update is purely linear: h_t = W_ih x_t + W_hh h_{t-1}.
        self.W_ih = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hh = nn.Linear(hidden_size, hidden_size, bias=False)
        # Output layer (with bias) mapping the hidden state to vocabulary logits.
        self.fc = nn.Linear(hidden_size, voca_size)

    def forward(self, sentence):
        """Run the recurrence over ``sentence`` and predict the next word.

        Args:
            sentence: Iterable of words; each word is converted to an index
                by :meth:`sentence_to_index`.

        Returns:
            ``(predicted, h_t)``: ``predicted`` are unnormalized logits of
            shape ``[voca_size]`` and ``h_t`` is the final hidden state of
            shape ``[hidden_size]``, taken *before* the ReLU. For an empty
            sentence ``h_t`` is all zeros and ``predicted`` equals the output
            layer's bias.
        """
        weight = self.word_embedded.weight
        # Create the indices on the embedding's device: a CPU index tensor
        # cannot be used with embedding weights that live on another device.
        idx_of_sentence = torch.tensor(
            self.sentence_to_index(sentence), dtype=torch.long, device=weight.device
        )
        # Shape: [length, input_size] (assumes each index is a scalar int).
        embed_of_sentence = self.word_embedded(idx_of_sentence)
        # h_0 matches the parameters' device AND dtype, so the recurrence also
        # works after e.g. .double() or .half().
        h_t = torch.zeros(self.hidden_size, device=weight.device, dtype=weight.dtype)
        # Iterating a 2-D tensor yields one embedding row per time step.
        for x_t in embed_of_sentence:
            # Linear hidden-state update (no activation function).
            h_t = self.W_ih(x_t) + self.W_hh(h_t)
        # The only non-linearity sits between the recurrence and the output layer.
        h_relu = torch.relu(h_t)
        predicted = self.fc(h_relu)  # Output logits
        return predicted, h_t

    def sentence_to_index(self, sentence):
        """Map each word of ``sentence`` to an index via the module-level ``corpus``.

        ``corpus.get_idx`` is defined outside this file; presumably it returns
        an integer vocabulary index (how it treats unknown words is not visible
        here — TODO confirm).
        """
        idx_of_sentence = [corpus.get_idx(word) for word in sentence]
        return idx_of_sentence