import torch
from typing import Tuple
import os
import shutil

from seism.motif import Motif
from seism.labels import Labels
from seism.dataset import Dataset, parse_fasta_dataset
from seism.association_score import AssociationScore
from seism.activation_function import ActivationFunction, ActivationVector
from seism.selection import Selecter
from seism.parallel_search import SearchWorkers, Campaign
from seism.definitions import TestCase

################################################################################
# PSI Inferer Class

class Inferer:
    """
    Compute p-values for a set of selected motifs.

    ``inference_type`` chooses the procedure run by ``__call__``; the keyword arguments
    given to ``__call__`` are forwarded unchanged to the matching method:

    * ``'hit_and_run'``: replicates drawn by ``HitAndRunSampler``
      (``_inference_using_hit_and_run``);
    * ``'rejection'``: replicates drawn by ``RejectionSampler``
      (``_inference_using_rejection``);
    * ``'data_split'``: Gaussian replicates on a held-out test set
      (``_inference_using_data_split``).
    """

    def __init__(self, inference_type: str):
        self.inference_type = inference_type
        assert self.inference_type in ['hit_and_run', 'rejection', 'data_split'], 'Unknown inference type'

    @staticmethod
    def _other_activation_vectors(motifs: Tuple[Motif], excluded_index: int,
            selecter: Selecter) -> torch.Tensor:
        """
        Concatenate along dim 1 the activation vectors, computed on ``selecter``'s sequences,
        of every motif in ``motifs`` except ``motifs[excluded_index]``.

        Concatenation starts from an empty ``(1, 0, nb_samples, 1)`` tensor; the assert checks
        that exactly ``len(motifs) - 1`` vectors were stacked along dim 1.
        """
        nb_samples = selecter.dataset.get_labels().get_number_of_samples()
        activation_vectors = ActivationVector(torch.empty(1, 0, nb_samples, 1))
        for j, motif in enumerate(motifs):
            if j != excluded_index:
                activation_vectors = torch.cat((activation_vectors,
                    selecter.activation_function(motif, selecter.dataset.get_sequences())), dim=1)
        assert activation_vectors.size()[1] == (len(motifs) - 1), (
            'Issue in the computation of the activation vectors of the other motifs')
        return activation_vectors

    def _successive_projection_matrices(self, successive_motifs: Tuple[Motif], 
            selecter: Selecter) -> Tuple[torch.Tensor]:
        """
        Return one projection matrix per motif, aligned with ``successive_motifs``.

        Entry ``i`` is ``selecter._compute_projection_matrix`` applied to the activation
        vectors of all motifs other than ``successive_motifs[i]``.
        """
        projection_matrices = ()
        for i in range(len(successive_motifs)):
            other_activation_vectors = self._other_activation_vectors(successive_motifs, i, selecter)
            projection_matrices += (selecter._compute_projection_matrix(other_activation_vectors),)
        assert len(projection_matrices) == len(successive_motifs)
        return projection_matrices

    def _inference_with_replicates(self, motifs_set: Tuple[Motif], selecter: Selecter, 
            location:str, mesh_size: float, nb_replicates: int, replicates: Labels, 
            selected_motifs_replicates: Tuple[Motif]):
        """
        Compute p-values for the selected motifs from sampled label replicates.

        Two families of p-values are produced:

        * mesh centers: each mesh center of ``motifs_set`` is scored, after projecting out
          the other mesh centers, on the data and on all replicates at once;
        * best motifs in mesh: on the data, each motif of ``motifs_set`` is scored after
          projecting out the other motifs of ``motifs_set``; on replicate ``i``, the motifs
          ``selected_motifs_replicates[i]`` (selected on that replicate) are used instead.

        The raw scores are saved as ``.pt`` files inside the directory ``location``.

        Returns:
            ``(pvalues_centers, pvalues_best_motifs_in_mesh, replicates)``.
        """
        nb_motifs = len(motifs_set)
        dataset_replicates = Dataset(selecter.dataset.get_sequences(), replicates, selecter.dataset.kmer_list)
        selecter_replicates = Selecter(dataset_replicates, selecter.association_score, 
            selecter.activation_function, selecter.get_kmers_activation_vector_tuple())

        # --- Mesh centers: one score on the data, one score per replicate ---------------
        mesh_centers = compute_mesh_centers(motifs_set, mesh_size)
        projection_matrices = self._successive_projection_matrices(mesh_centers, selecter)
        score_data_mesh_centers = ()
        scores_replicates_mesh_centers = ()
        for i in range(nb_motifs):
            score_data_mesh_centers += (selecter.compute_score(projection_matrices[i], mesh_centers[i]),)
            scores_replicates_mesh_centers += (selecter_replicates.compute_score(projection_matrices[i], 
                mesh_centers[i]).view(-1),)
            assert scores_replicates_mesh_centers[i].size()[0] == nb_replicates, (
                'Number of replicates is different from number of scores of the replicates')

        pvalues_centers = tuple(
            compute_p_value(score_data_mesh_centers[i], scores_replicates_mesh_centers[i])
            for i in range(nb_motifs))

        torch.save(score_data_mesh_centers, os.path.join(location, 'score_data_mesh_centers.pt'))
        torch.save(scores_replicates_mesh_centers, os.path.join(location, 'score_replicates_mesh_centers.pt'))

        # --- Best motifs in mesh, replicates: use the motifs selected on each replicate ---
        scores_replicates_best_motifs_in_mesh = ()
        for h in range(nb_motifs):
            scores_replicates_best_motifs_in_mesh_h = torch.empty(0)
            for i in range(replicates.get_number_of_labels()):
                dataset_replicate_i = Dataset(selecter.dataset.get_sequences(), replicates.extract_one_label(i),
                    selecter.dataset.kmer_list)
                selecter_replicate_i = Selecter(dataset_replicate_i, selecter.association_score, 
                    selecter.activation_function, selecter.get_kmers_activation_vector_tuple())
                selected_motifs_replicates_i = selected_motifs_replicates[i]

                # Activation vectors are computed with the data selecter: the sequences are shared.
                projection_matrix = selecter._compute_projection_matrix(
                    self._other_activation_vectors(selected_motifs_replicates_i, h, selecter))

                scores_replicates_best_motifs_in_mesh_h = torch.cat((scores_replicates_best_motifs_in_mesh_h,
                    selecter_replicate_i.compute_score(projection_matrix, 
                    selected_motifs_replicates_i[h]).view(-1)), dim=0)

            scores_replicates_best_motifs_in_mesh += (scores_replicates_best_motifs_in_mesh_h,)

        # --- Best motifs in mesh, data: project out the other observed motifs --------------
        projection_matrices_best = self._successive_projection_matrices(motifs_set, selecter)
        score_data_best_motifs_in_mesh = tuple(
            selecter.compute_score(projection_matrices_best[h], motifs_set[h]) for h in range(nb_motifs))

        torch.save(score_data_best_motifs_in_mesh, os.path.join(location, 'score_data_best_motifs_in_mesh.pt'))
        torch.save(scores_replicates_best_motifs_in_mesh,
            os.path.join(location, 'score_replicates_best_motifs_in_mesh.pt'))

        pvalues_best_motifs_in_mesh = tuple(
            compute_p_value(score_data_best_motifs_in_mesh[i], scores_replicates_best_motifs_in_mesh[i])
            for i in range(nb_motifs))

        return pvalues_centers, pvalues_best_motifs_in_mesh, replicates

    def _inference_using_hit_and_run(self, motifs_set: Tuple[Motif], selecter: Selecter, 
            model_noise: float, location:str, mesh_size: float, nb_burn_in: int, 
            nb_replicates: int) -> Tuple[float]:
        """
        Draw replicates with ``HitAndRunSampler`` and run ``_inference_with_replicates``.

        Returns ``(pvalues_centers, pvalues_best_motifs_in_mesh, replicates)``.
        """
        hit_and_run_sampler = HitAndRunSampler(nb_burn_in, nb_replicates, model_noise, location)
        replicates, selected_motifs_replicates = hit_and_run_sampler(motifs_set, mesh_size, selecter)
        return self._inference_with_replicates(motifs_set, selecter, location, mesh_size,
                 nb_replicates, replicates, selected_motifs_replicates)

    def _inference_using_rejection(self, motifs_set: Tuple[Motif], selecter: Selecter, 
            model_noise: float, location:str, mesh_size: float, nb_replicates: int):
        """
        Draw replicates with ``RejectionSampler`` and run ``_inference_with_replicates``.

        Returns ``(pvalues_centers, pvalues_best_motifs_in_mesh, replicates)``.
        """
        rejection_sampler = RejectionSampler(nb_replicates, model_noise, location)
        replicates, selected_motifs_replicates = rejection_sampler(motifs_set, mesh_size, selecter)
        return self._inference_with_replicates(motifs_set, selecter, location, mesh_size, 
            nb_replicates, replicates, selected_motifs_replicates)

    def _inference_using_data_split(self, test_selecter: Selecter, model_noise: float, 
            nb_replicates: int, motifs_set: Tuple[Motif]):
        """
        Score ``motifs_set`` on a held-out test set against unconditioned Gaussian replicates.

        Replicate labels are i.i.d. ``N(0, model_noise^2)`` per test sample; no conditioning on
        the selection event is applied here.

        Returns:
            ``(pvalues, replicates)`` -- note: a 2-tuple, unlike the other inference types.
        """
        nb_motifs = len(motifs_set)
        nb_samples_test = test_selecter.dataset.get_sequences().get_number()

        generator = torch.distributions.normal.Normal(torch.zeros(nb_samples_test), 
                        torch.ones(nb_samples_test)*model_noise)

        projection_matrices = self._successive_projection_matrices(motifs_set, test_selecter)

        # Replicates are stacked along dim 0, each with shape (1, 1, nb_samples_test, 1).
        replicates = Labels(torch.empty(0, 1, nb_samples_test, 1))
        for _ in range(nb_replicates):
            replicates = torch.cat((replicates, Labels(generator.sample().unsqueeze(0).unsqueeze(0).unsqueeze(3))), dim=0)

        replicates = Labels(replicates) 
        dataset_replicates = Dataset(test_selecter.dataset.get_sequences(), replicates, 
                                        test_selecter.dataset.kmer_list)
        selecter_replicates = Selecter(dataset_replicates, test_selecter.association_score, 
                    test_selecter.activation_function, test_selecter.get_kmers_activation_vector_tuple())

        score_data = ()
        score_replicates = ()
        for i in range(nb_motifs):
            score_data += (test_selecter.compute_score(projection_matrices[i], motifs_set[i]),)
            score_replicates += (selecter_replicates.compute_score(projection_matrices[i], 
                                motifs_set[i]).view(-1),)

        pvalues = tuple(compute_p_value(score_data[i], score_replicates[i]) for i in range(nb_motifs))
        return pvalues, replicates

    def __call__(self, **kwargs) -> Tuple[float]:
        """Run the inference procedure selected at construction, forwarding ``kwargs`` to it."""
        if self.inference_type == 'hit_and_run':
            return self._inference_using_hit_and_run(**kwargs)
        if self.inference_type == 'rejection':
            return self._inference_using_rejection(**kwargs)
        if self.inference_type == 'data_split':
            return self._inference_using_data_split(**kwargs)
        # Only reachable if the constructor assert was stripped (python -O) or the attribute
        # was modified afterwards; fail loudly instead of silently returning None.
        raise ValueError('Unknown inference type: ' + str(self.inference_type))

