import torch
import pytest

from pina import LabelTensor

# Shared fixtures: a 20x3 uniform-random tensor and one label per column.
data = torch.rand(20, 3)
labels = list('abc')


def test_constructor():
    """Constructing a LabelTensor keeps the shape and values of the input data."""
    tensor = LabelTensor(data, labels)
    # The original test only checked that construction did not raise; also
    # pin that the wrapped tensor is the same data.
    assert tensor.shape == data.shape
    assert torch.allclose(tensor, data)


def test_wrong_constructor():
    """A label list shorter than the number of columns is rejected."""
    two_labels = ['a', 'b']
    with pytest.raises(ValueError):
        LabelTensor(data, two_labels)


def test_labels():
    """`labels` is exposed as given and reassigning a wrong-length list fails."""
    lt = LabelTensor(data, labels)
    assert isinstance(lt, torch.Tensor)
    assert lt.labels == labels

    # One label fewer than the tensor has columns.
    too_short = labels[:-1]
    with pytest.raises(ValueError):
        lt.labels = too_short


def test_extract():
    """Extracting a subset of labels returns exactly those columns and labels."""
    label_to_extract = ['a', 'c']
    tensor = LabelTensor(data, labels)
    new = tensor.extract(label_to_extract)
    # Derive the expected columns from the labels instead of hard-coding a
    # slice, so the expectation stays correct if the labels change.
    expected = data[:, [labels.index(name) for name in label_to_extract]]
    assert new.labels == label_to_extract
    assert new.shape[1] == len(label_to_extract)
    assert torch.allclose(expected, new)


def test_extract_onelabel():
    """Extracting a single label still yields a 2D, one-column tensor."""
    wanted = ['a']
    source = LabelTensor(data, labels)
    extracted = source.extract(wanted)
    # Column 'a' of the raw data, kept two-dimensional as a column vector.
    expected = data[:, [0]]
    assert extracted.ndim == 2
    assert extracted.labels == wanted
    assert extracted.shape[1] == len(wanted)
    assert torch.allclose(expected, extracted)


def test_wrong_extract():
    """Requesting a label that the tensor does not carry raises ValueError."""
    source = LabelTensor(data, labels)
    missing = ['a', 'cc']  # 'cc' is not one of the tensor's labels
    with pytest.raises(ValueError):
        source.extract(missing)


def test_extract_order():
    """Extracted columns follow the requested label order, not storage order."""
    label_to_extract = ['c', 'a']
    tensor = LabelTensor(data, labels)
    new = tensor.extract(label_to_extract)
    # Gather the raw columns in the requested order to build the expectation.
    expected = data[:, [labels.index(name) for name in label_to_extract]]
    assert new.labels == label_to_extract
    assert new.shape[1] == len(label_to_extract)
    assert torch.allclose(expected, new)


def test_merge():
    """Appending extracted single-label tensors rebuilds the original columns."""
    tensor = LabelTensor(data, labels)
    tensor_a = tensor.extract('a')
    tensor_b = tensor.extract('b')
    tensor_c = tensor.extract('c')

    tensor_bc = tensor_b.append(tensor_c)
    assert torch.allclose(tensor_bc, tensor.extract(['b', 'c']))

    # Append 'a' in front of the two-column result: merging a single-column
    # tensor with a multi-column one should reproduce the full original.
    tensor_abc = tensor_a.append(tensor_bc)
    assert torch.allclose(tensor_abc, tensor)

def test_merge2():
    """Appending the 'c' column to the 'b' column equals extracting both."""
    source = LabelTensor(data, labels)
    merged = source.extract('b').append(source.extract('c'))
    expected = source.extract(['b', 'c'])
    assert torch.allclose(merged, expected)


def test_getitem():
    """Slicing rows keeps every label and the selected rows' values."""
    first_rows = LabelTensor(data, labels)[:5]

    assert first_rows.labels == labels
    assert torch.allclose(first_rows, data[:5])

def test_slice():
    """Indexing narrows labels to the selected columns and keeps values."""
    full = LabelTensor(data, labels)

    # Row and column slice: labels follow the column slice.
    rows_cols = full[:5, :2]
    assert rows_cols.labels == labels[:2]
    assert torch.allclose(rows_cols, data[:5, :2])

    # A single integer index selects one row; all labels are kept.
    single_row = full[3]
    assert single_row.labels == labels
    assert torch.allclose(single_row, data[3])

    # A single column index. NOTE(review): this asserts labels == labels[2],
    # i.e. the bare string 'c' rather than the list ['c'] -- confirm whether
    # that is intended LabelTensor behavior or a quirk being pinned.
    single_col = full[:, 2]
    assert single_col.labels == labels[2]
    assert torch.allclose(single_col, data[:, [2]])