import torch
from typing import Tuple
import subprocess
import numpy
import tempfile

from seism.motif import Motif
from seism.sequences import OneHotEncodedSequences
from seism.definitions import TestCase

################################################################################
# KmerList Class

class KmerList(Tuple[Motif, ...]):
    """
    An immutable tuple of Motifs.

    The factory functions in this module build it with at most one Motif per
    requested kmer length. An empty KmerList is valid.

    param tuple_motifs: tuple of Motif instances to store (defaults to empty)
    raises AssertionError: if tuple_motifs is not a tuple, or if any of its
        elements is not a Motif
    """
    def __init__(self, tuple_motifs: Tuple[Motif, ...] = ()):
        # The elements themselves are stored by tuple.__new__, which receives
        # the same argument; __init__ only validates it.
        assert isinstance(tuple_motifs, tuple), 'KmerList has to be a Tuple'
        # all() (unlike min() over a list) also accepts an empty tuple.
        assert all(isinstance(motif, Motif) for motif in tuple_motifs), (
            'KmerList has to be a Tuple[Motif]')
        super().__init__()

def kmer_list_from_ohe_sequences_and_lengths(ohe_sequences: OneHotEncodedSequences,
        possible_lengths: Tuple[int]) -> KmerList:
    """
    Enumerate the distinct kmers occurring in a set of one hot encoded sequences.

    For every requested length, a window of that length is slid over all
    positions of the sequences; the windows are flattened into a single batch
    of shape (-1, 4, length) and de-duplicated. One Motif is produced per
    requested length, in the order given.

    param ohe_sequences: OneHotEncodedSequences where you want to find kmers in
        (indexed here as [:, :, position]; the middle axis is assumed to hold
        the 4 one-hot channels)
    param possible_lengths: Tuple[int] the kmer sizes to use
    return: KmerList holding one Motif of unique kmers per length
    """
    motifs = []
    for length in possible_lengths:
        # Number of start positions at which a window of this length fits.
        n_windows = ohe_sequences.get_length() - length + 1
        windows = [ohe_sequences[:, :, start:start + length]
                   for start in range(n_windows)]
        # Merge the window axis and the leading axis into one batch of kmers.
        all_kmers = torch.stack(windows, dim=0).view(-1, 4, length)
        distinct_kmers = torch.unique(all_kmers, dim=0)
        motifs.append(Motif(distinct_kmers))
    return KmerList(tuple(motifs))

def kmer_list_from_file_location(fasta_file_location: str, kmer_sizes: Tuple[int, ...]) -> KmerList:
    """
    The function takes a fasta file location and a tuple of kmer sizes and returns a KmerList

    For each kmer size, the external tools `dsk` and `dsk2ascii` are run in a
    temporary directory to produce a text file whose lines are read as
    "<kmer> <number>". The kmers are one hot encoded over the alphabet ACGT
    (channel order A, C, G, T) and only those whose number is greater than or
    equal to the 0.95 quantile of all numbers for that size are kept.
    A size for which the tools report no kmer contributes no Motif, so the
    result may hold fewer Motifs than len(kmer_sizes).

    param fasta_file_location: The location of the fasta file containing 
        the sequences you want to find kmers in 
    param kmer_sizes: the kmer sizes to use
    return: KmerList with one Motif of shape (n_kept, 4, kmer_size) per
        non-empty kmer size
    raises FileNotFoundError: if `dsk` / `dsk2ascii` cannot be found, or if
        they did not produce the expected output file
    """
    # Translation table: bytes A, C, G, T -> 0..3, every other byte -> 0xff.
    # 0xff is out of range for the 4-row lookup table below, so kmers holding
    # any other character are rejected by the embedding lookup.
    byte_to_index = bytearray([0xff] * 256)
    for index, char in enumerate(b"ACGT"):
        byte_to_index[char] = index
    # Row i of the identity matrix is the one hot vector of nucleotide i.
    lookup_table = torch.eye(4)

    motifs = []
    with tempfile.TemporaryDirectory(suffix='.tmp_dsk') as directory:
        for kmer_size in kmer_sizes:
            counts_path = f"{directory}/kmer_file_length_{kmer_size}.h5"
            ascii_path = f"{directory}/kmer_file_ascii_length_{kmer_size}.h5"

            # Arguments are passed as a list without a shell so that paths
            # containing spaces or shell metacharacters are handled safely.
            subprocess.run(["dsk", "-verbose", "0", "-file", fasta_file_location,
                "-kmer-size", str(kmer_size), "-out", counts_path],
                stdout=subprocess.DEVNULL)
            subprocess.run(["dsk2ascii", "-verbose", "0", "-file", counts_path,
                "-out", ascii_path], stdout=subprocess.DEVNULL)

            kmer_indices = []
            numbers = []
            with open(ascii_path, "r") as kmer_file:
                for line in kmer_file:
                    fields = line.split()
                    if not fields:
                        continue  # ignore blank lines
                    kmer_indices.append(
                        list(fields[0].encode('utf-8').translate(byte_to_index)))
                    numbers.append(float(fields[1]))

            if not numbers:
                continue

            numbers = torch.as_tensor(numbers)
            # Encode all kmers at once: (n, kmer_size) indices
            # -> (n, kmer_size, 4) one hot -> (n, 4, kmer_size).
            all_kmers = torch.nn.functional.embedding(
                torch.tensor(kmer_indices, dtype=torch.long),
                lookup_table).transpose(1, 2).contiguous()
            keep = numbers >= torch.quantile(numbers, 0.95)
            motifs.append(Motif(all_kmers[keep, :, :]))
    return KmerList(tuple(motifs))
    
################################################################################
# Tests
class TestKmers(TestCase):
    """Unit tests for KmerList and the two kmer list factory functions."""

    def test_construction(self):
        # A tensor is not a tuple: construction must be refused.
        not_a_tuple = torch.zeros(3)
        with self.assertRaises(AssertionError):
            KmerList(not_a_tuple)

        # A tuple whose second element is not a Motif must be refused as well.
        mixed_tuple = (Motif(torch.rand(1, 4, 1)), torch.zeros(3))
        with self.assertRaises(AssertionError):
            KmerList(mixed_tuple)

        # A one-element tuple of Motif is accepted.
        KmerList((Motif(torch.zeros(1, 4, 3)),))

    def test_kmer_list_from_sequences(self):
        # Single sequence "AAC": channel 0 is set at positions 0 and 1,
        # channel 1 at position 2.
        aac = torch.tensor([[[1., 1., 0.],
                             [0., 0., 1.],
                             [0., 0., 0.],
                             [0., 0., 0.]]])
        sequences = OneHotEncodedSequences(aac)
        result = kmer_list_from_ohe_sequences_and_lengths(sequences, (1, 2))

        self.assertIsInstance(result, KmerList)
        self.assertEqual(len(result), 2)

        # Length 1: two distinct kmers (A, C).
        size_1mers = result[0].size()
        self.assertEqual(size_1mers[0], 2)
        self.assertEqual(size_1mers[2], 1)

        # Length 2: two distinct kmers (AA, AC).
        size_2mers = result[1].size()
        self.assertEqual(size_2mers[0], 2)
        self.assertEqual(size_2mers[2], 2)

    def test_kmer_list_from_file_location(self):
        # Goes through the dsk / dsk2ascii executables invoked by the
        # function under test, using the fixture fasta file.
        result = kmer_list_from_file_location(
            'test/test_files/50_seq_file.fasta', (5, 6))
        self.assertIsInstance(result, KmerList)
