import json
import random
import shutil
from PIL import Image
from networkx import constraint
from tqdm import tqdm
# from utils import load_kg_triples, load_entity_image_maps
from collections import defaultdict
from graphviz import Digraph

import data


# Running counter of rendered graphs: draw_kg() embeds it in each output
# filename and increments it once per call.
global_id = 0

from PIL import Image, ImageOps

def add_blank_area_to_bottom(image_path, output_path, blank_height):
    """Save a copy of an image with a white strip appended below it.

    Args:
        image_path: Path of the image to read.
        output_path: Path the padded image is written to; the output format
            is chosen by PIL from this path's extension.
        blank_height: Height in pixels of the white strip added at the bottom.

    The result is always an RGB image with the original pasted at the
    top-left corner.
    """
    # Open via a context manager so the underlying file handle is released
    # even if pasting or saving fails.
    with Image.open(image_path) as original_image:
        original_width, original_height = original_image.size
        new_image = Image.new(
            'RGB',
            (original_width, original_height + blank_height),
            color='white',
        )
        new_image.paste(original_image, (0, 0))
    new_image.save(output_path)


def draw_kg(nodes, edges, images, task_type):
    """Render a knowledge graph as a PNG with one picture per node.

    Every node picture is first copied to ``used_image/new_<task_type>_<i>.png``
    with an 80 px white strip added at the bottom (presumably to leave room
    for the bottom-aligned node label), and the graph is then rendered with
    Graphviz.

    Args:
        nodes: Node names. A name containing ``"[MASK]"`` is drawn with
            ``random_noise.png`` instead of its own image.
        edges: Sequence of ``(head, tail, relation)`` triples; the relation
            becomes the edge label.
        images: Image path for each node, same length as ``nodes``. Paths
            containing ``"VSEM/dataset/images/"`` get ``".jpg"`` appended.
            The list is not modified.
        task_type: Task identifier used in intermediate and output file names.

    Returns:
        The string ``generate/images/task<T>/task<T>_count_<id>.png`` where
        ``<id>`` is the value of the module-level ``global_id`` for this call.

    Side effects:
        Writes files under ``used_image/`` and ``generate/task<T>/`` and
        increments the module-level ``global_id``.
    """
    import os  # local import: only needed to create the scratch directory

    global global_id

    assert len(nodes) == len(images), "nodes and images must have equal length"

    kg = Digraph(
        format='png',
        encoding='utf-8',
        engine=random.choice(['dot', 'circo']),
        graph_attr={
            'rankdir': 'LR',
            'nodesep': '0.8',
            'ranksep': '0.6',
            'fontname': 'SimSun'
        },
        node_attr={
            'shape': 'rectangle',
            'width': '1.5',
            'height': '1.5',
            'fixedsize': 'true',
            'fontsize': '12',
            'fontcolor': 'black',
            'fontname': 'Times New Roman',  # global node font
            'color': 'black',
            'style': 'filled',
            'fillcolor': 'white',
            'labelloc': 'b',
            'imagescale': 'true'   # keep the image from being distorted
        },
        edge_attr={
            'fontsize': '12',
            'fontcolor': 'black',
            'fontname': 'Times New Roman',  # global edge-label font
            'color': 'black',
            'style': 'dashed'
        }
    )

    # Make sure the scratch directory for padded node images exists.
    os.makedirs("used_image", exist_ok=True)

    new_images = []
    for i, (node, image) in enumerate(zip(nodes, images)):
        if "[MASK]" in node:
            # Masked entities are shown as noise rather than their own picture.
            source_image = "random_noise.png"
        elif "VSEM/dataset/images/" in image:
            # These paths are stored without an extension. Build the full
            # path locally instead of mutating the caller's list.
            source_image = image + '.jpg'
        else:
            source_image = image
        new_path = "used_image/new_{}_{}.png".format(task_type, i)
        add_blank_area_to_bottom(source_image, new_path, blank_height=80)
        new_images.append(new_path)

    for node, node_image in zip(nodes, new_images):
        kg.node(name=node, label=node, image=node_image)
    for h, t, r in edges:
        kg.edge(h, t, label=r)

    kg.render(
        filename='task{}_count_{}'.format(task_type, global_id),
        directory='generate/task{}'.format(task_type),
        cleanup=True
    )
    # NOTE(review): render() is asked to write under generate/task<T>/, while
    # the path returned below is under generate/images/task<T>/. The
    # shutil.move that would have relocated the file was already disabled,
    # so the returned path is kept unchanged for downstream compatibility.
    # Confirm the files are moved there by some other step.
    tgt_path = 'generate/images/task{}/task{}_count_{}.png'.format(task_type, task_type, global_id)
    global_id += 1
    return tgt_path


def load_image_map(dataset):
    """Load the entity-to-image mapping for one dataset.

    Args:
        dataset: Dataset name; the mapping is read from
            ``dataset/subgraph/<dataset>_entity_image_map.json``.

    Returns:
        The parsed JSON content of that file.
    """
    map_path = "dataset/subgraph/{}_entity_image_map.json".format(dataset)
    # Use a context manager so the handle is closed promptly, and read as
    # UTF-8 regardless of the platform's default encoding.
    with open(map_path, "r", encoding="utf-8") as map_file:
        return json.load(map_file)
    

if __name__ == "__main__":
    # Driver: attach a rendered graph image to every instance of the selected
    # tasks and write the successfully rendered instances to a per-task JSON.
    with open("dataset/subgraph/processed_instances.json", "r", encoding="utf-8") as instances_file:
        processed_data = json.load(instances_file)

    # Merge the per-dataset entity -> image-candidates maps. For an entity
    # present in several datasets, the later dataset wins.
    image_map_full = {}
    for dataset in ['mkgy', 'fb15k237', 'vsem']:
        image_map_full.update(load_image_map(dataset))

    error_count = 0
    count = 0  # number of nodes with no candidate image (tallied, not reported)
    for i in [7]:
        instance_with_images = []
        for instance in tqdm(processed_data[str(i)]):
            nodes = instance["entity"]
            edges = instance["triple"]
            images = []
            for node in nodes:
                image_candidates = image_map_full.get(node)
                # Truthiness also covers an empty candidate list, which would
                # otherwise make random.choice raise outside the try below
                # and abort the whole run.
                if image_candidates:
                    images.append(random.choice(image_candidates))
                else:
                    images.append("")
                    count += 1
            try:
                image_path = draw_kg(nodes, edges, images, i)
                instance["image"] = image_path
                instance_with_images.append(instance)
            except Exception as e:
                # Best effort: a failed rendering skips the instance and is
                # only counted.
                print("Error processing instance.", e)
                error_count += 1
        output_path = "dataset/subgraph/processed_instances_with_images_task{}.json".format(i)
        # ensure_ascii=False emits raw non-ASCII text, so write as UTF-8
        # explicitly instead of relying on the platform default.
        with open(output_path, "w", encoding="utf-8") as output_file:
            json.dump(instance_with_images, output_file, ensure_ascii=False)
    print(error_count)