
#%%
import os

import random
from collections import defaultdict
from collections import Counter

from utils.data_functions import write_data_to_file, read_clusters, shuffle_data, get_cluster_size_distribution_cpred, count_obs_seqs
from utils.helper_functions import create_folder
from utils.microsoft_data import read_centers, generate_cpred_data

# Script entry point: read the ground-truth centers and their clusters from
# ./data (relative to the current working directory), hold out the first
# `test_size` shuffled clusters for testing, split the rest 90/10 into
# train/validation, and optionally write the three sets to
# data/finetuning_data_<cluster_case>/.
if __name__ == '__main__':

    seed_number = 42
    # Seed the stdlib RNG before any shuffling. Presumably shuffle_data draws
    # from `random`, which would make the split reproducible -- TODO confirm.
    random.seed(seed_number)

    cluster_case = "LC" # "SC" or "LC" (SC: small clusters, LC: large clusters)
    save_flag = True # save data

    # lb / ub are forwarded to generate_cpred_data as `lb` and `size`;
    # presumably the lower / upper bound on cluster size -- verify in
    # utils.microsoft_data.
    if cluster_case == "SC":
        lb = 2
        ub = 5
        test_size = 500 # number of test clusters
    elif cluster_case == "LC":
        lb = 6
        ub = 10
        test_size = 1000 # number of test clusters
    else:
        raise ValueError('Invalid cluster_case value')

    # region LOAD DATA
    # NOTE: despite its name this is the current working directory, not the
    # location of this file, so the script must be launched from the folder
    # that contains `data/`.
    current_file_path = os.getcwd()
    print(current_file_path)
    centers_file_name = 'Centers.txt'
    clusters_file_name = 'Clusters.txt'

    centers_file_path = os.path.join(current_file_path, 'data', centers_file_name)
    clusters_file_path = os.path.join(current_file_path, 'data', clusters_file_name)

    ground_truth_seqs = read_centers(centers_file_path)
    clusters = read_clusters(clusters_file_path)

    len_clusts = len(clusters)
    print('len of clusters:', len_clusts)
    len_ground_truth_seqs = len(ground_truth_seqs)
    print('len of centers:', len_ground_truth_seqs)

    # Centers and clusters are sliced with the same indices below, so they
    # must have the same length.
    if len_ground_truth_seqs != len_clusts:
        raise ValueError('Number of clusters and centers are not equal')

    # Without this check an undersized dataset would silently yield a short
    # test set and empty train/validation sets.
    if test_size >= len_clusts:
        raise ValueError(
            f'test_size ({test_size}) must be smaller than the number of '
            f'clusters ({len_clusts})'
        )
    # endregion

    # region SPLIT DATA
    ground_truth_seqs, clusters = shuffle_data(orig_seqs=ground_truth_seqs, clusters=clusters)

    # The first `test_size` shuffled entries form the test set; everything
    # after them is shared between training and validation.
    test_ground_truth = ground_truth_seqs[:test_size]
    test_clusters = clusters[:test_size]

    train_val_ground_truth = ground_truth_seqs[test_size:]
    train_val_clusters = clusters[test_size:]

    print('len of test_ground_truth:', len(test_ground_truth))
    print('len of test_clusters:', len(test_clusters))

    # Only the first return value of generate_cpred_data is used here.
    test_cpred, _ = generate_cpred_data(orig_seqs=test_ground_truth, clusters=test_clusters, size=ub, lb=lb)
    train_val_cpred, _ = generate_cpred_data(orig_seqs=train_val_ground_truth, clusters=train_val_clusters, size=ub, lb=lb)

    print('len of test_cpred:', len(test_cpred))
    print('len of train_val_cpred:', len(train_val_cpred))

    num = count_obs_seqs(test_cpred)
    print('num:', num)

    # 90% of the remaining examples go to training, the rest to validation.
    split_ratio = 0.9
    split_index = int(len(train_val_cpred) * split_ratio)
    train_data = train_val_cpred[:split_index]
    val_data = train_val_cpred[split_index:]

    print('len of train_data:', len(train_data))
    print('len of val_data:', len(val_data))

    # NOTE(review): the returned distributions are not used further in this
    # script; the calls are kept in case the helper reports them itself.
    test_dist = get_cluster_size_distribution_cpred(test_cpred)
    train_dist = get_cluster_size_distribution_cpred(train_data)
    val_dist = get_cluster_size_distribution_cpred(val_data)
    # endregion

    if save_flag:
        # Create the output folder only when something is actually written,
        # so a dry run (save_flag = False) leaves the filesystem untouched.
        output_dir = f'data/finetuning_data_{cluster_case}'
        create_folder(output_dir)

        write_data_to_file(f'{output_dir}/cpred_test_{cluster_case}.txt', test_cpred)
        write_data_to_file(f'{output_dir}/cpred_train_{cluster_case}.txt', train_data)
        write_data_to_file(f'{output_dir}/cpred_val_{cluster_case}.txt', val_data)
        print('DONE')
