import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import pytest
import numpy as np
from sklearn.preprocessing import MinMaxScaler

from src.datasets.preprocessing.scaling import MinMaxTransformation  # adjust import path

def create_data_list(list_of_tensors):
    """Wrap each entry of ``list_of_tensors`` in its own ``Data`` graph.

    Args:
        list_of_tensors: Iterable of array-likes (e.g. nested lists of node
            features). Each entry is converted to a ``torch.float`` tensor and
            stored as the ``x`` attribute of a separate ``Data`` object.

    Returns:
        A list of ``Data`` objects, one per input entry, in input order.
    """
    feature_tensors = (
        torch.tensor(rows, dtype=torch.float) for rows in list_of_tensors
    )
    return [Data(x=features) for features in feature_tensors]

def test_minmax_transformation_basic():
    """Scaling should match sklearn's ``MinMaxScaler`` fitted on the same graphs.

    The reference scaler receives one ``partial_fit`` call per training graph.
    Min/max statistics do not depend on order, so this is equivalent to fitting
    on all rows at once. The fitted scaler is then applied to a query holding
    an in-range value and the global extremes.
    """
    train_graphs = [
        [[1.0], [2.0], [3.0]],
        [[0.0], [4.0]],
    ]
    query = [[1.0], [2.5], [4.0]]

    loader = DataLoader(create_data_list(train_graphs), batch_size=1)
    transform = MinMaxTransformation(loader)

    # Build the input from the raw list. The transform may modify its input in
    # place, so expected values must not be derived from the transformed object.
    transformed = transform(Data(x=torch.tensor(query, dtype=torch.float)))

    # Fit the reference on exactly the rows the loader was built from, so the
    # expectation cannot drift from the training data.
    expected_scaler = MinMaxScaler()
    for graph in train_graphs:
        expected_scaler.partial_fit(graph)
    expected = torch.tensor(expected_scaler.transform(query), dtype=torch.float)

    assert torch.allclose(transformed.x, expected, atol=1e-5), (
        f"Expected {expected}, got {transformed.x}"
    )


def test_minmax_transformation_inplace_behavior():
    """The transform should scale and return the very ``Data`` object it received.

    The single training graph spans [0, 10], so 5 and 10 must map to 0.5 and 1.0.
    """
    loader = DataLoader(create_data_list([[[0.0], [10.0]]]), batch_size=1)
    transform = MinMaxTransformation(loader)

    data = Data(x=torch.tensor([[5.0], [10.0]], dtype=torch.float))
    transformed = transform(data)

    assert transformed is data, "Transformation should modify the input Data object in place"

    # Pin exact values. A bare "within [0, 1]" check would also accept wrong
    # outputs such as all zeros.
    expected = torch.tensor([[0.5], [1.0]], dtype=torch.float)
    assert torch.allclose(data.x, expected, atol=1e-5), (
        f"Expected {expected}, got {data.x}"
    )


def test_minmax_transformation_multi_dimensional_features():
    """Each feature column should be scaled by its own min and max.

    Across the training graphs, column 0 spans [0, 4] and column 1 spans [1, 5].
    """
    data_list = create_data_list([
        [[0.0, 1.0], [2.0, 3.0]],
        [[4.0, 5.0]],
    ])
    loader = DataLoader(data_list, batch_size=1)
    transform = MinMaxTransformation(loader)

    data = Data(x=torch.tensor([[2.0, 1.0], [3.0, 5.0]], dtype=torch.float))
    transformed = transform(data)

    # Column 0: (2 - 0) / 4 = 0.5,  (3 - 0) / 4 = 0.75
    # Column 1: (1 - 1) / 4 = 0.0,  (5 - 1) / 4 = 1.0
    expected = torch.tensor([[0.5, 0.0], [0.75, 1.0]], dtype=torch.float)

    assert torch.allclose(transformed.x, expected, atol=1e-5), (
        f"Expected {expected}, got {transformed.x}"
    )


# def test_minmax_transformation_with_zero_range():
#     # All values are the same => min == max
#     data_list = create_data_list([
#         [[1.0], [1.0]]
#     ])
#     loader = DataLoader(data_list, batch_size=1)
#     transform = MinMaxTransformation(loader)
#
#     data = Data(x=torch.tensor([[1.0]], dtype=torch.float))
#     with pytest.raises(ValueError):
#         _ = transform(data)  # Will raise error due to 0 range division
#
#
# def test_minmax_transformation_invalid_dtype():
#     data_list = create_data_list([
#         [[1], [2]]
#     ])
#     loader = DataLoader(data_list, batch_size=1)
#     transform = MinMaxTransformation(loader)
#
#     data = Data(x=torch.tensor([[1], [2]], dtype=torch.int))  # not float
#     with pytest.raises(ValueError):
#         _ = transform(data)