import torch
import math

from seism.motif import Motif
from seism.sequences import OneHotEncodedSequences
from seism.definitions import TestCase, create_mask_0_1

################################################################################
# ActivationVector Class

class ActivationVector(torch.Tensor):
    """Tensor subclass holding motif activations for a batch of sequences.

    Expected shape: (1, nb_motifs, nb_samples, 1); the shape is validated in
    ``__init__``. Values are meant to be centered along the nb_samples
    dimension (``ActivationFunction`` in this file produces them that way),
    but this class does not check it.
    """

    def __new__(cls, tensor: torch.Tensor):
        # Reject non-tensors (e.g. plain Python lists) up front.
        assert isinstance(tensor, torch.Tensor), (
            'ActivationVectors can only be created from torch.Tensor')
        return super().__new__(cls, tensor)

    def __init__(self, tensor: torch.Tensor):
        """Validate the (1, nb_motifs, nb_samples, 1) layout.

        Raises:
            AssertionError: if the tensor is not 4-D or its first/last
                dimensions are not 1.
        """
        super().__init__()
        assert len(self.size()) == 4, (
            'ActivationVectors should have 4 dimensions: '
            '1, nb_motifs, nb_samples, 1')
        assert self.size()[0] == 1, 'First dimension of ActivationVector must be 1'
        assert self.size()[3] == 1, 'Last dimension of ActivationVector must be 1'

    def __deepcopy__(self, memo):
        # Copy the data into a new, graph-detached ActivationVector.
        return ActivationVector(super().clone().detach())

    def get_number_of_samples(self):
        """Return nb_samples (size of dimension 2)."""
        return self.size()[2]

    def get_number_of_motifs(self):
        """Return nb_motifs (size of dimension 1)."""
        return self.size()[1]

class ActivationFunction:
    """Compute activation vectors from a set of motifs and a batch of sequences.

    The activation is selected by ``name`` at construction time; only
    ``'gaussian_max_pooling'`` is currently implemented.
    """

    def __init__(self, name: str):
        """
        Args:
            name: identifier of the activation function to use.

        Raises:
            AssertionError: if ``name`` is not an implemented activation.
        """
        assert name in ['gaussian_max_pooling'], (
            'this activation function is not implemented yet')
        self.name = name

    # ------------------------------------------------------------------
    # Gaussian max pooling activation function

    def _compute_sigma_gaussian_activation(self, motifs: Motif) -> torch.Tensor:
        """
        Given a Motif, compute the bandwidth sigma of the Gaussian kernel:
        sigma = sqrt(0.9 * motif_length) / 2, returned as a 0-d tensor.
        """
        motifs_length = motifs.get_length()
        return torch.as_tensor((0.9 * motifs_length) ** (1 / 2) / 2)

    def _compute_max_pooling_gaussian_activation_vector(self, motifs: Motif,
            sequences: OneHotEncodedSequences) -> torch.Tensor:
        """
        Compute the max pooling gaussian activation vector (without reverse
        complements) for each motif and each sequence in the batch.

        Returns:
            Tensor of shape (1, nb_motifs, nb_samples, 1).
        """
        # k(u,z) = exp(-d(u,z)^2 / 2sigma^2)
        #        = exp(-(||z||^2 + ||u||^2 - 2<u,z>) / 2sigma^2)
        #
        # u ranges over the k-mers of a sequence and z is a motif; as ||u|| is
        # constant, argmax_u k(u,z) = argmax_u <u,z>.
        motifs_length = motifs.get_length()

        # conv1d output position j covers input positions j..j+L-1, so slicing
        # the per-position mask from L-1 keeps a k-mer iff its last position
        # is valid. Presumably create_mask_0_1 is 1 on positions below the
        # original length -- TODO confirm.
        mask = create_mask_0_1(sequences.get_original_lengths(),
                sequences.get_length())  # nb_samples, length
        mask = mask[:, motifs_length - 1:]  # nb_samples, nb_kmers in seq
        # Size-1 motif axis: masked_fill broadcasts it over all motifs.
        invalid_kmers = mask.unsqueeze(1).logical_not()  # nb_samples, 1, nb_kmers

        dot_product = torch.nn.functional.conv1d(torch.Tensor(sequences),
                motifs)  # nb_samples, nb_motifs, nb_kmers in seq
        dot_product = dot_product.masked_fill(invalid_kmers, -math.inf)
        max_pooled_dot_product = torch.max(dot_product, dim=2).values  # nb_samples, nb_motifs

        # ||z-u||^2 = ||z||^2 + ||u||^2 - 2<u,z>, with ||u||^2 = motif_length
        motifs_norm_squared = (torch.norm(motifs, dim=(1, 2)) ** 2).unsqueeze(0)  # 1, nb_motifs
        squared_distance = (motifs_norm_squared + motifs_length
                - 2 * max_pooled_dot_product)  # nb_samples, nb_motifs
        # A squared distance is >= 0, but the expansion above can go slightly
        # negative through floating-point cancellation; clamp so the kernel
        # value never exceeds 1. (+inf from fully-masked rows is preserved.)
        squared_distance = squared_distance.clamp(min=0)

        sigma = self._compute_sigma_gaussian_activation(motifs)
        inv_two_sigma_squared = 1. / (2 * sigma ** 2)
        activation_tensor = torch.exp(-inv_two_sigma_squared * squared_distance)
        return activation_tensor.t().unsqueeze(0).unsqueeze(-1)

    def _compute_max_pooling_gaussian_activation_vector_including_rc(self,
            motifs: Motif, sequences: OneHotEncodedSequences) -> ActivationVector:
        """
        Compute the max pooling gaussian activation vector for the motifs and
        sequences in the batch. For each (motif, sequence) pair the larger of
        the motif's and its reverse complement's activation is kept, and the
        result is centered along the nb_samples dimension.
        """
        motifs_rc = motifs.get_reverse_complement()

        activation_tensor_no_rc = self._compute_max_pooling_gaussian_activation_vector(motifs, sequences)
        activation_tensor_rc = self._compute_max_pooling_gaussian_activation_vector(motifs_rc, sequences)

        # Both are (1, nb_motifs, nb_samples, 1): concatenate on the last axis,
        # take the max, and restore the trailing singleton dimension.
        activation_tensor = torch.max(torch.cat((activation_tensor_no_rc,
            activation_tensor_rc), dim=3), dim=3).values.unsqueeze(-1)

        # Center across samples; keepdim broadcasting does not depend on
        # sequences.get_number() matching the batch size.
        activation_tensor = activation_tensor - activation_tensor.mean(
                dim=2, keepdim=True)
        return ActivationVector(activation_tensor)

    # End of gaussian max pooling activation function
    # ------------------------------------------------------------------

    def __call__(self, motifs: Motif, sequences: OneHotEncodedSequences) -> ActivationVector:
        """Return the activation vector of ``motifs`` on ``sequences``.

        Raises:
            NotImplementedError: if ``self.name`` is not an implemented
                activation (only reachable if ``name`` was changed after
                construction).
        """
        if self.name == 'gaussian_max_pooling':
            return self._compute_max_pooling_gaussian_activation_vector_including_rc(
                motifs, sequences)
        raise NotImplementedError(
            f'activation function {self.name!r} is not implemented')

