import torch

import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
def test(dataloader, model, loss_fn, device='cpu', is_testing_generation=False, is_plot_t_sne=False):
    """Evaluate ``model`` over every batch of ``dataloader`` without gradients.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()`` and
            exposes ``.dataset`` (e.g. a ``torch.utils.data.DataLoader``).
        model: Callable module. It is switched to eval mode and left that way.
        loss_fn: Called as ``loss_fn(pred, y)``; must return a one-element tensor.
        device: Device every batch is moved to before the forward pass.
        is_testing_generation: If True, report the mean loss instead of accuracy.
        is_plot_t_sne: If True, also return the stacked model outputs and targets
            (e.g. to feed a t-SNE plot).

    Returns:
        ``metric``, or ``(metric, preds, targets)`` when ``is_plot_t_sne``.
        ``metric`` is the accuracy in percent (``100 * correct / len(dataset)``),
        or, when ``is_testing_generation``, the average of the per-batch losses.
        ``preds`` and ``targets`` are NumPy arrays concatenated along axis 0.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this guard the loss average below divides by zero.
        raise ValueError("dataloader yields no batches; nothing to evaluate")
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0.0, 0
    # Outputs are kept only when the caller asked for them, so plain evaluation
    # does not copy every prediction of the dataset to host memory.
    pred_chunks, y_chunks = [], []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            if not is_testing_generation:
                # Assumes pred holds per-class scores on dim 1 and y holds
                # integer class indices -- TODO confirm for every caller.
                correct += (pred.argmax(1) == y).sum().item()
            if is_plot_t_sne:
                pred_chunks.append(pred.cpu().numpy())
                y_chunks.append(y.cpu().numpy())
    # Mean of per-batch losses: a smaller final batch weighs as much as the others.
    test_loss /= num_batches
    metric = test_loss if is_testing_generation else 100 * (correct / size)
    if is_plot_t_sne:
        return metric, np.concatenate(pred_chunks, axis=0), np.concatenate(y_chunks, axis=0)
    return metric


# The name starts with "test", so pytest would otherwise collect this helper as a
# test (and error on the missing fixtures) in any test module that imports it.
test.__test__ = False
def _effective_perplexity(n_samples, perplexity):
    """Clamp ``perplexity`` so scikit-learn's t-SNE accepts ``n_samples`` points.

    ``sklearn.manifold.TSNE`` rejects ``perplexity >= n_samples``, so small
    evaluation sets would otherwise fail with the default of 30.

    Raises:
        ValueError: If fewer than 2 samples are given (no valid perplexity exists).
    """
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")
    return min(perplexity, n_samples - 1)


def plot_t_sne(x, y, save_path='t-SNE_visualization_decoder.png', *,
               perplexity=30, random_state=42, learning_rate=200):
    """Project ``x`` to 2-D with t-SNE and save a scatter plot colored by ``y``.

    Args:
        x: Samples to embed, one row per sample (e.g. the predictions returned
            by ``test(..., is_plot_t_sne=True)``).
        y: Per-sample values used as scatter colors and legend entries.
        save_path: Output image path; matplotlib infers the format from its
            extension.
        perplexity: Requested t-SNE perplexity; clamped to ``len(x) - 1``.
        random_state: Seed passed to ``TSNE`` for reproducible layouts.
        learning_rate: t-SNE learning rate.

    Raises:
        ValueError: If ``x`` has fewer than 2 samples.
    """
    tsne = TSNE(
        n_components=2,
        random_state=random_state,
        perplexity=_effective_perplexity(len(x), perplexity),
        learning_rate=learning_rate,
    )
    # TSNE ignores ``y`` in fit_transform, so only the samples are passed.
    x_2d = tsne.fit_transform(x)
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        scatter = ax.scatter(x_2d[:, 0], x_2d[:, 1], c=y, cmap='tab20', s=10, alpha=0.8)
        # One legend entry per distinct color value (up to ~40).
        handles, labels = scatter.legend_elements(num=40)
        ax.legend(handles, labels, title="Targets", bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.set_title("t-SNE visualization (colored by target)")
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if saving fails, so repeated calls do not leak.
        plt.close(fig)