################################################################################
#SAMPLERS
# Rejection Sampler
class RejectionSampler:
    """
    Draw label replicates by rejection sampling within the selection event.

    Candidate labels are drawn from ``N(0, model_noise^2)`` independently per sample. A
    candidate is kept only if re-running the motif selection on it yields the same mesh
    centers, in the same order, as the observed motifs. Candidates are drawn by parallel
    workers (``SearchWorkers``); each campaign tries up to ``worker_nb_attempts`` draws per
    worker, and a new campaign is launched while replicates are still missing.

    Args:
        number_of_replicates: number of accepted replicates to return.
        model_noise: standard deviation of the Gaussian label model.
        location: directory where the progress checkpoint ``current_state.pt`` is written.
        worker_nb_attempts: draws per worker and per campaign (default 50000).
    """

    def __init__(self, number_of_replicates: int, model_noise: float, location: str,
            worker_nb_attempts: int = 50000):
        self.desired_number_of_replicates = number_of_replicates
        self.model_noise = model_noise
        self.location = location
        self.worker_nb_attempts = worker_nb_attempts

    def _accept_new_point(self, new_point: Labels, mesh_centers: Tuple[Motif], 
            selecter: Selecter, mesh_size: float) -> Tuple[bool, Tuple[Motif]]:
        """
        Re-run the selection on ``new_point`` and check it lands in the same mesh centers.

        Returns:
            ``(accept, new_selected_motifs)`` where ``accept`` is True when the mesh centers of
            the newly selected motifs equal ``mesh_centers`` (order-sensitive).
        """
        new_dataset = Dataset(selecter.dataset.one_hot_encoded_sequences, 
            new_point, selecter.dataset.kmer_list)

        new_selecter = Selecter(new_dataset, selecter.association_score, 
            selecter.activation_function, selecter.get_kmers_activation_vector_tuple())
        new_selected_motifs = new_selecter.select_n_motifs(len(mesh_centers))
        new_mesh_centers = compute_mesh_centers(new_selected_motifs, mesh_size)
        accept = compare_mesh_centers(new_mesh_centers, mesh_centers)

        return accept, new_selected_motifs

    @staticmethod
    def search_function(campaign: Campaign, self, mesh_centers: Tuple[Motif], 
            selecter: Selecter, generator: torch.distributions.normal.Normal,
            mesh_size: float, worker_nb_attempts: int):
        """
        Worker body: draw up to ``worker_nb_attempts`` candidates and return the first accepted
        one as ``(new_point, new_selected_motifs)``; return None if none is accepted or if the
        campaign has already ended.
        """
        for _ in range(worker_nb_attempts):
            # Stop early when another worker has already completed the campaign.
            if campaign.has_ended():
                return None
            new_point = Labels(generator.sample().view(1, 1, -1, 1))
            new_point_is_accepted, new_selected_motifs = self._accept_new_point(new_point, 
                mesh_centers, selecter, mesh_size)
            if new_point_is_accepted:
                return new_point, new_selected_motifs
        return None

    def __call__(self, motifs_set: Tuple[Motif], mesh_size: float,
            selecter: Selecter) -> Tuple[Labels, Tuple[Tuple[Motif]]]:
        """
        Generate labels under the Gaussian null conditioned on the selection event.

        Returns:
            ``(replicates, selected_motifs_replicates)``: the accepted labels stacked along
            dim 0, and for each replicate the motifs selected on it.
        """
        replicates = Labels(torch.empty(0, 1, selecter.dataset.get_sequences().get_number(), 1))
        ongoing_number_of_replicates = 0
        mesh_centers = compute_mesh_centers(motifs_set, mesh_size)
        nb_samples = selecter.dataset.get_labels().get_number_of_samples()
        generator = torch.distributions.normal.Normal(torch.zeros(nb_samples), 
                                    torch.ones(nb_samples)*self.model_noise)
        checkpoint_path = os.path.join(self.location, 'current_state.pt')

        selected_motifs_replicates = ()
        with SearchWorkers(num_workers=torch.get_num_threads()) as workers:
            while ongoing_number_of_replicates < self.desired_number_of_replicates:
                # Checkpoint the (1-based) index of the replicate currently being searched.
                torch.save(ongoing_number_of_replicates+1, checkpoint_path)
                results = workers.campaign(RejectionSampler.search_function, 
                    self, mesh_centers, selecter, generator, mesh_size, self.worker_nb_attempts)
                if results is not None:
                    new_point, selected_motifs_replicates_i = results
                    replicates = torch.cat((replicates, new_point), dim=0)
                    selected_motifs_replicates += (selected_motifs_replicates_i,)
                    ongoing_number_of_replicates += 1
        return replicates, selected_motifs_replicates
                
