# %% imports
import json
import sys
import os
import csv

import time

from utils.helper_functions import create_folder, get_now_str
from utils.noisy_dna import file_to_list, fastq_to_list
from utils.data_functions import filter_string

# %% ################################## Load data ##################################
# Directory containing this script; it is only printed here.
script_dir = os.path.dirname(__file__)
print("script_dir: ", script_dir)

# Run switches.
print_flag = True   # gates the progress messages printed while clustering
test = True         # selects the reduced data set and the "*_test" output files
wandb_log = False   # not referenced anywhere in the visible part of this script

# Inclusive bounds on read length; reads outside this range are dropped below.
seq_len_lb = 55
seq_len_ub = 70

# Slice bounds of the part of each read that is searched for an index.
sub_read_lb = 30
sub_read_ub = 70

# Load the original sequences and the sequenced reads.
filename = "data/I16_S2_R1_001.fastq"
orig_seqs = file_to_list("data/File1_ODNA.txt")
seqs = fastq_to_list(filename)
print("all sequences: ", len(seqs))
print("all orig sequences: ", len(orig_seqs))

# Keep only the reads whose length lies within [seq_len_lb, seq_len_ub].
reads = [candidate for candidate in seqs if seq_len_lb <= len(candidate) <= seq_len_ub]
print("all trimmed sequences: ",  len(reads))

# In test mode only the first `test_orig_seqs_limit` original sequences are
# used; every length-filtered read is kept in both modes.
test_orig_seqs_limit = 2000

if test:
    # Clamp to the number of sequences actually loaded, so the size that is
    # printed and stored in the config never exceeds what is really used.
    orig_seqs_size = min(test_orig_seqs_limit, len(orig_seqs))
else:
    orig_seqs_size = len(orig_seqs)
reads_size = len(reads)

# `reads` is used in full, so it needs no truncation (slicing it to its own
# length would only copy the whole list).
orig_seqs = orig_seqs[:orig_seqs_size]

# The index of each original sequence is its characters 42..53.
# NOTE(review): a sequence shorter than 54 characters yields a shortened
# (possibly empty) index, and an empty string is a substring of every read -
# confirm that all original sequences are long enough.
index_start, index_stop = 42, 54
indices = [seq[index_start:index_stop] for seq in orig_seqs]

print('orig_seqs_size: ', orig_seqs_size)
print('reads_size: ', reads_size)

# %% ################################## config & wandb ##################################
add_info = ''

# Parameters of this run, written to disk below. Key order determines the
# order of the entries in the JSON file.
config = {
    'print_flag': print_flag,
    'add_info': add_info,
    'reads_size': reads_size,
    'orig_seqs_size': orig_seqs_size,
    'seq_len_lb': seq_len_lb,
    'seq_len_ub': seq_len_ub,
    'sub_read_lb': sub_read_lb,
    'sub_read_ub': sub_read_ub,
}

# Name the run after the current time stamp.
now_str = get_now_str()
run_name = f'index_clustering_{now_str}'
print('run_name: ', run_name)

# presumably creates a folder named after the run - verify against utils.helper_functions
create_folder(run_name)

# Save config as JSON; test runs go to a separate file.
config_json_path = (
    'data/config_index_clustering_test.json'
    if test
    else 'data/config_index_clustering.json'
)

with open(config_json_path, 'w') as json_file:
    json.dump(config, json_file, indent=4)


# %% ################################## Clustering ##################################
start_time = time.time()

# One cluster per distinct index, in first-seen order. Each cluster is the
# list of positions (within `reads`) of the reads that contain the index.
clusters = {index: [] for index in indices}

# The searched window depends only on the read, so slice every read once
# instead of once per (index, read) pair. This keeps one extra short string
# per read in memory.
sub_reads = [read[sub_read_lb:sub_read_ub] for read in reads]

# Iterate over the distinct indices: looping over `indices` directly would
# scan the reads again for a repeated index and append every matching read
# position to the same cluster a second time. `index_number` is therefore the
# position among the distinct indices (identical to the position in `indices`
# when there are no repeats).
for index_number, (index, members) in enumerate(clusters.items()):
    for i, sub_read in enumerate(sub_reads):
        # Progress message at read 0 and then every million reads.
        if i % 10**6 == 0 and print_flag:
            print(f'index number {index_number}, read number: {i:.2e}')

        if index in sub_read:
            members.append(i)

end_time = time.time()
time_taken = end_time - start_time
print('time taken:', time_taken)

# %% ################################## Save results ##################################

# Test runs are written to a separate file so they do not overwrite the
# results of a full run.
if test:
    clusters_csv_path = 'data/index_clusters_test.csv'
else:
    clusters_csv_path = 'data/index_clusters.csv'

# One CSV row per cluster, in the dictionary's insertion order. Only the read
# positions are written; the index string itself is not part of the row.
with open(clusters_csv_path, 'w', newline='') as csv_file:
    csv.writer(csv_file).writerows(clusters.values())
