# Remove train/val examples whose ground truth also appears in the test set.
import sys
from utils.data_functions import load_data_from_file, write_data_to_file

def get_ground_truth_list(data):
    """Return the second ':'-separated field of each entry in ``data``.

    Args:
        data: Iterable of strings, each containing at least one ':'.

    Returns:
        A list, in input order, holding the text between the first and
        second ':' of every entry (or everything after the first ':' when
        there is only one).

    Raises:
        IndexError: If an entry contains no ':' at all.
    """
    return [entry.split(':')[1] for entry in data]


# When True, the train/val file is overwritten in place with the filtered lines.
write = False
cluster_case = 'SC'

train_val_name = f"data/train_val_data_{cluster_case}.txt"
test_name = f'data/starcode_test_cpred_data_{cluster_case}.txt'

train_val = load_data_from_file(train_val_name)
test = load_data_from_file(test_name)

train_val_ground_truth = get_ground_truth_list(train_val)

print('len train_val')
print(len(train_val_ground_truth))
test_ground_truth = get_ground_truth_list(test)
print('len test')
print(len(test_ground_truth))

# Compare on unique ground-truth values only.
train_val_set = set(train_val_ground_truth)
test_set = set(test_ground_truth)

# Ground truths that appear in both the train/val data and the test data.
duplicates = test_set.intersection(train_val_set)

print('len(duplicates): ', len(duplicates))

# Keep the full train/val lines whose ground truth is not shared with the
# test set. Despite its name, this list holds complete lines (not just the
# ground-truth field) so that it can be written back to the data file.
filtered_train_val_ground_truth = [
    line
    for line, ground_truth in zip(train_val, train_val_ground_truth)
    if ground_truth not in duplicates
]

print('len(filtered_train_val_ground_truth)')
print(len(filtered_train_val_ground_truth))

# Check again. The ground-truth field must be extracted from the filtered
# lines first: intersecting the test ground truths with the full lines would
# always be empty and therefore could never reveal a remaining overlap.
remaining_ground_truth = get_ground_truth_list(filtered_train_val_ground_truth)
duplicates = test_set.intersection(remaining_ground_truth)
print("Duplicates:", duplicates)
print('len(duplicates): ', len(duplicates))

if write:
    write_data_to_file(train_val_name, filtered_train_val_ground_truth)