import pandas as pd
import torch
import numpy as np
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
import pickle
from sklearn.model_selection import GroupShuffleSplit
def _add_similarity_edges(graph, embeddings, subject_ids, k, threshold):
    """Add cosine-similarity edges between rows of *different* subjects to ``graph``.

    For every row ``i``, the ``k + 1`` rows with the highest cosine similarity to
    ``i`` are examined (``+1`` because row ``i`` usually ranks first against
    itself). A candidate ``j`` becomes an edge ``(i, j)`` with
    ``edge_type='similarity'`` and ``weight=cos_sim[i, j]`` when ``j != i``,
    the two rows have different subject ids, and the similarity is strictly
    above ``threshold``.

    Returns:
        The number of distinct similarity edges newly created. A pair that is
        discovered from both endpoints is counted once, because re-adding an
        existing edge to an undirected graph does not create a second edge.
    """
    cos_sim = cosine_similarity(embeddings)
    added = 0
    for i in range(len(subject_ids)):
        for j in np.argsort(-cos_sim[i, :])[:k + 1]:
            if i != j and subject_ids[i] != subject_ids[j] and cos_sim[i, j] > threshold:
                if not graph.has_edge(i, j):
                    added += 1
                # Re-adding an existing edge only refreshes its attributes.
                graph.add_edge(i, j, edge_type='similarity', weight=cos_sim[i, j])
    return added


def build_graphs(subject_ids, code_graph, lab_graph, image_graph, code_embeddings, lab_embeddings, image_embeddings, k, code_similarity_threshold, lab_similarity_threshold, image_similarity_threshold):
    """Populate three per-modality graphs (code, lab, image) over the same rows.

    Row ``i`` of the inputs becomes node ``i`` in every graph, with attributes
    ``subject_id`` and ``embedding`` (that modality's row ``i``). Two kinds of
    edges are added:

    * ``edge_type='temporal'``: between consecutive rows (in input order) that
      share a subject id. These are identical in all three graphs.
    * ``edge_type='similarity'``: added in each graph separately. Each row is
      linked to those of its ``k + 1`` most cosine-similar rows that belong to a
      different subject and whose similarity is strictly above that modality's
      threshold. The similarity is stored as ``weight``.

    The graphs are mutated in place and a summary is printed. Returns ``None``.
    The printed similarity counts are the number of distinct edges added.

    Args:
        subject_ids: One subject id per row. Temporal edges follow input order.
            This presumably means rows are chronological within a subject;
            TODO confirm with the data producer.
        code_graph, lab_graph, image_graph: Graphs to fill. ``load_graphs``
            passes fresh ``networkx.Graph`` instances.
        code_embeddings, lab_embeddings, image_embeddings: 2-D arrays that are
            row-aligned with ``subject_ids`` (``cosine_similarity`` needs 2-D).
        k: Number of nearest neighbours (besides the row itself) considered
            per row.
        code_similarity_threshold, lab_similarity_threshold,
        image_similarity_threshold: Per-modality strict lower bound on the
            cosine similarity of a similarity edge.
    """
    graphs = (code_graph, lab_graph, image_graph)

    for i, subject_id in enumerate(subject_ids):
        code_graph.add_node(i, subject_id=subject_id, embedding=code_embeddings[i])
        lab_graph.add_node(i, subject_id=subject_id, embedding=lab_embeddings[i])
        image_graph.add_node(i, subject_id=subject_id, embedding=image_embeddings[i])

    # Temporal edges: link every row to the previous row of the same subject.
    # A single dict pass is O(n). The old per-subject np.where scan was
    # O(n * n_subjects), and it silently found nothing when subject_ids was a list.
    temporal_edge_count = 0
    previous_row = {}
    for row, subject_id in enumerate(subject_ids):
        if subject_id in previous_row:
            for graph in graphs:
                graph.add_edge(previous_row[subject_id], row, edge_type='temporal')
            temporal_edge_count += 1
        previous_row[subject_id] = row

    print('temporal edges graph done')

    # Each helper call builds and drops its own n x n similarity matrix, so only
    # one such matrix is alive at a time.
    similarity_edge_count_code = _add_similarity_edges(
        code_graph, code_embeddings, subject_ids, k, code_similarity_threshold)
    similarity_edge_count_lab = _add_similarity_edges(
        lab_graph, lab_embeddings, subject_ids, k, lab_similarity_threshold)
    similarity_edge_count_image = _add_similarity_edges(
        image_graph, image_embeddings, subject_ids, k, image_similarity_threshold)

    # Print summary of the graph
    print(f"Code Graph: Nodes={code_graph.number_of_nodes()}, Temporal Edges={temporal_edge_count}, Similarity Edges={similarity_edge_count_code}")
    print(f"Lab Graph: Nodes={lab_graph.number_of_nodes()}, Temporal Edges={temporal_edge_count}, Similarity Edges={similarity_edge_count_lab}")
    print(f"Image Graph: Nodes={image_graph.number_of_nodes()}, Temporal Edges={temporal_edge_count}, Similarity Edges={similarity_edge_count_image}")
from sklearn.preprocessing import normalize

