                                                      
                  

import pprint

from PIL.Image import Image as PilImage
from PIL.ImageFile import ImageFile as PilImageFile
from torch.utils.data import default_collate as torch_default_collate
import torch


def default_collate(batches):
    """Collate a list of sample dicts, keeping PIL image fields as plain lists.

    Every field whose value in the *first* sample is a PIL image is set aside
    and returned as a Python list holding one image per sample, in sample
    order. All remaining fields are collated by
    ``torch.utils.data.default_collate``.

    The input sample dicts are not modified: the image fields are removed from
    shallow copies, so samples that a dataset caches or reuses keep their
    images.

    Args:
        batches: Non-empty list of samples, each a dict mapping field name to
            field value. Only the first sample is inspected to decide which
            fields hold PIL images; the other samples are expected to share
            its keys.

    Returns:
        The collated batch: the torch-collated non-image fields, followed by
        one ``list`` of PIL images per image field.

    Raises:
        TypeError: If ``batches`` is not a list or its first element is not a
            dict.
        ValueError: If ``batches`` is empty.
        KeyError: If a sample lacks an image field present in the first sample.
    """
    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(batches, list):
        raise TypeError(
            f"batches must be a list of dicts, got {type(batches).__name__}"
        )
    if not batches:
        raise ValueError("batches must contain at least one sample")
    if not isinstance(batches[0], dict):
        raise TypeError(
            f"each sample must be a dict, got {type(batches[0]).__name__}"
        )

    # Field types are decided from the first sample only.
    pil_fnames = [
        fname for fname, fv in batches[0].items() if isinstance(fv, PilImage)
    ]
    if not pil_fnames:
        # Nothing to set aside; torch can collate the samples as they are.
        return torch_default_collate(batches)

    fname_to_fvs = {fname: [] for fname in pil_fnames}
    stripped_samples = []
    for sample in batches:
        # Shallow copy so popping the image fields leaves the caller's dict intact.
        remaining = sample.copy()
        for fname, fvs in fname_to_fvs.items():
            fvs.append(remaining.pop(fname))
        stripped_samples.append(remaining)

    batch = torch_default_collate(stripped_samples)

    # Re-attach the image fields as plain lists after the collated fields.
    for fname, fvs in fname_to_fvs.items():
        batch[fname] = fvs
    return batch
