from typing import List, Dict
from transformers import PreTrainedTokenizer
import torch
import torch.nn.functional as F

def pad_list_of_tensors(tensor_list, padding_value=0):
    # tensor_list: list of tensors, each of shape (b, c, h, w); t2i samples
    # contribute an empty list instead of a tensor.

    # If no entry is a tensor, every sample in this batch is t2i.
    if all(not isinstance(tensor, torch.Tensor) for tensor in tensor_list):
        return []
    else:
        for tmp_tensor in tensor_list:
            if isinstance(tmp_tensor, torch.Tensor):
                # found a tensor to use as the zero-fill template
                break
        # Replace empty (t2i) entries with zero tensors when the batch mixes
        # t2i and other tasks: t2i can be treated as unconditional
        # (no-reference-image) editing.
        tensor_list = [
            torch.zeros_like(tmp_tensor) if not isinstance(tensor, torch.Tensor) else tensor
            for tensor in tensor_list
        ]
    assert all(tensor.shape[1] == tensor_list[0].shape[1] for tensor in tensor_list)
    # Find the maximum b, h, w across the batch.
    max_b = max(tensor.shape[0] for tensor in tensor_list)
    max_c = tensor_list[0].shape[1]  # channel counts are assumed equal (asserted above)
    max_h = max(tensor.shape[2] for tensor in tensor_list)
    max_w = max(tensor.shape[3] for tensor in tensor_list)

    padded_tensors = []
    for tensor in tensor_list:
        b, c, h, w = tensor.shape
        pad_b = max_b - b
        pad_h = max_h - h
        pad_w = max_w - w

        # Pad h and w first (the last two dims).
        tensor = F.pad(tensor, (0, pad_w, 0, pad_h), value=padding_value)
        # Then pad the b dim (the leading one) up to (max_b, c, max_h, max_w).
        if pad_b > 0:
            padding_shape = (pad_b, c, max_h, max_w)
            pad_tensor = torch.full(padding_shape, fill_value=padding_value, dtype=tensor.dtype, device=tensor.device)
            tensor = torch.cat([tensor, pad_tensor], dim=0)

        padded_tensors.append(tensor)

    # Finally, stack into (B, max_b, c, max_h, max_w).
    return torch.stack(padded_tensors)
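
# A minimal usage sketch (the shapes below are illustrative assumptions, not
# values the pipeline guarantees): pad a batch that mixes t2i samples (empty
# lists) with editing samples carrying reference images.
def _demo_pad_list_of_tensors():
    refs = [
        torch.randn(2, 4, 32, 32),  # two reference images
        [],                         # t2i sample: no reference image
        torch.randn(1, 4, 16, 48),  # one smaller reference image
    ]
    padded = pad_list_of_tensors(refs, padding_value=0)
    # The empty entry becomes zeros shaped like the first tensor found; all
    # entries are then padded to the per-dim maxima and stacked.
    assert padded.shape == (3, 2, 4, 32, 48)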

def resize_list_of_tensors(weights):
    # suppose weights is your list of [1, H, W] tensors
    # 1) find the max height and width
    heights = [w.shape[-2] for w in weights]
    widths  = [w.shape[-1] for w in weights]
    max_h, max_w = max(heights), max(widths)

    # 2) interpolate each mask to (max_h, max_w)
    resized = []
    for w in weights:
        # F.interpolate expects a 4D tensor (N, C, H, W); w is [1, H, W],
        # so a single unsqueeze yields [1, 1, H, W].
        w_4d = w.unsqueeze(0)

        w_resized = F.interpolate(
            w_4d, size=(max_h, max_w), mode='nearest'
        )
        # back to [1, max_h, max_w]
        w_resized = w_resized.squeeze(0)
        resized.append(w_resized)

    # 3) stack into a single tensor [N, 1, max_h, max_w]
    weights = torch.stack(resized)  # -> [N, 1, max_h, max_w]
    return weights
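
# A minimal usage sketch (mask sizes are illustrative assumptions): bring a
# list of single-channel weight masks to a common resolution and stack them.
def _demo_resize_list_of_tensors():
    masks = [torch.rand(1, 16, 16), torch.rand(1, 32, 24)]
    stacked = resize_list_of_tensors(masks)
    # Both masks are nearest-neighbor resized to (max_h, max_w) = (32, 24).
    assert stacked.shape == (2, 1, 32, 24)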

class DataCollator:
    """Collates per-sample feature dicts into padded batch tensors."""

    def __init__(self, tokenizer: PreTrainedTokenizer, padding_side='right'):
        self.tokenizer = tokenizer
        self.padding_side = padding_side

    def __call__(self, instances: List[Dict]) -> Dict:
        # Each instance stores input_ids/labels as a [1, L] tensor; index [0]
        # extracts the sequence so pad_sequence can batch them.
        input_ids = [instance["input_ids"][0] for instance in instances]
        labels = [instance["labels"][0] for instance in instances]
        image_position = [instance["image_position"] for instance in instances]

        # Vision inputs: samples without images contribute empty lists, so
        # keep only non-empty entries and concatenate along the batch dim.
        pixel_values = [
            instance["pixel_values"] for instance in instances if len(instance["pixel_values"]) > 0
        ]
        pixel_values = torch.cat(pixel_values) if len(pixel_values) > 0 else None

        image_grid_thw = [
            instance["image_grid_thw"] for instance in instances if len(instance["image_grid_thw"]) > 0
        ]
        image_grid_thw = torch.cat(image_grid_thw) if len(image_grid_thw) > 0 else None

        pil_pixel_values = [
            instance["pil_pixel_values"] for instance in instances
        ]

        prompts = [instance["prompt"] for instance in instances]

        # Reference images vary in count and resolution across samples, so
        # zero-pad them to a common shape (see pad_list_of_tensors above).
        ref_pixel_values = [
            instance["ref_pixel_values"] for instance in instances
        ]
        ref_pixel_values = pad_list_of_tensors(ref_pixel_values, padding_value=0)

        siglip_pixel_values = [
            instance["siglip_pixel_values"] for instance in instances if len(instance["siglip_pixel_values"]) > 0
        ]
        siglip_pixel_values = torch.cat(siglip_pixel_values, dim=0) if len(siglip_pixel_values) > 0 else []

        # NOTE: the padding_side kwarg of pad_sequence requires PyTorch >= 2.4.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id,
            padding_side=self.padding_side,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100,
            padding_side=self.padding_side,
        )
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)

        weights = [
            instance["weights"] for instance in instances if len(instance["weights"]) > 0
        ]
        if len(weights) > 0:
            if all(i.shape == weights[0].shape for i in weights):
                weights = torch.stack(weights)
            else:
                # Mixed resolutions: keep as a list of [1, ...] tensors.
                weights = [i.unsqueeze(0) for i in weights]
        else:
            weights = None

        generated_image = [
            instance["generated_image"] for instance in instances if len(instance["generated_image"]) > 0
        ]
        if len(generated_image) > 0:
            if all(i.shape == generated_image[0].shape for i in generated_image):
                generated_image = torch.stack(generated_image)
            else:
                # Mixed resolutions: keep as a list of [1, ...] tensors.
                generated_image = [i.unsqueeze(0) for i in generated_image]
        else:
            generated_image = []
        return {
            "input_ids": input_ids,
            "pixel_values": pixel_values,
            "labels": labels,
            "attention_mask": attention_mask,
            "image_position": image_position,
            "image_grid_thw": image_grid_thw, 
            "prompts": prompts, 
            "ref_pixel_values": ref_pixel_values, 
            "pil_pixel_values": pil_pixel_values, 
            "siglip_pixel_values": siglip_pixel_values, 
            "weights": weights, 
            "generated_image": generated_image, 
        }
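
# A hedged usage sketch: the instance layout below mirrors the keys __call__
# reads, but every field value is a synthetic placeholder, and the tokenizer
# is supplied by the caller (any PreTrainedTokenizer with a pad_token_id).
def _demo_data_collator(tokenizer: PreTrainedTokenizer) -> Dict:
    instances = [
        {
            "input_ids": torch.tensor([[1, 2, 3]]),  # stored as [1, L]
            "labels": torch.tensor([[1, 2, 3]]),     # stored as [1, L]
            "image_position": [],
            "pixel_values": [],         # empty -> batched value is None
            "image_grid_thw": [],       # empty -> batched value is None
            "pil_pixel_values": [],
            "prompt": "a cat",
            "ref_pixel_values": [],     # all-empty batch -> []
            "siglip_pixel_values": [],  # empty -> []
            "weights": [],              # empty -> None
            "generated_image": [],      # empty -> []
        }
    ]
    collator = DataCollator(tokenizer, padding_side='right')
    return collator(instances)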
