import os
import torch
import pytest
from torch_geometric.data import Data
from src.datasets.amazon_dataset import AmazonImplemented


class DummyAmazon:
    """Minimal stand-in for torch_geometric.datasets.Amazon.

    Exposes ``root``, ``name``, ``data`` and ``processed_paths`` attributes
    together with ``save`` and ``get`` methods.
    """

    def __init__(self, root, name, transform=None, pre_transform=None, force_reload=True):
        # transform, pre_transform and force_reload are accepted so the
        # constructor can be called with those arguments; they are ignored.
        processed_file = os.path.join(root, "processed.pt")
        self.processed_paths = [processed_file]
        self.data = Data()
        self.name = name
        self.root = root

    def save(self, data_list, path):
        """Write ``data_list`` to ``path`` using ``torch.save``."""
        torch.save(data_list, path)

    def get(self, idx):
        """Return the stored data object; ``idx`` is not consulted."""
        return self.data


@pytest.fixture
def dummy_dataset(tmp_path, monkeypatch):
    """Build a small AmazonImplemented dataset with synthetic labels, no downloads.

    ``Amazon`` in the module under test is swapped for ``DummyAmazon``. The
    returned dataset holds 100 samples with 8 random features each and integer
    labels cycling through 5 classes.
    """
    # Imported locally so the module attribute can be patched via monkeypatch.
    import src.datasets.amazon_dataset as amazon_module

    monkeypatch.setattr(amazon_module, "Amazon", DummyAmazon)

    n_nodes, n_labels, n_features = 100, 5, 8

    dataset_dir = tmp_path / "amazon_dummy"
    dataset_dir.mkdir(parents=True, exist_ok=True)

    dataset = amazon_module.AmazonImplemented(
        root=str(dataset_dir),
        name="computers",
        force_reload=True,
    )

    # Replace whatever data the constructor produced with synthetic tensors.
    features = torch.randn(n_nodes, n_features)
    labels = torch.arange(n_nodes) % n_labels
    dataset.data = Data(x=features, y=labels)

    dataset.processed_paths = [str(dataset_dir / "processed.pt")]
    return dataset


def test_split_data_no_leakage(dummy_dataset):
    """Verify the train/val/test masks are pairwise disjoint and cover every sample."""
    dummy_dataset.split_data(n_folds=0)
    graph = dummy_dataset.data

    train, val, test = graph.train_mask, graph.val_mask, graph.test_mask

    # Every mask must be present.
    for mask in (train, val, test):
        assert mask is not None

    # No sample may belong to two splits at once.
    overlap_checks = (
        (train, val, "Train/Val overlap detected"),
        (train, test, "Train/Test overlap detected"),
        (val, test, "Val/Test overlap detected"),
    )
    for first, second, message in overlap_checks:
        assert not torch.any(first & second), message

    # The union of the masks must mark every sample; together with the
    # disjointness checks above, each sample lands in exactly one split.
    covered = train | val | test
    assert torch.all(covered), "Some samples are missing from all splits"


def test_testset_persistence(dummy_dataset):
    """Check that calling split_data twice leaves the test mask unchanged."""
    dataset = dummy_dataset

    dataset.split_data(n_folds=0)
    # Snapshot the mask so a later in-place change cannot affect the comparison.
    mask_before = dataset.data.test_mask.clone()

    # Split again on the same dataset object. Presumably split_data reloads the
    # stored mask instead of resampling -- confirm against its implementation.
    dataset.split_data(n_folds=0)
    mask_after = dataset.data.test_mask

    assert torch.equal(mask_before, mask_after), "Test set was resampled"


def test_crossvalidation_masks(dummy_dataset):
    """Verify that a multi-fold split yields per-fold train/val masks of the right shape."""
    fold_count = 3
    dummy_dataset.split_data(n_folds=fold_count)
    graph = dummy_dataset.data

    assert hasattr(graph, "train_masks"), "train_masks not created"
    assert hasattr(graph, "val_masks"), "val_masks not created"

    # One row per label entry, one column per fold.
    expected_shape = (len(graph.y), fold_count)
    assert graph.train_masks.shape == expected_shape
    assert graph.val_masks.shape == expected_shape

    # At least one later fold must differ from the first fold's train mask.
    reference = graph.train_masks[:, 0]
    differs = [
        bool((reference != graph.train_masks[:, fold]).any())
        for fold in range(1, fold_count)
    ]
    assert any(differs), "All folds have identical train masks"


def test_saves_masks_to_disk(dummy_dataset):
    """Check that split_data leaves a seed-specific mask file under ``<root>/masks``."""
    dummy_dataset.split_data(n_folds=0)

    masks_dir = os.path.join(dummy_dataset.root, "masks")
    expected_file = os.path.join(masks_dir, f"set_masks_seed{dummy_dataset.seed}.pt")

    assert os.path.exists(expected_file), "Mask file not saved"