################################################################################
# Tests
class TestActivationVector(TestCase):
    """Checks construction-time validation and accessors of ActivationVector."""

    def test_construction(self):
        # A Python list is not a torch.Tensor and must be rejected.
        not_a_tensor = [0, 1, 2]
        with self.assertRaises(AssertionError):
            ActivationVector(not_a_tensor)

    def test_dimensions(self):
        # Only three dimensions where four are required.
        three_dimensional = torch.zeros(3, 4, 5)
        with self.assertRaises(AssertionError):
            ActivationVector(three_dimensional)

    def test_format(self):
        # Four dimensions, but the first and last are not 1.
        wrong_layout = torch.zeros(3, 2, 3, 2)
        with self.assertRaises(AssertionError):
            ActivationVector(wrong_layout)

    def test_get_number_of_samples(self):
        nb_samples = 10
        vector = ActivationVector(torch.zeros(1, 1, nb_samples, 1))
        self.assertEqual(vector.get_number_of_samples(), nb_samples)

class TestActivationFunction(TestCase):
    """Unit tests for ActivationFunction."""

    def test_construction(self):
        with self.assertRaises(AssertionError):
            ActivationFunction('pas implemente')

    def test_compute_sigma_gaussian_activation(self):
        activation_function = ActivationFunction('gaussian_max_pooling')
        length = 10
        m = Motif(torch.zeros(1, 4, length))
        sigma = activation_function._compute_sigma_gaussian_activation(m)
        # sqrt(0.9 * 10) / 2 = 1.5. Compare as floats instead of relying on
        # the truth value of an element-wise tensor comparison.
        self.assertAlmostEqual(float(sigma), 1.5)

    def test_gaussian_max_pooling_activation_function(self):
        activation_function = ActivationFunction('gaussian_max_pooling')
        nb_motifs = 2
        nb_samples = 3
        motif_length = 10
        sequence_length = 20
        m = Motif(torch.rand(nb_motifs, 4, motif_length))
        # Every position of every sequence has only channel 0 set; the fill
        # follows nb_samples / sequence_length instead of hard-coded sizes.
        s = torch.zeros(nb_samples, 4, sequence_length)
        s[:, 0, :] = 1
        sequences = OneHotEncodedSequences(s)
        activation_vector = activation_function(m, sequences)
        self.assertEqual(nb_motifs, activation_vector.get_number_of_motifs())
        self.assertEqual(nb_samples, activation_vector.get_number_of_samples())
        # Activations are centered along the nb_samples dimension.
        self.assert_torch_allclose(torch.mean(activation_vector, dim=2),
                torch.zeros(1, nb_motifs, 1))
        
