"""
SynthTIGER
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import os
import json
import numpy as np
from PIL import Image

from synthtiger import components, layers, templates, utils


class Multiline(templates.Template):
    """Render text images glyph by glyph, picking a supporting font per character.

    Labels are drawn from ``self.corpus`` and looked up among the samples built by
    ``generate_new_samples`` (keyed by transformed text). Each character is drawn
    with the first font in ``self.font_list`` whose sidecar ``<font>.txt`` glyph
    list contains it.
    """

    # Point size applied to every rendered character.
    CHAR_SIZE = 32

    def __init__(self, config=None):
        if config is None:
            config = {}

        self.count = config.get("count", 1)
        self.corpus = components.BaseCorpus(**config.get("corpus", {}))
        self.font = components.BaseFont(**config.get("font", {}))
        self.default_font = "resources/multilingual_mixed_fonts/NotoSans-VariableFont_wdth,wght.ttf"
        # NOTE(review): relies on BaseFont's private ``_paths`` (a list of path lists).
        self.font_list = [path for paths in self.font._paths for path in paths]
        # A missing, None or empty "paths" list yields None, which
        # generate_new_samples reports as "not found" instead of crashing here.
        self.samples_path = ((config.get("samples") or {}).get("paths") or [None])[0]
        self.linked_words_path = ((config.get("linked_words") or {}).get("paths") or [None])[0]
        # font path -> glyph-list text (None when the sidecar .txt is absent);
        # filled lazily so each sidecar is read at most once.
        self._glyph_cache = {}
        self.new_samples = self.generate_new_samples()
        self.new_samples_dict = self._load_samples()
        self.color = components.RGB(**config.get("color", {}))
        self.layout = components.FlowLayout(**config.get("layout", {}))
        # Labels emitted so far: the list keeps order/count, the set gives O(1) lookups.
        self.sampled_label = []
        self._sampled_label_set = set()

    def generate_new_samples(self):
        """Pair each samples.json entry with its transformed text from linked_words.txt.

        ``linked_words.txt`` holds ``original||transformed`` pairs, one per line.
        The transformed word is cut from the transformed text at the offset and
        length of the first occurrence of ``raw_word`` in ``raw_text``.
        NOTE(review): this assumes the transformation keeps character positions
        aligned one-to-one with the raw text — confirm for the data source.

        Returns:
            A list of dicts with keys raw_word, raw_text, trans_word, trans_text;
            an empty list if either input file is missing.
        """
        if not self.samples_path or not os.path.exists(self.samples_path):
            print(f"Error: samples.json not found: {self.samples_path}")
            return []

        if not self.linked_words_path or not os.path.exists(self.linked_words_path):
            print(f"Error: linked_words.txt not found: {self.linked_words_path}")
            return []

        with open(self.samples_path, "r", encoding="utf-8") as f:
            samples_data = json.load(f)

        linked_dict = {}
        with open(self.linked_words_path, "r", encoding="utf-8") as f:
            for line in f:
                original_text, sep, transformed_text = line.strip().partition("||")
                if not sep:
                    continue  # not a linked pair
                # A later line for the same original text overrides an earlier one.
                linked_dict[original_text] = transformed_text

        new_samples = []
        for sample in samples_data:
            raw_word = sample["raw_word"]
            raw_text = sample["raw_text"]
            trans_text = linked_dict.get(raw_text)
            if not trans_text:
                print(f"Warning: transformed text not found for raw_text: {raw_text}")
                continue

            start_idx = raw_text.find(raw_word)
            if start_idx == -1:
                print(f"Error: raw_word not found in raw_text: {raw_word} in {raw_text}")
                continue

            trans_word = trans_text[start_idx:start_idx + len(raw_word)]

            new_samples.append({
                "raw_word": raw_word,
                "raw_text": raw_text,
                "trans_word": trans_word,
                "trans_text": trans_text,
            })

        print(f"New samples generated: {len(new_samples)}")
        return new_samples

    def _load_samples(self):
        """Index ``self.new_samples`` by transformed text (last duplicate wins)."""
        return {sample["trans_text"]: sample for sample in self.new_samples}

    def _load_glyphs(self, font_path):
        """Return the cached glyph-list text for ``font_path``, or None if it has no .txt sidecar."""
        if font_path not in self._glyph_cache:
            txt_path = f"{os.path.splitext(font_path)[0]}.txt"
            glyphs = None
            if os.path.exists(txt_path):
                with open(txt_path, "r", encoding="utf-8") as file:
                    glyphs = file.read()
            self._glyph_cache[font_path] = glyphs
        return self._glyph_cache[font_path]

    def is_char_supported(self, font_path, char):
        """Return True if ``char`` occurs in the glyph list next to ``font_path``.

        The glyph list is the file with the font's extension replaced by ``.txt``;
        a font without that file supports nothing.
        """
        glyphs = self._load_glyphs(font_path)
        return glyphs is not None and char in glyphs

    def _find_font(self, char):
        """Return the first font in ``self.font_list`` that supports ``char``, else None."""
        return next(
            (path for path in self.font_list if self.is_char_supported(path, char)),
            None,
        )

    def _sample_unique_label(self):
        """Draw corpus labels until one not emitted before appears, then record it.

        NOTE(review): loops forever once every corpus entry has been sampled.
        """
        label = self.corpus.data(self.corpus.sample())
        while label in self._sampled_label_set:
            label = self.corpus.data(self.corpus.sample())
        self.sampled_label.append(label)
        self._sampled_label_set.add(label)
        return label

    def _render_word(self, text):
        """Render one whitespace-free word glyph by glyph.

        Returns:
            The merged layer for the word, or None if some character is not
            supported by any font in ``self.font_list``.
        """
        chars = utils.split_text(text, reorder=False)

        # Vertical reference: the middle of an "x" in the default font. Each
        # glyph is centred on it so glyphs from different fonts line up.
        baseline_font_meta = self.font.sample({"path": self.default_font})
        baseline_bbox = layers.TextLayer("x", **baseline_font_meta).bbox
        # NOTE(review): assumes bbox is (left, top, width, height) with top == 0
        # on a freshly created layer — TODO confirm against synthtiger.
        baseline_y = (baseline_bbox[1] + baseline_bbox[3]) / 2

        spacing = -1  # start each glyph 1 unit left of the previous glyph's right edge
        prev_char_right = 0
        char_layers = []
        for char in chars:
            font = self._find_font(char)
            if font is None:
                print(f"Error: unsupported char '{char}' in word '{text}'")
                return None

            font_meta = self.font.sample({"path": font})
            font_meta["size"] = self.CHAR_SIZE
            char_layer = layers.TextLayer(char, **font_meta)

            char_bbox = char_layer.bbox
            char_height = char_bbox[3] - char_bbox[1]
            char_layer.topleft = (prev_char_right + spacing, baseline_y - char_height / 2)
            prev_char_right = char_layer.right
            char_layers.append(char_layer)

        return layers.Group(char_layers).merge()

    def generate(self):
        """Render one image for a corpus label that has not been used yet.

        Returns:
            A dict with keys image, label, raw_text, trans_word, raw_word; or
            None when the label has no matching sample or contains a character
            that no font in ``self.font_list`` supports.
        """
        label = self._sample_unique_label()
        # NOTE(review): the sampled color is never applied to the text layers;
        # the call is kept only so the random stream matches earlier runs.
        self.color.data(self.color.sample())

        sample_data = self.new_samples_dict.get(label)
        if sample_data is None:
            print(f"Error: sample data not found for label: {label}")
            return None

        raw_word = sample_data["raw_word"]
        raw_text = sample_data["raw_text"]
        trans_word = sample_data["trans_word"]

        text_layers = []
        for text in label.split():
            text_layer = self._render_word(text)
            if text_layer is None:
                return None
            text_layers.append(text_layer)

        PADDING_LEFT = 30
        PADDING_RIGHT = 30
        PADDING_TOP = 10
        PADDING_BOTTOM = 10

        text_group = layers.Group(text_layers)
        self.layout.apply(text_group)
        text_group.topleft = (PADDING_LEFT, PADDING_TOP)
        bg_width = text_group.size[0] + PADDING_LEFT + PADDING_RIGHT
        bg_height = text_group.size[1] + PADDING_TOP + PADDING_BOTTOM
        # Opaque white rectangle behind the text, covering text plus padding.
        bg_layer = layers.RectLayer((bg_width, bg_height), (255, 255, 255, 255))
        bg_layer.topleft = (0, 0)
        image = (text_group + bg_layer).output()

        data = {
            "image": image,
            "label": label,
            "raw_text": raw_text,
            "trans_word": trans_word,
            "raw_word": raw_word,
        }

        print(f"Sampled count: {len(self.sampled_label)}")

        return data

    def init_save(self, root):
        """Create ``root``, open ``root/gt.txt`` for writing and reset the record list."""
        os.makedirs(root, exist_ok=True)
        gt_path = os.path.join(root, "gt.txt")
        self.gt_file = open(gt_path, "w", encoding="utf-8")
        self.new_samples_path = os.path.join(root, "samples.json")
        self.new_samples_list = []

    def save(self, root, data, idx):
        """Write one generated sample to disk.

        The image goes to ``root/images/<idx // 10000>/<idx>.jpg``. One
        tab-separated line (image_key, raw_word, trans_word, raw_text, label)
        is appended to gt.txt, and the record is queued for samples.json.
        """
        label = data["label"]
        raw_text = data["raw_text"]
        trans_word = data["trans_word"]
        raw_word = data["raw_word"]

        shard = str(idx // 10000)
        image_key = os.path.join("images", shard, f"{idx}.jpg")
        image_path = os.path.join(root, image_key)

        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        # Keep only the first three channels; JPEG cannot store a fourth one.
        pil_image = Image.fromarray(data["image"][..., :3].astype(np.uint8))
        pil_image.save(image_path, quality=95)

        self.gt_file.write(f"{image_key}\t{raw_word}\t{trans_word}\t{raw_text}\t{label}\n")

        self.new_samples_list.append({
            "raw_word": raw_word,
            "raw_text": raw_text,
            "trans_word": trans_word,
            "trans_text": label,
        })

    def end_save(self, root):
        """Close gt.txt and dump every saved record to the samples.json chosen in init_save."""
        self.gt_file.close()
        with open(self.new_samples_path, "w", encoding="utf-8") as f:
            json.dump(self.new_samples_list, f, ensure_ascii=False, indent=4)
        print(f"samples.json saved at {self.new_samples_path}")