#%%
import sys

import Levenshtein
from utils.print_functions import print_list
from utils.data_functions import write_data_to_file
from utils.noisy_dna import file_to_list


#%%
def prepare_clu_data(data_dict, cluster_size):
    """Split every cluster's reads into fixed-size, non-overlapping groups.

    Args:
        data_dict: Mapping whose values are ``(total_reads, reads_list)``
            pairs. Only ``reads_list`` is used; keys are ignored.
        cluster_size: Number of reads in each emitted group. Must be >= 1.

    Returns:
        A flat list of read sublists, each exactly ``cluster_size`` long,
        produced in the mapping's iteration order. Trailing reads of a
        cluster that do not fill a whole group are dropped.

    Raises:
        ValueError: if ``cluster_size`` is less than 1.
    """
    if cluster_size < 1:
        raise ValueError(f'cluster_size must be >= 1, got {cluster_size}')

    result = []
    for _total_reads, reads_list in data_dict.values():
        # The stop bound keeps only start offsets that still leave room for a
        # full group, so the final partial chunk is never emitted.
        last_start = len(reads_list) - cluster_size + 1
        result.extend(
            reads_list[start:start + cluster_size]
            for start in range(0, last_start, cluster_size)
        )
    return result

#%%
# Whether the final cluster/ground-truth pairs are written to disk.
save_flag = True
# Which starcode clustering run to evaluate; must be a key of the table below.
cluster_case = 'SC'

# Reads per group (passed to prepare_clu_data) for each supported run.
cluster_size = {"SC": 5, "LC": 10}.get(cluster_case)
if cluster_size is None:
    raise ValueError('Invalid cluster_case value')

file_path = f'data/starcode_clusters_{cluster_case}.txt'
orig_seqs = file_to_list("data/File1_ODNA.txt")

# Parse the starcode cluster file into {centroid: (total_reads, reads_list)}.
# Each non-blank line must hold at least three tab-separated fields: the
# centroid sequence, an integer read count, and a comma-separated list of
# member reads. Any extra fields are ignored.
data_dict = {}
with open(file_path, 'r', encoding='utf-8') as file:
    for line_number, line in enumerate(file, start=1):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # rather than crashing on a missing field.
            continue
        parts = line.split('\t')
        if len(parts) < 3:
            raise ValueError(
                f'{file_path}:{line_number}: expected 3 tab-separated '
                f'fields, got {len(parts)}'
            )
        centroid = parts[0]
        total_reads = int(parts[1])
        reads_list = parts[2].split(',')
        # A repeated centroid overwrites the earlier entry.
        data_dict[centroid] = (total_reads, reads_list)

#%%
def _closest_ground_truth(reads, candidates):
    """Return ``(avg_dist, candidate)`` for the candidate nearest to *reads*.

    Each candidate is scored by its mean Levenshtein distance to every read
    in *reads*. Ties resolve to the earliest candidate in *candidates*.

    Raises:
        ValueError: if *reads* or *candidates* is empty.
    """
    if not reads:
        raise ValueError('cannot score an empty cluster')
    if not candidates:
        raise ValueError('no ground-truth sequences to compare against')
    scored = (
        (
            sum(Levenshtein.distance(candidate, read) for read in reads)
            / len(reads),
            candidate,
        )
        for candidate in candidates
    )
    # min() returns the first of several equal minima, matching the original
    # list.index(min(...)) tie-breaking.
    return min(scored, key=lambda pair: pair[0])


result = prepare_clu_data(data_dict, cluster_size)

cpred_data = []
overall_dist_list = []

# For each cluster, find the closest ground-truth sequence and record the
# pair as "read1|read2|...:ground_truth".
for reads in result:
    cluster_dist, ground_truth = _closest_ground_truth(reads, orig_seqs)
    overall_dist_list.append(cluster_dist)
    cpred_data.append('|'.join(reads) + ':' + ground_truth)

if save_flag:
    write_data_to_file(f'data/starcode_test_cpred_data_{cluster_case}.txt', cpred_data)

# Guard against min()/max() on an empty list when no full clusters were formed.
if overall_dist_list:
    min_dist = min(overall_dist_list)
    max_dist = max(overall_dist_list)
    print('min_dist: ', min_dist)
    print('max_dist: ', max_dist)
else:
    print('No complete clusters were formed; nothing to report.')