# Hit & Run Sampler
class HitAndRunSampler:
    """
    Hit-and-run sampler of label replicates conditioned on the selection event.

    Labels ``y`` are moved in ``z = Phi(y)`` coordinates, where ``Phi`` is the per-sample CDF
    of ``N(0, model_noise^2)``. Each step draws a random unit direction and lets parallel
    workers propose points on the chord of the open unit cube along it (``_draw_new_point``).
    A proposal is accepted when re-running the selection on it yields the same mesh centers,
    in the same order, as the observed motifs. The first ``number_of_burn_in_iterations``
    accepted moves are discarded; every later accepted move is stored as a replicate.

    Args:
        number_of_burn_in_iterations: accepted moves discarded before collecting replicates.
        number_of_replicates: number of replicates to return.
        model_noise: standard deviation of the Gaussian label model.
        location: directory where the progress checkpoint ``current_state.pt`` is written.
        worker_nb_attempts: proposals per worker and per campaign (default 500).
    """

    def __init__(self, number_of_burn_in_iterations: int, number_of_replicates: int, 
        model_noise: float, location: str, worker_nb_attempts: int = 500):

        self.desired_number_of_burn_in_iterations = number_of_burn_in_iterations
        self.desired_number_of_replicates = number_of_replicates
        self.model_noise = model_noise
        self.location = location
        self.worker_nb_attempts = worker_nb_attempts

    def _draw_random_direction(self, generator):
        """
        Draw a random direction in the label space: a sample of ``generator`` normalised to
        unit Euclidean norm.
        """
        direction = generator.sample()
        return direction/torch.norm(direction)

    def _draw_new_point(self, previous_point: Labels, direction: torch.Tensor, 
        generator: torch.distributions.normal.Normal) -> Labels:
        """
        Draw a point on the chord through ``previous_point`` along ``direction``.

        Work in ``z = generator.cdf(y)`` coordinates: ``z + lamb * direction`` must stay in the
        open unit cube in every coordinate, which restricts ``lamb`` to ``(a, b)``. ``lamb`` is
        drawn uniformly in ``[a, b)`` and mapped back with ``generator.icdf``; draws whose image
        is not finite are redrawn.
        """
        previous_z = generator.cdf(torch.Tensor(previous_point.squeeze()))
        theta = direction
        positive = theta > 0
        negative = theta < 0
        # Constraints on lamb for coordinate k:
        #   theta_k > 0:  -z_k / theta_k      < lamb < (1 - z_k) / theta_k
        #   theta_k < 0:  (1 - z_k) / theta_k < lamb < -z_k / theta_k
        lower_bounds = torch.cat((-previous_z[positive] / theta[positive],
                                  (1 - previous_z[negative]) / theta[negative]))
        upper_bounds = torch.cat(((1 - previous_z[positive]) / theta[positive],
                                  -previous_z[negative] / theta[negative]))
        # Bounds are taken over every non-zero coordinate, so a direction whose entries all share
        # one sign still gets the exact chord; the 0/1 fallbacks only apply to a zero direction.
        a = lower_bounds.max() if lower_bounds.numel() > 0 else torch.tensor(0.)
        b = upper_bounds.min() if upper_bounds.numel() > 0 else torch.tensor(1.)
        assert a<b, 'issue in the HD sampling: b<=a'

        new_point = torch.full((1,), float('nan'))
        while not torch.isfinite(new_point).all():
            lamb = torch.rand(size=())*(b-a)+a
            new_point = generator.icdf(previous_z + lamb * theta)
        return Labels(new_point.view(previous_point.size()))

    def _accept_new_point(self, new_point: Labels, mesh_centers: Tuple[Motif], 
        selecter: Selecter, mesh_size: float) -> Tuple[bool, Tuple[Motif]]:
        """
        Re-run the selection on ``new_point`` and check it lands in the same mesh centers.

        Returns:
            ``(accept, new_selected_motifs)`` where ``accept`` is True when the mesh centers of
            the newly selected motifs equal ``mesh_centers`` (order-sensitive).
        """
        new_dataset = Dataset(selecter.dataset.one_hot_encoded_sequences, 
                            new_point, selecter.dataset.kmer_list)

        new_selecter = Selecter(new_dataset, selecter.association_score, selecter.activation_function, 
                                selecter.get_kmers_activation_vector_tuple())

        new_selected_motifs = new_selecter.select_n_motifs(len(mesh_centers))
        new_mesh_centers = compute_mesh_centers(new_selected_motifs, mesh_size)
        accept = compare_mesh_centers(new_mesh_centers, mesh_centers)

        return accept, new_selected_motifs

    @staticmethod
    def search_function(campaign: Campaign, self, previous_point: Labels, 
        direction: torch.Tensor, mesh_centers: Tuple[Motif], selecter: Selecter, 
        generator: torch.distributions.normal.Normal, mesh_size: float, worker_nb_attempts: int):
        """
        Worker body: propose up to ``worker_nb_attempts`` points along ``direction`` and return
        the first accepted one as ``(new_point, new_selected_motifs)``; return None if none is
        accepted or if the campaign has already ended.
        """
        for _ in range(worker_nb_attempts):
            # Stop early when another worker has already completed the campaign.
            if campaign.has_ended():
                return None
            new_point = self._draw_new_point(previous_point, direction, generator)
            new_point_is_accepted, new_selected_motifs = self._accept_new_point(new_point, 
                                                        mesh_centers, selecter, mesh_size)
            if new_point_is_accepted:
                return new_point, new_selected_motifs
        return None

    def __call__(self, motifs_set: Tuple[Motif], mesh_size: float,
            selecter: Selecter) -> Tuple[Labels, Tuple[Tuple[Motif]]]:
        """
        Generate labels under the Gaussian null conditioned on the selection event.

        The chain starts from the observed labels of ``selecter.dataset``. When a campaign
        fails (no accepted proposal), a new direction is drawn from the same point.

        Returns:
            ``(replicates, selected_motifs_replicates)``: the post-burn-in accepted labels
            stacked along dim 0, and for each replicate the motifs selected on it.
        """
        replicates = Labels(torch.empty(0, 1, selecter.dataset.get_sequences().get_number(), 1))
        ongoing_number_of_burn_in = 0
        ongoing_number_of_replicates = 0
        mesh_centers = compute_mesh_centers(motifs_set, mesh_size)
        previous_point = selecter.dataset.get_labels()
        generator = torch.distributions.normal.Normal(torch.zeros(previous_point.get_number_of_samples()), 
                                        torch.ones(previous_point.get_number_of_samples())*self.model_noise)
        checkpoint_path = os.path.join(self.location, 'current_state.pt')
        selected_motifs_replicates = ()
        with SearchWorkers(num_workers=torch.get_num_threads()) as workers:
            while ongoing_number_of_replicates < self.desired_number_of_replicates:
                if ongoing_number_of_burn_in < self.desired_number_of_burn_in_iterations:
                    printProgressBar(ongoing_number_of_burn_in, self.desired_number_of_burn_in_iterations, 
                            prefix="Progress burn-in iterations:", length=50)
                else:
                    printProgressBar(ongoing_number_of_replicates, self.desired_number_of_replicates, 
                            prefix="Progress replicates        :", length=50)
                # Checkpoint the (1-based) burn-in and replicate indices currently being searched.
                torch.save((ongoing_number_of_burn_in +1, ongoing_number_of_replicates+1), 
                            checkpoint_path)
                random_direction = self._draw_random_direction(generator)
                results = workers.campaign(HitAndRunSampler.search_function, self, previous_point, 
                            random_direction, mesh_centers, selecter, generator, mesh_size,
                            self.worker_nb_attempts)
                if results is not None:
                    previous_point, selected_motifs_replicates_i = results

                    if ongoing_number_of_burn_in >= self.desired_number_of_burn_in_iterations:
                        ongoing_number_of_replicates += 1
                        replicates = torch.cat((replicates, previous_point), dim=0)
                        selected_motifs_replicates += (selected_motifs_replicates_i,)
                    else:
                        ongoing_number_of_burn_in += 1
        return replicates, selected_motifs_replicates

