from torch.utils.data import Dataset
import numpy as np
import random
import torch
import sys 
from torch_geometric.utils import k_hop_subgraph
sys.path.append("../..")
from common import prepare_edge_list
from common import CLASSES as classes, MATCHING_TEMPLATES, GraphGPT_DESC as CLASSIFICATION_TEMPLATES


class TextGraphGroundDataset(Dataset):
    """One item per node: the node's raw text plus texts of randomly drawn neighbours.

    Args:
        graph_data: graph object; this class reads ``num_nodes``, ``edge_index``
            and ``raw_texts``.
        num_sampled_neighbors: number of neighbours drawn (with replacement)
            from ``edge_dict[idx]`` for every item.
    """

    def __init__(self, graph_data, num_sampled_neighbors):
        self.graph_data = graph_data
        self.num_nodes = graph_data.num_nodes

        # Neighbour lookup produced by the project helper, indexed by node id in
        # __getitem__ (its exact structure is defined in ``common``).
        cpu_edges = graph_data.edge_index.detach().cpu()
        self.edge_dict = prepare_edge_list(cpu_edges, self.num_nodes)
        self.num_sampled_neighbors = num_sampled_neighbors

    def __len__(self):
        return self.num_nodes

    def __getitem__(self, idx):
        # Independent draws, so the same neighbour may appear several times.
        drawn = []
        for _ in range(self.num_sampled_neighbors):
            drawn.append(np.random.choice(self.edge_dict[idx], replace=True))

        all_texts = self.graph_data.raw_texts
        drawn_texts = [all_texts[neighbour] for neighbour in drawn]
        return {
            "id": idx,
            "text": all_texts[idx],
            "neighbor_text": drawn_texts,
            "neighbor_ids": np.array(drawn),
        }


def fetch_title(txt, max_length=512):
    """Return the part of *txt* before its first ``"."``, truncated to ``max_length`` chars.

    If *txt* contains no ``"."`` the whole string (truncated) is returned.

    NOTE(review): a previous version also computed ``txt.split(":")[0]`` when
    the text contained a colon, but that value was unconditionally overwritten
    by the ``"."`` split on the following line, so it never had any effect.
    The dead branch has been removed and the output is unchanged. If a
    colon-delimited title was actually intended, confirm against the raw-text
    format first: texts starting with e.g. ``"Title: ..."`` would then
    yield just ``"Title"``.
    """
    # ``partition`` gives the same first field as ``split(".")[0]`` without
    # splitting the rest of the string.
    return txt.partition(".")[0][:max_length]


