import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import torch
from torch.utils.data import DataLoader
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans, AgglomerativeClustering
import time
import umap
import numpy as np
import os
import tqdm
from dataset import *
from models import *
from train_val_test_func import *
import torch.optim as optim
from TDHNODE import *
from utils import *

# Names of the biomarker / outcome features, in column order.
feature_names = [
    'HbA1c_low',
    'LDL_HighRisk',
    'Hypertension',
    'Obesity',
    'Foot_ulcer',
    'Blindness_and_vision_loss',
    'Visual_impairment',
    'Congestive_heart_failure',
    'Nephropathy',
    'Neuropathy',
    'Retinopathy',
    'Cerebrovascular_Disease',
    'Stroke',
    'Depression',
    'Hypoglycemia',
    'HbA1c_high',
    'BP_Diastolic_LowRisk',
    'Cardiac_revascularization',
    'Atrial_fibrillation',
    'Cancer',
    'Ketoacidosis',
]

# Candidate data directories; `input_source` selects the one actually used.
dataset_name = 'Hospital_2_12/'
UF_input_source = f'./data/{dataset_name}'
MIMIC_input_source = './data/MIMIC/slide_window_20/'
input_source = UF_input_source

# Build the train split first, then the test split, from the same directory.
train_set, test_set = (
    Hyper_Graph_Dataset_biomarker(input_source, split)
    for split in ('train', 'test')
)

# NOTE(review): 1690 presumably equals the number of training patients, so the
# whole split fits in one batch -- confirm against the dataset.
batch_size = 1690

# shuffle=False keeps the loaders iterating in dataset order.
train_dataloader, test_dataloader = (
    DataLoader(split_set, batch_size=batch_size, shuffle=False)
    for split_set in (train_set, test_set)
)
def extract_feature_names_with_time(idx, dataset=None, names=None):
    """Print and return one patient's biomarkers ordered by first occurrence.

    The patient's biomarker matrix is read as (time step, feature); a value of
    1 marks the feature as present at that step. For every feature that is
    present at least once, the first such time step is looked up and the
    features are listed in chronological order of that first step. Features
    that first appear at the same step keep their column order.

    Args:
        idx: Index of the patient inside the dataset.
        dataset: Object exposing ``biomarkers[idx]`` and ``times[idx]``.
            Defaults to ``train_dataloader.dataset``.
        names: Feature names, one per biomarker column. Defaults to the
            module-level ``feature_names``.

    Returns:
        A tuple ``(feature_names_sequence, time_sequence)``: the ordered
        feature names and the entries of ``times[idx]`` at the corresponding
        first-occurrence steps.

    Raises:
        IndexError: If an occurring feature column has no entry in ``names``.
    """
    if dataset is None:
        dataset = train_dataloader.dataset
    if names is None:
        names = feature_names

    biomarkers = dataset.biomarkers[idx]
    # Named `timestamps` so the imported `time` module is not shadowed.
    timestamps = dataset.times[idx]

    # Boolean (time step, feature) table of occurrences. Rows that are
    # entirely -1 (presumably padding) contain no 1 and so contribute nothing.
    occurred = np.asarray(biomarkers) == 1
    valid_indices = np.flatnonzero(occurred.any(axis=0))

    if valid_indices.size:
        # argmax on a boolean column yields the first True row.
        first_occurrence_array = occurred[:, valid_indices].argmax(axis=0)
        # Stable sort: ties on the time step keep ascending column order.
        order = np.argsort(first_occurrence_array, kind='stable')
        valid_indices = valid_indices[order]
        first_occurrence_array = first_occurrence_array[order]
    else:
        # No feature ever occurs (this also covers an empty time axis).
        first_occurrence_array = np.zeros(0, dtype=np.intp)

    feature_names_sequence = [names[i] for i in valid_indices]
    time_sequence = timestamps[first_occurrence_array]

    print("biomarkers: ", feature_names_sequence)
    print("timestamps: ", time_sequence)
    return feature_names_sequence, time_sequence

# embedding = np.load("./embedding_saved/hyper_embedding2025-02-13-18-33-02.npy", allow_pickle=True)
#
# embedding = np.load("./embedding_saved/patient_embedding_2025-03-17-23-11-37.npy", allow_pickle=True) # MIMIC
#
# embedding = np.load("./embedding_saved/patient_embedding_2025-03-17-23-18-16.npy", allow_pickle=True)

# embedding = np.load("./embedding_saved/patient_embedding_2025-03-19-13-04-41.npy", allow_pickle=True)
#
# embedding = np.load("./embedding_saved/patient_embedding_2025-03-19-17-38-48.npy", allow_pickle=True)

# Frozen:
# TD, test_set (242,19,128)
# embedding = np.load("./embedding_saved/patient_embedding_2025-03-25-00-34-39.npy", allow_pickle=True)
# TD, train_set (1690, 19, 128)
# embedding = np.load("./embedding_saved/patient_embedding_2025-03-25-00-34-54.npy", allow_pickle=True)