################################################################################
# Utils

def compute_mesh_centers(motifs_set: Tuple[Motif], mesh_size: float) -> Tuple[Motif]:
    """
    Map each motif to the center of the mesh cell that contains it.

    Every entry is first clamped to at most ``1 - 1e-5`` (so an entry equal to 1 falls in the
    last cell instead of opening a new one) and snapped down to a multiple of ``mesh_size``.
    Then ``(1 - sum over dim 1) / 4`` is added to every entry.

    NOTE(review): the divisor 4 assumes dim 1 is the 4-letter nucleotide axis, in which case
    every column of the result sums to 1 -- confirm against ``Motif``.

    Args:
        motifs_set: motifs to map.
        mesh_size: edge length of a mesh cell.

    Returns:
        A tuple of ``Motif`` mesh centers aligned with ``motifs_set``.
    """
    mesh_centers = ()
    for motif in motifs_set:
        clamped = torch.min(motif, torch.ones(motif.size())-1e-5)
        mesh_center = torch.trunc(clamped / mesh_size)*mesh_size
        # keepdim=True keeps the correction aligned with dim 1 whatever the size of dim 0;
        # without it a leading dimension != 1 would broadcast against the wrong axis.
        mesh_center += (1-mesh_center.sum(dim=1, keepdim=True))/4
        mesh_centers += (Motif(mesh_center),)
    return mesh_centers

