import torch

from seism.activation_function import ActivationVector
from seism.labels import Labels
from seism.definitions import TestCase

################################################################################
# AssociationScore Class

class AssociationScore:
    """Association score between a set of labels and a set of activation vectors.

    Calling an instance returns a tensor of size ``nb_labels * nb_motifs``
    holding one score per (label, motif) pair. With ``y`` a label vector,
    ``a`` an activation vector and ``P`` the projection matrix given at call
    time, the two available scores are:

    * ``'hsic'``:  ``(y^T P a)^2 / ||P y||^2``
    * ``'ridge'``: ``(y^T P a)^2 / (||P y||^2 * (||a||^2 + nb_samples * lmbda_ridge))``
    """

    # Names of the scores that __call__ knows how to compute.
    _IMPLEMENTED_SCORES = ('hsic', 'ridge')

    def __init__(self, name: str, lmbda_ridge: float=-1.0):
        """Select the score to compute.

        Args:
            name: ``'hsic'`` or ``'ridge'``.
            lmbda_ridge: penalization coefficient of the ridge score; it must
                be >= 0 when ``name`` is ``'ridge'`` and is ignored otherwise.
                The negative default forces ridge users to pass it explicitly.

        Raises:
            AssertionError: if ``name`` is unknown, or if ``name`` is
                ``'ridge'`` and ``lmbda_ridge`` is not >= 0.
        """
        # AssertionError is raised explicitly (rather than through ``assert``)
        # so that the validation is not stripped when Python runs with -O.
        if name not in self._IMPLEMENTED_SCORES:
            raise AssertionError('this association score is not implemented yet')
        self.name = name
        if self.name == 'ridge':
            # ``not (x >= 0)`` also rejects NaN.
            if not lmbda_ridge >= 0:
                raise AssertionError('A non-negative penalization coefficient must be '
                    'given for the ridge regression score')
            self.lmbda_ridge = lmbda_ridge

    def __call__(self, activation_vector: ActivationVector, labels: Labels,
                    projection_matrix: torch.Tensor) -> torch.Tensor:
        """Compute the score of every (label, motif) pair.

        Args:
            activation_vector: activations of the motifs; assumed to be laid
                out as (1, nb_motifs, nb_samples, 1), as in the tests of this
                file -- TODO confirm against ActivationVector.
            labels: labels of the samples; assumed to be laid out as
                (nb_labels, 1, nb_samples, 1), as in the tests of this
                file -- TODO confirm against Labels.
            projection_matrix: square (nb_samples, nb_samples) matrix applied
                to both the labels and the activations.

        Returns:
            A tensor of size (nb_labels, nb_motifs).

        Raises:
            AssertionError: if labels and activations do not hold the same
                number of samples.
        """
        nb_samples = activation_vector.get_number_of_samples()
        if nb_samples != labels.get_number_of_samples():
            raise AssertionError(
                'Labels and activation vectors must correspond to the same number of samples')

        nb_motifs = activation_vector.get_number_of_motifs()
        nb_labels = labels.get_number_of_labels()
        nb_pairs = nb_labels*nb_motifs

        # One (nb_samples, 1) column per (label, motif) pair, pairs ordered
        # label-major so that reshape(nb_labels, nb_motifs) recovers the grid.
        labels_for_score = torch.Tensor(labels.repeat(1, nb_motifs, 1, 1).reshape(
                                nb_pairs, nb_samples, 1))
        activation_vector_for_score = torch.Tensor(activation_vector.repeat(nb_labels, 1, 1, 1).reshape(
                                nb_pairs, nb_samples, 1))

        # torch.matmul broadcasts the 2-D projection matrix over the batch of
        # pairs, which avoids materialising nb_pairs copies of it.
        projected_labels = torch.matmul(projection_matrix, labels_for_score)
        projected_activations = torch.matmul(projection_matrix, activation_vector_for_score)

        # ||P y|| for every pair. No guard is applied when it is zero: the
        # corresponding scores then come out non-finite.
        normalisation_factor = torch.norm(projected_labels, dim=1).reshape(nb_labels, nb_motifs)

        # y^T P a for every pair, shared by both scores.
        cross_product = torch.bmm(labels_for_score.transpose(dim0=1, dim1=2),
                                projected_activations).reshape(nb_labels, nb_motifs)

        if self.name == 'hsic':
            return torch.square(1/normalisation_factor * cross_product)

        elif self.name == 'ridge':
            # 1 / (||a||^2 + nb_samples * lambda) for every pair.
            inv_coeff = 1/((torch.norm(activation_vector_for_score, dim=(1,2))**2+
                                nb_samples*self.lmbda_ridge).reshape(nb_labels, nb_motifs))

            return 1/normalisation_factor**2 * inv_coeff * torch.square(cross_product)

        # Only reachable if ``name`` was modified after construction.
        raise AssertionError('this association score is not implemented yet')

################################################################################
# Tests

class TestAssociationScore(TestCase):
    """Unit tests for the construction and the output size of AssociationScore."""

    def _check_score_size(self, score, nb_labels, nb_motifs):
        # A score must be a 2-D tensor of size nb_labels * nb_motifs.
        self.assertEqual(score.dim(),2)
        self.assertEqual(score.size()[0],nb_labels)
        self.assertEqual(score.size()[1],nb_motifs)

    def test_hsic_score_dimensions(self):
        """HSIC rejects mismatched sample counts and returns a 5 * 3 score."""
        labels = Labels(torch.rand(5, 1, 4, 1))
        mismatched_activations = ActivationVector(torch.rand(1, 3, 8, 1))
        hsic_score = AssociationScore('hsic')

        # 4 label samples against 8 activation samples must be refused.
        with self.assertRaises(AssertionError):
            hsic_score(mismatched_activations, labels, torch.eye(8))

        matching_activations = ActivationVector(torch.zeros(1, 3, 4, 1))
        self._check_score_size(
            hsic_score(matching_activations, labels, torch.eye(4)), 5, 3)

    def test_ridge_definition(self):
        """Ridge needs an explicit penalization coefficient that is not negative."""
        for extra_args in ((), (-3,)):
            with self.assertRaises(AssertionError):
                AssociationScore('ridge', *extra_args)

    def test_ridge_score(self):
        """Ridge returns a 5 * 3 score for 5 labels and 3 motifs."""
        labels = Labels(torch.ones(5, 1, 8, 1))
        activations = ActivationVector(torch.rand(1, 3, 8, 1))
        ridge_score = AssociationScore('ridge', lmbda_ridge=10)
        self._check_score_size(
            ridge_score(activations, labels, torch.eye(8)), 5, 3)

    