# Example Data: https://huggingface.co/datasets/Jiabin99/graph-matching
class GraphMatchingDataset(Dataset):
    """Graph-matching instruction samples: align graph tokens with node texts.

    For every focus node, ``sample_times`` neighbourhood samples are drawn from
    its ``k_hop`` subgraph. Each sample is a dict with:

    * ``"id"``: the focus node's original node ID;
    * ``"nodes"``: LongTensor of graph-token node indices, focus node first,
      padded with the focus node up to ``num_sampled_neighbors`` entries
      (in inductive mode these are positions in the training-node list, not
      original node IDs);
    * ``"query"``: ``MATCHING_TEMPLATES[graph_type]`` with the node titles
      inserted in shuffled order;
    * ``"label"``: text stating which graph token (1-based) matches which title.

    Args:
        graph_data: graph object; reads ``num_nodes``, ``edge_index``,
            ``raw_texts`` and, in inductive mode, ``train_mask``.
        k_hop: hop count passed to ``k_hop_subgraph``.
        num_sampled_neighbors: number of graph tokens per sample.
        graph_type: a key of ``_GRAPH_TYPE_SPECS``.
        sample_times: number of samples drawn per focus node.
        re_split: split mode; ``2`` (or ``0`` on graphs with more than 100000
            nodes) selects the inductive setting.

    Raises:
        ValueError: when building samples for an unsupported ``graph_type``.
    """

    # graph_type -> (query-template placeholder, entity noun, response preamble)
    _GRAPH_TYPE_SPECS = {
        "academic_network": (
            "{{paper_titles}}",
            "paper",
            "Based on the given graph tokens and the list of paper titles, we obtain the matching of graph tokens and papers as follows: ",
        ),
        "social_network": (
            "{{user_profiles}}",
            "user",
            "Based on the given graph tokens and the descriptions of users, we obtain the matching of graph tokens and users as follows: ",
        ),
        "ecommerce_network": (
            "{{item_comments}}",
            "item",
            "Based on the given graph tokens and the comments of items, we obtain the matching of graph tokens and items as follows: ",
        ),
    }

    def __init__(self, graph_data, k_hop=1, num_sampled_neighbors=8, graph_type="academic_network", sample_times=1, re_split=0):
        self.graph_data = graph_data
        self.num_nodes = graph_data.num_nodes
        self.k_hop = k_hop
        self.num_sampled_neighbors = num_sampled_neighbors
        self.sample_times = sample_times
        self.query_template = MATCHING_TEMPLATES[graph_type]
        self.graph_type = graph_type
        self.re_split = re_split
        self.is_inductive = (self.re_split == 2) or (self.re_split == 0 and self.num_nodes > 100000) # TODO: fix for arXiv

        # All samples are materialised eagerly at construction time.
        self.all_data = self._prepare_matching_data()

    def __len__(self):
        return len(self.all_data)

    def __getitem__(self, idx):
        return self.all_data[idx]

    @staticmethod
    def _mask_to_node_list(mask):
        """Return the indices of the True entries of a boolean mask as a list of ints.

        ``reshape(-1)`` keeps the result one-dimensional even when exactly one
        entry is set. ``squeeze()`` would produce a 0-d tensor whose
        ``tolist()`` is a bare int rather than a list.
        """
        return mask.nonzero(as_tuple=False).reshape(-1).detach().cpu().tolist()

    def _sample_subset(self, node, neighbors):
        """Draw graph-token node IDs from ``neighbors`` with ``node`` in position 0.

        Samples ``num_sampled_neighbors`` IDs with replacement, removes
        duplicates, moves (or inserts) the focus node to the front, and
        right-pads with the focus node up to ``num_sampled_neighbors`` entries.
        """
        size = self.num_sampled_neighbors
        subset = list(set(np.random.choice(neighbors, size=size).tolist()))
        if node not in subset:
            # Prepend the focus node and truncate only as far as needed to stay
            # within ``size``. Deduplication may already have left room, so no
            # neighbour is dropped unnecessarily.
            subset = [node] + subset[:size - 1]
        else:
            pos = subset.index(node)
            subset[0], subset[pos] = node, subset[0]

        subset += [node] * (size - len(subset))
        assert subset[0] == node
        return subset

    def _build_query_and_response(self, texts):
        """Format the query and answer from shuffled ``[token_id, title]`` pairs.

        ``token_id`` is the 0-based graph-token position. The query lists the
        titles in the given (shuffled) order. The answer maps graph token
        ``token_id + 1`` to its title, in ascending token order.
        """
        try:
            placeholder, entity, preamble = self._GRAPH_TYPE_SPECS[self.graph_type]
        except KeyError:
            raise ValueError(
                f"Unsupported graph_type {self.graph_type!r}; "
                f"expected one of {sorted(self._GRAPH_TYPE_SPECS)}"
            ) from None

        tokenid2text_mapping = {token_id + 1: title for token_id, title in texts}
        query_graph_texts = ". ".join(f"{pos + 1}. {title}" for pos, (_, title) in enumerate(texts))

        cur_query = self.query_template.replace(placeholder, query_graph_texts)
        cur_response = preamble + ". ".join(
            f"Graph token {k} corresponds to {entity} {tokenid2text_mapping[k]}"
            for k in sorted(tokenid2text_mapping)
        )
        return cur_query, cur_response

    def _prepare_matching_data(self):
        """Build every matching sample; called once from ``__init__``."""
        data_samples = []

        if self.is_inductive:
            # Inductive learning: only training nodes are focus nodes, and graph
            # tokens are positions in the training-node list.
            train_node_ids = self._mask_to_node_list(self.graph_data.train_mask)
            focus_nodes = train_node_ids
            node_id_to_index = {node_id: pos for pos, node_id in enumerate(train_node_ids)}
        else:
            focus_nodes = range(self.num_nodes)
            node_id_to_index = None

        for node in focus_nodes:
            neighbors, _, _, _ = k_hop_subgraph(node, num_hops=self.k_hop, edge_index=self.graph_data.edge_index)
            neighbors = neighbors.cpu().numpy()
            if neighbors.size == 0:
                continue

            for _ in range(self.sample_times):
                subset = self._sample_subset(node, neighbors)

                if self.is_inductive:
                    # Neighbours outside the training set have no training-list
                    # position; substitute the focus node (text included).
                    subset = [nid if nid in node_id_to_index else node for nid in subset]
                    token_nodes = [node_id_to_index[nid] for nid in subset]
                else:
                    token_nodes = subset

                # Pair each graph-token position with its node's title (looked up
                # by original node ID), then shuffle so the listed order differs
                # from the graph-token order.
                texts = [
                    [token_id, fetch_title(self.graph_data.raw_texts[nid])]
                    for token_id, nid in enumerate(subset)
                ]
                random.shuffle(texts)
                cur_query, cur_response = self._build_query_and_response(texts)

                data_samples.append({
                    "id": node,
                    "nodes": torch.LongTensor(token_nodes),
                    "query": cur_query,
                    "label": cur_response,
                })

        return data_samples


