import torch
from typing import Tuple
import os
import shutil
import time

from seism.sequences import OneHotEncodedSequences
from seism.dataset import Dataset, parse_fasta_dataset, dataset_from_sequences_and_labels
from seism.association_score import AssociationScore
from seism.motif import Motif
from seism.labels import Labels
from seism.activation_function import ActivationFunction, ActivationVector
from seism.definitions import TestCase
################################################################################
# Selecter Class

class Selecter: 
    # Selects the n best motifs given a dataset, an activation function and an association score
    def __init__(self, dataset: Dataset, association_score: AssociationScore, 
        activation_function: ActivationFunction, kmers_activation_vector_tuple: Tuple[ActivationVector]=None, 
        max_iteration_steps = None):

        self.dataset = dataset
        self.association_score = association_score
        self.activation_function = activation_function
        if not (max_iteration_steps is None):
            self.max_iteration_steps = max_iteration_steps
        else:
            if association_score.name=='hsic':
                self.max_iteration_steps = 1000
            elif association_score.name=='ridge':
                self.max_iteration_steps = 1000
            else:
                raise NotImplementedError('this association score has no default maximum iteration step, please provide one')
        
        if self.association_score.name=='hsic': 
            self.ls_initial_alpha_j = 0.5
            self.ls_c = 0.5
            self.ls_tau = 0.5
        elif self.association_score.name=='ridge':
            self.ls_initial_alpha_j = 10
            self.ls_tau = 0.5
            self.ls_c = 0.5
        else:
            raise NotImplementedError('this association score has no default line search parameters')

        if kmers_activation_vector_tuple is None:
            self.kmers_activation_vector_tuple = ()
            for kmers in dataset.get_kmer_list():
                self.kmers_activation_vector_tuple += (activation_function(kmers, self.dataset.get_sequences()),)
        else:
            self.kmers_activation_vector_tuple = kmers_activation_vector_tuple
        assert len(self.kmers_activation_vector_tuple) == len(dataset.get_kmer_list()), (
            'The kmers activation vectors are not the ones from the kmer list')
        assert min([ self.kmers_activation_vector_tuple[i].get_number_of_motifs() == dataset.get_kmer_list()[i].get_number()
                for i in range (len(dataset.get_kmer_list()))]), (
                    'The kmers activation vectors are not the ones from the kmer list')

    def get_kmers_activation_vector_tuple(self):
        return self.kmers_activation_vector_tuple

    def select_n_motifs(self,number_of_motifs: int) -> Tuple[Motif]:
        """
        Given a number of motifs to select, the function iteratively selects the best motifs,
        updates the activation vectors, and repeats until the number of motifs is reached
        """
        
        selected_motifs=()
        previous_activation_vectors = ActivationVector(torch.empty(1, 0, 
            self.dataset.get_sequences().get_number(), 1))
        projection_matrix = self._compute_projection_matrix(previous_activation_vectors)
        for i in range(number_of_motifs):
            best_kmers = self._get_best_kmers(projection_matrix)
            best_motif = self._get_best_motifs(projection_matrix, best_kmers)
            selected_motifs += (best_motif,)
            previous_activation_vectors = torch.cat((previous_activation_vectors, 
                self.activation_function(best_motif, self.dataset.get_sequences())), dim=1)
            projection_matrix = self._compute_projection_matrix(previous_activation_vectors)
        return selected_motifs

    def _get_best_kmers(self, projection_matrix: torch.Tensor) -> Tuple[Motif]:
        """
        Given a projection matrix, find the best k-mer for each each length
        """
        best_kmers=()
        for j in range(len(self.dataset.get_kmer_list())):
            kmers_scores = self.association_score(self.kmers_activation_vector_tuple[j],  
                self.dataset.get_labels(), projection_matrix)
            if not kmers_scores.size()[0] ==1:
                raise ValueError('best kmer search is only available for 1 label at a time')
            best_kmer = self.dataset.get_kmer_list()[j].extract_one_motif(int(kmers_scores.max(dim=1).indices))
            best_kmers += (best_kmer,)
        return best_kmers

    def compute_score(self, projection_matrix: torch.Tensor, motif: Motif) -> torch.Tensor:
        return self.association_score(self.activation_function(motif, 
            self.dataset.get_sequences()), self.dataset.get_labels(), projection_matrix).view(-1)

    def _evaluate_and_get_gradient(self, projection_matrix: torch.Tensor, motif: Motif):
        """
        Given a projection matrix and a motif, compute the score of the motif and the gradient of the score
        with respect to the motif
        """
        motif.requires_grad_()
        motif_score = self.compute_score(projection_matrix, motif)
        motif_score.backward(retain_graph=True)
        gradient = motif.grad.clone().detach()
        motif.grad.data.zero_()
        motif.requires_grad = False
        return motif_score, gradient

    def _line_search(self, projection_matrix: torch.Tensor, motif: Motif, 
        gradient: torch.Tensor, current_score: float):
        """
        Given a motif, a gradient, and a current score, the function finds the motif that minimizes the
        score by taking a step in the direction of the gradient
        (https://en.wikipedia.org/wiki/Backtracking_line_search)
        """
        m = -torch.norm(gradient)**2
        t = -self.ls_c*m
        j_ls=0
        max_iter = 20
        alpha_j_ls = self.ls_initial_alpha_j
        while ((self.compute_score(projection_matrix, (Motif(motif+alpha_j_ls*gradient))) - current_score) < alpha_j_ls*t) and (j_ls<max_iter):
                j_ls+=1
                alpha_j_ls = self.ls_tau*alpha_j_ls
        converged = not(j_ls == max_iter)
        return Motif(motif+alpha_j_ls*gradient).clone().detach(), converged

    def _get_best_motifs(self, projection_matrix: torch.Tensor, best_kmers: Tuple[Motif]) -> Motif:
        """
        Given a projection matrix, find the best motif by iteratively evaluating the gradient and
        performing line search
        """
        best_motifs_with_scores=()
        for j in range(len(self.dataset.get_kmer_list())):
            best_motif = best_kmers[j]
            iteration = 0
            scores = []
            compteur_for_break=0
            while iteration < self.max_iteration_steps:
                iteration += 1
                if iteration==self.max_iteration_steps:
                    print("optimization was early stopped due to max number of iterations")
                score_iteration, gradient = self._evaluate_and_get_gradient(projection_matrix, best_motif)
                if score_iteration.isnan():
                    assert False
                scores += [float(score_iteration)]
                best_motif, ls_converged = self._line_search(projection_matrix, 
                    best_motif, gradient, score_iteration)
                if (iteration > 2) and (abs(1-scores[iteration-1]/scores[iteration-2]) < 1e-4):
                    compteur_for_break+=1
                else:
                    compteur_for_break=0
                if compteur_for_break>0:
                    break
            best_motifs_with_scores += ((best_motif, scores[-1]),)
        return max(best_motifs_with_scores, key=lambda x:x[1])[0]

    def _compute_projection_matrix(self, activation_vectors: ActivationVector):
        n = activation_vectors.get_number_of_samples()
        nb_motifs = activation_vectors.get_number_of_motifs()
        A = torch.ones(n, 1)
        activation_vectors_for_projection = activation_vectors.squeeze()
        if nb_motifs == 1:
            A = torch.cat((A, activation_vectors_for_projection.unsqueeze(1)), dim=1)
        elif nb_motifs > 1:
            A = torch.cat((A, activation_vectors_for_projection.t()), dim=1)
            
        return torch.eye(n) - torch.matmul(A, torch.matmul(torch.inverse(torch.matmul(A.t(), A)), A.t()))

