import json
import os
import random


# Input: JSONL file of single-document conversations (one JSON object per line).
# Output: JSONL file that receives one merged multi-document conversation per line.
# The placeholder defaults can be overridden through the environment instead of
# editing this file.
aggregate_file = os.environ.get(
    "MULTIDOC_AGGREGATE_FILE", "/path/to/single/document/generated/data.jsonl"
)
output_file = os.environ.get("MULTIDOC_OUTPUT_FILE", "/path/to/output/file.jsonl")


## Multidocument concatenation algorithm
#
# Each input line is a JSON object whose "conversations" list is treated as
# consecutive two-message pairs (pair k = messages 2k and 2k+1), laid out as:
#   messages [0, 2)     opening pair of the document
#   messages [2, 52)    "hierarchical" question pairs (pair indices 1..25)
#   messages [52, 150)  "specific" question pairs (pair indices 26..74)
# NOTE(review): this layout is inferred from the hard-coded offsets; confirm it
# against the single-document generator.
_DOCS_PER_GROUP = 4
_INTRO_END = 2
_HIER_END = 52
_SPECIFIC_END = 150
# A deferred hierarchical follow-up is only emitted when random.random()
# exceeds this value, i.e. with probability ~0.4.
_HIER_FOLLOW_UP_THRESHOLD = 0.6


def _remaining_specific(conv, chosen):
    """Return the specific-section messages whose pair index is not in *chosen*."""
    taken = set(chosen)
    return [conv[m] for m in range(_HIER_END, _SPECIFIC_END) if m // 2 not in taken]


def _draw_specific_follow_ups(pool, count, out):
    """Move *count* randomly chosen pairs from *pool* to the end of *out*.

    Returns the pool without the drawn pairs. When *pool* holds fewer than
    *count* complete pairs, nothing is drawn and *pool* is returned as-is.
    """
    n_pairs = len(pool) // 2
    if n_pairs < count:
        return pool
    picked = random.sample(range(n_pairs), count)
    for y in picked:
        out.extend(pool[2 * y:2 * y + 2])
    picked_set = set(picked)
    return [msg for pos, msg in enumerate(pool) if pos // 2 not in picked_set]


def process_group(lines, *, hier_per_doc=5, specific_per_doc=5,
                  follow_up_hier=3, follow_up_specific=9):
    """Interleave four single-document conversations into one conversation.

    Documents are introduced in order. For each one this emits its opening
    pair, its first *hier_per_doc* hierarchical pairs and *specific_per_doc*
    randomly sampled specific pairs. From the second document on, follow-ups
    on earlier documents are emitted next:

    * the next *follow_up_hier* hierarchical pairs of an earlier document.
      When the second document is introduced, the first document always gets
      them. Afterwards each earlier document gets them independently with
      probability ~0.4, and only if they still fit in the hierarchical section.
    * *follow_up_specific* pairs drawn at random from the pool of specific
      pairs of earlier documents that have not been emitted yet.

    Pairs are always emitted whole, and no message is emitted twice.

    Args:
        lines: JSONL strings. Only the first four are used. Each must decode
            to an object with a "conversations" list of at least 150 entries.
        hier_per_doc: hierarchical pairs emitted when a document is introduced.
        specific_per_doc: specific pairs sampled when a document is introduced.
        follow_up_hier: hierarchical pairs per deferred follow-up.
        follow_up_specific: specific pairs drawn per follow-up round.

    Returns:
        The merged list of conversation entries.

    Raises:
        ValueError: fewer than four lines, or a conversation that is too short.
        json.JSONDecodeError, KeyError: a line is not a valid document object.
    """
    if len(lines) < _DOCS_PER_GROUP:
        raise ValueError(
            f"process_group needs {_DOCS_PER_GROUP} lines, got {len(lines)}"
        )
    convs = [json.loads(line)["conversations"] for line in lines[:_DOCS_PER_GROUP]]
    for doc_idx, conv in enumerate(convs):
        if len(conv) < _SPECIFIC_END:
            raise ValueError(
                f"document {doc_idx} has {len(conv)} entries; "
                f"at least {_SPECIFIC_END} are required"
            )

    specific_indices = list(range(_HIER_END // 2, _SPECIFIC_END // 2))
    hier_step = 2 * follow_up_hier
    conversations = []
    follow_up_pool = []  # not-yet-emitted specific messages of introduced documents
    hier_cursors = []    # per document: index of its next unused hierarchical message

    for doc_idx, conv in enumerate(convs):
        # Introduce the document: opening pair, leading hierarchical pairs and
        # a random sample of its specific pairs.
        conversations.extend(conv[:_INTRO_END])
        hier_stop = min(_INTRO_END + 2 * hier_per_doc, _HIER_END)
        conversations.extend(conv[_INTRO_END:hier_stop])
        hier_cursors.append(hier_stop)

        chosen = random.sample(specific_indices, specific_per_doc)
        for pair_idx in chosen:
            conversations.extend(conv[2 * pair_idx:2 * pair_idx + 2])
        remaining = _remaining_specific(conv, chosen)

        # Hierarchical follow-ups on earlier documents.
        if doc_idx == 1:
            start = hier_cursors[0]
            stop = min(start + hier_step, _HIER_END)
            conversations.extend(convs[0][start:stop])
            hier_cursors[0] = stop
        elif doc_idx > 1:
            for prev in range(doc_idx):
                start = hier_cursors[prev]
                # random.random() is evaluated before the bound check, so the
                # RNG is consumed on every check (keeps seeded runs reproducible).
                if (random.random() > _HIER_FOLLOW_UP_THRESHOLD
                        and start + hier_step <= _HIER_END):
                    conversations.extend(convs[prev][start:start + hier_step])
                    hier_cursors[prev] = start + hier_step

        # Specific follow-ups come only from earlier documents; the current
        # document's leftovers join the pool after the draw.
        if doc_idx > 0:
            follow_up_pool = _draw_specific_follow_ups(
                follow_up_pool, follow_up_specific, conversations
            )
        follow_up_pool.extend(remaining)

    return conversations




# Read the aggregate file once. This also yields the line count, so the file is
# not opened a second time (and leaked) just to count lines.
with open(aggregate_file, "r", encoding="utf-8") as infile:
    all_lines = infile.readlines()

total_lines = len(all_lines)
print(f"Total lines in aggregate file: {total_lines}")

group_count = 0
total_groups = total_lines // 4

# Each output line merges four consecutive input lines. Trailing lines that do
# not fill a complete group are ignored.
with open(output_file, "w", encoding="utf-8") as outfile:
    for i in range(0, total_groups * 4, 4):
        lines = all_lines[i:i + 4]
        try:
            processed_conversations = process_group(lines)
        except Exception as e:
            # Best-effort: report which group was malformed and move on.
            print(f"Skipping group starting at line {i + 1}: {e!r}")
            continue

        json.dump({"conversations": processed_conversations}, outfile)
        outfile.write('\n')

        group_count += 1
        if group_count % 100 == 0:
            print(f"Processed {group_count} groups")

print(f"Total conversation groups processed and saved: {group_count}")
print(f"Expected groups based on line count: {total_groups}")
# Guard the ratio: an input with fewer than 4 lines yields zero groups.
if total_groups:
    print(f"Percentage of file processed: {(group_count / total_groups) * 100:.2f}%")
else:
    print("Percentage of file processed: n/a (fewer than 4 input lines)")