def compare_mesh_centers(mesh_centers_A: Tuple[Motif], mesh_centers_B: Tuple[Motif])-> bool:
    """
    Tell whether two tuples of mesh centers are identical, position by position.

    The comparison is order-sensitive: the same centers in a different order do not match.
    Both tuples must have the same length (checked by an assert).
    """
    assert len(mesh_centers_A)==len(mesh_centers_B), 'The same number of mesh centers must be given'
    # all() stops at the first mismatching pair.
    return all(center_a.equal(center_b) for center_a, center_b in zip(mesh_centers_A, mesh_centers_B))

def compute_p_value(score_true: torch.Tensor, scores_replicates: torch.Tensor) -> float:
    """
    Compute the empirical p-value of the observed score against replicate scores.

    The p-value is ``(r + 1) / (n + 1)``, where ``n`` is the number of replicates and ``r``
    the number of replicate scores greater than or equal to ``score_true``; it is therefore
    never 0 and equals 1 when there are no replicates.

    Args:
        score_true: score of the observed data (a single value).
        scores_replicates: 1-D tensor of replicate scores.

    Returns:
        The p-value as a Python float.
    """
    assert scores_replicates.dim()==1
    nb_replicates = scores_replicates.size()[0]
    # A NaN replicate compares False and is simply not counted; it no longer hides the
    # other replicates (a descending sort puts NaN first and the old scan broke on it).
    r = int((score_true <= scores_replicates).sum())
    return (r+1)/(nb_replicates+1)

