import os
import cv2
import json
import copy
import torch
import numpy as np
from PIL import Image
from typing import Optional

from llava.train.train import preprocess_multimodal, preprocess_v1
from util import get_kernel_size

class ExtractionDataset(torch.utils.data.Dataset):
    """Image-conversation dataset that pairs every image with a blurred copy.

    The JSON file at ``args.data_path`` is read once. Every entry is tagged
    with ``save_idx`` (its position in that file); entries that have an
    ``'image'`` key become the samples served by this dataset, the remaining
    entries are kept in ``self.text_dataset`` and are not served.

    Optional attributes read from ``args``:
        image_folder: directory the ``'image'`` paths are relative to
            (defaults to ``"data"``).
        cache_dir / cache_preprocessed: when both are truthy, the stacked
            ``[image, blurred]`` tensor of every sample is stored in
            ``cache_dir`` and reused on later accesses.

    NOTE(review): with caching enabled the image tensor is stored (and
    returned) as float16, while without caching it keeps whatever dtype the
    image processor produces - confirm downstream code accepts both.
    """

    def __init__(self, tokenizer, image_processor, args):
        super().__init__()
        with open(args.data_path, 'rb') as f:
            data = json.load(f)

        image_dataset = []
        text_dataset = []

        for i, item in enumerate(data):
            # Remember the position in the original file so outputs can be
            # matched back even though text-only entries are filtered out.
            item['save_idx'] = i
            if 'image' in item:
                image_dataset.append(item)
            else:
                text_dataset.append(item)

        self.data = image_dataset
        self.text_dataset = text_dataset
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.args = args
        self.image_folder = getattr(args, "image_folder", "data")
        self.cache_dir: Optional[str] = getattr(args, "cache_dir", None)
        self.cache_preprocessed: bool = bool(getattr(args, "cache_preprocessed", False))
        if self.cache_dir and self.cache_preprocessed:
            os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        """Return the number of image-bearing samples."""
        return len(self.data)

    def _cache_path(self, image_file: str, image_size, kernel_size: int, save_idx=None):
        """Return the cache file path for one sample.

        Args:
            image_file: image path as stored in the sample.
            image_size: ``(width, height)`` pair of the image.
            kernel_size: blur kernel size used for this image.
            save_idx: index embedded in the file name. When omitted, the
                ``self.save_idx`` set by the most recent ``__getitem__`` call
                is used (the historical behaviour).
        """
        if save_idx is None:
            save_idx = self.save_idx
        # Flatten the relative path into a single file name.
        safe = image_file.replace("/", "_")
        if os.sep != "/":
            safe = safe.replace(os.sep, "_")
        # Include size/kernel to avoid accidental mismatches if preprocessing settings change.
        w, h = image_size
        return os.path.join(self.cache_dir, f"{save_idx}.{safe}.w{w}h{h}.k{kernel_size}.pt")

    def _preprocess_pair(self, image, kernel_size):
        """Return ``torch.stack([image, blurred image])`` after preprocessing.

        The blurred copy is a Gaussian blur of ``image`` with a square kernel
        of ``kernel_size`` and ``sigmaX = kernel_size - 1``; both pictures go
        through ``self.image_processor.preprocess``.
        """
        blurred = cv2.GaussianBlur(
            np.asarray(image, dtype=np.uint8),
            (kernel_size, kernel_size),
            sigmaX=kernel_size - 1,
        )
        blurred = Image.fromarray(blurred)
        image_t = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        blur_t = self.image_processor.preprocess(blurred, return_tensors="pt")["pixel_values"][0]
        return torch.stack([image_t, blur_t])

    @staticmethod
    def _write_cache(tensor, cache_path):
        """Save ``tensor`` to ``cache_path`` via a per-process temp file.

        The tensor is written next to the target and moved into place with
        ``os.replace`` so concurrent workers never observe a partial file.
        """
        tmp_path = f"{cache_path}.tmp.{os.getpid()}"
        try:
            torch.save(tensor, tmp_path)
            os.replace(tmp_path, cache_path)
        finally:
            # Only still present when saving or replacing failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def __getitem__(self, idx):
        """Build the model inputs for the sample at integer position ``idx``.

        Returns:
            dict with ``input_ids`` and ``labels`` (first entries of the
            ``preprocess_v1`` output), ``convs`` (shallow copy of the raw
            conversation list), ``ids``, ``image`` (stacked original and
            blurred tensors, presumably ``(2, 3, H, W)``), ``image_size``
            (``(width, height)`` as reported by PIL), ``image_name`` and
            ``save_idx``.
        """
        item = self.data[idx]
        convs = item['conversations'].copy()
        ids = item['id']
        image_file = item['image']
        save_idx = item['save_idx']
        # Kept as an attribute for callers that still read it.
        self.save_idx = save_idx

        image_path = os.path.join(self.image_folder, image_file)
        use_cache = bool(self.cache_dir and self.cache_preprocessed)

        # The context manager closes the underlying file handle promptly.
        with Image.open(image_path) as raw:
            image_size = raw.size
            kernel_size = get_kernel_size(image_size)
            cache_path = (
                self._cache_path(image_file, image_size, kernel_size, save_idx)
                if use_cache
                else None
            )
            cache_hit = cache_path is not None and os.path.exists(cache_path)
            # Decoding the pixels is only needed when the pair must be computed.
            image = None if cache_hit else raw.convert('RGB')

        if cache_hit:
            images_2 = torch.load(cache_path, map_location="cpu")
        else:
            images_2 = self._preprocess_pair(image, kernel_size)
            if cache_path is not None:
                # Cached tensors are stored in half precision.
                images_2 = images_2.to(dtype=torch.float16).contiguous()
                self._write_cache(images_2, cache_path)

        # Deep copy so the preprocessing cannot modify the stored conversations.
        sources = preprocess_multimodal(
            copy.deepcopy([item["conversations"]]),
            self.args
        )

        data_dict = preprocess_v1(sources, self.tokenizer, has_image=('image' in item))

        return dict(
            input_ids=data_dict["input_ids"][0],
            labels=data_dict["labels"][0],
            convs=convs,
            ids=ids,
            image=images_2,
            image_size=image_size,
            image_name=image_file,
            save_idx=save_idx,
        )

