import argparse
import os
from typing import Tuple
import torch
import numpy as np

from seism.motif import Motif, ppm_to_bits
from seism.definitions import TestCase
from seism.sequences import OneHotEncodedSequences
from seism.labels import Labels
from seism.activation_function import ActivationFunction
from seism.association_score import AssociationScore
from seism.dataset import Dataset, parse_fasta_dataset
from seism.selection import Selecter
from seism.inference import Inferer
from seism.kmers import kmer_list_from_ohe_sequences_and_lengths

class AnalysisConfiguration:
    """
    Stores the overall configuration for the analysis.

    Copies the attributes needed by the analysis from a parsed
    ``argparse.Namespace`` onto the instance. Two attributes are renamed:
    ``output`` is stored (as an absolute path) in ``location`` and ``input``
    is stored in ``fasta_file_location``; every other attribute keeps its
    original name.
    """
    def __init__(self, argument_values: argparse.Namespace):
        """
        :param argument_values: parsed command-line arguments. It must define
            every attribute read below; a missing one raises AttributeError.
        """
        # Output directory, made absolute so it does not depend on a later cwd.
        self.location = os.path.abspath(argument_values.output)
        # Path of the input FASTA dataset.
        self.fasta_file_location = argument_values.input

        # Motif selection parameters.
        self.nb_motifs = argument_values.nb_motifs
        self.min_motifs_length = argument_values.min_motifs_length
        self.max_motifs_length = argument_values.max_motifs_length
        self.association_score = argument_values.association_score
        self.ridge_lambda = argument_values.ridge_lambda

        # Inference control.
        self.perform_inference = argument_values.perform_inference
        self.inferer_type = argument_values.inferer_type
        self.nb_threads = argument_values.nb_threads

        # hr_* : parameters of the 'hit_and_run' inferer.
        self.hr_nb_burn_in = argument_values.hr_nb_burn_in
        self.hr_nb_replicates = argument_values.hr_nb_replicates
        self.hr_mesh_size = argument_values.hr_mesh_size
        # rs_* : parameters of the 'rejection' inferer.
        self.rs_nb_replicates = argument_values.rs_nb_replicates
        self.rs_mesh_size = argument_values.rs_mesh_size
        # ds_* : parameters of the 'data_split' inferer.
        self.ds_nb_replicates = argument_values.ds_nb_replicates
        self.ds_split_ratio = argument_values.ds_split_ratio