def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, 
        length = 100, fill = '█', printEnd = "\r"):
    """
    Print a one-line terminal progress bar; call it repeatedly in a loop to update it.

    The printed line is ``'\\r{prefix} |{bar}|{iteration}/{total}{suffix}'`` followed by
    ``printEnd``; an extra newline is printed once ``iteration == total``.

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int); 0 is rendered as a full bar
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : accepted for compatibility, currently unused (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
        printEnd    - Optional  : end character, e.g. '\\r' or '\\r\\n' (Str)
    """
    if total > 0:
        filled_length = int(length * iteration // total)
    else:
        # Nothing to do counts as complete; also avoids a ZeroDivisionError.
        filled_length = length
    # Keep the bar exactly `length` characters wide even for out-of-range iterations.
    filled_length = max(0, min(length, filled_length))
    bar = fill * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}|' + str(iteration) + '/' + str(total) + suffix, end=printEnd)
    # Print New Line on Complete
    if iteration == total:
        print()

################################################################################
# Tests

class TestUtils(TestCase):
    """Tests for compute_mesh_centers, compute_p_value and compare_mesh_centers."""

    def test_mesh_center(self):
        # With mesh_size=1 every entry (< 1) truncates to 0, so the center is uniform 0.25.
        (center,) = compute_mesh_centers((Motif(torch.rand(1, 4, 4)),), mesh_size=1)
        expected = Motif(torch.full((1, 4, 4), 0.25))
        self.assert_torch_allclose(center, expected)

    def test_compute_pvalue(self):
        # Three of the five replicate scores (4, 5, 6) are >= 3, hence (3 + 1) / (5 + 1).
        observed = torch.ones(1) * 3
        replicate_scores = torch.tensor([4, 2, 5, 1, 6])
        self.assertEqual(compute_p_value(observed, replicate_scores), 4/6)

    def test_compare_mesh_centers(self):
        ones = torch.ones(3)
        reference = (ones, ones * 2, ones * 3)

        # Same centers in another order: the comparison is order-sensitive.
        permuted = (ones * 3, ones, ones * 2)
        self.assertFalse(compare_mesh_centers(reference, permuted))
        self.assertFalse(compare_mesh_centers(permuted, reference))

        # Equal values held by distinct tensor objects.
        identical = (torch.ones(3), torch.ones(3) * 2, torch.ones(3) * 3)
        self.assertTrue(compare_mesh_centers(reference, identical))
        self.assertTrue(compare_mesh_centers(identical, reference))

        # Only the first center matches.
        different = (ones, ones * 4, ones * 2 * 5)
        self.assertFalse(compare_mesh_centers(reference, different))