# Example Data: https://huggingface.co/datasets/Jiabin99/Arxiv-PubMed-mix-NC-LP
class GraphInstructionTuningDataset(Dataset):
    """Node-classification instruction samples for one split (train/val/test).

    Each sample is a dict with:

    * ``"id"``: the focus node index in the graph the sample was built from;
    * ``"nodes"``: LongTensor of ``maximum_neighbors + 1`` node indices, focus
      node first, then sampled (with replacement) or padded neighbourhood;
    * ``"query"``: the dataset's classification prompt with the node's raw text;
    * ``"label"``: the node's class name.

    Args:
        graph_data: full graph (transductive) or full graph used only for its
            split masks (inductive).
        k_hop: hop count passed to ``k_hop_subgraph``.
        maximum_neighbors: neighbourhood size kept per sample.
        dataset_name: key into ``CLASSES`` / ``CLASSIFICATION_TEMPLATES``.
        data_type: ``"train"``, ``"val"`` or ``"test"``.
        re_split: split mode; ``2`` (or ``0`` for arxiv) selects inductive.
        split_data: per-phase graph for the inductive setting; the code reads
            its ``num_nodes``, ``edge_index``, ``raw_texts``, ``y`` and (for
            val/test) ``node_ids``.
    """

    def __init__(self, graph_data, k_hop=1, maximum_neighbors=4, dataset_name="cora", data_type="train", re_split=0, split_data=None):
        self.graph_data = graph_data  # This is the full graph for transductive or full_graph_data for inductive
        self.split_data = split_data  # This is train_data/val_data/test_data for inductive
        self.num_nodes = graph_data.num_nodes
        self.k_hop = k_hop
        self.maximum_neighbors = maximum_neighbors
        self.label_names = classes[dataset_name]
        self.data_type = data_type
        self.re_split = re_split
        # Inductive when re_split == 2, or for arxiv under the default split.
        self.is_inductive = (re_split == 2) or (re_split == 0 and dataset_name == "arxiv")

        label_names = ", ".join(classes[dataset_name])
        self.query_prompt = CLASSIFICATION_TEMPLATES[dataset_name].replace("{{label_names}}", label_names)
        # All samples are materialised eagerly at construction time.
        self.data_list = self.format_data()

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

    @staticmethod
    def _mask_to_node_list(mask):
        """Return the indices of the True entries of a boolean mask as a list of ints.

        ``reshape(-1)`` keeps the result one-dimensional even when exactly one
        entry is set. ``squeeze()`` would produce a 0-d tensor whose
        ``tolist()`` is a bare int, breaking iteration and ``set(...)``.
        """
        return mask.nonzero(as_tuple=False).reshape(-1).detach().cpu().tolist()

    def _select_focus_nodes(self):
        """Return ``(current_graph, focus_nodes)`` for ``self.data_type``.

        Inductive with ``split_data``: node indices are local to ``split_data``.
        Otherwise: node indices come from the full graph's split mask.
        """
        if self.is_inductive and self.split_data is not None:
            current_graph = self.split_data
            if self.data_type == "train":
                # Train phase: every node of the training split graph is a training node.
                return current_graph, list(range(current_graph.num_nodes))

            # ``node_ids`` maps local indices back to original node IDs.
            local_to_original = current_graph.node_ids.cpu().numpy().tolist()
            if self.data_type == "val":
                # Val phase: nodes of val_data that are not training nodes.
                # NOTE(review): this would also keep test nodes if val_data
                # contains any; confirm val_data only holds train + val nodes.
                train_ids = set(self._mask_to_node_list(self.graph_data.train_mask))
                return current_graph, [i for i, orig_id in enumerate(local_to_original) if orig_id not in train_ids]

            # Any other data_type is treated as the test phase.
            test_ids = set(self._mask_to_node_list(self.graph_data.test_mask))
            return current_graph, [i for i, orig_id in enumerate(local_to_original) if orig_id in test_ids]

        # Transductive setting (or inductive without split_data): full graph + split mask.
        focus_mask = {"train": self.graph_data.train_mask, "val": self.graph_data.val_mask, "test": self.graph_data.test_mask}[self.data_type]
        return self.graph_data, self._mask_to_node_list(focus_mask)

    def _sample_neighbors(self, current_graph, cur_node):
        """Return ``[cur_node]`` followed by exactly ``maximum_neighbors`` neighbourhood IDs.

        Larger neighbourhoods are subsampled with replacement; smaller ones are
        padded with ``cur_node``.
        """
        edge_index = current_graph.edge_index
        # Nodes that appear in no edge are handled without k_hop_subgraph.
        touches_node = (edge_index[0] == cur_node) | (edge_index[1] == cur_node)
        if touches_node.any().item():
            neighbors, _, _, _ = k_hop_subgraph(cur_node, num_hops=self.k_hop, edge_index=edge_index)
            neighbors = neighbors.cpu().numpy().tolist()
        else:
            neighbors = [cur_node]

        if len(neighbors) > self.maximum_neighbors:
            sampled = np.random.choice(np.array(neighbors), size=self.maximum_neighbors).tolist()
            return [cur_node] + sampled

        pad_length = self.maximum_neighbors - len(neighbors)
        return [cur_node] + neighbors + [cur_node] * pad_length

    def format_data(self):
        """Build the sample list for ``self.data_type``; called once from ``__init__``."""
        current_graph, focus_nodes = self._select_focus_nodes()

        available_data_list = []
        for cur_node in focus_nodes:
            neighbors = self._sample_neighbors(current_graph, cur_node)
            assert cur_node == neighbors[0]

            # In the inductive setting node IDs are local to the phase's graph;
            # per the original note, embeddings are aligned elsewhere
            # (model.set_phase()) — not visible here.
            cur_query = self.query_prompt.replace("{{raw_text}}", current_graph.raw_texts[cur_node])
            cur_response = self.label_names[current_graph.y[cur_node].item()]

            available_data_list.append({
                "id": cur_node,
                "nodes": torch.LongTensor(neighbors),
                "query": cur_query,
                "label": cur_response,
            })

        return available_data_list