import os
import numpy as np

# os.curdir is the portable spelling of the current directory. The original
# '.\\' literal is only a valid path on Windows and makes chdir raise
# FileNotFoundError on POSIX. Either way this is a no-op: the relative file
# names below resolve against the directory the script is launched from.
os.chdir(os.curdir)

# Score file: presumably one row of per-class scores per sample (rows are
# argmax'ed per sample further down) -- confirm against the producer.
file_source_name = "best_cm_RrED.txt"
# file_source_name = "start_cm.txt"  # alternative score file
# Label list: the second whitespace-separated field of each line is the
# integer class label of the corresponding sample.
label_source_name = "validation_list.txt"

max_indices = []  # ground-truth class label per sample, in file order
index = []

with open(label_source_name, 'r', encoding='utf-8') as label_file:
    for line in label_file:
        # split() with no argument tolerates tabs, repeated spaces and \r\n.
        fields = line.split()
        if not fields:
            # Blank lines (e.g. a trailing empty line) used to crash with
            # IndexError; they carry no sample, so skip them.
            continue
        max_indices.append(int(fields[1]))

# ndmin=2 keeps the result 2-D even when the file holds a single row, so the
# downstream code can always treat it as "one row per sample".
loaded_source_datasets = np.loadtxt(file_source_name, ndmin=2)
print(max_indices[:50])
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import seaborn as sns

# Class ids to include in the confusion matrix (ids 0-11).
printindex = list(range(12))


source_label = []  # true label of every kept sample
target_label = []
array1 = []  # score row of every kept sample; replaced by predictions below
array2 = []
print(len(loaded_source_datasets), len(max_indices))
# Ignore score rows beyond the last labelled sample.
loaded_source_datasets = loaded_source_datasets[:len(max_indices)]

# Walk labels and score rows in lockstep (zip stops at the shorter sequence,
# matching the original index range) and keep only samples whose true label
# is one of the selected classes.
for label, scores in zip(max_indices, loaded_source_datasets):
    if label in printindex:
        source_label.append(label)
        array1.append(scores)

if not array1:
    # Without this guard np.argmax below fails with an opaque AxisError on an
    # empty 1-D array; report the actual cause instead.
    raise ValueError(f"no samples have a label in {printindex}; nothing to plot")

# Predicted class of each kept sample = column index of its highest score.
array1 = np.argmax(np.array(array1), axis=1)
len_data = len(array1)
print(len_data)

# Entry [r, c] counts samples whose true label is the r-th class and whose
# predicted label is the c-th class (sklearn's confusion_matrix convention).
confusion_mat = confusion_matrix(source_label, array1)

# Write counts only on the diagonal (correct predictions); every off-diagonal
# cell gets an empty annotation string.
annot_array = np.full(confusion_mat.shape, "", dtype=object)
for k, count in enumerate(np.diag(confusion_mat)):
    annot_array[k, k] = str(count)

# fmt="" because the annotations are already strings; vmin/vmax pin the colour
# scale to a fixed 0..10000 range instead of the data's own min/max.
sns.heatmap(
    confusion_mat,
    annot=annot_array,
    cmap="Greens",
    fmt="",
    xticklabels=False,
    yticklabels=False,
    vmin=0,
    vmax=10000,
    annot_kws={'size': 9},
)

plt.show()