class TestHitAndRun(TestCase):
    """Tests for HitAndRunSampler: direction drawing, point drawing and full sampling."""

    def test_draw_direction(self):
        shutil.rmtree('tmp_dd', ignore_errors=True)
        os.mkdir('tmp_dd')
        hr = HitAndRunSampler(10, 10, 1, 'tmp_dd/')
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10,))
        nb_samples = dataset.get_labels().get_number_of_samples()
        generator = torch.distributions.normal.Normal(torch.zeros(nb_samples), 
                                                    torch.ones(nb_samples)*1)
        direction = hr._draw_random_direction(generator)
        # The direction is normalised to unit Euclidean norm.
        self.assertAlmostEqual(float(torch.norm(direction)), 1.0, places=5)
        shutil.rmtree('tmp_dd/')

    def test_draw_new_point(self):
        initial_point = Labels(torch.rand(1,1,50,1))
        shutil.rmtree('tmp_dnp', ignore_errors=True)
        os.mkdir('tmp_dnp')
        hr = HitAndRunSampler(10, 10, 1, 'tmp_dnp/')
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10,))
        nb_samples = dataset.get_labels().get_number_of_samples()
        generator = torch.distributions.normal.Normal(torch.zeros(nb_samples), 
                                                    torch.ones(nb_samples)*1)
        direction = hr._draw_random_direction(generator)
        new_point = hr._draw_new_point(initial_point, direction, generator)
        self.assertIsInstance(new_point, Labels)
        self.assertEqual(new_point.size(), initial_point.size())
        # The sampler must never return NaN/inf labels.
        self.assertTrue(bool(torch.isfinite(new_point).all()))
        shutil.rmtree('tmp_dnp/')

    def test_hit_and_run(self):
        shutil.rmtree('tmp_hr', ignore_errors=True)
        os.mkdir('tmp_hr')
        nb_replicates = 5
        hr = HitAndRunSampler(5, nb_replicates, 1, 'tmp_hr')
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10,))
        hsic_score = AssociationScore('hsic')
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        selecter = Selecter(dataset, hsic_score, gaussian_activation_function)
        motifs_set = selecter.select_n_motifs(3)
        mesh_size=0.5
        replicates_condition_order, _ = hr(motifs_set, mesh_size, selecter)
        self.assertIsInstance(replicates_condition_order, Labels)
        self.assertEqual(replicates_condition_order.get_number_of_samples(), 50)
        self.assertEqual(replicates_condition_order.get_number_of_labels(), nb_replicates)
        shutil.rmtree('tmp_hr')

