import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.utils import k_hop_subgraph
from transformers import AutoTokenizer, AutoModel

from models import GNN

from tqdm import tqdm

# Single source of truth for the sentence encoder used by text_embedding and
# concept_embedding. graph_embedding hard-codes in_dim=384, which must match
# this model's hidden size if the checkpoint is ever changed.
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
# Loaded once at import time and shared by every function below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text_encoder = AutoModel.from_pretrained(MODEL_NAME)

def text_embedding(dataset, texts):
    """Return one embedding per text, cached on disk per dataset.

    If ``processed_data/{dataset}_text_emb.pt`` exists it is loaded and
    ``texts`` is ignored. Otherwise the texts are encoded in batches of 1000
    with the module-level ``tokenizer``/``text_encoder``, and the result is
    saved to that path before being returned.

    Args:
        dataset: Dataset name used to build the cache file path.
        texts: Sliceable sequence of raw strings. It is only read when the
            cache file does not exist yet.

    Returns:
        Tensor ``[len(texts), hidden_dim]`` on the CPU when freshly computed.

    Note:
        Token states are averaged under the attention mask. As a result, the
        padding added to shorter texts in a batch does not dilute their
        embedding. Cache files written by an earlier version that averaged
        padding too are still loaded as-is; delete them to recompute.
    """
    text_emb_file = f'processed_data/{dataset}_text_emb.pt'
    if os.path.exists(text_emb_file):
        return torch.load(text_emb_file)

    batch_size = 1000
    # Follow the encoder's device (CPU unless it was moved elsewhere).
    device = next(text_encoder.parameters()).device
    chunks = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch_texts = texts[start:start + batch_size]
        tokens = tokenizer(batch_texts, return_tensors='pt', padding=True,
                           truncation=True, max_length=512)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        with torch.no_grad():
            hidden = text_encoder(**tokens).last_hidden_state  # [B, L, H]
        # Masked mean: only real tokens (attention_mask == 1) contribute.
        mask = tokens['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # [B, L, 1]
        batch_emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        chunks.append(batch_emb.cpu())
    text_emb = torch.cat(chunks, dim=0)  # [total_texts, hidden_dim]

    os.makedirs(os.path.dirname(text_emb_file), exist_ok=True)
    torch.save(text_emb, text_emb_file)
    return text_emb

def mean_pooling(last_hidden_state, attention_mask):
    """Average token states over non-padding positions.

    Args:
        last_hidden_state: Tensor ``[batch, seq_len, hidden]``.
        attention_mask: Tensor ``[batch, seq_len]``; nonzero marks real tokens.

    Returns:
        Tensor ``[batch, hidden]``. Rows whose mask is all zero come out as
        zeros, because the clamp below prevents a division by zero.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def concept_embedding(concepts):
    """Embed a batch of concept strings with the shared sentence encoder.

    The strings are tokenized together (padded to the longest one). Token
    states are then mean-pooled under the attention mask, so padding does not
    skew shorter concepts.

    This runs with autograd enabled (unlike text_embedding), so the result
    carries gradients with respect to ``text_encoder``'s parameters.

    Args:
        concepts: A string or list of strings accepted by ``tokenizer``.

    Returns:
        Tensor ``[num_concepts, hidden_dim]``.
    """
    tokens = tokenizer(concepts, return_tensors='pt', padding=True, truncation=True, max_length=512)
    # Keep inputs on the same device as the encoder.
    tokens = tokens.to(next(text_encoder.parameters()).device)
    hidden = text_encoder(**tokens).last_hidden_state
    return mean_pooling(hidden, tokens['attention_mask'])

def graph_text_embedding(dataset, data, num_hop):
    """Return per-node text embeddings averaged over each node's k-hop neighborhood.

    With ``num_hop == 0`` this is just ``text_embedding(dataset, data.raw_texts)``.
    Otherwise, each node's row is the mean text embedding of the node subset
    that ``k_hop_subgraph`` returns for it within ``num_hop`` hops. That subset
    includes the node itself and uses PyG's default ``flow``. The result is
    cached in ``processed_data/{dataset}_{num_hop}_hop_neighbor_emb.pt``.

    Args:
        dataset: Dataset name used to build the cache file paths.
        data: Graph object exposing ``raw_texts``, ``x`` (its first dimension
            is the node count) and ``edge_index``.
        num_hop: Neighborhood radius; ``0`` disables aggregation.

    Returns:
        Tensor ``[num_nodes, hidden_dim]``.
    """
    text_emb = text_embedding(dataset, data.raw_texts)
    if num_hop == 0:
        return text_emb

    graph_emb_file = f'processed_data/{dataset}_{num_hop}_hop_neighbor_emb.pt'
    if os.path.exists(graph_emb_file):
        return torch.load(graph_emb_file)

    num_nodes = data.x.size(0)
    neighbor_means = []
    for node_idx in tqdm(range(num_nodes)):
        # subset: original (non-relabelled) ids of nodes within num_hop hops.
        # Passing num_nodes explicitly keeps isolated nodes whose id exceeds
        # the largest id in edge_index from indexing past the internal mask.
        subset, _, _, _ = k_hop_subgraph(
            node_idx=node_idx,
            num_hops=num_hop,
            edge_index=data.edge_index,
            relabel_nodes=False,
            num_nodes=num_nodes,
        )
        neighbor_means.append(text_emb[subset].mean(0))
    graph_emb = torch.stack(neighbor_means)

    os.makedirs(os.path.dirname(graph_emb_file), exist_ok=True)
    torch.save(graph_emb, graph_emb_file)
    return graph_emb

def graph_embedding(dataset, data, model_type, device):
    """Encode every node with a pretrained GNN applied to its text embedding.

    A ``GNN`` is built for ``model_type`` and loaded from
    ``../GraphTextAlign/pretrained/{model_type}_contrastive.pth``. It is then
    run in eval mode, without gradient tracking, on the cached text
    embeddings and ``data.edge_index_test``.

    Args:
        dataset: Dataset name forwarded to ``text_embedding``.
        data: Graph object exposing ``raw_texts`` and ``edge_index_test``.
        model_type: GNN variant name; ``'gcn'`` uses a 384-wide hidden layer,
            every other value uses 64.
        device: Device on which the model, features and edges are placed.

    Returns:
        Whatever ``gnn`` returns for the full graph. The model is built with
        ``out_dim=384``.
    """
    hidden_dim = 384 if model_type == 'gcn' else 64
    gnn = GNN(in_dim=384, hidden_dim=hidden_dim, out_dim=384, model_type=model_type).to(device)
    # Alternative checkpoint used previously: ../GraphTextAlign/{model_type}_pretrained.pth
    model_path = f'../GraphTextAlign/pretrained/{model_type}_contrastive.pth'
    # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
    gnn.load_state_dict(torch.load(model_path, map_location=device))
    gnn.eval()
    x_emb = text_embedding(dataset, data.raw_texts).to(device)
    # Edges must live on the same device as the features and model.
    edge_index = data.edge_index_test.to(device)
    with torch.no_grad():
        graph_emb = gnn(x_emb, edge_index)
    return graph_emb

def cosine_similarity(emb1, emb2, scaled=True):
    """Pairwise cosine similarity between the rows of two 2-D tensors.

    Args:
        emb1: Tensor ``[M, D]``.
        emb2: Tensor ``[N, D]``.
        scaled: When true, map the similarity from ``[-1, 1]`` onto
            ``[0, 1]`` via ``(s + 1) / 2``.

    Returns:
        Tensor ``[M, N]``; entry ``(i, j)`` compares ``emb1[i]`` with ``emb2[j]``.
    """
    left = F.normalize(emb1, dim=1)
    right = F.normalize(emb2, dim=1)
    sim = left @ right.T
    return (sim + 1.0) / 2.0 if scaled else sim