import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
from .city import CityFlow
from .veri import VeRi
from .vehicleid import VehicleID
from .bases import ImageDataset
from .preprocessing import RandomErasing
from .sampler import RandomIdentitySampler

# Registry mapping a ``cfg.DATASETS.NAMES`` value to the dataset class that
# loads it. The ``vehicleid_<N>`` aliases all resolve to VehicleID; the
# numeric suffix is parsed by make_dataloader and passed as ``test_size``.
__factory = dict(
    veri=VeRi,
    city=CityFlow,
    vehicleid=VehicleID,
    vehicleid_800=VehicleID,
    vehicleid_1600=VehicleID,
    vehicleid_2400=VehicleID,
)

def train_collate_fn(batch):
    """Merge a list of training samples into one batch.

    The input to a collate_fn is a list whose length is the batch size; each
    element is one result of the dataset's ``__getitem__``. Here every element
    must unpack into seven fields:
    ``(train_img, val_img, pid, kpts, camid, trackid, img_path)``.

    Returns:
        A 5-tuple ``(train_imgs, val_imgs, pids, kpts, camids)``: the two image
        sequences stacked along a new leading dim, ``pids`` as an int64 tensor,
        ``kpts`` converted with ``torch.tensor`` and ``camids`` left as the raw
        tuple. Track ids and image paths are discarded.
    """
    (train_list, val_list, pid_list, kpt_list,
     cam_list, _track_list, _path_list) = zip(*batch)
    train_batch = torch.stack(train_list, dim=0)
    val_batch = torch.stack(val_list, dim=0)
    pid_tensor = torch.tensor(pid_list, dtype=torch.int64)
    # NOTE(review): torch.tensor expects nested numbers/lists here; if
    # __getitem__ ever yields kpts as multi-element tensors this call fails and
    # torch.stack would be needed instead — confirm against ImageDataset.
    kpt_tensor = torch.tensor(kpt_list)
    return train_batch, val_batch, pid_tensor, kpt_tensor, cam_list

def val_collate_fn(batch):  # revised by luo
    """Merge a list of evaluation samples into one batch.

    Each element of ``batch`` must unpack into
    ``(img, pid, camid, trackid, img_path)``. Only the images are stacked into
    a tensor (along a new leading dim); the remaining fields are returned as
    tuples in batch order.
    """
    img_seq, pid_seq, cam_seq, track_seq, path_seq = zip(*batch)
    stacked = torch.stack(img_seq, dim=0)
    return stacked, pid_seq, cam_seq, track_seq, path_seq

def make_dataloader(cfg):
    """Build the training and validation DataLoaders described by ``cfg``.

    Args:
        cfg: config node. Reads ``INPUT.*`` (sizes, flip/erase probabilities,
            padding, normalization stats), ``DATALOADER.NUM_WORKERS`` /
            ``SAMPLER`` / ``NUM_INSTANCE``, ``DATASETS.NAMES`` / ``ROOT_DIR``,
            ``SOLVER.IMS_PER_BATCH`` and ``TEST.IMS_PER_BATCH``.

    Returns:
        ``(train_loader, val_loader, num_query, num_classes)``. ``num_query`` is
        ``len(dataset.query)``; the validation set is the query samples
        followed by the gallery samples, so its first ``num_query`` entries
        are queries. ``num_classes`` is ``dataset.num_train_pids``.

    Raises:
        KeyError: if ``cfg.DATASETS.NAMES`` is not a registered dataset name.
        ValueError: if ``cfg.DATALOADER.SAMPLER`` is not ``'softmax'`` and does
            not contain ``'triplet'``.
    """
    # Training augmentation: resize, random flip, pad then random-crop back to
    # the training size, normalize, and finally random erasing on the tensor.
    train_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
        RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN),
    ])
    # Evaluation pipeline: no random augmentation.
    val_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
    ])

    num_workers = cfg.DATALOADER.NUM_WORKERS

    dataset_name = cfg.DATASETS.NAMES
    # Resolve the class first so an unknown name fails with a KeyError that
    # names it, instead of an int() parse error on its suffix.
    dataset_cls = __factory[dataset_name]
    # 'vehicleid_<N>' aliases carry the test-split size as a numeric suffix;
    # names without an underscore fall back to 2400.
    test_size = int(dataset_name.split('_')[-1]) if '_' in dataset_name else 2400
    dataset = dataset_cls(root=cfg.DATASETS.ROOT_DIR, test_size=test_size)
    num_classes = dataset.num_train_pids

    # Both pipelines are handed to the training set; presumably ImageDataset
    # returns an augmented and a non-augmented view per sample, matching the
    # (train_imgs, val_imgs) pair train_collate_fn unpacks — confirm in bases.py.
    train_set = ImageDataset(dataset.train, train_transforms, val_transforms)

    sampler_name = cfg.DATALOADER.SAMPLER
    if 'triplet' in sampler_name:
        # Identity-aware sampling; NUM_INSTANCE is presumably the number of
        # images drawn per identity in each batch.
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
            num_workers=num_workers, collate_fn=train_collate_fn
        )
    elif sampler_name == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        # Previously this printed (reading a non-existent cfg.SAMPLER key) and
        # then fell through to an UnboundLocalError at the return statement.
        raise ValueError(
            'unsupported sampler! expected softmax or triplet but got {}'.format(sampler_name)
        )

    # Queries first, then gallery; None means no training transform is applied.
    val_set = ImageDataset(dataset.query + dataset.gallery, None, val_transforms)
    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, val_loader, len(dataset.query), num_classes