# TD, test_set (242,19,21)
# embedding = np.load("./embedding_saved/patient_embedding_2025-03-25-00-58-04.npy", allow_pickle=True)
# TD, train_set (1690, 19, 21)
embedding = np.load("./embedding_saved/patient_embedding_2025-03-25-00-58-16.npy", allow_pickle=True)
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# only load embedding files that come from a trusted source.

# For every patient keep the embedding(s) at the last K valid time steps.
# The first mask column is dropped before use.
# NOTE(review): presumably the embedding has one entry per step after the
# first, so mask[:, 1:] lines up with embedding axis 1 -- confirm.
mask = train_dataloader.dataset.mask[:, 1:]

# The mask always comes from the train split while the embedding file above is
# switched by hand; fail early with a clear message when they do not describe
# the same set of patients instead of hitting an IndexError / bad reshape below.
if mask.shape[0] != embedding.shape[0]:
    raise ValueError(
        f"mask covers {mask.shape[0]} samples but the loaded embedding has "
        f"{embedding.shape[0]}; load the embedding that matches the dataset split"
    )

K = 1
result = np.zeros((mask.shape[0], K, embedding.shape[2]))

for i in range(mask.shape[0]):
    true_indices = np.where(mask[i])[0]
    if len(true_indices) >= K:
        selected = true_indices[-K:]
    else:
        # Fewer than K valid steps: left-pad by repeating the earliest valid
        # step, or step 0 when the patient has no valid step at all.
        pad = np.full(K - len(true_indices), true_indices[0] if len(true_indices) > 0 else 0)
        selected = np.concatenate([pad, true_indices])
    result[i] = embedding[i, selected]

# result = result[:,0,:]


print(f"result shape: {result.shape}")
num_sample = embedding.shape[0]
# One row per patient: the K selected steps are concatenated feature-wise.
embedding_flattened = result.reshape(num_sample, -1)

# scaler = StandardScaler()
# embedding_flattened = scaler.fit_transform(embedding_flattened)

# Two candidate 2-D projections with fixed seeds; only t-SNE is fitted below.
# The UMAP model is kept so the commented-out line can be re-enabled.
umap_model = umap.UMAP(random_state=42, n_neighbors=15, min_dist=0.1)
tsne = TSNE(random_state=42, n_components=2, perplexity=15)

# embedding_2d = umap_model.fit_transform(embedding_flattened)
embedding_2d = tsne.fit_transform(embedding_flattened)

# Clustering configuration. The clusters are computed on the flattened
# embeddings themselves, not on the 2-D projection.
n_clusters = 3
save_idx_file = True

# --- K-Means ---
# kmeans = KMeans(n_clusters=n_clusters, random_state=24)
# cluster_labels = kmeans.fit_predict(embedding_flattened)

# --- Hierarchical clustering (Ward linkage) ---
hier = AgglomerativeClustering(linkage='ward', n_clusters=n_clusters)
hier.fit(embedding_flattened)
cluster_labels = hier.labels_

# Scatter-plot the 2-D projection coloured by cluster and, optionally, save
# the sample indices that belong to each cluster.
plt.figure(figsize=(8, 6))
special_indices = []
# sorted() makes the index order explicit instead of relying on set iteration
# order, so the saved membership lists are always ascending.
other_indices = sorted(set(range(len(embedding_2d))) - set(special_indices))
id_cluster = [[] for _ in range(n_clusters)]

for cluster_id in range(n_clusters):
    idx = [i for i in other_indices if cluster_labels[i] == cluster_id]
    id_cluster[cluster_id] = idx
    plt.scatter(
        embedding_2d[idx, 0],
        embedding_2d[idx, 1],
        alpha=0.6,
        s=15,
        label=f"Cluster {cluster_id + 1}"
    )

# if special_indices:
#     plt.scatter(
#         embedding_2d[special_indices, 0],
#         embedding_2d[special_indices, 1],
#         color='red',
#         alpha=0.8,
#         s=30,
#         label="Highlighted Points"
#     )

# Fill a 1-D object array element by element. np.array(id_cluster, dtype=object)
# would silently produce a 2-D array whenever all clusters have the same size,
# so the saved layout would depend on the data.
id_cluster_np = np.empty(len(id_cluster), dtype=object)
for cluster_pos, members in enumerate(id_cluster):
    id_cluster_np[cluster_pos] = members
if save_idx_file:
    # np.save does not create missing parent directories.
    os.makedirs('./case_study', exist_ok=True)
    np.save(f'./case_study/id_cluster_{len(id_cluster_np)}_embedding_{embedding.shape[0]}_{embedding.shape[2]}.npy', id_cluster_np)
plt.xlabel('t-SNE Component 1', fontsize=18)
plt.ylabel('t-SNE Component 2', fontsize=18)
plt.title('t-SNE Visualization of Patient Embeddings', fontsize=22)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14)
plt.tight_layout()
# plt.savefig('tsne_clustering3.png', dpi=500)
plt.show()
