import argparse
import os
from typing import Tuple
import torch
import numpy as np

from seism.motif import Motif, ppm_to_bits
from seism.definitions import TestCase
from seism.sequences import OneHotEncodedSequences
from seism.labels import Labels
from seism.activation_function import ActivationFunction
from seism.association_score import AssociationScore
from seism.dataset import Dataset
from seism.selection import Selecter
from seism.inference import Inferer
from seism.kmers import kmer_list_from_ohe_sequences_and_lengths

class SimulationConfiguration:
    """
    Stores the overall configuration for the simulation.

    Every attribute is copied verbatim from the parsed command-line namespace,
    except ``location`` which is the absolute path of the ``output`` argument.
    """
    def __init__(self, argument_values: argparse.Namespace):
        """
        Copy the simulation settings out of ``argument_values``.

        ``argument_values`` must expose every attribute read below; a missing
        one raises AttributeError.
        """
        # General settings
        self.is_calibration = argument_values.is_calibration
        # Stored as an absolute path so later file writes do not depend on the
        # current working directory.
        self.location = os.path.abspath(argument_values.output)
        self.nb_motifs = argument_values.nb_motifs
        self.nb_samples = argument_values.nb_samples
        self.model_noise = argument_values.model_noise
        self.sequences_length = argument_values.sequences_length
        self.min_motifs_length = argument_values.min_motifs_length
        self.max_motifs_length = argument_values.max_motifs_length
        # Ground-truth motifs
        self.nb_true_motifs = argument_values.nb_true_motifs
        self.true_motifs_length = argument_values.true_motifs_length
        self.true_motifs_information = argument_values.true_motifs_information
        # Selection / inference settings
        self.association_score = argument_values.association_score
        self.ridge_lambda = argument_values.ridge_lambda
        self.inferer_type = argument_values.inferer_type
        self.nb_threads = argument_values.nb_threads
        # "hr_*" options are forwarded to the "hit_and_run" inferer
        self.hr_nb_burn_in = argument_values.hr_nb_burn_in
        self.hr_nb_replicates = argument_values.hr_nb_replicates
        self.hr_mesh_size = argument_values.hr_mesh_size
        # "rs_*" options are forwarded to the "rejection" inferer
        self.rs_nb_replicates = argument_values.rs_nb_replicates
        self.rs_mesh_size = argument_values.rs_mesh_size
        # "ds_*" options are forwarded to the "data_split" inferer
        self.ds_nb_replicates = argument_values.ds_nb_replicates
        self.ds_split_ratio = argument_values.ds_split_ratio
        
