import os
import json
import pickle
from PIL import Image
from io import BytesIO
from collections import defaultdict


def get_image_size(bytes):
    """Return the ``(width, height)`` of an encoded image held in memory.

    Args:
        bytes: Raw encoded image content, wrapped in a ``BytesIO`` and handed
            to ``PIL.Image.open``. The name shadows the builtin but is kept
            so existing keyword callers continue to work.

    Returns:
        A ``(width, height)`` tuple read from the opened image.
    """
    buffer = BytesIO(bytes)
    with Image.open(buffer) as picture:
        dimensions = (picture.width, picture.height)
    return dimensions


def find_valid(pkl_path="pkl/data_test.pkl", out_dir="demo/diagram",
               min_ratio=0.6, max_ratio=1.5, max_triples=15):
    """Collect dataset items that satisfy the image-ratio and triple-count limits.

    Each item loaded from the pickle is kept when the width/height ratio of
    both its real and synthetic image lies strictly between ``min_ratio`` and
    ``max_ratio`` and ``item['triples']`` contains fewer than ``max_triples``
    newline characters. Kept items are grouped by ``item['category']``, a
    per-category count is printed, and the grouping is written to
    ``<out_dir>/valid_diagrams.json``.

    Args:
        pkl_path: Path of the pickled dataset (a sequence of dicts providing
            'real_bytes', 'syn_bytes', 'triples', 'category' and 'img_url').
        out_dir: Directory receiving ``valid_diagrams.json``; created if
            missing.
        min_ratio: Exclusive lower bound on width / height.
        max_ratio: Exclusive upper bound on width / height.
        max_triples: Exclusive upper bound on the newline count of the
            item's 'triples' value.

    Returns:
        None. Results are printed and written to disk.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point this at dataset files from a trusted source.
    with open(pkl_path, 'rb') as file:
        data = pickle.load(file)

    valid_dict = defaultdict(list)
    for idx, item in enumerate(data):
        real_width, real_height = get_image_size(item['real_bytes'])
        syn_width, syn_height = get_image_size(item['syn_bytes'])

        # A zero height leaves the aspect ratio undefined, so such an item
        # can never qualify; skip it instead of raising ZeroDivisionError.
        if real_height == 0 or syn_height == 0:
            continue

        # Compute each metric once and reuse it for both the filter and the
        # exported record.
        real_ratio = real_width / real_height
        syn_ratio = syn_width / syn_height
        num_triple = item['triples'].count('\n')

        ratios_ok = (min_ratio < real_ratio < max_ratio
                     and min_ratio < syn_ratio < max_ratio)
        if ratios_ok and num_triple < max_triples:
            valid_dict[item['category']].append({
                'idx': idx,
                'img_url': item['img_url'],
                'real_ratio': round(real_ratio, 2),
                'syn_ratio': round(syn_ratio, 2),
                'num_triple': num_triple,
            })

    for key, value in valid_dict.items():
        print(f"{key}: {len(value)}")

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "valid_diagrams.json"), 'w') as f:
        json.dump(valid_dict, f, indent=4)


def get_examples(indices=(1353,), pkl_path="pkl/data_test.pkl",
                 diagram_dir="demo/diagram"):
    """Export the images, triples and QAs of selected dataset items.

    For every index in ``indices`` the item's 'real_bytes' and 'syn_bytes'
    are written unchanged to ``<diagram_dir>/image/<idx>_real.png`` and
    ``<diagram_dir>/image/<idx>_syn.png``, and its 'triples' and 'QAs'
    values are collected into ``<diagram_dir>/triples_QAs.json``.

    Args:
        indices: Iterable of positions in the pickled dataset to export.
            Defaults to the single example previously hard-coded here.
        pkl_path: Path of the pickled dataset (an indexable collection of
            dicts providing 'real_bytes', 'syn_bytes', 'triples' and 'QAs').
        diagram_dir: Output directory; it and its ``image`` subdirectory are
            created if missing.

    Returns:
        None. All results are written to disk.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point this at dataset files from a trusted source.
    with open(pkl_path, 'rb') as file:
        data = pickle.load(file)

    image_dir = os.path.join(diagram_dir, "image")
    os.makedirs(image_dir, exist_ok=True)

    triples = []
    for idx in indices:
        item = data[idx]
        # The stored bytes are dumped as-is; the ".png" suffix is fixed and
        # no re-encoding takes place.
        real_path = os.path.join(image_dir, f"{idx}_real.png")
        syn_path = os.path.join(image_dir, f"{idx}_syn.png")
        with open(real_path, "wb") as f:
            f.write(item['real_bytes'])
        with open(syn_path, "wb") as f:
            f.write(item['syn_bytes'])
        triples.append({
            'idx': idx,
            'triple': item['triples'],
            'QA': item['QAs'],
        })

    with open(os.path.join(diagram_dir, "triples_QAs.json"), "w") as f:
        json.dump(triples, f, indent=4)


if __name__ == "__main__":
    # Script entry point: only the example export runs by default.
    # Call find_valid() here instead to rebuild the list of valid diagrams.
    get_examples()

