import os
import json
import random

# Fixed seed so the random.sample() draws in the topic-combination loop at the
# bottom of this script pick the same topic groups on every run.
random.seed(2025)

def read_json(path):
    """Parse the UTF-8 encoded JSON document at *path* and return the result."""
    with open(path, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

def read_jsonl(path):
    """Load a JSON Lines file.

    Args:
        path: path to a UTF-8 text file with one JSON value per line.

    Returns:
        A list with one parsed value per non-blank line, in file order.
        Blank or whitespace-only lines (e.g. a trailing empty line) are
        skipped instead of being handed to json.loads, which rejects them.
    """
    with open(path, 'r', encoding='utf-8') as lines:
        return [json.loads(line) for line in lines if line.strip()]

def write_json(datas, path):
    """Serialize *datas* to *path* as JSON indented by four spaces.

    The file is opened in UTF-8 text mode and overwritten if it exists.
    """
    with open(path, mode='w', encoding='utf-8') as out_file:
        json.dump(datas, out_file, indent=4)

def count_facts(fact_datas):
    """Return the total number of facts across all entries of *fact_datas*.

    *fact_datas* maps an id to a list of facts; only the list lengths matter.
    """
    return sum(len(fact_list) for fact_list in fact_datas.values())

def count_modality(fact_datas):
    """Count text vs. non-text facts in *fact_datas*.

    A fact counts as text when its 'source' is exactly 'text'; every other
    'source' value counts as visual. Raises KeyError for a fact without a
    'source' key.

    Returns:
        (text_counts, visual_counts)
    """
    sources = [fact['source'] for fact_list in fact_datas.values() for fact in fact_list]
    n_text = sources.count('text')
    return n_text, len(sources) - n_text

def count_page_facts(fact_datas):
    """Per-entry fact counts for *fact_datas* (id -> list of facts).

    Entries are presumably pages, judging by the caller's labels; confirm
    against facts.json.

    Returns:
        Three parallel lists, in the iteration order of *fact_datas*:
        total facts, text facts, and visual facts per entry. Text and visual
        are classified by count_modality.
    """
    totals, texts, visuals = [], [], []
    for entry_id, entry_facts in fact_datas.items():
        totals.append(len(entry_facts))
        n_text, n_visual = count_modality({entry_id: entry_facts})
        texts.append(n_text)
        visuals.append(n_visual)
    return totals, texts, visuals

def count_topics(topic_datas):
    """Return how many topic records *topic_datas* holds."""
    n_topics = len(topic_datas)
    return n_topics

def _page_from_fact_id(fact_id):
    """Extract the integer page number embedded in a fact id.

    Takes the second '::'-separated segment, keeps the part after its last
    '_', and drops everything from the first '.' on, e.g.
    'doc::page_3.png' -> 3.
    """
    segment = fact_id.split('::')[1]
    page_token = segment.split('_')[-1]
    return int(page_token.split('.')[0])


def count_page_topics(topic_datas):
    """Per-topic page spread and modality counts.

    Args:
        topic_datas: topic dicts with 'fact_ids' (ids parsed by
            _page_from_fact_id) and 'facts' (dicts with a 'source' key).

    Returns:
        Four parallel lists, one entry per topic:
        - sorted distinct page numbers referenced by the topic's fact ids,
        - number of fact ids (duplicates included),
        - number of facts whose 'source' is exactly 'text',
        - number of all other facts.
    """
    fact_across_page = []
    facts_by_topic = []
    text_facts_by_topic = []
    visual_facts_by_topic = []
    for topic in topic_datas:
        ids = topic['fact_ids']
        fact_across_page.append(sorted({_page_from_fact_id(fid) for fid in ids}))
        facts_by_topic.append(len(ids))

        sources = [fact['source'] for fact in topic['facts']]
        n_text = sources.count('text')
        text_facts_by_topic.append(n_text)
        visual_facts_by_topic.append(len(sources) - n_text)

    return fact_across_page, facts_by_topic, text_facts_by_topic, visual_facts_by_topic

def merge_topics(topic_datas):
    """Merge several topics into a single combined topic.

    Args:
        topic_datas: topic dicts, each with 'cluster_id', 'fact_ids' and
            'facts' (a list of JSON-serializable dicts).

    Returns:
        A one-element list holding {'fact_ids': [...], 'facts': [...]}:
        - 'fact_ids': the distinct fact ids, in first-seen order;
        - 'facts': the distinct facts, in first-seen order. Each is tagged
          with the 'cluster_id' of the topic it came from and is an
          independent copy of the input dict.

    Output order is deterministic. set() over strings would follow the
    per-process string-hash seed (PYTHONHASHSEED), not random.seed.

    NOTE(review): like the original, this also sets 'cluster_id' on each
    input fact dict in place. Callers that serialize the input topics see
    that field.
    """
    fact_ids = {}       # dict used as an insertion-ordered set
    unique_facts = {}   # canonical JSON -> copied fact
    for topic in topic_datas:
        fact_ids.update(dict.fromkeys(topic['fact_ids']))
        cluster_id = topic['cluster_id']

        for fact in topic['facts']:
            fact['cluster_id'] = cluster_id
            # sort_keys makes facts that differ only in key order dedupe.
            key = json.dumps(fact, sort_keys=True)
            if key not in unique_facts:
                # JSON round-trip yields a copy detached from the input,
                # keeping the fact's original key order.
                unique_facts[key] = json.loads(json.dumps(fact))

    return [{'fact_ids': list(fact_ids), 'facts': list(unique_facts.values())}]

fact_path = './facts.json'
topic_path = './clusters.jsonl'
output_path = './sample_topics.json'


def _print_stats(label, values):
    """Print min, max and mean (2 decimals) of *values*, one line each, prefixed by *label*.

    An empty *values* prints a single '<label>: no data' line. Otherwise
    min() would raise ValueError and the mean would divide by zero.
    """
    if not values:
        print(f'{label}: no data')
        return
    print(f'{label} min: {min(values)}')
    print(f'{label} max: {max(values)}')
    print(f'{label} mean: {(sum(values) / len(values)):.2f}')


# --- Fact-level statistics (facts.json) ---
fact_datas = read_json(fact_path)
fact_counts = count_facts(fact_datas)

text_counts, visual_counts = count_modality(fact_datas)
print(f'fact_counts: {fact_counts}')
print(f'text_counts: {text_counts}')
print(f'visual_counts: {visual_counts}')
print()

page_facts_counts, page_text_facts_counts, page_visual_facts_counts = count_page_facts(fact_datas)
_print_stats('page fact', page_facts_counts)
print()

_print_stats('page text fact', page_text_facts_counts)
print()

_print_stats('page visual fact', page_visual_facts_counts)
print()

# --- Topic-level statistics (clusters.jsonl) ---
topic_datas = read_jsonl(topic_path)

fact_across_page, facts_by_topic, text_facts_by_topic, visual_facts_by_topic = count_page_topics(topic_datas)

# Number of distinct pages each topic's facts are drawn from.
topic_fact_diff_page = [len(pages) for pages in fact_across_page]
_print_stats('topic fact in different page', topic_fact_diff_page)
print()

topic_length = count_topics(topic_datas)
print(f'topic counts: {topic_length}')

print('_' * 30)

# Randomly combine topics. For each group size 5, 10, 15 and 20, draw 50
# random groups of distinct topics. Each group is keyed by its sorted,
# comma-joined cluster ids, so repeated draws of the same group collapse
# onto one key.
datas = {}
for group_size in range(5, 21, 5):
    if group_size > len(topic_datas):
        # random.sample raises ValueError when asked for more items than exist.
        print(f'top_len: {group_size}, skipped: only {len(topic_datas)} topics available')
        continue

    topic_key2facts = {}
    for _ in range(50):
        topics = random.sample(topic_datas, group_size)
        topics_key = ",".join(sorted(t['cluster_id'] for t in topics))
        topic_key2facts[topics_key] = topics

        # Tag every fact with its topic's cluster id. The samples written to
        # output_path carry this field on each fact.
        for topic in topics:
            for fact in topic['facts']:
                fact['cluster_id'] = topic['cluster_id']

    datas[group_size] = topic_key2facts
    print(f'top_len: {group_size}, sampled topic length: {len(topic_key2facts)}')

write_json(datas, output_path)