def do_simulation(argument_values: argparse.Namespace):
    """
    Run one full simulation and write every result under the output directory.

    Steps, in order:
      1. create the output directory and save the raw arguments in it;
      2. draw the true motifs and the simulated sequences;
      3. build the labels as the sum of the true motifs' activations (or zero
         for a calibration run) plus Gaussian noise, rescaled so that their
         norm is sqrt(nb_samples);
      4. select motifs and compute p-values with the configured inferer;
      5. draw and save the selected motifs.

    Raises:
        AssertionError: if the output directory already exists.
        OSError: if the output directory cannot be created for another reason.
        ValueError: if ``inferer_type`` is not "hit_and_run", "data_split" or
            "rejection".
    """
    config = SimulationConfiguration(argument_values)
    torch.set_num_threads(config.nb_threads)
    try:
        os.mkdir(config.location)
    except FileExistsError as error:
        # Raised explicitly (not via `assert`) so the check survives
        # `python -O`; other OS errors propagate with their real cause.
        raise AssertionError('This directory already exists') from error
    torch.save(argument_values, config.location + "/config.pt")

    # Simulated dataset generation
    true_motifs = []
    for i in range(config.nb_true_motifs):
        motif = generate_random_motif(config.true_motifs_length,
            config.true_motifs_information)
        motif.draw_to_file(config.location + '/true_motif_'+str(i+1))
        true_motifs.append(motif)
    true_motifs = tuple(true_motifs)
    sequences, _ = generate_positive_and_negative_sequences(true_motifs,
        config.nb_samples, config.sequences_length)

    activation_function = ActivationFunction('gaussian_max_pooling')
    activation_vectors = tuple(activation_function(motif, sequences)
        for motif in true_motifs)

    # Noise-free signal: zero under calibration, otherwise the summed
    # activations of the true motifs.
    mu = Labels(torch.zeros(1, 1, config.nb_samples, 1))
    if not config.is_calibration:
        for i in range(config.nb_true_motifs):
            mu += Labels(activation_vectors[i])
        mu = Labels(mu.view(-1).unsqueeze(0).unsqueeze(0).unsqueeze(-1))
    else:
        assert mu.equal(Labels(torch.zeros(1, 1, config.nb_samples, 1))), (
            'Calibration test but mu is not zero !')
    gaussian_noise_generator = torch.distributions.normal.Normal(
        torch.zeros(1, 1, config.nb_samples, 1),
        torch.ones(1, 1, config.nb_samples, 1)*config.model_noise)
    initial_label = Labels(mu+gaussian_noise_generator.sample())
    # Rescale the labels so that their norm equals sqrt(nb_samples).
    initial_label = initial_label/torch.norm(initial_label)*(config.nb_samples**(1/2))
    torch.save(torch.Tensor(initial_label), config.location + "/labels.pt")

    motifs_length = tuple(range(config.min_motifs_length, config.max_motifs_length+1))
    dataset = Dataset(sequences, initial_label,
        kmer_list_from_ohe_sequences_and_lengths(sequences, motifs_length))
    association_score = AssociationScore(config.association_score, config.ridge_lambda)

    selecter = Selecter(dataset, association_score, activation_function)
    inferer = Inferer(config.inferer_type)
    # Standard deviation of the labels concatenated with their negation.
    model_noise_for_selecter = torch.std(torch.cat((initial_label.view(-1), -initial_label.view(-1)), dim=0))

    if config.inferer_type == "hit_and_run":
        selected_motifs = selecter.select_n_motifs(config.nb_motifs)
        pvalues_center, pvalues_best_motifs_in_mesh, replicates = inferer(motifs_set = selected_motifs,
            selecter = selecter, model_noise = model_noise_for_selecter, location = config.location,
            mesh_size = config.hr_mesh_size, nb_burn_in = config.hr_nb_burn_in,
            nb_replicates = config.hr_nb_replicates)
        torch.save(pvalues_center, config.location + "/pvalues_center_psi.pt")
        torch.save(pvalues_best_motifs_in_mesh, config.location + "/pvalues_best_motifs_in_mesh_psi.pt")
        torch.save(torch.Tensor(replicates), config.location +"/replicates_psi.pt")

    elif config.inferer_type == "data_split":
        # Motifs are selected on the training part and tested on the other part.
        training_dset, test_dset = dataset.split(config.ds_split_ratio)
        train_selecter = Selecter(training_dset, association_score, activation_function)
        selected_motifs = train_selecter.select_n_motifs(config.nb_motifs)
        test_selecter = Selecter(test_dset, association_score, activation_function)
        pvalues, replicates = inferer(test_selecter = test_selecter,
                model_noise = model_noise_for_selecter,
                nb_replicates = config.ds_nb_replicates, motifs_set = selected_motifs)
        torch.save(pvalues, config.location + "/pvalues_data_split.pt")
        torch.save(torch.Tensor(replicates), config.location +"/replicates_data_split.pt")

    elif config.inferer_type == "rejection":
        print('DO NOT USE REJECTION INFERENCE, ONLY MAINTAINED FOR DEVELOPMENT')
        selected_motifs = selecter.select_n_motifs(config.nb_motifs)
        pvalues_center, pvalues_best_motifs_in_mesh, replicates = inferer(motifs_set = selected_motifs,
            selecter = selecter, model_noise = model_noise_for_selecter,
            location = config.location, mesh_size = config.rs_mesh_size,
            nb_replicates = config.rs_nb_replicates)
        torch.save(pvalues_center, config.location + "/pvalues_center_rejection.pt")
        torch.save(pvalues_best_motifs_in_mesh, config.location + "/pvalues_best_motifs_in_mesh_rejection.pt")
        torch.save(torch.Tensor(replicates), config.location + "/replicates_rejection.pt")

    else:
        # Without this branch, an unknown type would only surface below as a
        # NameError on `selected_motifs`.
        raise ValueError(
            "Unknown inferer type " + repr(config.inferer_type) +
            ", expected 'hit_and_run', 'data_split' or 'rejection'")

    # Draw each selected motif and its reverse complement, and save the motif.
    for i in range(len(selected_motifs)):
        selected_motifs[i].project_on_simplex().draw_to_file(config.location +
            '/motif_'+str(i))
        selected_motifs[i].get_reverse_complement().project_on_simplex().draw_to_file(config.location +
            '/motif_rc'+str(i))
        torch.save(torch.Tensor(selected_motifs[i].project_on_simplex()), config.location+'/motif_'+str(i)+'.pt')

################################################################################
# Utils
    
