import torch

from seism.definitions import TestCase

################################################################################
# OneHotEncodedSequences Class

class OneHotEncodedSequences(torch.Tensor):
    """A batch of one-hot encoded sequences stored as a torch.Tensor subclass.

    The tensor must have exactly 3 dimensions:
    (nb_sequences, 4, sequences_length). Dimension 1 must have size 4
    (presumably one channel per DNA base -- confirm the channel order with the
    code that builds these tensors).

    Each instance also carries ``original_lengths``, which must have one entry
    per sequence. When it is not supplied, every sequence gets
    ``sequences_length`` as its length (presumably the length before any
    padding -- verify against callers).
    """

    def __new__(cls, tensor: torch.Tensor, original_lengths: torch.Tensor = None):
        assert isinstance(tensor, torch.Tensor), 'Sequences can only be created from torch.Tensor'
        # Validate the shape first so malformed inputs report a shape problem
        # rather than a (misleading) one-hot failure.
        assert tensor.dim() == 3, 'Sequences should have 3 dimensions: nb_sequences, dna_length (4), sequences_length'
        assert tensor.size(1) == 4, 'Sequences dimension 1 should be equal to 4 (dna)'
        # An empty tensor (no sequences, or zero-length sequences) has no
        # positions to validate, so the one-hot check is skipped.
        if tensor.numel() > 0:
            assert tensor_is_one_hot_encoded(tensor), 'Sequences can only be created from OHE tensors'
        return super().__new__(cls, tensor)

    def __init__(self, tensor: torch.Tensor, original_lengths: torch.Tensor = None):
        """Attach ``original_lengths`` (one entry per sequence) to the instance.

        Raises:
            AssertionError: if ``original_lengths`` does not have one entry per
                sequence.
        """
        super().__init__()
        if original_lengths is None:
            # Default: every sequence spans the whole dimension 2.
            original_lengths = torch.full((self.size(0),), float(self.size(2)))
        self.original_lengths = original_lengths
        assert len(self.original_lengths) == self.size(0), 'Issue in the sequences length'

    def __deepcopy__(self, memo):
        """Return an independent copy, including a copy of ``original_lengths``."""
        import copy

        # Presumably instances produced by torch operations on this class do
        # not go through __init__, so the attribute may be missing; in that
        # case the copy falls back to the default lengths.
        lengths = getattr(self, 'original_lengths', None)
        if isinstance(lengths, torch.Tensor):
            lengths = lengths.detach().clone()
        elif lengths is not None:
            lengths = copy.deepcopy(lengths, memo)
        result = OneHotEncodedSequences(super().clone().detach(), lengths)
        memo[id(self)] = result
        return result

    def get_number(self):
        """Return the number of sequences (size of dimension 0)."""
        return self.size()[0]

    def get_length(self):
        """Return the sequence length (size of dimension 2)."""
        return self.size()[2]

    def get_original_lengths(self):
        """Return the per-sequence lengths stored at construction."""
        return self.original_lengths

def tensor_is_one_hot_encoded(tensor: torch.Tensor, dim: int = 1) -> bool:
    """
    Check whether a tensor is one-hot encoded along dimension ``dim``.

    A tensor is one-hot encoded when every slice along ``dim`` sums to exactly
    1 and the tensor's distinct values are exactly {0, 1}. A tensor with no
    positions to check (the sum along ``dim`` is empty) is vacuously one-hot
    encoded.

    Args:
        tensor: the tensor to check.
        dim: the encoding dimension; defaults to 1, the channel dimension used
            by OneHotEncodedSequences.

    Returns:
        True if ``tensor`` is one-hot encoded along ``dim``, False otherwise.
    """
    sums = tensor.sum(dim=dim)
    if sums.numel() == 0:
        return True
    # Combined with the value check below, "every sum equals 1" guarantees
    # exactly one active entry per position. NaN sums compare unequal to 1.
    if not bool((sums == 1).all()):
        return False
    # Requiring the sorted distinct values to be exactly [0, 1] rejects
    # fractional or negative entries that could still sum to 1.
    return torch.unique(tensor, sorted=True).tolist() == [0, 1]

################################################################################
#Tests

class TestSequences(TestCase):
    """Unit tests for OneHotEncodedSequences construction and accessors."""

    @staticmethod
    def _one_hot(nb_sequences, length):
        # Every position activates channel 0, giving a valid one-hot tensor.
        encoded = torch.zeros(nb_sequences, 4, length)
        encoded[:, 0, :] = 1
        return encoded

    def test_construction(self):
        # A plain Python list is not a torch.Tensor and must be rejected.
        with self.assertRaises(AssertionError):
            OneHotEncodedSequences([0, 1, 2])

    def test_dimensions(self):
        # Tensors whose shape is not (n, 4, L) must be rejected.
        for bad_shape in ((1, 1), (1, 5, 4)):
            with self.assertRaises(AssertionError):
                OneHotEncodedSequences(torch.zeros(*bad_shape))

    def test_number(self):
        count = 3
        sequences = OneHotEncodedSequences(self._one_hot(count, 3))
        self.assertEqual(sequences.get_number(), count)

    def test_length(self):
        size = 4
        sequences = OneHotEncodedSequences(self._one_hot(2, size))
        self.assertEqual(sequences.get_length(), size)

    def test_not_OHE(self):
        # An all-zero tensor has no active channel, so it is not one-hot.
        with self.assertRaises(AssertionError):
            OneHotEncodedSequences(torch.zeros(2, 4, 10))


    
