import torch
import math
from typing import List, Tuple

from seism.definitions import TestCase
################################################################################
# Motif Class

class Motif(torch.Tensor):
    """
    Batch of DNA motifs stored as a tensor of shape (nb_motifs, 4, motif_length).

    Dimension 1 holds the 4 DNA channels; the drawing code of this file labels
    them in the order A, C, G, T.
    A Motif can only be built from a torch.Tensor with exactly 3 dimensions
    and a dimension 1 of size 4 (checked by assertions at construction).

    Methods:
    - get_number(): number of motifs in the batch
    - get_length(): length of the motifs
    - get_reverse_complement(): reverse complement of the motifs
    - project_on_simplex(): projection of the motifs on the simplex
    - extract_one_motif(): one specific motif of the batch, as a Motif
    - draw_to_file(): visualization of the motifs as .png file(s)
    """

    def __new__(cls, tensor: torch.Tensor):
        assert isinstance(tensor, torch.Tensor), 'Motifs can only be created from torch.Tensor'
        return super().__new__(cls, tensor)

    def __init__(self, tensor: torch.Tensor):
        super().__init__()
        assert len(self.size())==3, 'Motifs should have 3 dimensions: nb_motifs, dna_length (4), motif_length'
        assert self.size()[1]==4, 'Motifs dimension 1 should be equal to 4 (dna)'

    def get_number(self) -> int:
        """Return the number of motifs (size of dimension 0)."""
        return self.size()[0]

    def get_length(self) -> int:
        """Return the length of the motifs (size of dimension 2)."""
        return self.size()[2]

    def __deepcopy__(self, memo):
        """Return a cloned, detached copy wrapped in a new Motif ; `memo` is not used."""
        return Motif(super().clone().detach())

    def get_reverse_complement(self) -> "Motif":
        """
        Return the reverse complement of the motifs.

        Flipping dimension 2 reverses the positions ; flipping dimension 1
        reverses the channel order, which maps A<->T and C<->G for the
        A, C, G, T channel order.
        """
        return Motif(self.flip([1,2]))

    def project_on_simplex(self, dim: int = 1, radius: float = 1.0) -> "Motif":
        """
        Project every vector lying along `dim` on the simplex
        { x : x_i >= 0, sum_i x_i = radius }.

        dim: dimension along which the vectors are taken ; negative values
             count from the last dimension, as usual in torch.
        radius: sum of the components after projection ; must be > 0.
        Returns a detached copy ; `self` is left untouched.
        """
        assert radius > 0
        points = self.clone().detach()
        dim_size: int = points.size(dim)  # Raises IndexError if dim is out of range
        # Normalize negative dims: the broadcast shape below compares against
        # non-negative dimension indices.
        dim = dim % points.dim()
        points_sorted, _ = points.sort(dim = dim, descending = True) # points_sorted[..., :, ...] = (x'_0 >= ... >= x'_{n-1})
        # The algorithm computes a delta to be applied to all components of point x.
        # For 0 <= i <= n-1, delta_candidate_i = dc_i = 1/(i+1) * (sum_{0 <= j <= i} x'_j - radius)
        # delta = dc_i with i = max{ 0 <= i <= n-1, dc_i < x'_i }
        # The strategy here is to compute all dc_i, and select the right one afterwards
        dc_numerators = points_sorted.cumsum(dim = dim) - radius # numerators[..., i, ...] = sum_{0 <= j <= i} x'_j - radius
        # Built on the same device as the points, otherwise the division below
        # fails for motifs that do not live on the default device.
        dc_denominators = torch.arange(
            1, dim_size + 1, dtype = points.dtype, device = points.device
        ) # dim = (dim_size)
        # Expand denominators to (1..., dim_size, 1...) to force broadcast
        broadcast_shape = tuple(
            dim_size if d == dim else 1 for d in range(points.dim())
        )
        delta_candidates = dc_numerators / dc_denominators.view(broadcast_shape)
        # delta_candidates[..., i, ....] = dc_i for point[..., :, ...]
        # It is easy to prove that if dc_i >= x'_i, then dc_{i+1} >= x'_{i+1}
        # Also, dc_0 = x'_0 - radius < x'_0
        # Thus { dc_i < x'_i }_i = [True, ..., True, False, ... False]
        # Thus max{ 0 <= i <= n-1, dc_i < x'_i } = count(True in { dc_i < x'_i }) - 1
        delta_indices = torch.sum(
            delta_candidates < points_sorted,
            dim = dim,
            keepdim = True
        ).long() - 1 # dim = (..., 1, ...)
        delta = delta_candidates.gather(
            dim = dim,
            index = delta_indices
        ) # dim = (..., 1, ...)
        points.add_(-delta)
        points.clamp_(min = 0)
        return points

    def extract_one_motif(self, ind: int) -> "Motif":
        """Return motif number `ind` as a Motif of shape (1, 4, motif_length)."""
        return Motif(self[ind].unsqueeze(0))

    def draw_to_file(self, filename: str) -> None:
        """
        Draw the sequence logo of every motif, after projection on the simplex.

        filename: path prefix, without extension.
        A single motif is written to `filename + ".png"` ; otherwise motif i is
        written to `filename + str(i) + ".png"`.
        """
        nb_motifs = self.get_number()
        if nb_motifs == 1:
            # Single motif: no index in the file name
            bits = ppm_to_bits(self.project_on_simplex())
            draw_to_file(filename +".png", bits)
            return
        for i in range(nb_motifs):
            bits = ppm_to_bits(self.extract_one_motif(i).project_on_simplex())
            draw_to_file(filename +str(i) +".png", bits)

