import random
import os
import numpy as np
import torch
import scipy.io
import matplotlib.pyplot as plt
import math
from CuDeRes import *
from sklearn.manifold import TSNE


def set_seed(seed=2025):
    """Seed every random number generator this script relies on.

    The same value is handed, in order, to Python's ``random`` module,
    NumPy's global generator, and torch's ``manual_seed``,
    ``cuda.manual_seed`` and ``cuda.manual_seed_all``.

    Args:
        seed: Integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def get_data(path_list, scaling=1.0):
    """Load the ``"P"`` array of each MATLAB file, stack and standardise them.

    Every array is converted to a float32 tensor, given a leading sample axis
    and a trailing singleton axis, and reordered with
    ``permute(0, 2, 3, 1, 4)``. The two ``unsqueeze`` calls followed by a
    five-axis permute mean each ``"P"`` has to be three-dimensional. The
    per-file tensors are concatenated along axis 0, then standardised with a
    mean and standard deviation computed separately for each index of axis 1.

    Args:
        path_list: Iterable of paths to files readable by
            ``scipy.io.loadmat`` that contain a variable named ``"P"``.
        scaling: Factor applied to the standardised tensor.

    Returns:
        A float32 tensor of shape ``(n_files, d1, d2, d0, 1)`` for ``"P"``
        arrays of shape ``(d0, d1, d2)``.
    """
    collection = []
    for path in path_list:
        raw = scipy.io.loadmat(path)["P"]
        sample = torch.tensor(raw, dtype=torch.float32).unsqueeze(0).unsqueeze(4)
        collection.append(sample.permute(0, 2, 3, 1, 4))
    data = torch.cat(collection, dim=0)

    mean = data.mean(dim=(0, 2, 3, 4), keepdim=True)
    std = data.std(dim=(0, 2, 3, 4), keepdim=True)
    # A constant slice has std == 0 (and a single-value slice has std == NaN),
    # which would turn the whole slice into NaN/inf. Dividing by 1 instead
    # leaves such a slice at exactly 0 after centring; every other slice is
    # divided by its unchanged deviation.
    std = torch.where(std > 0, std, torch.ones_like(std))
    return (data - mean) / std * scaling


def run(scaling=10, reservoir_size=50, radii=(0.9, 0.9, 0.9), regular=16):
    """Fit models for the test data and plot them next to the stored normal models.

    The stored models are read from ``./normal_model_depot.npy``. Every file
    in ``../data/`` is loaded through ``get_data`` and passed to a ``CuDeRes``
    transform; its output is stacked under the stored models, embedded in 2-D
    with t-SNE and shown as a scatter plot in which stored rows are labelled
    "Normal" and freshly fitted rows "Abnormal".

    Args:
        scaling: Scale factor forwarded to ``get_data``.
        reservoir_size: Forwarded to ``CuDeRes``.
        radii: Forwarded to ``CuDeRes``.
        regular: Forwarded to ``CuDeRes``.

    Returns:
        Always ``0``.

    Note:
        The axis limits are fixed, so embedded points that fall outside
        ``x in [-10, 25]`` / ``y in [-25, 15]`` are not visible.
    """
    normal_models = np.load("./normal_model_depot.npy")

    # fitting test blocks (pipeline)
    transform = CuDeRes(1, reservoir_size=reservoir_size, radii=radii, regular=regular)
    # os.listdir yields entries in arbitrary order; sorting fixes the row
    # order of the fitted models so repeated runs see identical input.
    test_path = ["../data/" + name for name in sorted(os.listdir("../data/"))]
    # NOTE(review): the third exponent (-0.15) is an order of magnitude larger
    # than the first two -- confirm this is intended.
    rate = (math.e ** (-0.015), math.e ** (-0.01), math.e ** (-0.15))
    data = get_data(test_path, scaling)
    test_models = transform(data, decay=rate).numpy()

    # 2D visualization
    total_models = np.concatenate([normal_models, test_models], axis=0)
    total_label = np.concatenate(
        [np.zeros(normal_models.shape[0]), np.ones(test_models.shape[0])], axis=0
    )

    tsne = TSNE(n_components=2, random_state=2025)
    x_tsne = tsne.fit_transform(total_models)

    plt.figure(figsize=(8, 6))

    present_labels = np.unique(total_label)
    # plt.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap takes the
    # same (name, lut) arguments and returns the same resampled colormap.
    colors = plt.get_cmap('tab10', len(present_labels))

    # Iterate over the labels that actually occur so that a run with only one
    # group is still drawn and named correctly.
    for color_index, label_value in enumerate(present_labels):
        mask = total_label == label_value
        plt.scatter(
            x_tsne[mask, 0],
            x_tsne[mask, 1],
            label="Normal" if label_value == 0 else "Abnormal",
            color=colors(color_index),
        )

    plt.title('2D Visualization of Normal and Abnormal Models')
    plt.xlabel('t-SNE component 1')
    plt.ylabel('t-SNE component 2')
    plt.xlim([-10, 25])
    plt.ylim([-25, 15])
    plt.legend()
    plt.show()

    return 0


if __name__ == "__main__":
    # Seed all random number generators before any model fitting happens.
    set_seed(2025)
    # reservoir_size is spelled out here although it equals run()'s default.
    run(reservoir_size=50)


