#%%
import os
import requests
import zipfile

# TextVQA 0.5.1 annotation files, image archives, and Rosetta OCR v0.2 results.
# Every entry ends with a comma: adjacent string literals without one are
# silently concatenated by Python into a single (invalid) URL.
urls = [
    "https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_train.json",
    "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
    "https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json",
    "https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json",
    "https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip",
    "https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_train.json",
    "https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_val.json",
    "https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_Rosetta_OCR_v0.2_test.json",
]

# Directory (relative to the working directory) that receives all downloads
# and extracted archives; created up front so later writes cannot fail on it.
output_dir = "textvqa_data"
os.makedirs(output_dir, exist_ok=True)

def download_file(url, save_path, *, timeout=60):
    """Download ``url`` to ``save_path`` unless that path already exists.

    The response body is streamed into a sibling ``<save_path>.part`` file and
    moved into place only after the transfer completes. An interrupted or
    failed download therefore never leaves a truncated file at ``save_path``
    that a later run would mistake for a finished one.

    Args:
        url: Address to fetch with an HTTP GET request.
        save_path: Destination path of the downloaded file.
        timeout: Value forwarded to ``requests.get``. It bounds how long to
            wait for the server to respond or send more data, not the total
            duration of the download.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(save_path):
        print(f"Already exists: {save_path}")
        return
    print(f"Downloading {url}")
    partial_path = f"{save_path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Publish the file only once it is complete.
        os.replace(partial_path, save_path)
    finally:
        # After a successful replace the partial file is gone and this is a
        # no-op; after any failure it removes the incomplete data.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Saved to {save_path}")

def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive at ``zip_path`` into ``extract_to``.

    Progress is reported on stdout before and after extraction.
    """
    print(f"Extracting {zip_path}")
    archive = zipfile.ZipFile(zip_path, mode='r')
    try:
        archive.extractall(path=extract_to)
    finally:
        # Release the archive handle whether or not extraction succeeded.
        archive.close()
    print(f"Extracted to {extract_to}")

# Fetch each listed file into output_dir; zip archives are unpacked there too.
for url in urls:
    # The file name is whatever follows the final "/" of the URL.
    filename = url.rpartition("/")[2]
    save_path = os.path.join(output_dir, filename)
    download_file(url, save_path)
    if not filename.endswith(".zip"):
        continue
    extract_zip(save_path, output_dir)


#%%
import json
import os
from collections import Counter
import matplotlib.pyplot as plt

# Rosetta OCR result files for each split, as written by the download cell.
json_files = [
    "textvqa_data/TextVQA_Rosetta_OCR_v0.2_test.json",
    "textvqa_data/TextVQA_Rosetta_OCR_v0.2_train.json",
    "textvqa_data/TextVQA_Rosetta_OCR_v0.2_val.json",
]


def count_ocr_boxes_per_image(paths):
    """Tally OCR boxes per image across the given JSON files.

    Each file is read as a JSON object whose ``"data"`` list holds entries
    with an ``"image_id"`` and an optional ``"ocr_info"`` list; the length of
    ``"ocr_info"`` is added to that image's total. Entries sharing an
    ``image_id`` (within or across files) are summed.

    Args:
        paths: Iterable of JSON file paths. Paths that do not exist are
            skipped with a notice on stdout.

    Returns:
        A ``Counter`` mapping ``image_id`` to its number of OCR boxes. Images
        without any boxes are present with a count of 0.

    Raises:
        KeyError: If an entry lacks ``"image_id"``.
        json.JSONDecodeError: If a file is not valid JSON.
    """
    counts = Counter()
    for path in paths:
        if not os.path.exists(path):
            print(f"Skipping missing file: {path}")
            continue
        # JSON text is UTF-8; decode explicitly instead of relying on the
        # platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            content = json.load(f)
        for entry in content.get("data", []):
            counts[entry["image_id"]] += len(entry.get("ocr_info", []))
    return counts


bbox_count_per_image = count_ocr_boxes_per_image(json_files)

# Cutoff announced by the axis label and title below: only images with at most
# this many OCR boxes are included in the chart.
max_boxes = 3

# Number of images for each OCR box count, restricted to the cutoff.
count_distribution = Counter(
    count for count in bbox_count_per_image.values() if count <= max_boxes
)
# The printed total covers every image, not just those within the cutoff.
total_images = len(bbox_count_per_image)
print(f"Total images: {total_images}")

box_counts = sorted(count_distribution)
image_totals = [count_distribution[k] for k in box_counts]

plt.figure(figsize=(8, 6))
# Box counts are passed as strings so they are drawn as evenly spaced
# categories.
bars = plt.bar([str(k) for k in box_counts], image_totals)

# Annotate each bar with its image count. The text comes from the integer
# tally, so it does not depend on how matplotlib stores the bar height.
for bar, total in zip(bars, image_totals):
    plt.text(bar.get_x() + bar.get_width() / 2, total + 1, str(total),
             ha='center', va='bottom', fontsize=12)

plt.xlabel(f"Number of OCR Bounding Boxes per Image (<={max_boxes})")
plt.ylabel("Number of Images")
plt.title(f"Image Count by Number of OCR Bounding Boxes (≤{max_boxes})")
plt.tight_layout()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()
