import torch
from torch_geometric.transforms import OneHotDegree
from torch_geometric.utils import degree
from torch_geometric.data import Dataset, Data
from torch_geometric.datasets import MNISTSuperpixels, TUDataset

class CustomDataset(Dataset):
    """Dataset wrapper around an already-built, in-memory list of graphs.

    The feature and class counts are supplied by the caller and reported
    back verbatim through ``num_features`` / ``num_classes`` instead of
    being derived from the stored graphs.
    """

    def __init__(self, data_list, num_features: int, num_classes: int):
        """Store the graph list together with its declared metadata.

        Args:
            data_list: Indexable collection of graph objects to serve.
            num_features: Value to report from ``num_features``.
            num_classes: Value to report from ``num_classes``.
        """
        super().__init__()
        self.data_list = data_list
        self._num_features = num_features
        self._num_classes = num_classes

    @property
    def num_features(self) -> int:
        """Feature count given at construction time."""
        return self._num_features

    @property
    def num_classes(self) -> int:
        """Class count given at construction time."""
        return self._num_classes

    # ``len`` and ``get`` are presumably the hooks the base ``Dataset``
    # relies on for ``len(ds)`` / ``ds[i]`` -- confirm against the
    # torch_geometric version in use.
    def len(self) -> int:
        """Return how many graphs are stored."""
        return len(self.data_list)

    def get(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.data_list[idx]

def process_tu_dataset(dataset):
    """Materialise ``dataset`` as a list of graphs that all carry features.

    If the dataset already reports a positive ``num_features`` its graphs
    are collected unchanged. Otherwise every graph is cloned and passed
    through ``OneHotDegree``, sized by the largest degree found anywhere in
    the dataset.

    Args:
        dataset: Iterable of graph objects exposing ``num_features``; in
            the degree branch each graph must also provide ``edge_index``,
            ``num_nodes`` and ``clone()``.

    Returns:
        A ``(graphs, num_features)`` tuple: the list of (possibly
        transformed) graphs and the reported feature count.
    """
    existing_dim = dataset.num_features
    if existing_dim > 0:
        # Features are already present: just collect the graphs.
        return list(dataset), existing_dim

    # First pass: find the largest degree over all graphs so that one
    # shared encoder can be used for the whole dataset. Degrees are counted
    # from the first row of ``edge_index``; 0 is the floor when there are
    # no graphs.
    largest_degree = max(
        (
            int(degree(graph.edge_index[0], graph.num_nodes).max())
            for graph in dataset
        ),
        default=0,
    )

    # Second pass: encode a clone of every graph with the shared transform.
    encoder = OneHotDegree(largest_degree)
    encoded_graphs = [encoder(graph.clone()) for graph in dataset]

    # Feature count is reported as one slot per degree value 0..largest.
    return encoded_graphs, largest_degree + 1

def get_dataset(name, root, transform, k, use_node_attr=False):
    """Load a dataset by name and return a ``(train, test)`` pair.

    ``name`` equal to ``'mnist'`` (case-insensitive) selects
    ``MNISTSuperpixels`` with its own train/test partition. Any other value
    is treated as a ``TUDataset`` name: the graphs are run through
    ``process_tu_dataset``, optionally transformed, shuffled with a fixed
    seed and split 80/20.

    Args:
        name: Dataset identifier (``'mnist'`` or a TU dataset name).
        root: Directory handed to the dataset constructors as ``root``.
        transform: Optional callable applied to each graph. For MNIST it is
            passed to the dataset constructor. For TU datasets it is
            applied once, eagerly, to every processed graph, so a
            randomised transform is frozen at load time.
        k: Not used by this function; kept so existing callers keep
            working.
        use_node_attr: Forwarded to ``TUDataset``.

    Returns:
        ``(train_dataset, test_dataset)``. For MNIST both are
        ``MNISTSuperpixels`` instances; otherwise both are
        ``CustomDataset`` instances.
    """
    if name.lower() == 'mnist':
        train_dataset = MNISTSuperpixels(
            root=root,
            train=True,
            transform=transform,
        )
        test_dataset = MNISTSuperpixels(
            root=root,
            train=False,
            transform=transform,
        )
        return train_dataset, test_dataset

    # Everything else is treated as a TU dataset.
    print(f"\nLoading {name} dataset...")
    dataset = TUDataset(
        root=root,
        name=name,
        transform=None,  # The transform is applied after processing below.
        use_node_attr=use_node_attr,
        cleaned=True,
    )

    processed_data_list, num_features = process_tu_dataset(dataset)
    print(f"Dataset processed with {num_features} features")

    if transform is not None:
        processed_data_list = [transform(data) for data in processed_data_list]

    # NOTE: this reseeds torch's *global* RNG so the split is reproducible.
    # Callers may rely on that side effect, so it is kept as-is.
    torch.manual_seed(12345)
    dataset_size = len(processed_data_list)
    # Convert to plain ints once rather than indexing a Python list with
    # 0-d tensors.
    indices = torch.randperm(dataset_size).tolist()
    train_size = int(0.8 * dataset_size)

    train_data_list = [processed_data_list[i] for i in indices[:train_size]]
    test_data_list = [processed_data_list[i] for i in indices[train_size:]]

    # Read once; both splits share the class count of the full dataset.
    num_classes = dataset.num_classes

    train_dataset = CustomDataset(
        train_data_list,
        num_features=num_features,
        num_classes=num_classes,
    )
    test_dataset = CustomDataset(
        test_data_list,
        num_features=num_features,
        num_classes=num_classes,
    )

    print("\nDataset preparation complete:")
    print(f"Number of training graphs: {len(train_dataset)}")
    print(f"Number of test graphs: {len(test_dataset)}")
    print(f"Number of features: {num_features}")
    print(f"Number of classes: {num_classes}\n")

    return train_dataset, test_dataset