################################################################################
# For graphical visualization of the motifs

def ppm_to_bits(motif: "Motif") -> torch.Tensor:
    """
    Generate bits representation from PPM ; see https://en.wikipedia.org/wiki/Sequence_logo.

    motif: a single motif, shape (1, 4, length). Each position is expected to
           hold a probability distribution over the 4 channels (for instance
           after a projection on the simplex) ; this is not checked here.
    output: (length, 4) ; output[pos, c] = motif[0, c, pos] * information_content[pos]
            with information_content[pos] = log2(4) + sum_c p_c * log2(p_c).
    Raises ValueError if `motif` holds more than one motif.
    """
    if motif.get_number() > 1:
        raise ValueError('Only one motif can be tranformed to bit')
    m = motif[0].t() # float(length, 4)
    information_content = (
        math.log2(4) + torch.sum(
            # Sum position by position pos * log2(pos), ignoring pos = 0 values (-inf log)
            torch.where(
                m > 0.,
                m * m.log2(),
                m.new_tensor(0.)
            ),
            dim = 1, keepdim = True
        )
    ) # float(length, 1)
    bits = m * information_content # float(length, 4)
    return bits
        
# Letter colors used when drawing logos: palette name -> letter -> [red, green, blue],
# each component being a float in [0, 1].
PALETTES = {
    "dna": {
        "A": [0.0, 0.5, 0.0],
        "C": [0.0, 0.0, 1.0],
        "G": [1.0, 0.65, 0.0],
        "T": [1.0, 0.0, 0.0],
    },
}