###################################################################################################
# Test

class TestSelection(TestCase):
    def test_kmer_list(self): 
        """
        A Selecter builds its own k-mer activation vectors by default, and rejects
        user-supplied ones that do not match the dataset's k-mer list.
        """
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10, 12))
        hsic_score = AssociationScore('hsic')
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        # Default construction must succeed.
        Selecter(dataset, hsic_score, gaussian_activation_function)

        # A one-element tuple; the trailing comma is what makes it a tuple. The test
        # expects the Selecter to reject it.
        kmers_activation_vector_tuple = (ActivationVector(torch.zeros(1,4,3,1)),)
        with self.assertRaises(AssertionError):
            Selecter(dataset, hsic_score, gaussian_activation_function, kmers_activation_vector_tuple)

        # A two-element tuple whose vectors are expected not to match the dataset's k-mers.
        kmers_activation_vector_tuple = (ActivationVector(torch.zeros(1,4,3,1)), 
            ActivationVector(torch.zeros(1,4,3,1)))
        with self.assertRaises(AssertionError):
            Selecter(dataset, hsic_score, gaussian_activation_function, kmers_activation_vector_tuple)

    def test_get_best_kmers_hsic(self):
        """With the HSIC score, one best k-mer of the right length is returned per requested length."""
        one_hot = torch.zeros(3, 4, 3)
        one_hot[0, 0, :] = 1  # AAA
        one_hot[1, 1, :] = 1  # CCC
        for position in range(3):  # ACG
            one_hot[2, position, position] = 1
        sequences = OneHotEncodedSequences(one_hot)

        # Original author's note on these labels: -2/3, -2/3, 4/3
        labels = Labels(torch.tensor([-1.0, -1.0, 1.0]).reshape(1, 1, 3, 1))

        lengths = (2, 3)
        dataset = dataset_from_sequences_and_labels(sequences, labels, lengths)
        selecter = Selecter(dataset, AssociationScore('hsic'),
            ActivationFunction('gaussian_max_pooling'))
        identity = torch.eye(labels.get_number_of_samples())
        best_kmers = selecter._get_best_kmers(identity)

        self.assertEqual(len(best_kmers), len(lengths))
        for kmer, expected_length in zip(best_kmers, lengths):
            self.assertEqual(kmer.get_length(), expected_length)

    def test_compute_projection_matrix(self):
        """
        The projection matrix is n x n in its first dimension, maps every (centred)
        activation vector to zero, and is idempotent.
        """
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10, 12))
        selecter = Selecter(dataset, AssociationScore('hsic'),
            ActivationFunction('gaussian_max_pooling'))

        nb_samples, nb_motifs = 10, 4
        rand_vectors = torch.rand(1, nb_motifs, nb_samples, 1)
        # Centre each motif's values across the sample axis (dim 2).
        centred = rand_vectors - rand_vectors.mean(dim=2, keepdim=True)
        activation_vectors = ActivationVector(centred)

        projection_matrix = selecter._compute_projection_matrix(activation_vectors)
        self.assertEqual(projection_matrix.size()[0], nb_samples)

        zero_vector = torch.zeros(nb_samples)
        for motif_index in range(nb_motifs):
            one_vector = activation_vectors[:, motif_index, :, :].squeeze()
            self.assert_torch_allclose(torch.matmul(projection_matrix, one_vector),
                zero_vector, atol=1e-5)

        squared = torch.matmul(projection_matrix, projection_matrix)
        self.assert_torch_allclose(projection_matrix, squared, atol=1e-4)

    def test_selecter_hsic(self):
        """End-to-end run with the HSIC score: select four motifs and draw each one to a file."""
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10, ))
        dataset = dataset_from_sequences_and_labels(dataset.get_sequences(), 
            dataset.get_labels(), (10,))
        hsic_score = AssociationScore('hsic')
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        selecter = Selecter(dataset, hsic_score, gaussian_activation_function)
        nb_motifs = 4
        selected_motifs = selecter.select_n_motifs(nb_motifs)
        # select_n_motifs appends exactly one motif per requested round.
        self.assertEqual(len(selected_motifs), nb_motifs)

        location = 'tmp_test_hsic/'
        name = 'selected_motif_'
        # Remove leftovers from an earlier interrupted run; a missing directory is fine.
        shutil.rmtree(location, ignore_errors=True)
        os.mkdir(location)
        try:
            for i, motif in enumerate(selected_motifs, start=1):
                motif.draw_to_file(location + name + str(i))
        finally:
            # Clean up even when drawing fails, so the directory is not left behind.
            shutil.rmtree(location)

    def test_get_best_kmer_ridge(self):
        """With the ridge score, one best k-mer of the right length is returned per requested length."""
        one_hot = torch.zeros(3, 4, 3)
        one_hot[0, 0, :] = 1  # AAA
        one_hot[1, 1, :] = 1  # CCC
        for position in range(3):  # ACG
            one_hot[2, position, position] = 1
        sequences = OneHotEncodedSequences(one_hot)

        # Original author's note on these labels: -2/3, -2/3, 4/3
        labels = Labels(torch.tensor([-1.0, -1.0, 1.0]).reshape(1, 1, 3, 1))

        lengths = (2, 3)
        dataset = dataset_from_sequences_and_labels(sequences, labels, lengths)
        selecter = Selecter(dataset, AssociationScore('ridge', lmbda_ridge=1),
            ActivationFunction('gaussian_max_pooling'))
        identity = torch.eye(labels.get_number_of_samples())
        best_kmers = selecter._get_best_kmers(identity)

        self.assertEqual(len(best_kmers), len(lengths))
        for kmer, expected_length in zip(best_kmers, lengths):
            self.assertEqual(kmer.get_length(), expected_length)

    def test_selecter_ridge(self):
        """End-to-end run with the ridge score: select four motifs and draw each one to a file."""
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10, ))
        dataset = dataset_from_sequences_and_labels(dataset.get_sequences(), 
            dataset.get_labels(), (10,))
        ridge_score = AssociationScore('ridge', lmbda_ridge=1000)
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        selecter = Selecter(dataset, ridge_score, gaussian_activation_function)
        nb_motifs = 4
        selected_motifs = selecter.select_n_motifs(nb_motifs)
        # select_n_motifs appends exactly one motif per requested round.
        self.assertEqual(len(selected_motifs), nb_motifs)

        location = 'tmp_test_ridge/'
        name = 'selected_motif_'
        # Remove leftovers from an earlier interrupted run; a missing directory is fine.
        shutil.rmtree(location, ignore_errors=True)
        os.mkdir(location)
        try:
            for i, motif in enumerate(selected_motifs, start=1):
                motif.draw_to_file(location + name + str(i))
        finally:
            # Clean up even when drawing fails, so the directory is not left behind.
            shutil.rmtree(location)