def generate_random_motif(motif_length: int, information: float, concentration: float = 0.1) -> Motif:
    """
    Draw random motifs of length motif_length until the summed ppm_to_bits
    value of one of them is within 0.001 of information, then return it.

    Each column is sampled from a symmetric 4-dimensional Dirichlet whose
    concentration starts at concentration and is multiplied or divided by 1.1
    after every draw, depending on whether the draw was above or below the
    target. There is no iteration limit.
    """
    tolerance = 0.001
    # Starting value chosen far from any usual target so the loop is entered.
    current_information = -100
    while abs(current_information - information) > tolerance:
        column_sampler = torch.distributions.dirichlet.Dirichlet(
            torch.ones(4) * concentration)

        # One independent Dirichlet sample per motif position.
        columns = torch.zeros(1, 4, motif_length)
        for position in range(motif_length):
            columns[0, :, position] = column_sampler.sample().view(-1)

        candidate = Motif(columns).project_on_simplex()
        current_information = torch.sum(ppm_to_bits(candidate))

        # Raise the concentration after a draw above the target, lower it
        # otherwise.
        concentration = (concentration * 1.1
                         if current_information > information
                         else concentration / 1.1)
    assert isinstance(candidate, Motif), 'the class of the generated is not Motif'
    return candidate

def generate_random_sequences(sequence_length: int) -> OneHotEncodedSequences:
    """
    Build a single one-hot encoded sequence of sequence_length positions.

    For every position, one of the 4 rows is picked uniformly at random with
    numpy's global random generator and set to 1. The result has shape
    (1, 4, sequence_length).
    """
    # Draw the hot row of every position first, in position order.
    hot_rows = [int(np.floor(np.random.random() * 4))
                for _ in range(sequence_length)]

    one_hot = torch.zeros(1, 4, sequence_length)
    for column, row in enumerate(hot_rows):
        one_hot[0, row, column] = 1
    return OneHotEncodedSequences(one_hot)

def random_kmer_from_motif(motif: Motif) -> OneHotEncodedSequences:
    """
    Given a motif, randomly draw a kmer from it: at each position, one row is
    picked with the probabilities stored in that column of the motif
    (categorical distribution), using numpy's global random generator.

    The returned one-hot tensor has the same size as the motif.
    """
    kmer = torch.zeros(motif.size())
    last_row = kmer.size()[1] - 1
    for i in range(kmer.size()[2]):
        rand = np.random.random()
        # Walk the cumulative distribution of column i until it reaches the
        # uniform draw.
        pos = 0
        total = motif[0, pos, i].clone()
        # Stop at the last row: a column whose probabilities sum to slightly
        # less than 1 (rounding) must not push the index out of range.
        while rand > total and pos < last_row:
            pos += 1
            total += motif[0, pos, i]
        kmer[0, pos, i] = 1
    return OneHotEncodedSequences(kmer)

def generate_positive_and_negative_sequences(true_motif_tuple: Tuple[Motif], nb_sequences: int = 50, sequences_length: int = 100,
            positivity_probability: float = 0.5) -> Tuple[OneHotEncodedSequences, list]:
    """
    Given a tuple of motifs, the number of sequences to generate, the length of the sequences, and the
    probability of having a positive sequence, generate a set of random sequences with a kmer drawn from
    one of the motifs planted in the positive ones.

    A positive sequence receives its kmer at a uniformly random start position p such that
    motif_length <= p < sequences_length - motif_length. The motif is chosen from the same uniform draw
    that decides positivity, so each motif is used with about the same probability.

    Returns the sequences, of shape (nb_sequences, 4, sequences_length), and the list of indices of the
    positive sequences. If true_motif_tuple is empty, nothing is planted and the list is empty.

    Raises:
        ValueError: if a motif is too long to be planted in a sequence of length sequences_length.
    """
    nb_true_motifs = len(true_motif_tuple)
    # Start from an empty batch so that nb_sequences == 0 still yields a (0, 4, sequences_length) result.
    # Sequences are collected and concatenated once at the end instead of re-copying the batch each time.
    batches = [OneHotEncodedSequences(torch.empty(0, 4, sequences_length))]
    positifs = []
    for i in range(nb_sequences):
        p = np.random.rand()
        sequence = generate_random_sequences(sequences_length)
        if nb_true_motifs > 0 and p < positivity_probability:
            positifs.append(i)
            # p / positivity_probability lies in [0, 1): map it to a motif index. The min() guards against
            # floating-point rounding reaching nb_true_motifs.
            motif_index = min(int(p*nb_true_motifs//positivity_probability), nb_true_motifs - 1)
            true_motif = true_motif_tuple[motif_index]
            true_motif_length = true_motif.get_length()
            first_position = max(1, true_motif_length)
            end_position = sequences_length - true_motif_length
            if first_position >= end_position:
                raise ValueError(
                    "Cannot plant a motif of length " + str(true_motif_length) +
                    " in sequences of length " + str(sequences_length))
            # Drawn directly from the admissible range (upper bound excluded).
            position = np.random.randint(first_position, end_position)
            sequence[0, :, position : position + true_motif_length] = random_kmer_from_motif(true_motif)
        batches.append(sequence)
    return OneHotEncodedSequences(torch.cat(batches, dim=0)), positifs