def do_analysis(argument_values: argparse.Namespace):
    """
    Run motif selection (and optionally post-selection inference) on a FASTA
    dataset, writing every result into a freshly created output directory.

    :param argument_values: parsed command-line arguments; the attributes read
        are the ones copied by AnalysisConfiguration.
    :raises ValueError: if ``inferer_type`` is not one of 'hit_and_run',
        'data_split' or 'rejection'. This is checked before anything is
        written to disk.
    :raises FileExistsError: if the output directory already exists. Other
        OSErrors from ``os.mkdir`` (e.g. missing parent directory) propagate
        unchanged.

    Files written in the output directory:
      - ``config.pt``: the raw argument namespace (via ``torch.save``);
      - when ``perform_inference`` is set, the p-values and replicates of the
        chosen inferer (file names suffixed ``_psi``, ``_data_split`` or
        ``_rejection``);
      - for each selected motif ``i``: ``motif_<i>`` and ``motif_rc<i>``
        (written by ``draw_to_file``) and ``motif_<i>.pt``.
    """
    config = AnalysisConfiguration(argument_values)

    # Only these inferers are handled below. Any other value used to leave
    # ``selected_motifs`` unbound and crash with a NameError after the output
    # directory had been created, so fail fast instead.
    supported_inferers = ('hit_and_run', 'data_split', 'rejection')
    if config.inferer_type not in supported_inferers:
        raise ValueError(
            f"Unsupported inferer_type {config.inferer_type!r}; expected one "
            f"of {supported_inferers}")

    torch.set_num_threads(config.nb_threads)

    # Create the output directory. Only "already exists" is reported as such;
    # a real exception is raised so the check survives ``python -O``.
    try:
        os.mkdir(config.location)
    except FileExistsError as err:
        raise FileExistsError(
            f'This directory already exists: {config.location}') from err

    def output_path(file_name):
        # Every output file lives directly inside the analysis directory.
        return os.path.join(config.location, file_name)

    torch.save(argument_values, output_path("config.pt"))

    # Parameters definition
    association_score = AssociationScore(config.association_score,
            config.ridge_lambda)

    activation_function = ActivationFunction('gaussian_max_pooling')

    # Candidate motif lengths: every length from min to max, both inclusive.
    motif_lengths = tuple(range(config.min_motifs_length,
                                config.max_motifs_length + 1))
    fasta_dataset = parse_fasta_dataset(config.fasta_file_location,
                                        motif_lengths)

    selecter = Selecter(dataset = fasta_dataset, association_score =
            association_score, activation_function = activation_function)

    # Noise level: standard deviation of the flattened labels concatenated
    # with their negation.
    labels = fasta_dataset.get_labels().view(-1)
    sigma = torch.std(torch.cat((labels, -labels), dim=0))

    # Motifs selection and test

    inferer = Inferer(config.inferer_type)

    if config.inferer_type == 'hit_and_run':
        selected_motifs = selecter.select_n_motifs(config.nb_motifs)
        if config.perform_inference:
            pvalues_center, pvalues_best_motifs_in_mesh, replicates = inferer(
                    motifs_set = selected_motifs, selecter = selecter, model_noise = sigma,
                    location = config.location, mesh_size = config.hr_mesh_size,
                    nb_burn_in = config.hr_nb_burn_in, nb_replicates = config.hr_nb_replicates)

            torch.save(pvalues_center, output_path("pvalues_center_psi.pt"))
            torch.save(pvalues_best_motifs_in_mesh,
                       output_path("pvalues_best_motifs_in_mesh_psi.pt"))
            torch.save(torch.Tensor(replicates), output_path("replicates_psi.pt"))

    elif config.inferer_type == 'data_split':
        # Select motifs on one part of the data and test them on the other.
        training_dset, test_dset = fasta_dataset.split(config.ds_split_ratio)
        train_selecter = Selecter(training_dset, selecter.association_score, selecter.activation_function)
        selected_motifs = train_selecter.select_n_motifs(config.nb_motifs)
        test_selecter = Selecter(test_dset, association_score, selecter.activation_function)
        if config.perform_inference:
            pvalues, replicates = inferer(test_selecter = test_selecter, model_noise = sigma,
                    nb_replicates = config.ds_nb_replicates, motifs_set = selected_motifs)

            torch.save(pvalues, output_path("pvalues_data_split.pt"))
            torch.save(torch.Tensor(replicates), output_path("replicates_data_split.pt"))

    else:  # 'rejection' (guaranteed by the validation above)
        print('DO NOT USE REJECTION INFERENCE, ONLY MAINTAINED FOR DEVELOPMENT')
        selected_motifs = selecter.select_n_motifs(config.nb_motifs)
        if config.perform_inference:
            pvalues_center, pvalues_best_motifs_in_mesh, replicates = inferer(motifs_set = selected_motifs,
                    selecter = selecter, model_noise = sigma, location = config.location,
                    mesh_size = config.rs_mesh_size, nb_replicates = config.rs_nb_replicates)

            torch.save(pvalues_center, output_path("pvalues_center_rejection.pt"))
            torch.save(pvalues_best_motifs_in_mesh,
                       output_path("pvalues_best_motifs_in_mesh_rejection.pt"))
            torch.save(torch.Tensor(replicates), output_path("replicates_rejection.pt"))

    # Export every selected motif (and its reverse complement) after
    # projection on the simplex.
    for i, motif in enumerate(selected_motifs):
        projected = motif.project_on_simplex()
        projected.draw_to_file(output_path('motif_' + str(i)))
        motif.get_reverse_complement().project_on_simplex().draw_to_file(
                output_path('motif_rc' + str(i)))
        torch.save(torch.Tensor(projected), output_path('motif_' + str(i) + '.pt'))