def draw_to_file(
    filename: str, motif_bits: torch.Tensor,
    width_scale: float = 0.8
):
    """
    Draw a sequence logo from a bits representation and save it to `filename`.

    filename: path of the image file to write (passed as is to `savefig`).
    motif_bits: (motif_length, 4) ; non-negative letter heights, the 4 columns
                being the letters A, C, G, T in this order.
    width_scale: figure width allotted to each position ; must be > 0.

    At each position the letters are stacked by increasing score, the highest
    one on top ; letters with a null score are not drawn.
    Raises AssertionError on invalid arguments or on a negative score.
    """
    assert width_scale > 0.
    motif_length, motif_channels = motif_bits.size()
    assert motif_channels == 4
    max_stack_height: float = math.log2(motif_channels)

    # Import plot libs lazily, as they are only needed here
    import matplotlib
    import matplotlib.font_manager
    import matplotlib.patches
    import matplotlib.pyplot
    import matplotlib.text
    import matplotlib.transforms
    matplotlib.use('Agg') # Set backend

    # Prepare resources
    palette = PALETTES["dna"]
    font = matplotlib.font_manager.FontProperties(
        family = ["sans-serif"], weight = "bold"
    )

    # Do not hscale letters more than a "E" ; avoids 'I' becoming a filled rectangle.
    min_letter_width: float = matplotlib.text.TextPath(
        (0, 0), "E", size = 1, prop = font
    ).get_extents().width

    # Create figure
    figure, axes = matplotlib.pyplot.subplots(
        figsize = (motif_length * width_scale, max_stack_height + 1.)
    )
    # From here on the figure is registered in pyplot: make sure it is closed
    # even if drawing or saving fails.
    try:
        axes.set_xticks(range(1, motif_length + 1))
        axes.set_xlim((0, motif_length + 1))
        axes.set_yticks(range(0, math.ceil(max_stack_height) + 1))
        axes.set_ylabel("bits", fontsize = 15)
        for d in ["top", "right", "bottom"]:
            axes.spines[d].set_visible(False)

        # Add letters
        for position, scores in enumerate(motif_bits):
            letter_and_score_increasing: List[Tuple[str, float]] = sorted(
                zip("ACGT", scores.tolist()),
                key = lambda letter_score: letter_score[1]
            )
            stack_height = 0.
            for letter, score in letter_and_score_increasing:
                assert score >= 0.
                if score == 0.:
                    continue # Do not generate any path
                # Render letter to vector path
                letter_path = matplotlib.text.TextPath(
                    (0, 0), letter, size = 1, prop = font
                )
                # Compute transform to put letter in right position
                letter_box = letter_path.get_extents()
                transform = matplotlib.transforms.Affine2D()
                transform.translate(
                    - letter_box.xmin - letter_box.width / 2.,
                    -letter_box.ymin
                ) # Letter in rect x = [-w/2,w/2], y = [0, h]
                transform.scale(
                    0.95 / max(letter_box.width, min_letter_width),
                    (0.95 / letter_box.height) * score
                ) # Scale to rect x = [-0.5, 0.5], y = [0, score] with 5% padding
                transform.translate(position + 1, stack_height) # Move to final position
                # Add to figure
                axes.add_artist(matplotlib.patches.PathPatch(
                    letter_path,
                    linewidth = 0,
                    facecolor = palette[letter],
                    transform = transform + axes.transData
                ))
                stack_height += score

        # Finalize
        figure.tight_layout()
        figure.savefig(filename)
    finally:
        matplotlib.pyplot.close(figure)
        
###################################################################################################
#Tests
class TestMotif(TestCase):
    """Unit tests of the Motif class."""

    def test_construction(self):
        # Only torch.Tensor inputs are accepted
        with self.assertRaises(AssertionError):
            Motif([0,1,2])

    def test_dimensions(self):
        # Wrong number of dimensions
        with self.assertRaises(AssertionError):
            Motif(torch.zeros(1,1))
        # Dimension 1 is not 4
        with self.assertRaises(AssertionError):
            Motif(torch.zeros(1,5,4))

    def test_number(self):
        nb = 3
        m = Motif(torch.rand(nb,4,1))
        self.assertEqual(m.get_number(), nb)

    def test_length(self):
        length = 4
        m = Motif(torch.rand(1,4,length))
        self.assertEqual(m.get_length(), length)

    def test_reverse_complement(self):
        tensor = torch.zeros(1,4,4)
        tensor[0, 0, 0] = 1
        tensor[0, 3, 1] = 1
        tensor[0, 1, 2] = 1
        tensor[0, 2, 3] = 1
        motif = Motif(tensor) #ATCG

        tensor_rc = torch.zeros(1,4,4)
        tensor_rc[0, 1, 0] = 1
        tensor_rc[0, 2, 1] = 1
        tensor_rc[0, 0, 2] = 1
        tensor_rc[0, 3, 3] = 1
        motif_rc = Motif(tensor_rc) #CGAT

        self.assert_torch_equal(motif.get_reverse_complement(), motif_rc)

    def test_simplex_projection(self):
        # Values in [-0.5, 0.5): some components are negative before projection
        motifs = Motif(torch.rand(2,4,5)-torch.ones(2,4,5)/2)
        projected_motifs = motifs.project_on_simplex()
        self.assertIsInstance(projected_motifs, Motif)
        self.assertTrue((projected_motifs>=0).all())
        self.assert_torch_allclose(projected_motifs.sum(dim = 1), torch.ones(2, 5,
                                                    dtype = projected_motifs.dtype))

    def test_draw_logo(self):
        import os
        import tempfile
        motif = Motif(torch.rand(2,4,10))
        # Draw inside a temporary directory: nothing is written in the working
        # directory, and nothing is left behind if drawing fails midway.
        with tempfile.TemporaryDirectory() as tmp_dir:
            prefix = os.path.join(tmp_dir, "test_logo")
            motif.draw_to_file(prefix)
            self.assertTrue(os.path.isfile(prefix + "0.png"))
            self.assertTrue(os.path.isfile(prefix + "1.png"))