class DataCollator(object):
    """Collate per-sample dicts into one padded batch dict.

    ``input_ids`` are padded with the tokenizer's pad token id and ``labels``
    with ``-100``; both are then cut to ``tokenizer.model_max_length``
    columns. The metadata fields (``convs``, ``ids``, ``image_size``,
    ``image_name``, ``save_idx``) are passed through as plain lists.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, instances):
        """Return the batch dict built from ``instances``.

        When the first instance has an ``'image'`` entry, the images are
        stacked if every one is non-None and they all share a shape;
        otherwise they are returned as a list. A stacked 5-D result has its
        first two axes merged into one.
        """
        def gather(key):
            # One list per field, in instance order.
            return [sample[key] for sample in instances]

        token_seqs = gather("input_ids")
        label_seqs = gather("labels")
        ids = gather('ids')
        convs = gather('convs')
        image_size = gather('image_size')
        image_name = gather('image_name')
        save_idx = gather('save_idx')

        pad = torch.nn.utils.rnn.pad_sequence
        padded_ids = pad(
            token_seqs,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        padded_labels = pad(label_seqs, batch_first=True, padding_value=-100)

        # Truncate after padding so both tensors keep the same width.
        padded_ids = padded_ids[:, :self.tokenizer.model_max_length]
        padded_labels = padded_labels[:, :self.tokenizer.model_max_length]

        batch = {
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(self.tokenizer.pad_token_id),
            "convs": convs,
            "ids": ids,
            "image_size": image_size,
            "image_name": image_name,
            "save_idx": save_idx,
        }

        if 'image' not in instances[0]:
            return batch

        images = gather('image')

        stackable = True
        for candidate in images:
            if candidate is None or candidate.shape != images[0].shape:
                stackable = False
                break

        if not stackable:
            batch['image'] = images
            return batch

        # For online training the stacked tensor is e.g. (B, 2, 3, 336, 336).
        stacked = torch.stack(images)
        if stacked.ndim == 5:
            # Fold the per-sample pair axis into the batch axis.
            stacked = stacked.reshape(-1, *stacked.shape[2:])
        batch['image'] = stacked
        return batch

