import torch

from seism.definitions import TestCase

################################################################################
# Labels Class

class Labels(torch.Tensor):
    """Tensor subclass holding a batch of labels.

    Labels are tensors of shape ``(nb_labels, 1, nb_samples, 1)``. By
    convention they are centered along the ``nb_samples`` dimension (this
    convention is not checked here).

    Construction wraps the given tensor without copying: the result shares
    storage with the input and keeps its dtype, device and autograd history.
    Invalid inputs raise ``AssertionError``.
    """

    def __new__(cls, tensor: torch.Tensor):
        if not isinstance(tensor, torch.Tensor):
            # AssertionError (rather than TypeError) is kept because existing
            # callers and tests catch AssertionError.
            raise AssertionError('Labels can only be created from torch.Tensor')
        # `as_subclass` re-wraps the same data as `cls` for any dtype/device.
        # The legacy `torch.Tensor.__new__(cls, tensor)` path only accepted
        # tensors matching the default tensor type (e.g. CPU float32).
        return tensor.as_subclass(cls)

    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        shape = self.size()
        # Explicit raises instead of `assert` so validation survives `python -O`;
        # the AssertionError type is preserved for backward compatibility.
        if len(shape) != 4:
            raise AssertionError('Labels should have 4 dimensions: nb_labels, 1, nb_samples, 1')
        if shape[1] != 1:
            raise AssertionError('Second dimension of Labels must be 1')
        if shape[3] != 1:
            raise AssertionError('Last dimension of Labels must be 1')

    def __deepcopy__(self, memo):
        """Return an independent copy (new storage, detached from autograd)."""
        copied = Labels(self.detach().clone())
        # Register the copy so an object referenced several times inside a
        # deep-copied container is copied only once, per the copy protocol.
        memo[id(self)] = copied
        return copied

    def get_number_of_samples(self) -> int:
        """Return ``nb_samples`` (size of dimension 2)."""
        return self.size()[2]

    def get_number_of_labels(self) -> int:
        """Return ``nb_labels`` (size of dimension 0)."""
        return self.size()[0]

    def extract_one_label(self, ind: int) -> 'Labels':
        """Return label ``ind`` as a Labels of shape ``(1, 1, nb_samples, 1)``.

        The result is a view sharing storage with ``self``. Negative indices
        count from the end; an out-of-range index raises IndexError.
        """
        return Labels(self[ind, :, :, :].unsqueeze(0))

################################################################################
# Tests 

class TestLabels(TestCase):
    """Unit tests for the Labels tensor subclass."""

    def test_construction(self):
        with self.assertRaises(AssertionError):
            Labels([0,1,2])

    def test_dimensions(self):
        with self.assertRaises(AssertionError):
            Labels(torch.zeros(3,4,5))

    def test_format(self):
        with self.assertRaises(AssertionError):
            Labels(torch.zeros(3,2,3,2))

    def test_last_dimension(self):
        # Only the last dimension is wrong here, so this isolates that check.
        with self.assertRaises(AssertionError):
            Labels(torch.zeros(3, 1, 3, 2))

    def test_get_number_of_samples(self):
        nb_samples = 10
        self.assertEqual(Labels(torch.zeros(1,1, nb_samples, 1)).get_number_of_samples(), nb_samples)

    def test_get_number_of_labels(self):
        nb_labels = 10
        self.assertEqual(Labels(torch.zeros(nb_labels,1, 1, 1)).get_number_of_labels(), nb_labels)

    def test_extract_one_label(self):
        labels = Labels(torch.arange(6.0).reshape(2, 1, 3, 1))
        one = labels.extract_one_label(1)
        self.assertIsInstance(one, Labels)
        self.assertEqual(tuple(one.size()), (1, 1, 3, 1))
        self.assertEqual(one.flatten().tolist(), [3.0, 4.0, 5.0])

    def test_deepcopy(self):
        import copy
        labels = Labels(torch.zeros(2, 1, 3, 1))
        copied = copy.deepcopy(labels)
        self.assertIsInstance(copied, Labels)
        self.assertEqual(copied.tolist(), labels.tolist())
        # The copy must own its storage: mutating it leaves the original intact.
        copied[0, 0, 0, 0] = 1.0
        self.assertEqual(labels[0, 0, 0, 0].item(), 0.0)
