import random
import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, Flickr, HeterophilousGraphDataset
from torch_geometric.utils import k_hop_subgraph

def get_subgraphs(data, node_list, num_hops=2):
    """Build one ego-subgraph ``Data`` object per node, with trainable target edges.

    For each node in ``node_list`` the ``num_hops``-hop neighbourhood is
    extracted with ``k_hop_subgraph`` (nodes relabelled to
    ``0 .. len(subset) - 1``). The edges of every subgraph are laid out as:

    1. Candidate edges: a bidirectional pair ``[target, v]`` / ``[v, target]``
       for every other subgraph node ``v``. Both directions get weight 1.0 if
       ``target`` and ``v`` are adjacent in the extracted subgraph (in either
       direction) and 0.0 otherwise; all candidate edges are flagged ``True``
       in ``trainable_edge``.
    2. Context edges: every remaining original edge (edges not touching the
       target, plus a target self-loop if present), with weight 1.0 and
       ``trainable_edge`` ``False``.

    Args:
        data: Source graph exposing ``x``, ``y`` and ``edge_index``.
        node_list: Iterable of integer node indices into ``data``.
        num_hops: Neighbourhood radius forwarded to ``k_hop_subgraph``.

    Returns:
        A list of ``Data`` objects in the same order as ``node_list``. Each
        carries ``edge_index`` of shape ``[2, num_edges]`` (``[2, 0]`` when the
        subgraph has no edges), ``edge_weight``, ``trainable_edge``,
        ``neighbor_mask`` (False only at the target), ``target_node_index``,
        ``original_idx`` (the node id in ``data``), ``node_idx`` (the original
        ids of the subgraph nodes) and an all-zero ``batch`` vector.
    """
    graph_list = []
    for node in node_list:
        subset, sub_edge_index, mapping, _ = k_hop_subgraph(node_idx=node,
                                                            num_hops=num_hops,
                                                            edge_index=data.edge_index,
                                                            relabel_nodes=True)
        target_node = mapping.item()
        num_sub_nodes = len(subset)
        sub_edges = sub_edge_index.T.tolist()

        # Subgraph nodes adjacent to the target in either direction. A set gives
        # O(1) membership tests instead of scanning the edge list per neighbour.
        linked = set()
        for src, dst in sub_edges:
            if src == target_node:
                linked.add(dst)
            elif dst == target_node:
                linked.add(src)

        neighbor_mask = torch.ones(num_sub_nodes, dtype=torch.bool)
        neighbor_mask[target_node] = False

        adjusted_edge_index = []
        edge_weight = []
        trainable_edge = []

        # Candidate set (target <-> every other node). Both directions are added
        # with the same initial weight: 1.0 if already connected, else 0.0.
        for neighbor in range(num_sub_nodes):
            if neighbor == target_node:
                continue
            adjusted_edge_index.extend([[target_node, neighbor], [neighbor, target_node]])
            weight = 1.0 if neighbor in linked else 0.0
            edge_weight.extend([weight, weight])
            trainable_edge.extend([True, True])

        # Fixed context edges: everything not already represented as a candidate.
        # Edges touching the target are dropped (otherwise a one-directional
        # original edge would appear both as candidate and as context); a
        # target self-loop is not a candidate, so it is kept.
        context_edges = [[src, dst] for src, dst in sub_edges
                         if target_node not in (src, dst) or src == dst]
        adjusted_edge_index.extend(context_edges)
        edge_weight.extend([1.0] * len(context_edges))
        trainable_edge.extend([False] * len(context_edges))

        # reshape(-1, 2) keeps a [2, E] long layout even when E == 0
        # (e.g. an isolated target node).
        edge_index_tensor = torch.tensor(adjusted_edge_index, dtype=torch.long).reshape(-1, 2).T

        subgraph_data = Data(x=data.x[subset],
                             edge_index=edge_index_tensor,
                             edge_weight=torch.tensor(edge_weight, dtype=torch.float),
                             y=data.y[node],
                             neighbor_mask=neighbor_mask,
                             target_node_index=torch.tensor(target_node),
                             trainable_edge=torch.tensor(trainable_edge, dtype=torch.bool),
                             original_idx=node,
                             node_idx=subset,
                             batch=torch.zeros(num_sub_nodes, dtype=torch.long)  # Temp batch for single data
                             )

        graph_list.append(subgraph_data)
    return graph_list

def load_node_data(dataset_name, data_folder):
    """Load a node-classification dataset from PyTorch Geometric.

    Args:
        dataset_name: One of ``'Cora'``, ``'CiteSeer'``, ``'PubMed'``
            (Planetoid), ``'Amazon-ratings'``, ``'Minesweeper'``
            (HeterophilousGraphDataset) or ``'Flickr'``.
        data_folder: Root directory; each dataset family is stored in its own
            sub-folder beneath it.

    Returns:
        ``(data, num_features, num_classes)`` where ``data`` is ``dataset[0]``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset = Planetoid(root=f'{data_folder}/Planetoid', name=dataset_name)
    elif dataset_name in ['Amazon-ratings', 'Minesweeper']:
        dataset = HeterophilousGraphDataset(root=f'{data_folder}/HeterophilousGraphDataset', name=dataset_name)
    elif dataset_name == 'Flickr':
        dataset = Flickr(root=f'{data_folder}/Flickr')
    else:
        # Name the offending value and the valid choices so misconfigurations
        # are easy to diagnose.
        raise ValueError(
            f"Unknown dataset: {dataset_name!r}; expected one of "
            "'Cora', 'CiteSeer', 'PubMed', 'Amazon-ratings', 'Minesweeper', 'Flickr'"
        )

    return dataset[0], dataset.num_features, dataset.num_classes

def NodeDownstream(data, shots=5, test_node_num=1000):
    """Draw a few-shot training node list and a disjoint test node list.

    For every class ``0 .. max(y)``, up to ``shots`` nodes are sampled with
    ``random.sample``; a class with fewer than ``shots`` nodes contributes all
    of them, in index order. The test list is a random sample of at most
    ``test_node_num`` nodes taken from those not already selected for training.

    Args:
        data: Graph exposing ``y`` (labels, squeezed before comparison) and
            ``num_nodes``.
        shots: Maximum number of training nodes per class.
        test_node_num: Maximum number of test nodes.

    Returns:
        ``(train_nodes, test_nodes)`` as two lists of integer node indices.
    """
    labels = data.y.squeeze()
    train_nodes = []
    for label in range(data.y.max().item() + 1):
        members = (labels == label).nonzero(as_tuple=True)[0].tolist()
        chosen = members if len(members) < shots else random.sample(members, k=shots)
        train_nodes.extend(chosen)

    # Every node not used for training is eligible for the test split.
    candidates = list(set(range(data.num_nodes)) - set(train_nodes))
    test_size = min(len(candidates), test_node_num)
    test_nodes = random.sample(candidates, k=test_size)

    return train_nodes, test_nodes