import os

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Amazon, Coauthor, Twitch
from torch_geometric.data import Data
from torch_geometric.utils import stochastic_blockmodel_graph, to_undirected
from torch_sparse import SparseTensor

def to_sparse_tensor(edge_index, edge_feat, num_nodes):
    """Convert a COO ``edge_index`` into a transposed ``SparseTensor``.

    The edges are sorted by ``(col, row)`` and then stored with rows and
    columns swapped, so the returned matrix is the transpose of the
    adjacency described by ``edge_index`` and can be built with
    ``is_sorted=True``.

    Args:
        edge_index: tensor whose two rows are unpacked as ``(row, col)``.
        edge_feat: per-edge values aligned with the columns of
            ``edge_index``, or ``None`` for an unweighted matrix.
        num_nodes: number of nodes ``N``; the result is ``N x N``.

    Returns:
        adj_t: ``SparseTensor`` with its CSR row pointer and CSR->CSC
        permutation computed eagerly.
    """
    row, col = edge_index
    n = num_nodes

    # Sort by (col, row): after swapping row/col below this is row-major
    # order for the transposed matrix, which is what is_sorted=True promises.
    perm = (col * n + row).argsort()
    row, col = row[perm], col[perm]

    # Allow unweighted graphs instead of failing on ``None[perm]``.
    value = None if edge_feat is None else edge_feat[perm]
    adj_t = SparseTensor(row=col, col=row, value=value,
                         sparse_sizes=(n, n), is_sorted=True)

    # Pre-process some important attributes so later sparse ops can reuse them.
    adj_t.storage.rowptr()
    adj_t.storage.csr2csc()

    return adj_t


def load_dataset(data_dir, name, ood_type):
    '''
    Load a graph dataset and return it as an (ind, ood_tr, ood_te) triple.

    ``load_graph_dataset`` returns ONE PyG ``Data`` object in which the
    in-distribution / OOD split is encoded by boolean masks (``ood_mask``
    together with ``train_mask``, ``val_mask`` and ``test_mask``).  The
    previous implementation tried to unpack that single object into three
    names, which fails at runtime.  The same object is now returned in all
    three positions so that callers keep the documented three-value
    contract; select in-distribution / OOD nodes through its masks.

    NOTE(review): all three returned values are the same object — callers
    that mutate one of them mutate all; confirm this matches their usage.

    Returns:
        dataset_ind: in-distribution training dataset
        dataset_ood_tr: ood-distribution training dataset as ood exposure
        dataset_ood_te: ood testing dataset
    '''
    dataset = load_graph_dataset(data_dir, name, ood_type)
    return dataset, dataset, dataset


def create_sbm_dataset(data, ood_fraction=0.3, p_ii=1.5, p_ij=0.5):
    """
    Rewire a random subset of nodes (which become OOD samples) with an SBM.

    ``int(num_nodes * ood_fraction)`` nodes are chosen uniformly at random.
    Every original edge touching one of them is removed, and the chosen
    nodes are reconnected among themselves by a stochastic block model.
    Edges between the remaining (in-distribution) nodes are left untouched.

    Args:
        data: PyG Data object carrying ``x``, ``edge_index``, ``y``,
            ``train_mask``, ``val_mask`` and ``test_mask``.
        ood_fraction: Fraction of nodes to be marked as OOD (default: 0.3)
        p_ii: Within-block edge probability multiplier; it is scaled by the
            original graph's edge density.
        p_ij: Between-block edge probability multiplier, scaled likewise.

    Returns:
        dataset: PyG Data object with the rewired, undirected ``edge_index``
        and a boolean ``ood_mask`` marking the rewired nodes.
    """
    n = data.num_nodes
    n_ood = int(n * ood_fraction)

    # Pick the OOD nodes uniformly at random.
    ood_mask = torch.zeros(n, dtype=torch.bool)
    ood_indices = torch.randperm(n)[:n_ood]
    ood_mask[ood_indices] = True

    # Edge density of the original graph: stored edges over ordered node pairs.
    d = data.edge_index.size(1) / n / (n - 1)
    # One SBM block per label value (assumes labels are 0..C-1 — confirm).
    num_blocks = int(data.y.max()) + 1
    p_ii, p_ij = p_ii * d, p_ij * d

    # Split the n_ood nodes into num_blocks blocks; the last block absorbs
    # the remainder so the sizes always sum to n_ood.  The previous
    # remainder ``n_ood % block_size`` could leave OOD nodes out of every
    # block, and divided by zero whenever n_ood < num_blocks.
    block_size = n_ood // num_blocks
    block_sizes = [block_size] * (num_blocks - 1) + [block_size + n_ood % num_blocks]
    edge_probs = torch.ones((num_blocks, num_blocks)) * p_ij
    edge_probs[torch.arange(num_blocks), torch.arange(num_blocks)] = p_ii

    # SBM edges among the OOD nodes, using local ids 0..n_ood-1.
    new_edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs)

    # Remove every original edge that has an OOD endpoint.
    keep = ~(ood_mask[data.edge_index[0]] | ood_mask[data.edge_index[1]])
    remaining_edges = data.edge_index[:, keep]

    # Map local SBM ids back to the chosen global node ids.
    new_edge_index = ood_indices[new_edge_index]

    edge_index = torch.cat([remaining_edges, new_edge_index], dim=1)
    edge_index = to_undirected(edge_index)  # Ensure the graph remains undirected

    dataset = Data(x=data.x,
                   edge_index=edge_index,
                   y=data.y,
                   train_mask=data.train_mask,
                   val_mask=data.val_mask,
                   test_mask=data.test_mask)
    dataset.ood_mask = ood_mask

    return dataset

def create_feat_noise_dataset(data, ood_fraction=0.3):
    """
    Turn a random subset of nodes into OOD samples by replacing their features.

    Each selected node's feature vector is overwritten with a random convex
    combination of the original features of two nodes drawn uniformly (with
    replacement) from the whole graph.  Edges and labels are kept as-is.

    Args:
        data: PyG Data object carrying ``x``, ``edge_index``, ``y``,
            ``train_mask``, ``val_mask`` and ``test_mask``.
        ood_fraction: Fraction of nodes to be marked as OOD (default: 0.3)

    Returns:
        dataset: PyG Data object with the perturbed features and a boolean
        ``ood_mask`` marking the perturbed nodes.
    """
    num_nodes = data.num_nodes
    num_ood = int(num_nodes * ood_fraction)

    picked = torch.randperm(num_nodes)[:num_ood]
    ood_mask = torch.zeros(num_nodes, dtype=torch.bool)
    ood_mask[picked] = True

    # Two donor nodes and one mixing coefficient per perturbed node.
    donors = torch.randint(0, num_nodes, (num_ood, 2))
    alpha = torch.rand(num_ood).unsqueeze(1)
    first_donor = data.x[donors[:, 0]]
    second_donor = data.x[donors[:, 1]]
    blended = first_donor * alpha + second_donor * (1 - alpha)

    features = data.x.clone()
    features[picked] = blended

    return Data(
        x=features,
        edge_index=data.edge_index,
        y=data.y,
        train_mask=data.train_mask,
        val_mask=data.val_mask,
        test_mask=data.test_mask,
        ood_mask=ood_mask,
    )

def create_label_noise_dataset(data, noise_rate=0.5):
    """
    Corrupt the labels of a random subset of nodes.

    ``int(num_nodes * noise_rate)`` nodes are chosen at random.  Each chosen
    node gets a label drawn uniformly from ``0 .. y.max()`` inclusive, so a
    resampled label may coincide with the original one.  Features and edges
    are reused unchanged.

    Args:
        data: PyG Data object with ``x``, ``edge_index`` and integer ``y``.
        noise_rate: Fraction of nodes whose label is resampled (default: 0.5).

    Returns:
        dataset: PyG Data object with the noisy labels as ``y`` and
        ``node_idx`` holding ``arange(num_nodes)``.
    """
    y = data.y
    n = data.num_nodes
    num_noisy = int(n * noise_rate)

    idx = torch.randperm(n)[:num_noisy]
    y_new = y.clone()
    # torch.randint's upper bound is exclusive: use max+1 so the largest
    # class can be drawn too (the old bound ``y.max()`` never produced it).
    num_classes = int(y.max()) + 1
    y_new[idx] = torch.randint(0, num_classes, (num_noisy,),
                               dtype=y.dtype, device=y.device)

    dataset = Data(x=data.x, edge_index=data.edge_index, y=y_new)
    dataset.node_idx = torch.arange(n)

    return dataset




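# Follows GNNSafe-style label-shift data processing:
# labels 0, 1, 2 as test OOD
# label 3 as train OOD
# labels > 3 as in-distribution data
# load_graph_dataset below marks every label <= 3 with a single ood_mask and
# does not separate train OOD from test OOD.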
def load_graph_dataset(data_dir, dataname="cora", ood_type="label",
                       ood_class_threshold=3, ood_fraction=0.3):
    """
    Load a Planetoid graph with train/val/test splits and an OOD mask.

    Args:
        data_dir: Directory under which the ``Planetoid`` folder lives.
        dataname: Planetoid dataset name (default: "cora").
        ood_type: Type of OOD data:
            - "label": nodes with ``y <= ood_class_threshold`` are OOD and are
              removed from train/val/test masks;
            - "structure": a random ``ood_fraction`` of nodes is rewired with
              an SBM (see ``create_sbm_dataset``);
            - "feature": a random ``ood_fraction`` of nodes gets mixed
              features (see ``create_feat_noise_dataset``).
        ood_class_threshold: Largest label treated as OOD for
            ``ood_type="label"`` (default: 3).
        ood_fraction: Fraction of OOD nodes for the "structure" and
            "feature" types (default: 0.3).

    Returns:
        dataset: A PyG Data object with train_mask, val_mask, test_mask, and ood_mask

    Raises:
        NotImplementedError: if ``ood_type`` is not one of the types above.
    """
    # Validate before touching the disk / network.
    if ood_type not in ("label", "structure", "feature"):
        raise NotImplementedError(f"OOD type {ood_type} not implemented")

    transform = T.NormalizeFeatures()
    # os.path.join: plain concatenation produced e.g. "dataPlanetoid" when
    # data_dir had no trailing separator.
    torch_dataset = Planetoid(root=os.path.join(data_dir, 'Planetoid'), split='public',
                              name=dataname, transform=transform)
    dataset = torch_dataset[0]

    if ood_type == "label":
        # True for labels <= ood_class_threshold.
        ood_mask = (dataset.y <= ood_class_threshold).squeeze()

        # Restrict train/val/test masks to in-distribution samples.
        ind_mask = ~ood_mask
        dataset.train_mask = dataset.train_mask & ind_mask
        dataset.val_mask = dataset.val_mask & ind_mask
        dataset.test_mask = dataset.test_mask & ind_mask

        dataset.ood_mask = ood_mask
    elif ood_type == "structure":
        # NOTE(review): unlike the "label" branch, train/val/test masks are
        # not restricted to in-distribution nodes here — confirm intended.
        dataset = create_sbm_dataset(dataset, ood_fraction=ood_fraction)
    else:  # "feature"; the same masks note as "structure" applies.
        dataset = create_feat_noise_dataset(dataset, ood_fraction=ood_fraction)

    return dataset