import os
import json
import torch
from torch.utils.data import Dataset
from data.dataset_utils import ImageLoader


class CoarseDataset(Dataset):
    """Bilingual (English/Chinese) image-caption dataset with negative captions.

    Expected on-disk layout, as read by ``__init__``::

        <dataset_root>/<sub>/images/<image file>
        <dataset_root>/<sub>/*_en.jsonl            + matching *_zh.jsonl
        <dataset_root>/<sub>/*_en_entities.jsonl   + matching *_zh_entities.jsonl

    The English and Chinese files are read line by line in lockstep. Each
    line is a JSON object with the keys ``image``, ``positive_caption``,
    ``negative_captions`` and, optionally, ``short_caption`` (which falls
    back to ``positive_caption`` when absent).

    Args:
        dataset_root: Directory containing one sub-directory per data split.
        tokenizer_en: Callable used to tokenize English captions. It is
            called with ``max_length``, ``padding``, ``truncation`` and
            ``return_tensors`` keyword arguments and must return a mapping
            with ``input_ids`` and ``attention_mask`` entries (presumably a
            Hugging Face tokenizer -- confirm against the caller).
        tokenizer_zh: Same as ``tokenizer_en`` but for Chinese captions.
        max_length: Sequence length every caption is padded/truncated to.
        sample_n: If truthy, keep only the first ``sample_n`` samples.
        image_size: Forwarded to ``ImageLoader`` as ``target_size``.
        max_negatives: Number of negative captions returned per sample.
    """

    # (English suffix, Chinese suffix) pairs of annotation files that belong together.
    _ANNOTATION_SUFFIXES = (
        ('_en_entities.jsonl', '_zh_entities.jsonl'),
        ('_en.jsonl', '_zh.jsonl'),
    )

    def __init__(
            self, dataset_root, tokenizer_en, tokenizer_zh,
            max_length: int = 128, sample_n=None,
            image_size: int = 224, max_negatives: int = 10
    ):
        self.tokenizer_en = tokenizer_en
        self.tokenizer_zh = tokenizer_zh
        self.loader = ImageLoader(target_size=image_size)
        self.max_length = max_length
        self.max_negatives = max_negatives
        self.samples = []

        for sub in sorted(os.listdir(dataset_root)):
            sub_dir = os.path.join(dataset_root, sub)
            img_dir = os.path.join(sub_dir, "images")
            if not os.path.isdir(img_dir):
                continue
            # Sorted so that the sample order (and therefore the effect of
            # ``sample_n``) does not depend on filesystem listing order.
            for fn in sorted(os.listdir(sub_dir)):
                zh_fn = self._zh_filename(fn)
                if zh_fn is None:
                    continue
                en_path = os.path.join(sub_dir, fn)
                zh_path = os.path.join(sub_dir, zh_fn)
                if not os.path.exists(zh_path):
                    continue
                self._read_pair(en_path, zh_path, img_dir)

        if sample_n:
            self.samples = self.samples[:sample_n]

    @classmethod
    def _zh_filename(cls, fn):
        """Return the Chinese counterpart of English annotation file *fn*, or ``None``."""
        for en_suffix, zh_suffix in cls._ANNOTATION_SUFFIXES:
            if fn.endswith(en_suffix):
                # Swap the suffix on the bare filename only, so directory
                # names that happen to contain the suffix are left alone.
                return fn[:-len(en_suffix)] + zh_suffix
        return None

    def _read_pair(self, en_path, zh_path, img_dir):
        """Append one sample per aligned line of the English/Chinese JSONL pair."""
        with open(en_path, encoding='utf-8') as f_en, \
                open(zh_path, encoding='utf-8') as f_zh:
            # zip() stops at the shorter file; extra trailing lines are ignored.
            for line_en, line_zh in zip(f_en, f_zh):
                # Skip blank lines (both files advance together, so the
                # remaining lines stay aligned).
                if not line_en.strip() or not line_zh.strip():
                    continue
                d_en = json.loads(line_en)
                d_zh = json.loads(line_zh)
                img_path = os.path.join(img_dir, d_en['image'])
                if not os.path.isfile(img_path):
                    continue
                self.samples.append({
                    'img_path': img_path,
                    'positive_en': d_en['positive_caption'],
                    'negative_en': d_en['negative_captions'],
                    'short_en': d_en.get('short_caption', d_en['positive_caption']),
                    'positive_zh': d_zh['positive_caption'],
                    'negative_zh': d_zh['negative_captions'],
                    'short_zh': d_zh.get('short_caption', d_zh['positive_caption']),
                })

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def _encode(self, tokenizer, text):
        """Tokenize *text* (a string or list of strings) to fixed-length tensors."""
        return tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

    def _fixed_negatives(self, negatives):
        """Return exactly ``max_negatives`` captions taken from *negatives*.

        Longer lists are truncated. Shorter lists are padded by repeating
        the first caption; an empty list is padded with empty strings so
        that every sample yields tensors of the same shape.
        """
        picked = list(negatives[:self.max_negatives])
        missing = self.max_negatives - len(picked)
        if missing > 0:
            filler = picked[0] if picked else ''
            picked.extend([filler] * missing)
        return picked

    def __getitem__(self, idx):
        """Load the image and tokenize all captions of sample *idx*.

        Returns:
            A dict with ``image`` (whatever ``self.loader.load`` returns) and,
            for each language suffix ``en`` / ``zh``:

            * ``pos_ids_*`` / ``pos_mask_*``: positive caption, batch dim squeezed.
            * ``short_ids_*`` / ``short_mask_*``: short caption, batch dim squeezed.
            * ``neg_ids_*`` / ``neg_mask_*``: negatives, ``[max_negatives, L]``.
        """
        with torch.no_grad():
            rec = self.samples[idx]
            image = self.loader.load(rec['img_path'])

            # English captions
            pos_en = self._encode(self.tokenizer_en, rec['positive_en'])
            short_en = self._encode(self.tokenizer_en, rec['short_en'])
            neg_en = self._encode(
                self.tokenizer_en, self._fixed_negatives(rec['negative_en'])
            )

            # Chinese captions
            pos_zh = self._encode(self.tokenizer_zh, rec['positive_zh'])
            short_zh = self._encode(self.tokenizer_zh, rec['short_zh'])
            neg_zh = self._encode(
                self.tokenizer_zh, self._fixed_negatives(rec['negative_zh'])
            )

            return {
                'image': image,

                'pos_ids_en': pos_en['input_ids'].squeeze(0),
                'pos_mask_en': pos_en['attention_mask'].squeeze(0),
                'short_ids_en': short_en['input_ids'].squeeze(0),
                'short_mask_en': short_en['attention_mask'].squeeze(0),
                'neg_ids_en': neg_en['input_ids'],  # [max_negatives, L]
                'neg_mask_en': neg_en['attention_mask'],  # [max_negatives, L]

                # Chinese tensors mirror the English ones.
                'pos_ids_zh': pos_zh['input_ids'].squeeze(0),
                'pos_mask_zh': pos_zh['attention_mask'].squeeze(0),
                'short_ids_zh': short_zh['input_ids'].squeeze(0),
                'short_mask_zh': short_zh['attention_mask'].squeeze(0),
                'neg_ids_zh': neg_zh['input_ids'],  # [max_negatives, L]
                'neg_mask_zh': neg_zh['attention_mask'],  # [max_negatives, L]
            }