import random
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.decomposition import PCA
from dataloader import generate_dataset, build_corpus
from models import NextWordPredSSM
from trainer import train, test
import os
# --- Experiment configuration -------------------------------------------------
train_size = 40
test_size = 800  # size actually requested from generate_dataset below
SEED = 1234

# Seed every RNG this script can reach so dataset generation, weight
# initialisation and training are reproducible across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

corpus = build_corpus()

# --- Model hyper-parameters -------------------------------------------------------
input_size = 10
hidden_size = 145
num_layers = 1

# --- Data -------------------------------------------------------------------------
training_set, legal_training_path, test_set, legal_test_path = generate_dataset(
    train_size=train_size, test_size=test_size
)
print('Training set has {} sentences. Test set has {} sentences'.format(len(training_set), len(test_set)))

# --- Model (constructed exactly once) ---------------------------------------------
model = NextWordPredSSM(input_size, hidden_size, num_layers, len(corpus))
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Model has {} parameters to be trained.'.format(num_params))

# --- Training hyper-parameters and run --------------------------------------------
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
# Multiplies the learning rate by 0.8 on every lr_scheduler.step() call
# (presumably once per epoch inside train() — confirm in trainer.train).
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)
epochs = 10
final_acc = train(epochs, model, training_set, test_set, optimizer, criterion, lr_scheduler, legal_test_path)
# --- Eigenvalue spectrum of the hidden-to-hidden weight matrix ----------------
# Output directory for all figures; exist_ok avoids the check-then-create race.
os.makedirs('Plots', exist_ok=True)

hidden_weights = model.W_hh.weight
# .cpu() returns the tensor itself when it is already on the CPU and makes
# .numpy() work if the model was moved to a GPU during training.
hidden_weights_np = hidden_weights.detach().cpu().numpy()
# np.linalg.eigvals needs a square matrix; assumes W_hh maps hidden -> hidden
# (hidden_size x hidden_size) — TODO confirm in models.NextWordPredSSM.
eigenvalues = np.linalg.eigvals(hidden_weights_np)

plt.figure(figsize=(5, 5))
plt.plot(eigenvalues.real, eigenvalues.imag, 'o', markersize=5, label='Eigenvalues')

# Unit circle |z| = 1 as a reference for eigenvalue magnitudes.
theta = np.linspace(0, 2 * np.pi, 100)
x_circle = np.cos(theta)
y_circle = np.sin(theta)
plt.plot(x_circle, y_circle, color='red', linestyle='--', label='Unit Circle (r=1)')

# Real and imaginary axes.
plt.axhline(0, color='black', lw=0.5, ls='--')
plt.axvline(0, color='black', lw=0.5, ls='--')

plt.title('Eigenvalues of the Hidden-Hidden Weight Matrix')
plt.xlabel('Real Part')
plt.ylabel('Imaginary Part')
plt.grid()
plt.axis('equal')  # equal scaling so the unit circle is drawn as a circle
plt.legend()
plt.savefig('Plots/eigenvalues_plot_ssm.jpg', format='jpg', dpi=800)
plt.show()
# --- Hidden states on the test set ------------------------------------------------
# Run test() once per prefix length 6, 7 and 8 (range end is exclusive) with
# hidden=True and gather everything it returns into one flat list.
all_hidden_states = []
for shown_size in range(6, 9):
    all_hidden_states.extend(
        test(
            test_dataset=test_set,
            legal_test_path=legal_test_path,
            model=model,
            shown_size=shown_size,
            hidden=True,
        )
    )
print("hidden states", type(all_hidden_states), len(all_hidden_states))
# --- 2-D PCA of hidden states for a few target trigrams ---------------------------
# Only hidden states whose context ends in one of these trigrams are plotted.
TARGET_NGRAMS = {
    ('and', 'put', 'on'),
    ('put', 'on', 'his'),
    ('on', 'his', 'jeans'),
    ('eight', 'o’clock', 'and'),
    ('o’clock', 'and', 'was'),
    ('and', 'was', 'too'),
}

hidden_vectors = []
labels = []
label_counts = {}
for hidden_state in all_hidden_states:
    # assumes hidden_state is (word_sequence, hidden_tensor) as indexed below —
    # TODO confirm against trainer.test(hidden=True).
    # tuple() so the comparison with the tuple targets also works when the
    # word sequence is a list (a list never compares equal to a tuple).
    ngram = tuple(hidden_state[0][-3:])
    if ngram not in TARGET_NGRAMS:
        continue
    # Drop size-1 dims and convert to a NumPy vector (.cpu() for GPU safety).
    hidden_vectors.append(hidden_state[1].squeeze().detach().cpu().numpy())
    labels.append(ngram)  # label = last 3 words of the context
    label_counts[ngram] = label_counts.get(ngram, 0) + 1
print("Label counts", label_counts)

# PCA(n_components=2) needs at least two samples.
if len(hidden_vectors) < 2:
    raise ValueError(
        'Need at least 2 hidden states matching the target n-grams for a '
        '2-component PCA, found {}.'.format(len(hidden_vectors))
    )
hidden_vectors = np.array(hidden_vectors)
pca = PCA(n_components=2)
hidden_vectors_2d = pca.fit_transform(hidden_vectors)

# sorted() gives a stable label -> colour-index mapping; the iteration order of
# a plain set of strings can change between runs due to hash randomization.
unique_labels = {label: idx for idx, label in enumerate(sorted(set(labels)))}
print("Unique_labels", unique_labels)
colors = [unique_labels[label] for label in labels]

plt.figure(figsize=(7, 4))
scatter = plt.scatter(hidden_vectors_2d[:, 0], hidden_vectors_2d[:, 1], c=colors, cmap='viridis', alpha=0.7)

# Custom legend: one marker per label, coloured with the scatter's own
# colormap and normalisation so legend and points agree.
handles = [
    plt.Line2D([0], [0], marker='o', color='w', label=' '.join(label),
               markerfacecolor=scatter.cmap(scatter.norm(idx)), markersize=10)
    for label, idx in unique_labels.items()
]
# Place the legend outside the axes, to the right.
plt.legend(handles=handles, title="Labels", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('N-gram embeddings')
plt.xlabel('PCA Component 1')
plt.ylabel('PCA Component 2')
plt.grid()
plt.tight_layout()  # make room for the outside legend
os.makedirs('Plots', exist_ok=True)
plt.savefig('Plots/embeddings_plot_ssm.jpg', format='jpg', dpi=800)
plt.show()