"""
SynthTIGER
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import functools
import json
import os
import random

import cv2
import numpy as np
from PIL import Image, ImageDraw

from synthtiger import components, layers, templates, utils

# Shard number for this generation run. It names the per-shard output
# sub-directories (images/, masks/, glyph_masks/), offsets saved sample indices
# by shard * 10000, and suffixes the glyph_coords_<shard>.json file.
shard: int = 4


@functools.lru_cache(maxsize=None)
def _load_glyphs(txt_path):
    """Return the full text of a glyph-list file, read once per path.

    Raises:
        FileNotFoundError: if ``txt_path`` does not exist. Exceptions are not
            cached by ``lru_cache``, so a later call checks the disk again.
    """
    if not os.path.exists(txt_path):
        raise FileNotFoundError(f"Glyph file not found: {txt_path}")
    with open(txt_path, "r", encoding="utf-8") as file:
        return file.read()


def is_char_supported(font_path, char):
    """Return True if ``char`` occurs in the glyph list that accompanies a font.

    The glyph list is the UTF-8 text file sitting next to the font with the
    same stem and a ``.txt`` extension (``fonts/A.ttf`` -> ``fonts/A.txt``).
    The check is a plain substring test against that file's contents.

    The file contents are cached per path because this function is called
    repeatedly for the same fonts. Edits to a glyph file made while the
    process is running are therefore not picked up.

    Raises:
        FileNotFoundError: if the glyph-list file does not exist.
    """
    txt_path = f"{os.path.splitext(font_path)[0]}.txt"
    return char in _load_glyphs(txt_path)


class SynthTiger(templates.Template):
    """Template that renders one corpus line per sample, choosing a font per character.

    ASCII letters are drawn with ``default_font``. Every other character uses
    the first font in ``font_list`` whose glyph list contains it. A label with
    an unsupported (or no) character is skipped. Characters are laid out with a
    flow layout on a white background. The image, text mask, glyph mask and
    per-character glyph boxes are written by ``save``.

    Config keys read here (all optional): ``coord_output``, ``mask_output``,
    ``glyph_coord_output``, ``glyph_mask_output``, ``vertical``, ``quality``,
    ``visibility_check``, ``multilingual_corpus``, ``multilingual_font``,
    ``linked_words``, ``color``, ``layout`` and ``pad``. The following keys
    default to the previously hard-coded behavior: ``default_font``,
    ``font_size`` (40), ``shard`` (module-level ``shard``) and ``vis_output``
    (False).
    """

    def __init__(self, config=None):
        if config is None:
            config = {}

        self.coord_output = config.get("coord_output", True)
        self.mask_output = config.get("mask_output", True)
        self.glyph_coord_output = config.get("glyph_coord_output", True)
        self.glyph_mask_output = config.get("glyph_mask_output", True)
        self.vertical = config.get("vertical", False)
        self.quality = config.get("quality", [95, 95])
        # NOTE(review): read but not used anywhere in this class.
        self.visibility_check = config.get("visibility_check", False)
        # When True, save a debug copy of each image with the word box (red)
        # and glyph boxes (blue) drawn on it.
        self.vis_output = config.get("vis_output", False)

        # Font used for ASCII letters; other characters search font_list.
        self.default_font = config.get(
            "default_font",
            "resources/multilingual_mixed_fonts/NotoSans-VariableFont_wdth,wght.ttf",
        )
        # Every character layer is rendered at this fixed size.
        self.font_size = config.get("font_size", 40)
        # Output sub-directory name / index offset; defaults to the module constant.
        self.shard = config.get("shard", shard)

        self.corpus = components.BaseCorpus(**config.get("multilingual_corpus", {}))
        self.font = components.BaseFont(**config.get("multilingual_font", {}))

        # Flatten BaseFont's nested path lists into one ordered candidate list.
        self.font_list = [path for paths in self.font._paths for path in paths]

        # A missing, None or empty "paths" entry all mean "no linked-words file".
        linked_paths = config.get("linked_words", {}).get("paths") or [None]
        self.linked_words_path = linked_paths[0]
        self.linked_words_dict = self._load_linked_words()

        self.color = components.RGB(**config.get("color", {}))
        self.layout = components.FlowLayout(**config.get("layout", {}))
        self.pad = components.Switch(components.Pad(), **config.get("pad", {}))

        self.corpus_data = self._load_corpus_data()
        # Initializes self.sampled_labels and self.current_index.
        self._shuffle_corpus_data()

    def _load_corpus_data(self):
        """Collect every corpus line that passes the corpus length/charset checks.

        Returns:
            list[str]: Stripped lines in file order, across all corpus paths.
        """
        corpus_data = []
        for path in self.corpus.paths:
            with open(path, "r", encoding="utf-8") as file:
                for line in file:
                    text = line.strip()
                    if self.corpus._check_length(text) and self.corpus._check_charset(text):
                        corpus_data.append(text)
        return corpus_data

    def _shuffle_corpus_data(self):
        """Start a new pass: store a random permutation of the corpus, reset the cursor."""
        self.sampled_labels = random.sample(self.corpus_data, len(self.corpus_data))
        self.current_index = 0

    def _get_next_label(self):
        """Return the next label, visiting each corpus line once per shuffled pass.

        Raises:
            IndexError: if no corpus line survived the filters.
        """
        if not self.corpus_data:
            raise IndexError("corpus is empty: no line passed the length/charset filters")
        if self.current_index >= len(self.sampled_labels):
            self._shuffle_corpus_data()

        label = self.sampled_labels[self.current_index]
        self.current_index += 1
        return label

    def _load_linked_words(self):
        """Load ``original||transformed`` lines into a ``{transformed: original}`` dict.

        A missing path setting or a missing file is reported on stdout and
        yields an empty dict. Malformed lines are reported and skipped.
        """
        linked_words_dict = {}

        if not self.linked_words_path:
            print(" Error: linked_words_path 설정이 없음!")
            return linked_words_dict
        if not os.path.exists(self.linked_words_path):
            print(f"Error: linked_words.txt 파일이 존재하지 않습니다! ({self.linked_words_path})")
            return linked_words_dict

        print(f"linked_words.txt 파일 로드 중: {self.linked_words_path}")

        with open(self.linked_words_path, "r", encoding="utf-8") as file:
            for line in file:
                line = line.strip()

                if not line or "||" not in line:
                    print(f"Warning: 잘못된 형식의 줄 무시됨: {line}")
                    continue

                # Split on the first "||" only. It is known to be present, so
                # partition always yields both halves.
                original, _, transformed = line.partition("||")
                linked_words_dict[transformed] = original

        print(f"linked_words.txt 로드 완료: {len(linked_words_dict)} 개의 단어 매핑됨")
        return linked_words_dict

    def generate(self):
        """Render one sample.

        Returns:
            dict | int: The sample dict consumed by ``save``. Returns ``0``
            when ``_generate_text`` skipped the drawn label.
        """
        print("generate..")
        quality = np.random.randint(self.quality[0], self.quality[1] + 1)
        color = self.color.sample()
        result = self._generate_text(color)

        if result is None:
            print("❌ 데이터 생성 건너뜀!")
            return 0

        fg_image, label, bboxes, glyph_fg_image, glyph_bboxes = result
        # (height, width) -> (width, height) for the background layer size.
        fg_image_shape = fg_image.shape[:2][::-1]
        bg_image = self._generate_background(fg_image_shape)

        image = utils.blend_image(fg_image, bg_image, mode="normal")

        data = {
            "image": image,
            "label": label,
            "quality": quality,
            "mask": fg_image[..., 3],
            "bboxes": bboxes,
            "glyph_mask": glyph_fg_image[..., 3],
            "glyph_bboxes": glyph_bboxes,
        }

        return data

    def init_save(self, root):
        """Create ``root`` and open the output files used by ``save``/``end_save``."""
        os.makedirs(root, exist_ok=True)

        self.gt_file = open(os.path.join(root, "gt.txt"), "w", encoding="utf-8")
        if self.coord_output:
            self.coords_file = open(os.path.join(root, "coords.txt"), "w", encoding="utf-8")
        if self.glyph_coord_output:
            # Accumulated in memory; dumped as JSON by end_save.
            self.glyph_coords = []

        # NOTE(review): opened and closed, but nothing in this class writes to it.
        self.missing_fonts_file = open(
            os.path.join(root, "missing_fonts.txt"), "w", encoding="utf-8"
        )

    def save(self, root, data, idx):
        """Write one sample's image and masks, and append its annotations.

        Args:
            root: Output directory previously passed to ``init_save``.
            data: Dict returned by ``generate``. Falsy values (skipped samples)
                are ignored.
            idx: Sample index supplied by the caller. Files are named
                ``self.shard * 10000 + idx``.
        """
        if not data:
            return

        label = data["label"]
        quality = data["quality"]
        glyph_bboxes = data["glyph_bboxes"]

        # Map the rendered (transformed) label back to its source word.
        original_word = self.linked_words_dict.get(label, "***********")

        image = Image.fromarray(data["image"][..., :3].astype(np.uint8))
        mask = Image.fromarray(data["mask"].astype(np.uint8))
        glyph_mask = Image.fromarray(data["glyph_mask"].astype(np.uint8))

        # Glyph boxes are (x, y, w, h); the word box is their union.
        boxes = [
            [float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])]
            for bbox in glyph_bboxes
        ]
        x_min = min(x for x, _, _, _ in boxes)
        y_min = min(y for _, y, _, _ in boxes)
        x_max = max(x + w for x, _, w, _ in boxes)
        y_max = max(y + h for _, y, _, h in boxes)

        word_bbox = {
            "label": label,
            "bbox": [x_min, y_min, x_max - x_min, y_max - y_min],
        }

        # NOTE(review): indices collide with the next shard once idx reaches 10000.
        base_idx = self.shard * 10000 + idx

        if self.glyph_coord_output:
            # Re-split exactly as _generate_text did (one layer per element),
            # so each box is paired with the text it was rendered from.
            # Indexing label[j] drifts whenever an element spans several
            # code points.
            chars = utils.split_text(label, reorder=False)
            self.glyph_coords.append({
                "idx": base_idx,
                "word_bbox": word_bbox,
                "characters": [
                    {"character": char, "bbox": box}
                    for char, box in zip(chars, boxes)
                ],
            })

        shard_dir = str(self.shard)
        image_key = os.path.join("images", shard_dir, f"{base_idx}.jpg")
        mask_key = os.path.join("masks", shard_dir, f"{base_idx}.png")
        glyph_mask_key = os.path.join("glyph_masks", shard_dir, f"{base_idx}.png")

        image_path = os.path.join(root, image_key)
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        image.save(image_path, quality=quality)
        if self.mask_output:
            mask_path = os.path.join(root, mask_key)
            os.makedirs(os.path.dirname(mask_path), exist_ok=True)
            mask.save(mask_path)
        if self.glyph_mask_output:
            glyph_mask_path = os.path.join(root, glyph_mask_key)
            os.makedirs(os.path.dirname(glyph_mask_path), exist_ok=True)
            glyph_mask.save(glyph_mask_path)
        if self.vis_output:
            self._save_visualization(root, image, word_bbox["bbox"], boxes, base_idx, quality)

        self.gt_file.write(f"{image_key}\t{label}\t{original_word}\n")
        if self.coord_output:
            coords = ",".join(str(value) for value in word_bbox["bbox"])
            self.coords_file.write(f"{image_key}\t{coords}\n")

        print(label)
        print(image_path)

    def _save_visualization(self, root, image, word_box, glyph_boxes, base_idx, quality):
        """Save a copy of ``image`` with the word box (red) and glyph boxes (blue)."""
        vis_image = image.copy()
        draw = ImageDraw.Draw(vis_image)
        x, y, w, h = word_box
        draw.rectangle([x, y, x + w, y + h], outline="red", width=2)
        for gx, gy, gw, gh in glyph_boxes:
            draw.rectangle([gx, gy, gx + gw, gy + gh], outline="blue", width=2)

        vis_path = os.path.join(root, "visualizations", str(self.shard), f"{base_idx}_vis.jpg")
        os.makedirs(os.path.dirname(vis_path), exist_ok=True)
        vis_image.save(vis_path, quality=quality)

    def end_save(self, root):
        """Close the output files and dump the accumulated glyph coordinates as JSON."""
        self.gt_file.close()
        if self.coord_output:
            self.coords_file.close()
        if self.glyph_coord_output:
            json_path = os.path.join(root, f"glyph_coords_{self.shard}.json")
            with open(json_path, "w", encoding="utf-8") as json_file:
                json.dump(self.glyph_coords, json_file, ensure_ascii=False, indent=4)

        self.missing_fonts_file.close()

    def _select_font(self, char):
        """Return a font path that can render ``char``, or ``None``.

        ASCII letters always use ``default_font``. Anything else gets the
        first entry of ``font_list`` whose glyph list contains it.
        """
        if char.isascii() and char.isalpha():
            return self.default_font
        for font_path in self.font_list:
            if is_char_supported(font_path, char):
                return font_path
        return None

    def _generate_text(self, color):
        """Render the next corpus label as one text layer per character.

        Args:
            color: A sample from ``self.color.sample()``, applied to the text.

        Returns:
            tuple | None: ``(text_out, label, text_bboxes, text_glyph_out,
            text_glyph_bboxes)``. Returns ``None`` when the label is empty or
            one of its characters has no supporting font.
        """
        label = self._get_next_label()
        chars = utils.split_text(label, reorder=False)
        if not chars:
            return None

        char_layers = []
        for char in chars:
            font = self._select_font(char)
            if font is None:
                print(f"❌ 필터링됨: '{label}' (문자 '{char}' 사용 불가)")
                return None

            font_meta = self.font.sample({"path": font, "vertical": self.vertical})
            font_meta["size"] = self.font_size
            char_layers.append(layers.TextLayer(char, **font_meta))

        self.layout.apply(char_layers, {"meta": {"vertical": self.vertical}})
        # Copy the laid-out characters before merging; the copies provide the
        # per-character glyph boxes.
        char_glyph_layers = [char_layer.copy() for char_layer in char_layers]
        text_layer = layers.Group(char_layers).merge()
        text_glyph_layer = text_layer.copy()

        pad = self.pad.sample()
        self.color.apply([text_layer, text_glyph_layer], color)
        # Only the text layer is padded; the glyph layer is rendered below
        # into the padded text bbox so both outputs share one canvas.
        self.pad.apply([text_layer], pad)

        # Make character positions relative to the padded text layer's origin.
        for char_layer in char_layers:
            char_layer.topleft -= text_layer.topleft
        for char_glyph_layer in char_glyph_layers:
            char_glyph_layer.topleft -= text_layer.topleft

        text_out = text_layer.output()
        text_bboxes = [char_layer.bbox for char_layer in char_layers]

        text_glyph_out = text_glyph_layer.output(bbox=text_layer.bbox)
        text_glyph_bboxes = [char_glyph_layer.bbox for char_glyph_layer in char_glyph_layers]

        return text_out, label, text_bboxes, text_glyph_out, text_glyph_bboxes

    def _generate_background(self, text_size):
        """Return the output of a solid white rectangle layer of ``text_size``.

        ``generate`` passes the foreground image size as (width, height).
        """
        bg_layer = layers.RectLayer(text_size)
        self.color.apply([bg_layer], {"rgb": (255, 255, 255)})
        bg_out = bg_layer.output()
        return bg_out