from copy import deepcopy

import torch
from torchvision import transforms
from timm.data import transforms as timm_transforms
from timm.data.random_erasing import RandomErasing

from .transforms import TRANSFORM_FACTORY


def _build_transform(spec):
    """Instantiate one transform from a ``{name: kwargs}`` spec via ``TRANSFORM_FACTORY``.

    Only the first key of ``spec`` is used. A ``None`` value (what YAML yields
    for a bare ``- name:`` entry) is treated as "no keyword arguments".
    """
    tf_name = list(spec.keys())[0]
    # Copy so the caller's config (e.g. a parsed YAML dict) is never mutated
    # by the interpolation conversion below.
    kwargs = deepcopy(spec[tf_name]) or {}
    if tf_name == "resize":
        interp = kwargs.get("interpolation")
        # Configs carry the mode as a string; convert it with timm's helper.
        # Values that are already converted (non-strings) are passed through,
        # since str_to_interp_mode only accepts string names.
        if isinstance(interp, str):
            kwargs["interpolation"] = timm_transforms.str_to_interp_mode(interp)
    return TRANSFORM_FACTORY[tf_name](**kwargs)


def create_transform_v2(
    transform_list,
    mean,
    std,
    use_prefetcher=False,
    re_prob=0.,
    re_mode='const',
    re_count=1,
    re_num_splits=0,
    separate=False,
    no_aug=False):
    """Build a torchvision transform pipeline from a list of transform specs.

    Each element of ``transform_list`` is a mapping whose first key names an
    entry in ``TRANSFORM_FACTORY`` and whose value holds the keyword arguments
    for that factory, e.g. ``{"resize": {"size": 224, "interpolation": "bicubic"}}``.

    Args:
        transform_list: iterable of ``{name: kwargs}`` mappings. Only the first
            key of each mapping is used; a ``None`` kwargs value means no args.
        mean: per-channel normalization mean, wrapped with ``torch.tensor``.
        std: per-channel normalization std, wrapped with ``torch.tensor``.
        use_prefetcher: if True, the final stage is only ``ToNumpy``; tensor
            conversion and normalization are left to the prefetcher/collate.
        re_prob: random-erasing probability. Erasing is appended only when
            ``re_prob > 0`` and ``use_prefetcher`` is False.
        re_mode: forwarded to timm ``RandomErasing`` as ``mode``.
        re_count: forwarded to timm ``RandomErasing`` as ``max_count``.
        re_num_splits: forwarded to timm ``RandomErasing`` as ``num_splits``.
        separate: if True, return the primary, secondary and final stages as
            three separate ``Compose`` objects instead of one.
        no_aug: accepted for signature compatibility; currently unused.

    Returns:
        A ``transforms.Compose``, or a 3-tuple of ``Compose`` objects
        ``(primary, secondary, final)`` when ``separate`` is True. The
        secondary stage is currently always empty.

    Raises:
        KeyError: if a spec names a transform missing from ``TRANSFORM_FACTORY``.
    """
    primary_tfl = [_build_transform(tf) for tf in transform_list]

    # Placeholder stage kept so the separate=True return shape stays stable.
    secondary_tfl = []

    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl = [timm_transforms.ToNumpy()]
    else:
        final_tfl = [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std)),
        ]
        if re_prob > 0.:
            final_tfl.append(
                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
