from math import floor
import torch
from typing import Tuple
import Bio.SeqIO.FastaIO
import numpy
import random

from seism.sequences import OneHotEncodedSequences
from seism.labels import Labels
from seism.definitions import TestCase
from seism.kmers import KmerList, kmer_list_from_file_location, kmer_list_from_ohe_sequences_and_lengths

################################################################################
# Dataset Class

class Dataset:
    """A collection of one-hot encoded sequences, their labels and a k-mer list.

    The number of sequences must equal the number of label samples.
    """

    def __init__(self, one_hot_encoded_sequences: OneHotEncodedSequences, labels: Labels,
                 kmer_list: KmerList):
        """
        :param one_hot_encoded_sequences: the encoded input sequences.
        :param labels: one label sample per sequence.
        :param kmer_list: k-mer list associated with these sequences.
        """
        self.one_hot_encoded_sequences = one_hot_encoded_sequences
        self.labels = labels
        self.kmer_list = kmer_list
        assert one_hot_encoded_sequences.get_number() == labels.get_number_of_samples(), (
            'Different number of labels and sequences were given...')

    def get_sequences(self):
        """Return the one-hot encoded sequences."""
        return self.one_hot_encoded_sequences

    def get_labels(self):
        """Return the labels."""
        return self.labels

    def get_kmer_list(self):
        """Return the k-mer list."""
        return self.kmer_list

    def split(self, split_ratio: float):
        """
        Randomly partition the dataset into two disjoint datasets.

        :param split_ratio: fraction of samples, in [0, 1], placed in the first
            dataset; the first dataset receives floor(split_ratio * n) samples.
        :return: (dataset_1, dataset_2), i.e. a random subset and its complement.
            Both share this dataset's k-mer list.
        :raises ValueError: if split_ratio is outside [0, 1].
        """
        number_of_samples = self.get_sequences().get_number()
        if not 0 <= split_ratio <= 1:
            raise ValueError(f'split_ratio must be in [0, 1], got {split_ratio}')
        first_size = floor(split_ratio * number_of_samples)
        random_indices = random.sample(range(number_of_samples), first_size)
        # A set makes each membership test O(1), so building the complement is
        # O(n) rather than O(n^2) as with a list lookup.
        selected = set(random_indices)
        complements = [i for i in range(number_of_samples) if i not in selected]
        return self._subset(random_indices), self._subset(complements)

    def _subset(self, indices):
        """Build a new Dataset from the samples at `indices`.

        Sequences are indexed on their first axis and labels on their third axis.
        """
        # NOTE(review): original sequence lengths are not forwarded to the new
        # OneHotEncodedSequences here; confirm whether that is intended.
        sequences = OneHotEncodedSequences(self.get_sequences()[indices, :, :])
        labels = Labels(self.get_labels()[:, :, indices, :])
        return Dataset(sequences, labels, self.kmer_list)

def dataset_from_sequences_and_labels(one_hot_encoded_sequences: OneHotEncodedSequences,
        labels: Labels, kmer_lengths: Tuple[int]):
    """
    Build a Dataset whose k-mer list is derived from the given sequences.

    :param one_hot_encoded_sequences: the encoded input sequences.
    :param labels: one label sample per sequence.
    :param kmer_lengths: k-mer lengths forwarded to
        ``kmer_list_from_ohe_sequences_and_lengths``.
    :return: the assembled Dataset.
    """
    return Dataset(
        one_hot_encoded_sequences,
        labels,
        kmer_list_from_ohe_sequences_and_lengths(one_hot_encoded_sequences, kmer_lengths),
    )

