import json
import os
from collections import defaultdict

# Dataset statistics script.
#
# Walks ``root_dir`` recursively and counts samples:
#   * every ``.json`` file counts as exactly one sample;
#   * every ``.jsonl`` file contributes, for each parseable line, the length
#     of that record's ``'solution'`` field.
# A sample is additionally tallied as positive and/or negative when the path
# of its containing directory includes the substring 'positive'/'negative'.
root_dir = './filtered_data/'


def count_jsonl_samples(file_path):
    """Return the number of samples contained in one ``.jsonl`` file.

    Each non-blank line is parsed as JSON and ``len(record['solution'])`` is
    added to the result (presumably ``'solution'`` holds a list of solutions;
    confirm against the data producer).

    Lines that are not valid JSON, or whose record has no usable
    ``'solution'`` field, are reported on stdout and skipped, so one bad
    record can neither abort the run nor be counted with the data of the
    previous line.

    Args:
        file_path: Path of the ``.jsonl`` file to read (UTF-8 encoded).

    Returns:
        The summed ``len(record['solution'])`` over all valid lines.

    Raises:
        OSError: If the file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    sample_count = 0
    # Stream the file instead of materialising every line with readlines().
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) carry no record.
            if not line.strip():
                continue
            try:
                data = json.loads(line)
                sample_count += len(data['solution'])
            except (ValueError, KeyError, TypeError) as e:
                # json.JSONDecodeError is a ValueError; KeyError/TypeError
                # cover records without a sized 'solution' field.
                print(f'Error processing line in file {file_path}: {e}')
    return sample_count


total_samples = 0
total_files = 0
total_lines = 0
positive_samples = 0
negative_samples = 0
# Per-directory sample counts, keyed by the directory path from os.walk.
samples = defaultdict(int)

for root, dirs, files in os.walk(root_dir):
    for filename in files:
        if filename.endswith('.json'):
            # A standalone .json file is one sample.
            count = 1
            total_files += 1
        elif filename.endswith('.jsonl'):
            count = count_jsonl_samples(os.path.join(root, filename))
            total_lines += count
        else:
            continue

        total_samples += count
        samples[root] += count
        # Two independent checks: a path may contain both substrings.
        if 'negative' in root:
            negative_samples += count
        if 'positive' in root:
            positive_samples += count

print('\n'+'='*20+' Summary '+'='*20+'\n')
print(f'Total samples: {total_samples}')
print(f'Total .json files: {total_files}')
print(f'Total lines in .jsonl files: {total_lines}')
print(
    f'Positive samples: {positive_samples}, Negative samples: {negative_samples}'
)
print('\n'+'='*10+' Samples breakdown by directory '+'='*10+'\n')
for directory, count in samples.items():
    print(f'{os.path.relpath(directory, root_dir)}: {count}')
