# %% region import statements
import sys
import csv
import json 
import time
import os 

import numpy as np
import random

from utils.data_functions import shuffle_data, write_data_to_file, get_cluster_size_distribution_cpred
from utils.noisy_dna import generate_noisy_dna_cpred, file_to_list, fastq_to_list
from utils.helper_functions import create_folder, get_now_str
from utils.print_functions import print_dict
# end region

# %% ################################# general variables #################################
# Directory part of this script's path; only printed here for reference.
script_dir = os.path.dirname(__file__)
print("script_dir: ", script_dir)

# Inclusive length window applied to the sequenced reads when they are loaded.
seq_len_lb, seq_len_ub = 55, 70

# Distance bounds handed to generate_noisy_dna_cpred
# (presumably Levenshtein distance, judging by the name -- confirm in utils.noisy_dna).
lev_dist_lb, lev_dist_ub = 5, 13

# Seed both the stdlib and the NumPy global RNGs with the same value.
seed_number = 42
random.seed(seed_number)
np.random.seed(seed_number)

cluster_case = "SC"  # or LC (SC: small clusters, LC: large clusters)
save_flag = True     # save data
test = True          # when truthy, the test index-clusters CSV is loaded

# cluster_case -> (lb, ub, test_train_cut).
# test_train_cut is the number of test clusters; lb/ub are forwarded to
# generate_noisy_dna_cpred (presumably cluster-size bounds -- TODO confirm).
_CLUSTER_PRESETS = {
    "SC": (2, 5, 500),
    "LC": (6, 10, 1000),
}
if cluster_case not in _CLUSTER_PRESETS:
    raise ValueError('Invalid cluster_case value')
lb, ub, test_train_cut = _CLUSTER_PRESETS[cluster_case]
# end region

# %% ################################## Load data ##################################

##### Load synthesized data #####
# Original (reference) sequences and the sequenced reads are loaded through
# project helpers; both results are used as sequences of strings below.
orig_seqs = file_to_list("data/File1_ODNA.txt")
filename = "data/I16_S2_R1_001.fastq"
seqs = fastq_to_list(filename)
print("all sequences: ", len(seqs))
print("all orig sequences: ", len(orig_seqs))
# Keep only reads whose length lies inside [seq_len_lb, seq_len_ub] (inclusive).
reads = [seq for seq in seqs if seq_len_lb <= len(seq) <= seq_len_ub]
print("all trimmed sequences: ",  len(reads))

# The `test` flag switches between the test and the full index-clusters file.
if test:
    index_clusters = 'data/index_clusters_test.csv'
else:
    index_clusters = 'data/index_clusters.csv'

##### Load the clusters #####
# Each CSV row is one cluster: a list of integer indices (later used to index
# into `reads`). An empty row yields an empty cluster.
# newline='' is the file mode the csv module documents for reader input.
with open(index_clusters, 'r', newline='') as file:
    clusters = [[int(x) for x in row] for row in csv.reader(file)]

# Cut-off knobs recorded in the run config. reads_cut currently keeps every
# read; lower it to work on a prefix of the reads.
reads_cut = len(reads)
reads = reads[:reads_cut]

# Truncate the original sequences to the number of loaded clusters
# (presumably so that orig_seqs[i] pairs with clusters[i] -- verify against
# how the index-clusters file was produced).
orig_seqs_cut = len(clusters)
orig_seqs = orig_seqs[:orig_seqs_cut]

# Run name embeds the distance bounds and the string returned by get_now_str().
now_str = get_now_str()
run_name = f'filtered_clusters_lev_lb_{lev_dist_lb}_ub_{lev_dist_ub}_now_{now_str}'

# Parameters of this run, collected in one place so they can be dumped to JSON.
config = {
    'seq_len_lb': seq_len_lb,
    'seq_len_ub': seq_len_ub,
    'orig_seqs_cut': orig_seqs_cut,
    'reads_cut': reads_cut,
    'test_train_cut': test_train_cut,
    'lev_dist_lb': lev_dist_lb,
    'lev_dist_ub': lev_dist_ub,
    'wandb_run_name': run_name,
    'seed_number': seed_number,
}

# The run folder is created regardless of save_flag.
create_folder(f'wandb/{run_name}')

# Save config as JSON (one file per cluster_case; only when saving is enabled).
config_json_path = f'data/config_filter_index_clusters_{cluster_case}.json'
if save_flag:
    with open(config_json_path, 'w') as json_file:
        json.dump(config, json_file, indent=4)

#sys.exit()


# %% ################################## Filter clusters ##################################
# perf_counter() is a monotonic, high-resolution clock, so the elapsed time
# below cannot be distorted by system clock adjustments.
start_time = time.perf_counter()

# Print the first original sequence before and after shuffling as a quick
# visual check that the order changed.
print(orig_seqs[0])
orig_seqs, clusters = shuffle_data(orig_seqs, clusters)
print(orig_seqs[0])

# The first test_train_cut entries form the test split; the rest is train/val.
test_orig_seqs = orig_seqs[:test_train_cut]
test_clusters = clusters[:test_train_cut]

train_val_orig_seqs = orig_seqs[test_train_cut:]
train_val_clusters = clusters[test_train_cut:]

# Build the train/val data from the remaining clusters; the second return
# value is not used here.
train_val_list, _ = generate_noisy_dna_cpred(orig_seqs = train_val_orig_seqs,
                                             clusters = train_val_clusters,
                                             ub = ub, lb = lb,
                                             lev_dist_ub = lev_dist_ub, lev_dist_lb = lev_dist_lb,
                                             reads = reads, remove_trailing_C_flag = False)

# Flatten the test clusters into one list of reads, resolving every stored
# index against `reads` and keeping cluster order.
test_reads = [reads[i] for test_cluster in test_clusters for i in test_cluster]

end_time = time.perf_counter()
time_taken = end_time - start_time  # elapsed seconds
print('time taken:', time_taken)
# end region

# Write both splits to disk only when saving is enabled; file names carry the
# cluster_case so SC and LC runs do not overwrite each other.
if save_flag:
    for _out_path, _records in (
        (f'data/train_val_data_{cluster_case}.txt', train_val_list),
        (f'data/test_reads_{cluster_case}.txt', test_reads),
    ):
        write_data_to_file(_out_path, _records)

# Summarise the train/val data and print the resulting dictionary.
train_val_dist = get_cluster_size_distribution_cpred(train_val_list)
print('train_val_dist:')
print_dict(train_val_dist)