def parse_fasta_dataset(file_location: str, kmer_lengths: Tuple[int]) -> Dataset:
    """
    Parse a two-line FASTA file into a Dataset.

    Each record's label is read as the second whitespace-separated token of its
    header (``title.split()[1]``). Nucleotides A, C, G, T are mapped to indices
    0-3 and one-hot encoded. Shorter sequences are right-padded with index 0
    (i.e. one-hot 'A') up to the longest sequence. The unpadded lengths are
    passed along as ``original_lengths``. Labels are centred by subtracting
    their mean.

    :param file_location: path to the FASTA file.
    :param kmer_lengths: k-mer lengths passed to ``kmer_list_from_file_location``.
    :return: Dataset whose sequence tensor has shape (n, 4, max_length) and
        whose label tensor has shape (1, 1, n, 1).
    :raises ValueError: if the file contains no sequences.
    """
    kmer_list = kmer_list_from_file_location(file_location, kmer_lengths)

    # Byte -> nucleotide index table. Any byte outside "ACGT" maps to 0xff,
    # which is out of range for the 4-row lookup table and makes the
    # embedding lookup below fail.
    byte_to_index = bytearray([0xff] * 256)
    for index, char in enumerate(b"ACGT"):
        byte_to_index[char] = index
    nb_indexes = 4
    # Row i of the identity matrix is the one-hot vector for index i.
    lookup_table = torch.eye(nb_indexes)

    labels_list = []
    encoded_sequences = []
    # `with` guarantees the file is closed even if parsing raises.
    with open(file_location, 'r') as file_handle:
        for title, sequence in Bio.SeqIO.FastaIO.FastaTwoLineParser(file_handle):
            labels_list.append(float(title.split()[1]))
            indices = sequence.encode('utf-8').translate(byte_to_index)
            encoded_sequences.append(torch.as_tensor(numpy.array(memoryview(indices))))

    if not encoded_sequences:
        raise ValueError(f'No sequences found in {file_location}')

    original_lengths = torch.tensor([len(seq) for seq in encoded_sequences])
    max_length = int(original_lengths.max())
    # Preallocate the zero-padded index matrix instead of growing it with
    # torch.cat. Each row is filled up to its OWN length (the previous code
    # padded every row by the length of the last parsed sequence).
    sequences_tensor = torch.zeros(len(encoded_sequences), max_length, dtype=torch.long)
    for row, seq in enumerate(encoded_sequences):
        sequences_tensor[row, :len(seq)] = seq

    labels_tensor = torch.tensor(labels_list)
    labels_tensor = labels_tensor - torch.mean(labels_tensor)
    labels_tensor = labels_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
    labels = Labels(labels_tensor)

    # (n, L) indices -> (n, L, 4) one-hot -> (n, 4, L).
    one_hot = torch.nn.functional.embedding(sequences_tensor, lookup_table).transpose(
        dim0=1, dim1=2)
    one_hot_encoded_sequences = OneHotEncodedSequences(one_hot, original_lengths)
    return Dataset(one_hot_encoded_sequences, labels, kmer_list)

################################################################################
# Tests

class TestDataset(TestCase):
    # Fixture file; the split expectations below imply it holds 50 records.
    FASTA_FILE = 'test/test_files/50_seq_file.fasta'

    def test_parsing(self):
        dataset = parse_fasta_dataset(self.FASTA_FILE, (10,))
        self.assertIsInstance(dataset, Dataset)
        self.assertEqual(dataset.get_sequences().get_number(), 50)
        self.assertEqual(dataset.get_labels().get_number_of_samples(), 50)

    def test_splitting(self):
        dataset = parse_fasta_dataset(self.FASTA_FILE, (10,))
        d1, d2 = dataset.split(0.1)
        # floor(0.1 * 50) = 5 samples go to d1; the remaining 45 go to d2.
        # Exact integers avoid comparing against float products.
        self.assertEqual(d1.get_labels().get_number_of_samples(), 5)
        self.assertEqual(d1.get_sequences().get_number(), 5)
        self.assertEqual(d2.get_labels().get_number_of_samples(), 45)
        self.assertEqual(d2.get_sequences().get_number(), 45)
        # The two parts together must cover the whole dataset.
        self.assertEqual(
            d1.get_sequences().get_number() + d2.get_sequences().get_number(),
            dataset.get_sequences().get_number())
