"""
SynthTIGER
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import os
import json
import numpy as np
from PIL import Image
import random

from synthtiger import components, layers, templates, utils

class Multiline(templates.Template):
    """Template that renders corpus lines one glyph at a time.

    Every whitespace-separated word of a corpus line is rendered character by
    character with ``default_font``. The glyphs of a word are merged into one
    layer, and the word layers are arranged by a ``FlowLayout``. The text is
    then placed on a white background with a fixed margin. The image is
    returned together with per-character and per-line boxes.

    Only corpus lines that also appear (keyed by ``safe_text``) in the samples
    JSON file are rendered; ``generate`` returns ``None`` for the others.

    Config keys read: ``count``, ``corpus``, ``font``, ``samples.paths`` (only
    the first path is used), ``color`` (constructed but not applied to the
    text), ``layout`` and the optional ``char_size`` (font size forced on every
    glyph, default 32).
    """

    # Margin (pixels) between the laid-out text and the image border.
    PADDING_LEFT = 30
    PADDING_TOP = 10
    # Horizontal gap between consecutive glyphs; a negative value overlaps them.
    CHAR_SPACING = -1
    # Glyph whose vertical centre every character of a word is aligned to.
    BASELINE_REF_CHAR = "x"

    def __init__(self, config=None):
        if config is None:
            config = {}

        self.count = config.get("count", 1)
        self.corpus = components.BaseCorpus(**config.get("corpus", {}))
        self.font = components.BaseFont(**config.get("font", {}))
        self.default_font = "resources/multilingual_mixed_fonts/NotoSans-VariableFont_wdth,wght.ttf"
        self.font_list = [path for paths in self.font._paths for path in paths]
        # A missing, None or empty ``paths`` entry all mean "no samples file"
        # (an empty list used to raise IndexError here).
        sample_paths = config.get("samples", {}).get("paths") or [None]
        self.samples_path = sample_paths[0]
        self.samples_data = self._load_samples()
        self.color = components.RGB(**config.get("color", {}))
        self.layout = components.FlowLayout(**config.get("layout", {}))
        self.char_size = config.get("char_size", 32)
        self.glyph_coords = []
        self.corpus_data = self._load_corpus_data()
        # Sets both ``self.sampled_labels`` and ``self.current_index``.
        self._shuffle_corpus_data()

    def _load_samples(self):
        """Return ``{safe_text: sample}`` from the samples JSON file.

        Returns an empty dict when no samples path is configured or the file
        does not exist. The file is expected to hold a list of objects with at
        least ``safe_text`` (and, for ``generate``, ``safe_word``) keys.
        """
        if not self.samples_path or not os.path.exists(self.samples_path):
            return {}
        with open(self.samples_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        return {sample["safe_text"]: sample for sample in data}

    def _load_corpus_data(self):
        """Read all corpus files and keep the stripped lines that pass the
        corpus length and charset checks."""
        corpus_data = []
        for path in self.corpus.paths:
            with open(path, "r", encoding="utf-8") as file:
                for line in file:
                    text = line.strip()
                    if self.corpus._check_length(text) and self.corpus._check_charset(text):
                        corpus_data.append(text)
        return corpus_data

    def _shuffle_corpus_data(self):
        """Start a new epoch: a fresh random permutation of the corpus lines."""
        self.sampled_labels = random.sample(self.corpus_data, len(self.corpus_data))
        self.current_index = 0

    def _get_next_label(self):
        """Return the next corpus line, reshuffling once an epoch is exhausted.

        Raises:
            ValueError: if no corpus line passed the length/charset filters.
        """
        if self.current_index >= len(self.sampled_labels):
            self._shuffle_corpus_data()
        if not self.sampled_labels:
            raise ValueError("Multiline: the corpus is empty after length/charset filtering")
        label = self.sampled_labels[self.current_index]
        self.current_index += 1
        return label

    def _sample_glyph_font(self):
        """Font metadata for one glyph: default font at ``self.char_size``."""
        font_meta = self.font.sample({"path": self.default_font})
        font_meta["size"] = self.char_size
        return font_meta

    def _reference_mid_y(self):
        """Vertical centre of the reference glyph, rendered like every other glyph."""
        ref_layer = layers.TextLayer(self.BASELINE_REF_CHAR, **self._sample_glyph_font())
        ref_bbox = ref_layer.bbox
        return (ref_bbox[1] + ref_bbox[3]) / 2

    def _render_word(self, word, mid_y):
        """Render ``word`` glyph by glyph, left to right.

        Each glyph is shifted vertically so its centre sits on ``mid_y``, and
        horizontally so it starts ``CHAR_SPACING`` pixels after the previous
        glyph's right edge.

        Returns:
            ``(chars, char_layers, sizes)``: the split characters, their placed
            ``TextLayer`` objects, and ``(width, height)`` per glyph measured
            before placement.
        """
        chars = utils.split_text(word, reorder=False)
        char_layers = []
        sizes = []
        cursor_x = 0
        for char in chars:
            char_layer = layers.TextLayer(char, **self._sample_glyph_font())
            bbox = char_layer.bbox
            # NOTE(review): assumes a freshly rendered layer sits at the origin,
            # so these differences are the glyph's width/height — confirm.
            sizes.append((float(bbox[2] - bbox[0]), float(bbox[3] - bbox[1])))
            glyph_mid_y = (bbox[1] + bbox[3]) / 2
            char_layer.topleft = (cursor_x + self.CHAR_SPACING, mid_y - glyph_mid_y)
            cursor_x = char_layer.right
            char_layers.append(char_layer)
        return chars, char_layers, sizes

    def generate(self):
        """Render the next corpus line.

        Returns:
            ``None`` when the line has no entry in the samples file or contains
            no words. Otherwise a dict with:

            - ``image``: the ``output()`` of text plus background;
            - ``safe_word`` and ``safe_text``: copied from the sample record;
            - ``text_bboxes``: a one-element list with the line box;
            - ``characters``: one ``{"character", "bbox"}`` dict per glyph.

            Boxes are ``[x, y, width, height]`` in image pixels.
        """
        label = self._get_next_label()
        sample_data = self.samples_data.get(label)
        if not sample_data:
            return None
        words = label.split()
        if not words:
            return None

        safe_word = sample_data["safe_word"]
        safe_text = sample_data["safe_text"]

        mid_y = self._reference_mid_y()
        rendered = [self._render_word(word, mid_y) for word in words]

        # The merged word layers are what gets laid out and drawn; the
        # individual glyph layers stay where _render_word put them.
        word_layers = [layers.Group(char_layers).merge() for _, char_layers, _ in rendered]
        text_group = layers.Group(word_layers)
        before_bboxes = [layer.bbox for layer in text_group.layers]
        self.layout.apply(text_group)
        # Move the text to its final position BEFORE measuring the word
        # offsets. The offsets then include the margin wherever the layout
        # left the group, instead of assuming it ended at the origin.
        text_group.topleft = (self.PADDING_LEFT, self.PADDING_TOP)
        after_bboxes = [layer.bbox for layer in text_group.layers]

        characters = []
        for before, after, (chars, char_layers, sizes) in zip(before_bboxes, after_bboxes, rendered):
            dx = after[0] - before[0]
            dy = after[1] - before[1]
            for char, char_layer, (width, height) in zip(chars, char_layers, sizes):
                bbox = char_layer.bbox
                characters.append({
                    "character": char,
                    "bbox": [float(bbox[0] + dx), float(bbox[1] + dy), width, height],
                })

        text_width, text_height = text_group.size[0], text_group.size[1]
        bg_layer = layers.RectLayer(
            (text_width + self.PADDING_LEFT * 2, text_height + self.PADDING_TOP * 2),
            (255, 255, 255, 255),
        )
        bg_layer.topleft = (0, 0)
        image = (text_group + bg_layer).output()

        # Width/height come from ``size`` (used above as the text extent), so
        # the box does not depend on how ``bbox`` encodes its last two values.
        group_bbox = text_group.bbox
        sentence_bbox = {
            "label": safe_text,
            "bbox": [
                float(group_bbox[0]),
                float(group_bbox[1]),
                float(text_width),
                float(text_height),
            ],
        }

        return {
            "image": image,
            "safe_word": safe_word,
            "safe_text": safe_text,
            "text_bboxes": [sentence_bbox],
            "characters": characters,
        }

    def init_save(self, root):
        """Create ``root`` and open the ``gt.txt`` / ``coords.txt`` writers."""
        os.makedirs(root, exist_ok=True)
        self.gt_file = open(os.path.join(root, "gt.txt"), "w", encoding="utf-8")
        self.coords_file = open(os.path.join(root, "coords.txt"), "w", encoding="utf-8")
        self.glyph_coords = []

    def save(self, root, data, idx):
        """Write one sample: a JPEG under ``images/<idx // 10000>/`` plus one
        line in ``gt.txt`` and ``coords.txt``. Glyph boxes are buffered for
        ``end_save``.

        ``data`` may be ``None`` (``generate`` returns ``None`` for lines
        without sample metadata); such samples are skipped.
        """
        if data is None:
            return

        image = data["image"]
        safe_word = data["safe_word"]
        safe_text = data["safe_text"]
        sentence_bbox = data["text_bboxes"][0]
        character_bboxes = data["characters"]

        shard = str(idx // 10000)
        image_key = os.path.join("images", shard, f"{idx}.jpg")
        image_path = os.path.join(root, image_key)

        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        # Drop any alpha channel; JPEG has none.
        image = Image.fromarray(image[..., :3].astype(np.uint8))
        image.save(image_path, quality=95)

        self.glyph_coords.append({
            "idx": idx,
            "word_bbox": sentence_bbox,
            "characters": character_bboxes,
        })

        self.gt_file.write(f"{image_key}\t{safe_word}\t{safe_text}\n")
        coords = ",".join(str(value) for value in sentence_bbox["bbox"][:4])
        self.coords_file.write(f"{image_key}\t{coords}\n")

    def end_save(self, root):
        """Close the text writers and dump all buffered glyph boxes as JSON."""
        self.gt_file.close()
        self.coords_file.close()
        with open(os.path.join(root, "glyph_coords_4.json"), "w", encoding="utf-8") as json_file:
            json.dump(self.glyph_coords, json_file, ensure_ascii=False, indent=4)