import os
import json
import torch
from torch.utils.data import Dataset
from data.dataset_utils import ImageLoader, crop_and_resize_patch


class FineDataset(Dataset):
    """Paired English/Chinese image-caption dataset with region-level crops.

    Expected layout under ``dataset_root`` (as read by ``__init__``)::

        <dataset_root>/<sub>/images/<image file>
        <dataset_root>/<sub>/*_en.jsonl           paired with *_zh.jsonl
        <dataset_root>/<sub>/*_en_entities.jsonl  paired with *_zh_entities.jsonl

    The English and Chinese files are read line by line in lockstep; line
    ``k`` of both files is treated as describing the same image.  Each JSON
    record must provide ``image`` (English file only) and
    ``positive_caption``; ``short_caption`` and ``regions`` are optional.

    Args:
        dataset_root: Directory containing one sub-directory per data split.
        tokenizer_en: Callable tokenizer for English text.  It is invoked with
            ``max_length``, ``padding``, ``truncation`` and ``return_tensors``
            keyword arguments and must return a mapping with ``input_ids`` and
            ``attention_mask`` (presumably a Hugging Face tokenizer -- confirm).
        tokenizer_zh: Same contract as ``tokenizer_en``, for Chinese text.
        max_text_length: Length every text is padded/truncated to.
        sample_n: If truthy, keep only the first ``sample_n`` samples.
        image_size: Target size forwarded to ``ImageLoader``.
        max_rois: Exact number of region crops returned per sample.
        include_zh: When ``True``, ``__getitem__`` additionally returns the
            Chinese token tensors under ``*_zh`` keys.  Defaults to ``False``
            so the returned keys match what existing callers already receive.
    """

    # (English suffix, Chinese suffix) of annotation files that are read as a pair.
    _ANNOTATION_SUFFIXES = (
        ('_en.jsonl', '_zh.jsonl'),
        ('_en_entities.jsonl', '_zh_entities.jsonl'),
    )

    def __init__(
            self, dataset_root, tokenizer_en, tokenizer_zh,
            max_text_length=128, sample_n=None,
            image_size=224, max_rois=8, include_zh=False
    ):
        self.loader = ImageLoader(target_size=image_size)
        self.tokenizer_en = tokenizer_en
        self.tokenizer_zh = tokenizer_zh
        self.max_text_length = max_text_length
        self.max_rois = max_rois
        self.include_zh = include_zh
        self.samples = []

        for sub in sorted(os.listdir(dataset_root)):
            sub_path = os.path.join(dataset_root, sub)
            img_dir = os.path.join(sub_path, "images")
            if not os.path.isdir(img_dir):
                continue

            # Sorted so that sample order (and therefore the `sample_n`
            # prefix) does not depend on filesystem enumeration order.
            for fn in sorted(os.listdir(sub_path)):
                zh_fn = self._zh_filename(fn)
                if zh_fn is None:
                    continue
                zh_path = os.path.join(sub_path, zh_fn)
                if not os.path.exists(zh_path):
                    continue
                en_path = os.path.join(sub_path, fn)
                self.samples.extend(self._read_pairs(en_path, zh_path, img_dir))

        if sample_n:
            self.samples = self.samples[:sample_n]

    @classmethod
    def _zh_filename(cls, en_filename):
        """Return the Chinese counterpart of an English annotation filename, or None."""
        for en_suffix, zh_suffix in cls._ANNOTATION_SUFFIXES:
            if en_filename.endswith(en_suffix):
                # Swap only the trailing suffix; a substring replace on the
                # whole path could also rewrite a matching directory name.
                return en_filename[:-len(en_suffix)] + zh_suffix
        return None

    @staticmethod
    def _read_pairs(en_path, zh_path, img_dir):
        """Read line-aligned English/Chinese JSONL files into sample records."""
        records = []
        with open(en_path, encoding='utf-8') as f_en, \
                open(zh_path, encoding='utf-8') as f_zh:
            for line_en, line_zh in zip(f_en, f_zh):
                # Tolerate lines that are blank in BOTH files (e.g. a trailing
                # newline).  A blank on only one side means the files are
                # misaligned, so json.loads is left to raise in that case.
                if not line_en.strip() and not line_zh.strip():
                    continue
                d_en = json.loads(line_en)
                d_zh = json.loads(line_zh)
                img_path = os.path.join(img_dir, d_en['image'])
                if not os.path.isfile(img_path):
                    continue

                records.append({
                    'img_path': img_path,
                    'regions_en': d_en.get('regions', []),
                    'regions_zh': d_zh.get('regions', []),
                    'positive_en': d_en['positive_caption'],
                    'short_en': d_en.get('short_caption', d_en['positive_caption']),
                    'positive_zh': d_zh['positive_caption'],
                    'short_zh': d_zh.get('short_caption', d_zh['positive_caption']),
                })
        return records

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def _tokenize(self, tokenizer, text):
        """Tokenize a string or list of strings to fixed-length 'pt' tensors."""
        return tokenizer(
            text,
            max_length=self.max_text_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def _collect_rois(self, rec, raw):
        """Crop up to ``max_rois`` valid regions; return (patches, caps_en, caps_zh)."""
        patches, caps_en, caps_zh = [], [], []
        for r_en, r_zh in zip(rec['regions_en'], rec['regions_zh']):
            if len(patches) >= self.max_rois:
                break
            # Box coordinates are taken from the English record only.
            x0, y0, x1, y1 = map(int, r_en['bbox'])
            if x1 <= x0 or y1 <= y0:
                continue  # degenerate (empty or inverted) box
            patches.append(crop_and_resize_patch(
                raw, x0, y0, x1, y1,
                self.loader.target_size
            ))
            caps_en.append(r_en['caption'])
            caps_zh.append(r_zh['caption'])
        return patches, caps_en, caps_zh

    def __getitem__(self, idx):
        """Build the tensors for sample ``idx``.

        Returns:
            A dict with ``image``, ``rois`` and the English token tensors
            (``pos_*_en``, ``short_*_en``, ``region_*_en``).  When
            ``include_zh`` is set, the matching ``*_zh`` keys are added.
            Shapes, going by the inline comments: image ``[3,H,W]``, rois
            ``[R,3,H,W]``, region ids/mask ``[R, L]`` with ``R == max_rois``
            (the exact image shape depends on ``ImageLoader`` -- confirm).

        Raises:
            RuntimeError: If the sample has no region with a valid box.  The
                ``[Skip]`` prefix suggests callers catch this to drop the
                sample -- confirm against the training loop.
        """
        with torch.no_grad():
            rec = self.samples[idx]
            image = self.loader.load(rec['img_path'])
            raw = self.loader.load_raw(rec['img_path'])

            # Crop ROIs first so that unusable samples fail before any
            # tokenization work is spent on them.
            patches, caps_en, caps_zh = self._collect_rois(rec, raw)
            if not patches:
                raise RuntimeError(f"[Skip] All ROI invalid idx={idx}")

            # Pad to exactly `max_rois` entries by repeating the first region.
            for _ in range(self.max_rois - len(patches)):
                patches.append(patches[0].clone())
                caps_en.append(caps_en[0])
                caps_zh.append(caps_zh[0])

            rois = torch.stack(patches)  # [R,3,H,W]

            pos_en = self._tokenize(self.tokenizer_en, rec['positive_en'])
            short_en = self._tokenize(self.tokenizer_en, rec['short_en'])
            tok_rois_en = self._tokenize(self.tokenizer_en, caps_en)

            item = {
                'image': image,  # [3,H,W]
                'rois': rois,  # [R,3,H,W]

                # global English
                'pos_ids_en': pos_en['input_ids'].squeeze(0),
                'pos_mask_en': pos_en['attention_mask'].squeeze(0),
                'short_ids_en': short_en['input_ids'].squeeze(0),
                'short_mask_en': short_en['attention_mask'].squeeze(0),

                # region-level English
                'region_ids_en': tok_rois_en['input_ids'],  # [R, L]
                'region_mask_en': tok_rois_en['attention_mask'],
            }

            # Chinese text is tokenized only on request: previously it was
            # tokenized for every item and then discarded.
            if self.include_zh:
                pos_zh = self._tokenize(self.tokenizer_zh, rec['positive_zh'])
                short_zh = self._tokenize(self.tokenizer_zh, rec['short_zh'])
                tok_rois_zh = self._tokenize(self.tokenizer_zh, caps_zh)
                item.update({
                    # global Chinese
                    'pos_ids_zh': pos_zh['input_ids'].squeeze(0),
                    'pos_mask_zh': pos_zh['attention_mask'].squeeze(0),
                    'short_ids_zh': short_zh['input_ids'].squeeze(0),
                    'short_mask_zh': short_zh['attention_mask'].squeeze(0),

                    # region-level Chinese
                    'region_ids_zh': tok_rois_zh['input_ids'],  # [R, L]
                    'region_mask_zh': tok_rois_zh['attention_mask'],
                })

            return item