class TestInference(TestCase):
    """End-to-end tests of Inferer for each inference type."""

    def test_inference_with_hsic_score(self):
        inferer = Inferer('hit_and_run')
        shutil.rmtree('tmp_inf', ignore_errors=True)
        os.mkdir('tmp_inf')
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10,))
        hsic_score = AssociationScore('hsic')
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        selecter = Selecter(dataset, hsic_score, gaussian_activation_function)
        motifs_set = selecter.select_n_motifs(3)
        pvalues, _ , _= inferer(motifs_set = motifs_set, selecter = selecter, 
            model_noise = 1, location = 'tmp_inf/', mesh_size = 0.5, nb_burn_in=5, nb_replicates=5)
        self.assertEqual(len(pvalues),3)
        shutil.rmtree('tmp_inf')

    def test_inference_with_hsic_score_and_rejection(self):
        inferer = Inferer('rejection')
        shutil.rmtree('tmp_inf', ignore_errors=True)
        os.mkdir('tmp_inf')
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (3,))
        hsic_score = AssociationScore('hsic')
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        selecter = Selecter(dataset, hsic_score, gaussian_activation_function)
        motifs_set = selecter.select_n_motifs(1)
        pvalues, _ , _= inferer(motifs_set = motifs_set, selecter = selecter, 
            model_noise = 1, location = 'tmp_inf/', mesh_size = 0.5, nb_replicates=5)
        self.assertEqual(len(pvalues),1)
        shutil.rmtree('tmp_inf')

    def test_inference_with_ridge_score(self):
        inferer = Inferer('hit_and_run')
        shutil.rmtree('tmp_inf', ignore_errors=True)
        os.mkdir('tmp_inf')
        nb_replicates = 5
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10,))
        ridge_score = AssociationScore('ridge', lmbda_ridge=1)
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        selecter = Selecter(dataset, ridge_score, gaussian_activation_function)
        motifs_set = selecter.select_n_motifs(3)
        pvalues, _ , _= inferer(motifs_set = motifs_set, selecter = selecter, 
            model_noise = 1, location = 'tmp_inf/', mesh_size = 0.5, nb_burn_in=5,
            nb_replicates=nb_replicates)
        self.assertEqual(len(pvalues),3)
        shutil.rmtree('tmp_inf')

    def test_inference_data_split(self):
        # The data-split path writes no files, so no temporary directory is needed.
        inferer = Inferer('data_split')
        nb_replicates = 10
        dataset = parse_fasta_dataset('test/test_files/50_seq_file.fasta', (10,))
        hsic_score = AssociationScore('hsic')
        gaussian_activation_function = ActivationFunction('gaussian_max_pooling')
        train_dset, test_dset = dataset.split(0.8)
        train_selecter = Selecter(train_dset, hsic_score, gaussian_activation_function, 
            max_iteration_steps=100)
        test_selecter = Selecter(test_dset, hsic_score, gaussian_activation_function, 
            max_iteration_steps=100)
        motifs_set = train_selecter.select_n_motifs(3)
        # data_split returns (pvalues, replicates), unlike the 3-tuple of the sampling paths.
        pvalues, replicates = inferer(test_selecter = test_selecter, model_noise = 1, 
            nb_replicates = nb_replicates, motifs_set=motifs_set)
        self.assertEqual(len(pvalues), 3)
        self.assertEqual(replicates.get_number_of_labels(), nb_replicates)

        

        