def _stack_rows(rows, fill_shape):
    """Stack per-entry arrays into one array, using ``np.zeros(fill_shape)`` for ``None`` rows."""
    return np.array([np.zeros(fill_shape) if row is None else row for row in rows])


def _load_split(path, lab_fallback_shape=None, image_fallback_shape=(1,)):
    """Load one ``torch.load``-able split and build its L2-normalised embedding matrices.

    The loaded object is iterated as a sequence of entries. Each entry is read
    for the keys ``'patient_id'``, ``'code_embeddings'``, ``'labs'`` and
    ``'image_embeddings'``.

    Args:
        path: File passed to ``torch.load``.
        lab_fallback_shape: Zero-vector shape for missing labs, used when no
            entry in this split has labs (e.g. the train split's lab shape).
        image_fallback_shape: Same, for entries without image embeddings,
            used when no entry in this split has any. Defaults to ``(1,)``.

    Returns:
        ``(subject_ids, code, lab, image, lab_shape, image_shape)``. The three
        matrices are row-aligned with ``subject_ids`` and L2-normalised per
        row. The two shapes are the zero-fill shapes this split used.

    Raises:
        ValueError: if some entries lack labs and no lab shape is known.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    data = torch.load(path)
    subject_ids = np.array([entry['patient_id'] for entry in data])

    # squeeze(1) assumes each code embedding carries a leading axis of size 1.
    code = np.array([entry['code_embeddings'] for entry in data])
    code = normalize(code.squeeze(1), norm='l2')

    # Labs: NaNs become 0. Missing labs become a zero vector shaped like the
    # first available lab vector.
    labs = [entry['labs'] for entry in data]
    lab_shape = next((np.shape(lab) for lab in labs if lab is not None), lab_fallback_shape)
    if lab_shape is None and any(lab is None for lab in labs):
        raise ValueError(f"{path}: entries are missing 'labs' and no lab shape is available for zero-filling")
    lab_rows = [None if lab is None else np.nan_to_num(lab) for lab in labs]
    lab = normalize(_stack_rows(lab_rows, lab_shape), norm='l2')

    # Images: keep only the last embedding in each entry's list. Entries without
    # images get zeros shaped like a real embedding so every row has the same width.
    # The former fixed np.zeros((1,)) produced ragged rows.
    image_rows = [
        entry['image_embeddings'][-1].numpy() if entry['image_embeddings'] else None
        for entry in data
    ]
    image_shape = next((np.shape(row) for row in image_rows if row is not None), image_fallback_shape)
    image = normalize(_stack_rows(image_rows, image_shape), norm='l2')

    return subject_ids, code, lab, image, lab_shape, image_shape


def load_graphs(train_path='cbica/home/NAME/project/downsampled_data/train_data_20.pt',
                test_path='cbica/home/NAME/project/downsampled_data/test_data_20.pt',
                k=100, code_threshold=0.8, lab_threshold=0.8, image_threshold=0.9):
    """Load the train/test splits and build code/lab/image graphs for each.

    Graphs are built for three row sets: ``train``, ``test`` and ``all``. The
    ``all`` set is train rows followed by test rows. Every graph is built by
    ``build_graphs`` with the same ``k`` and per-modality thresholds.

    Args:
        train_path, test_path: Files passed to ``torch.load``.
            NOTE(review): the defaults are relative paths (no leading '/');
            confirm whether '/cbica/home/...' was intended.
        k: Neighbour count forwarded to ``build_graphs``.
        code_threshold, lab_threshold, image_threshold: Cosine-similarity
            thresholds forwarded to ``build_graphs``.

    Returns:
        A dict with keys ``'{split}_{modality}'`` (split in train/test/all,
        modality in code/lab/image), each mapping to a ``networkx.Graph``.
    """
    train_ids, train_code, train_lab, train_image, lab_shape, image_shape = _load_split(train_path)
    # Reuse train's zero-fill shapes. A test split with no labs or images then
    # still matches train's widths, which the 'all' concatenation requires.
    test_ids, test_code, test_lab, test_image, _, _ = _load_split(
        test_path, lab_fallback_shape=lab_shape, image_fallback_shape=image_shape)

    train = (train_ids, train_code, train_lab, train_image)
    test = (test_ids, test_code, test_lab, test_image)
    # Combined train+test rows (train first) for the 'all' graphs.
    combined = tuple(np.concatenate([a, b], axis=0) for a, b in zip(train, test))

    graph = {}
    for split_name, (subject_ids, code, lab, image) in (('train', train), ('test', test), ('all', combined)):
        code_graph, lab_graph, image_graph = nx.Graph(), nx.Graph(), nx.Graph()
        build_graphs(subject_ids, code_graph, lab_graph, image_graph, code, lab, image,
                     k, code_threshold, lab_threshold, image_threshold)
        graph[f'{split_name}_code'] = code_graph
        graph[f'{split_name}_lab'] = lab_graph
        graph[f'{split_name}_image'] = image_graph
    return graph


def main():
    """Build every train/test/all graph and hand back the dict produced by ``load_graphs``."""
    graphs = load_graphs()
    return graphs


if __name__ == "__main__":
    # Script entry point: announce start-up, then run the graph-construction pipeline.
    print('main here')
    main()


