import Levenshtein
import numpy as np
import random

def remove_trailing_Cs(input_string):
    """Return *input_string* with its run of trailing 'C' characters removed.

    Only the suffix made up entirely of 'C' is dropped; leading and interior
    'C's are kept (e.g. 'ACGCC' -> 'ACG').
    """
    trailing_char = 'C'
    stripped = input_string.rstrip(trailing_char)
    return stripped

def generate_noisy_dna_cpred(orig_seqs, clusters, ub=-1, lb=-1, lev_dist_ub=-1, lev_dist_lb=-1, reads=None, remove_trailing_C_flag=False, verbose=False):
    """Build examples pairing groups of noisy reads with their reference sequence.

    For each reference sequence and its cluster of read indices (paired via
    ``zip``, so iteration stops at the shorter of the two), every read is looked
    up in ``reads``. It is optionally stripped of trailing 'C's and kept only if
    its Levenshtein distance to the reference lies in the inclusive range
    ``[lev_dist_lb, lev_dist_ub]``. The kept reads are shuffled with the global
    ``random`` state. They are then split into consecutive groups of at most
    ``ub`` reads, and each group with at least ``lb`` reads becomes one example.

    Args:
        orig_seqs: reference sequences, one per cluster.
        clusters: iterable of index lists into ``reads``, aligned with ``orig_seqs``.
        ub: maximum number of reads per example (required, must be >= 1).
        lb: minimum number of reads per example (required).
        lev_dist_ub: inclusive upper bound on edit distance (required).
        lev_dist_lb: inclusive lower bound on edit distance (required).
        reads: indexable collection of read strings (required, non-empty).
        remove_trailing_C_flag: strip trailing 'C's from each read before use.
        verbose: print every reference, read and distance (debug output).

    Returns:
        ``(data_list, ground_truth_list)``. Each ``data_list`` entry has the form
        ``'read1|read2|...:orig_seq'``, and ``ground_truth_list`` holds the
        corresponding ``orig_seq``.

    Raises:
        ValueError: if a required parameter is missing, ``ub`` < 1, or ``reads``
            is missing or empty.
    """
    if ub == -1 or lb == -1 or lev_dist_ub == -1 or lev_dist_lb == -1:
        raise ValueError('Please provide values for ub, lb, lev_dist_ub, lev_dist_lb')
    if ub < 1:
        # ub is used as a range() step below; 0 would crash, negatives yield nothing.
        raise ValueError('ub must be a positive integer')
    # `reads=None` replaces the shared mutable default `reads=[]`.
    if reads is None or len(reads) == 0:
        raise ValueError('Please provide a list of reads')

    data_list = []
    ground_truth_list = []

    for counter, (orig_seq, cluster) in enumerate(zip(orig_seqs, clusters)):

        if counter % 1000 == 0:
            print('counter: ', counter)

        if verbose:
            print('#######################################################################')
            print('#######################################################################')
            print(orig_seq)
            print('-----------------------------------------------------------------------')

        filtered_cluster = []
        for seq_index in cluster:
            seq = reads[seq_index]
            if remove_trailing_C_flag:
                seq = remove_trailing_Cs(seq)

            lev_dist = Levenshtein.distance(orig_seq, seq)
            if verbose:
                print(seq)
                print('levenshtein distance: ', lev_dist)
            if lev_dist_lb <= lev_dist <= lev_dist_ub:
                filtered_cluster.append(seq)

        random.shuffle(filtered_cluster)

        if len(filtered_cluster) > ub:
            groups = [filtered_cluster[i:i + ub] for i in range(0, len(filtered_cluster), ub)]
        else:
            # Keep the cluster whole (even when empty) so that an empty cluster
            # still yields an example when lb <= 0, as before.
            groups = [filtered_cluster]

        for group in groups:
            if len(group) >= lb:
                data_list.append('|'.join(group) + ':' + orig_seq)
                ground_truth_list.append(orig_seq)

    return data_list, ground_truth_list

def file_to_list(filename):
    """Return the lines of text file *filename* with their trailing newline removed.

    The newline is stripped only when present, so the final line is kept intact
    even if the file does not end with '\\n'. Slicing off the last character
    unconditionally would drop a real character there.
    """
    with open(filename) as f:
        return [line[:-1] if line.endswith('\n') else line for line in f]

def list_to_file(strings, filename):
    """Write each string of *strings* to *filename*, one per line.

    The file is created or overwritten, and every entry, including the last,
    is followed by '\\n'.
    """
    with open(filename, 'w') as out:
        out.writelines(s + '\n' for s in strings)
                      
def fastq_to_list(fastq_filename):
    """Return the sequence lines of a FASTQ file.

    Lines are counted from 0 and the sequence is taken as line 1 of every
    4-line group (header, sequence, '+', quality). This assumes unwrapped
    4-line records; TODO confirm for the input data. The trailing newline is
    stripped only when present, so a truncated final record keeps its last
    character.
    """
    with open(fastq_filename) as f:
        return [line[:-1] if line.endswith('\n') else line
                for i, line in enumerate(f) if i % 4 == 1]

def seq_stats(seqs, max_len=200):
    """Compute a length histogram and per-base character fractions for *seqs*.

    Args:
        seqs: iterable of sequence strings (may be a one-shot iterator).
        max_len: initial size of the length histogram. The histogram is
            extended automatically when a sequence of length >= its size is
            seen, instead of raising IndexError.

    Returns:
        ``(length, freqs, count)``:
        - ``length``: float array where ``length[k]`` is the number of sequences
          of length k. Its size is ``max(max_len, longest + 1)``.
        - ``freqs``: 4-element array of the counts of 'A', 'C', 'G', 'T' divided
          by the total number of characters, so other characters count toward
          the total. It is all-NaN when there are no characters.
        - ``count``: number of sequences seen.
    """
    length = np.zeros(max_len)
    nucleotides = np.zeros(4)
    total = 0
    ctr = 0
    for seq in seqs:
        ctr += 1
        seq_len = len(seq)
        if seq_len >= len(length):
            # Grow the histogram so this length gets its own bin.
            length = np.concatenate([length, np.zeros(seq_len + 1 - len(length))])
        length[seq_len] += 1
        for k, base in enumerate('ACGT'):
            nucleotides[k] += seq.count(base)
        total += seq_len
    if total == 0:
        # Same NaN result as 0/0, without numpy's divide-by-zero RuntimeWarning.
        return length, np.full(4, np.nan), ctr
    return length, nucleotides / total, ctr

def fraction_recovered(candidates, orig_seqs):
    """Report how many reference sequences occur among the candidates.

    The function prints two statistics:
    - the fraction of distinct sequences in *orig_seqs* that appear at least
      once in *candidates*;
    - when at least one was recovered, the average number of candidate copies
      per recovered sequence.

    Returns:
        The recovered fraction (0.0 when *orig_seqs* is empty). The original
        version returned None and raised ZeroDivisionError on empty input.
    """
    counts = dict.fromkeys(orig_seqs, 0)
    for cand in candidates:
        if cand in counts:
            counts[cand] += 1
    if not counts:
        print("Fraction of recovered sequences: ", 0.0)
        return 0.0
    recovered = sum(1 for c in counts.values() if c > 0)
    av = recovered / len(counts)
    print("Fraction of recovered sequences: ", av)
    if recovered:
        # Total matching candidates divided by the number of recovered
        # sequences. This can exceed 1, so it is not a fraction.
        print("Average copies per recovered sequence: ", sum(counts.values()) / recovered)
    return av