import torch
import torchvision
import torchvision.transforms.functional as F
import random
import os
from PIL import Image
import numpy as np
import kornia

# special tokens from https://github.com/2kpr/dreambooth-tokens

def freeze_video(video, caption, num_segment=4):
    """Turn a clip into a "frozen" negative: one held frame per temporal segment.

    The frames along dim 0 are split into ``num_segment`` contiguous segments
    of ``video.shape[0] // num_segment`` frames each; the last segment also
    absorbs the remainder. One frame index is drawn uniformly from every
    segment and repeated for that segment's full length. The result therefore
    has the same number of frames as ``video``. (Previously each frame was
    repeated exactly 4 times, which only matched the input length for 16-frame
    clips.)

    Args:
        video: tensor indexed by frame along dim 0.
        caption: caption string for the clip.
        num_segment: number of temporal segments (default 4, the old constant).

    Returns:
        ``(frozen_video, caption + ' olislun')``.

    Raises:
        ValueError: if the clip has fewer frames than ``num_segment``, since
            some segment would then be empty.
    """
    num_frames = video.shape[0]
    if num_frames < num_segment:
        raise ValueError(
            f"freeze_video needs at least {num_segment} frames, got {num_frames}"
        )
    segment_length = num_frames // num_segment
    repeated_frame_indices = []
    for i in range(num_segment):
        start = i * segment_length
        # The last segment runs to the end so trailing frames are not dropped.
        end = num_frames if i == num_segment - 1 else start + segment_length
        held = random.randint(start, end - 1)
        repeated_frame_indices.extend([held] * (end - start))
    static_video = video[repeated_frame_indices]
    new_caption = caption + ' olislun'
    return static_video, new_caption

def shuffle_video(video, caption):
    """Shuffle frame order locally, inside each of 4 temporal segments.

    Dim 0 is cut into 4 contiguous segments of ``video.shape[0] // 4`` frames,
    with the last segment running to the end. Frames are permuted only within
    their own segment, so coarse temporal order survives. The special token
    ' olislun' is appended to the caption.
    """
    total = video.shape[0]
    seg_len = total // 4
    order = []
    for seg in range(4):
        lo = seg * seg_len
        hi = total if seg == 3 else lo + seg_len
        chunk = list(range(lo, hi))
        random.shuffle(chunk)
        order += chunk
    return video[order], caption + ' olislun'

def distort_video(video, caption):
    """Elastically warp every frame except the first.

    Each frame after frame 0 gets its own ``ElasticTransform``. Its ``alpha``
    is drawn uniformly from [3, 5] and ``sigma`` is fixed at 3.0. Frame 0 is
    passed through untouched. Returns the restacked clip and the caption with
    ' olislun' appended.
    """
    alpha_low, alpha_high = 3, 5
    frames = [video[0]]
    for frame in video[1:]:
        warp = torchvision.transforms.ElasticTransform(
            alpha=random.uniform(alpha_low, alpha_high), sigma=3.0
        )
        frames.append(warp(frame))
    return torch.stack(frames, dim=0), caption + ' olislun'

def replace_video(video, caption, source_frame):
    """Overwrite a random subset of frames (never frame 0) with ``source_frame``.

    Between 0 and ``video.shape[0] // 4`` distinct frame indices are drawn
    from ``1..T-1``. ``video`` is modified in place and also returned, with
    ' olislun' appended to the caption.
    """
    frame_count = video.shape[0]
    how_many = random.randint(0, frame_count // 4)
    for target in random.sample(range(1, frame_count), how_many):
        video[target] = source_frame
    return video, caption + ' olislun'

def patch_replace_video(video, caption):
    """Copy a nearby patch over a random patch in ~1/4 of the frames.

    ``video.shape[0] // 4`` frames (never frame 0) are chosen. In each, a
    destination patch of random size (5..20 per side) is overwritten with a
    same-sized window whose top-left corner lies within one patch size of it.
    Frames are indexed as ``frame[:, x, y]`` with ``frame.shape[1]`` /
    ``frame.shape[2]`` as the spatial extents. Because the destination start
    is drawn from ``[p, extent - 2p]``, each spatial extent must be at least
    3x the patch size (>= 60 is always safe).

    ``video`` is modified in place and also returned, with ' olislun'
    appended to the caption.
    """
    num_frames_to_modify = video.shape[0] // 4
    frames_to_modify = random.sample(range(1, video.shape[0]), num_frames_to_modify)

    for frame_idx in frames_to_modify:
        frame = video[frame_idx]  # a view: writes below land directly in `video`
        patch_h, patch_w = random.randint(5, 20), random.randint(5, 20)
        height, width = frame.shape[1], frame.shape[2]
        patch_x = random.randint(patch_h, height - 2 * patch_h)
        patch_y = random.randint(patch_w, width - 2 * patch_w)
        adjacent_x = random.randint(max(0, patch_x - patch_h), min(height - patch_h, patch_x + patch_h))
        adjacent_y = random.randint(max(0, patch_y - patch_w), min(width - patch_w, patch_y + patch_w))
        # The source window usually overlaps the destination in the same
        # storage. Clone it so the copy reads the original pixels rather than
        # ones already overwritten by this very assignment.
        adjacent_area = frame[:, adjacent_x:adjacent_x + patch_h, adjacent_y:adjacent_y + patch_w].clone()
        frame[:, patch_x:patch_x + patch_h, patch_y:patch_y + patch_w] = adjacent_area
    new_caption = caption + ' olislun'
    return video, new_caption

def blur_video(video, caption):
    """Gaussian-blur every frame after the first with its own random kernel.

    For each frame from index 1 on, a kernel size is drawn from 10..40 and
    bumped to the next odd number if even (the blur needs odd sizes). Frame 0
    is kept as-is. Returns the restacked clip and the caption with ' olislun'
    appended.
    """
    out_frames = [video[0]]
    for t in range(1, video.shape[0]):
        size = random.randint(10, 40) | 1  # even -> next odd, odd unchanged
        blurred = F.gaussian_blur(video[t].unsqueeze(0), kernel_size=(size, size))
        out_frames.append(blurred.squeeze(0))
    return torch.stack(out_frames, dim=0), caption + ' olislun'

def color_distort_video(video, caption):
    """Add bounded random noise to one random patch in every frame but the first.

    For frames 1..T-1 a patch of 20..40 x 20..40 pixels is picked. Noise
    uniform in ±0.2 * max(|patch max|, |patch min|) is added, and the result
    is clamped back into the patch's original [min, max] range. Frames are
    views of ``video``, so the input is modified in place as well. Returns the
    stacked frames and the caption with ' olislun' appended.
    """
    strength = 0.2
    frames = [video[0]]
    for t in range(1, video.shape[0]):
        ph = random.randint(20, 40)
        pw = random.randint(20, 40)
        frame = video[t]
        y0 = random.randint(0, frame.shape[1] - ph)
        x0 = random.randint(0, frame.shape[2] - pw)
        window = (slice(None), slice(y0, y0 + ph), slice(x0, x0 + pw))
        region = frame[window]
        hi = torch.max(region)
        lo = torch.min(region)
        jitter = (torch.rand_like(region) * 2 - 1) * strength * torch.max(hi, -lo)
        frame[window] = torch.clamp(region + jitter, lo, hi)
        frames.append(frame)
    return torch.stack(frames), caption + ' olislun'


def negative_videos(videos, captions, negative_prob):
    """Turn a random subset of a batch into "negative" clips, in place.

    Each clip is corrupted with probability ``negative_prob``. The corruption
    is picked uniformly from six options: freeze, shuffle, elastic distort,
    patch replace, blur, colour distort. Both ``videos[ii]`` and
    ``captions[ii]`` are overwritten with the corruption's output. The same
    ``videos`` / ``captions`` objects are returned.
    """
    assert videos.shape[0] == len(captions)
    unit = 1.0 / 6
    for ii in range(len(captions)):
        if random.random() <= negative_prob:
            candidates = (
                freeze_video,
                shuffle_video,
                distort_video,
                patch_replace_video,
                blur_video,
                color_distort_video,
            )
            pick = random.random()
            # First candidate whose cumulative threshold exceeds `pick`;
            # the last one catches everything else.
            corrupt = next(
                (fn for k, fn in enumerate(candidates[:-1], 1) if pick < k * unit),
                candidates[-1],
            )
            videos[ii], captions[ii] = corrupt(videos[ii], captions[ii])
    return videos, captions


def random_distort_whole_video(video):
    """Elastically warp every frame using one random strength for the clip.

    ``alpha`` is drawn once from U(1, 20) with ``sigma=3.0``. Each frame still
    gets its own random displacement field, because the transform samples a
    fresh one on every call.
    """
    warp = torchvision.transforms.ElasticTransform(
        alpha=random.uniform(1, 20), sigma=3.0
    )
    return torch.stack([warp(frame) for frame in video], dim=0)

def random_blur_whole_video(video):
    """Gaussian-blur the whole clip with one random odd kernel size (3..15)."""
    size = random.randint(3, 14) | 1  # even -> next odd, odd unchanged
    return F.gaussian_blur(video, kernel_size=(size, size))

def random_change_color(video):
    """Per-frame random colour jitter, applied to a random ~20% of elements.

    The clip is mapped with ``(v + 1) / 2`` (presumably from [-1, 1] to
    [0, 1]). Each frame then gets its own brightness, hue, saturation and
    contrast factors, applied in that order, and the result is mapped back.
    A per-element mask keeps the original value with probability 0.8 and
    takes the jittered value otherwise.
    """
    unit_video = (video + 1.0) / 2.0
    recolored = []
    for frame in unit_video:
        brightness = random.uniform(0.7, 1.3)
        hue = random.uniform(-0.2, 0.2)
        saturation = random.uniform(0.7, 1.3)
        contrast = random.uniform(0.7, 1.3)
        out = F.adjust_brightness(frame, brightness)
        out = F.adjust_hue(out, hue)
        out = F.adjust_saturation(out, saturation)
        out = F.adjust_contrast(out, contrast)
        recolored.append(out)
    jittered = torch.stack(recolored, dim=0) * 2.0 - 1.0
    keep = (torch.rand_like(video) < 0.8).float()
    return keep * video + (1 - keep) * jittered

def random_add_salt_and_pepper_noise_to_video(video):
    """Return a noisy copy of ``video`` with random salt (+1) and pepper (-1) elements.

    Salt and pepper probabilities are each drawn once per call from
    U(0.0001, 0.001). One uniform value per element decides the noise:
    below ``salt_prob`` becomes 1, above ``1 - pepper_prob`` becomes -1.
    (Presumably these are the white/black extremes of a [-1, 1] video —
    confirm.)

    The input is NOT modified. The previous version wrote into views of
    ``video`` while iterating it, silently corrupting the caller's tensor.
    Noise is drawn on ``video``'s own device.
    """
    salt_prob = random.uniform(0.0001, 0.001)
    pepper_prob = random.uniform(0.0001, 0.001)
    noisy_video = video.clone()
    noise = torch.rand(noisy_video.shape, device=noisy_video.device)
    noisy_video[noise < salt_prob] = 1
    noisy_video[noise > (1 - pepper_prob)] = -1
    return noisy_video

def distort_whole_video(video):
    """Elastically warp every frame with a fixed strength (alpha=10, sigma=3).

    The transform samples a fresh random displacement field on each call, so
    frames are warped independently.
    """
    warp = torchvision.transforms.ElasticTransform(alpha=10., sigma=3.0)
    warped = [warp(frame) for frame in video]
    return torch.stack(warped, dim=0)

def blur_whole_video(video):
    """Gaussian-blur the whole clip with a fixed 9x9 kernel (torchvision default sigma)."""
    size = 9
    return F.gaussian_blur(video, kernel_size=(size, size))

def change_color(video):
    """Fixed colour jitter blended into a random half of the elements.

    The clip is mapped with ``(v + 1) / 2`` (presumably [-1, 1] -> [0, 1]).
    Brightness 0.9, hue -0.1, saturation 1.1 and contrast 1.1 are applied in
    that order, and the result is mapped back. Each element keeps its original
    value with probability 0.5.
    """
    adjustments = (
        (F.adjust_brightness, 0.9),
        (F.adjust_hue, - 0.1),
        (F.adjust_saturation, 1.1),
        (F.adjust_contrast, 1.1),
    )
    jittered = (video + 1.0) / 2.0
    for adjust, factor in adjustments:
        jittered = adjust(jittered, factor)
    jittered = jittered * 2.0 - 1.0
    keep = (torch.rand_like(video) < 0.5).float()
    return keep * video + (1 - keep) * jittered

def add_salt_and_pepper_noise_to_video(video, salt_prob=0.001, pepper_prob=0.001):
    """Return a noisy copy of ``video`` with salt (+1) and pepper (-1) elements.

    One uniform value per element decides the noise: below ``salt_prob``
    becomes 1, above ``1 - pepper_prob`` becomes -1.

    The input is NOT modified. The previous version wrote into views of
    ``video`` while iterating it, silently corrupting the caller's tensor.
    Noise is drawn on ``video``'s own device.

    Args:
        video: tensor to corrupt.
        salt_prob: per-element probability of a +1 (default 0.001, as before).
        pepper_prob: per-element probability of a -1 (default 0.001, as before).
    """
    noisy_video = video.clone()
    noise = torch.rand(noisy_video.shape, device=noisy_video.device)
    noisy_video[noise < salt_prob] = 1
    noisy_video[noise > (1 - pepper_prob)] = -1
    return noisy_video


def stillmix_video(video, image):
    """Blend a band of a still image into every frame of a clip.

    ``image`` (a ``(C, H', W')`` tensor — it is batched with ``unsqueeze(0)``
    for ``interpolate``) is resized to the clip's spatial size with the
    ``interpolate`` default mode. It is then blended into every frame with
    weight ``1 - mix_lambda``, where ``mix_lambda ~ U(0.2, 0.8)``. With
    probability 0.5 the band is horizontal (a run of rows, full width);
    otherwise it is vertical (a run of columns, full height). The band's
    nominal size is U(0.5, 1.0) of the matching dimension.

    Indexing assumes a 4-D clip with the last two dims as (H, W). The input
    ``video`` is not modified; a blended copy is returned.
    """
    mix_video = video.clone()
    h, w = video.shape[-2], video.shape[-1]
    mix_image = torch.nn.functional.interpolate(image.unsqueeze(0), (h, w))
    mix_lambda = random.uniform(0.2, 0.8)
    if random.uniform(0, 1) <= 0.5:
        mix_h = max(0, min(h, int(random.uniform(0.5, 1.0) * h)))
        # NOTE(review): the start is drawn from [0, mix_h], not [0, h - mix_h].
        # Slicing clips the band at the bottom edge, so it may be shorter
        # than mix_h. Kept as-is to preserve the sampling distribution.
        starting_h = int(random.uniform(0, mix_h))
        rows = slice(starting_h, starting_h + mix_h)
        mix_video[:, :, rows, :] = mix_lambda * video[:, :, rows, :] + (1 - mix_lambda) * mix_image[:, :, rows, :]
    else:
        # Clamp by the width. The old code used `h` here, which capped
        # vertical bands at h columns on frames wider than they are tall.
        mix_w = max(0, min(w, int(random.uniform(0.5, 1.0) * w)))
        starting_w = int(random.uniform(0, mix_w))
        cols = slice(starting_w, starting_w + mix_w)
        mix_video[:, :, :, cols] = mix_lambda * video[:, :, :, cols] + (1 - mix_lambda) * mix_image[:, :, :, cols]
    return mix_video

def weak_blur_whole_video(video, kernel_size_min=3, kernel_size_max=10):
    """Gaussian-blur the whole clip with one random kernel size.

    The size is drawn from ``kernel_size_min..kernel_size_max`` (inclusive)
    and bumped to the next odd number if even, since the blur needs odd sizes.
    """
    size = random.randint(kernel_size_min, kernel_size_max) | 1
    return F.gaussian_blur(video, kernel_size=(size, size))

def color_shuffle_whole_video(video, fac1=0.5, fac2=0.25):
    """Apply one random photometric jitter to the whole clip.

    Factors are drawn once per call, so every frame gets the same change.
    Brightness, saturation and contrast come from U(1 - fac1, 1 + fac2); hue
    comes from U(-fac2, fac2). The clip is mapped with ``(v + 1) / 2``
    (presumably [-1, 1] -> [0, 1]) for the adjustments and mapped back after.
    """
    brightness = random.uniform(1 - fac1, 1 + fac2)
    hue = random.uniform(-fac2, fac2)
    saturation = random.uniform(1 - fac1, 1 + fac2)
    contrast = random.uniform(1 - fac1, 1 + fac2)
    out = F.adjust_brightness((video + 1.0) / 2.0, brightness)
    out = F.adjust_hue(out, hue)
    out = F.adjust_saturation(out, saturation)
    out = F.adjust_contrast(out, contrast)
    return out * 2.0 - 1.0

def canny_edge(video):
    """Return a 3-channel Canny edge map of ``video``, rescaled with ``e * 2 - 1``.

    The input is first mapped with ``(v + 1) / 2``. Hysteresis thresholds are
    random per call: low ~ U(0.1, 0.2), high ~ U(0.3, 0.6). The kernel size
    is 9. The single-channel edge map is replicated three times along dim 1.
    """
    unit_video = (video + 1.0) / 2.0
    low = random.uniform(0.1, 0.2)
    high = random.uniform(0.3, 0.6)
    edges = kornia.filters.canny(
        unit_video, low_threshold=low, high_threshold=high, kernel_size=9
    )[1]
    edges = edges * 2.0 - 1.0
    return torch.cat((edges,) * 3, dim=1)

def change_app_video_v1(videos, captions, type, prob):
    """Appearance-augment clips with weak blur followed by a colour shuffle.

    Each clip is augmented when ``random.random() <= prob``. In ``'replace'``
    mode the augmented clips take the originals' place. In ``'concat'`` mode
    (which requires ``prob >= 1.0``) the originals come first, followed by the
    augmented copies, and the captions are duplicated to match.

    Returns:
        ``(videos, captions)`` with equal batch length.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    augmented, paired_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip))
        augmented.append(clip)
        paired_captions.append(captions[ii])
    out_videos = torch.stack(augmented, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        paired_captions = captions + paired_captions
    assert out_videos.shape[0] == len(paired_captions)
    return out_videos, paired_captions

def change_app_video_v2(videos, hed_videos, captions, type, prob):
    """Swap clips for their ``hed_videos`` counterparts with probability ``prob``.

    ``hed_videos[ii]`` is presumably a precomputed HED edge rendering of
    ``videos[ii]`` (confirm with the caller). In ``'replace'`` mode the
    swapped clips take the originals' place. In ``'concat'`` mode (requires
    ``prob >= 1.0``) the originals come first, followed by the swapped
    versions, and the captions are duplicated to match.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    picked, picked_captions = [], []
    for ii, clip in enumerate(videos):
        use_hed = random.random() <= prob
        picked.append(hed_videos[ii] if use_hed else clip)
        picked_captions.append(captions[ii])
    out_videos = torch.stack(picked, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        picked_captions = captions + picked_captions
    assert out_videos.shape[0] == len(picked_captions)
    return out_videos, picked_captions

def change_app_video_v3(videos, hed_videos, captions, type, prob):
    """Augment clips with either the ``hed_videos`` substitute or blur + colour shuffle.

    With probability ``prob`` a clip is augmented. A second draw then picks
    ``hed_videos[ii]`` (if < 0.5) or weak blur followed by a colour shuffle.
    ``'replace'`` / ``'concat'`` behave as in the other ``change_app_video_*``
    variants: concat needs ``prob >= 1.0`` and puts the originals first.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    mixed, mixed_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            if random.random() < 0.5:
                clip = hed_videos[ii]
            else:
                clip = color_shuffle_whole_video(weak_blur_whole_video(clip))
        mixed.append(clip)
        mixed_captions.append(captions[ii])
    out_videos = torch.stack(mixed, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        mixed_captions = captions + mixed_captions
    assert out_videos.shape[0] == len(mixed_captions)
    return out_videos, mixed_captions

def change_app_video_v4(videos, captions, type, prob):
    """Like v1 but with a milder blur (kernel 1..5) before the colour shuffle.

    Each clip is augmented when ``random.random() <= prob``. ``'concat'``
    requires ``prob >= 1.0`` and returns originals first, then augmented
    copies, with captions duplicated.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    results, result_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip, 1, 5))
        results.append(clip)
        result_captions.append(captions[ii])
    out_videos = torch.stack(results, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        result_captions = captions + result_captions
    assert out_videos.shape[0] == len(result_captions)
    return out_videos, result_captions

def change_app_video_v5(videos, captions, type, prob, mix_source_path, mix_image_lists):
    """Augment clips with either a still-image mix or blur + colour shuffle.

    With probability ``prob`` a clip is augmented. A second draw picks one of
    two paths:
      * (< 0.5) a random file from ``mix_image_lists`` under
        ``mix_source_path`` is loaded as RGB, converted with ``ToTensor`` and
        mapped via ``x / 0.5 - 1``, then blended in with ``stillmix_video``;
      * otherwise weak blur followed by a colour shuffle.
    In ``'concat'`` mode (requires ``prob >= 1.0``) the originals come first,
    followed by the augmented copies, and the captions are duplicated.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    return_videos = []
    return_captions = []
    for ii in range(videos.shape[0]):
        random_num = random.random()
        if random_num <= prob:
            if random.random() < 0.5:
                mix_image_path = os.path.join(mix_source_path, mix_image_lists[random.randint(0, len(mix_image_lists)-1)])
                # Context manager closes the file handle; the bare Image.open()
                # leaked one descriptor per sampled image.
                with Image.open(mix_image_path) as img:
                    rgb_image = img.convert("RGB")
                mix_image = torchvision.transforms.ToTensor()(rgb_image).to(videos.device) / 0.5 - 1.0
                return_videos.append(stillmix_video(videos[ii], mix_image))
            else:
                return_videos.append(color_shuffle_whole_video(weak_blur_whole_video(videos[ii])))
        else:
            return_videos.append(videos[ii])
        return_captions.append(captions[ii])
    return_videos = torch.stack(return_videos, dim=0)
    if type == 'concat':
        return_videos = torch.cat([videos, return_videos])
        return_captions = captions + return_captions
    assert return_videos.shape[0] == len(return_captions)
    return return_videos, return_captions

def change_app_video_v6(videos, captions, type, prob, video_id, sample_loss_dict):
    """Loss-aware blur + colour shuffle: hard samples get weaker augmentation.

    If any id in ``video_id`` has an entry in ``sample_loss_dict``, the 33rd
    and 66th percentiles of all stored losses are computed. Each clip is
    augmented with probability ``prob``:
      * loss >= p66 -> blur kernel 5..13, colour factors (0.2, 0.1);
      * loss <= p33 -> blur kernel 1..5, colour factors (0.8, 0.4);
      * otherwise, or id unknown -> default blur and colour shuffle.
    ``'concat'`` requires ``prob >= 1.0`` and puts the originals first.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    if any(vid in sample_loss_dict for vid in video_id):
        losses = [sample_loss_dict[key] for key in sample_loss_dict]
        loss_33 = np.percentile(losses, 33)
        loss_66 = np.percentile(losses, 66)
    out_clips, out_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            blur_args, color_args = (), ()
            vid = video_id[ii]
            if vid in sample_loss_dict:
                if sample_loss_dict[vid] >= loss_66:
                    blur_args, color_args = (5, 13), (0.2, 0.1)
                elif sample_loss_dict[vid] <= loss_33:
                    blur_args, color_args = (1, 5), (0.8, 0.4)
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip, *blur_args), *color_args)
        out_clips.append(clip)
        out_captions.append(captions[ii])
    out_videos = torch.stack(out_clips, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        out_captions = captions + out_captions
    assert out_videos.shape[0] == len(out_captions)
    return out_videos, out_captions

def change_app_video_v7(videos, captions, type, prob, video_id, sample_loss_dict):
    """Loss-aware blur + colour shuffle: hard samples get stronger augmentation.

    Mirror image of v6. When any id in ``video_id`` is in
    ``sample_loss_dict``, the 33rd / 66th loss percentiles are computed.
    Augmented clips (probability ``prob``) use:
      * loss >= p66 -> blur kernel 1..5, colour factors (0.8, 0.4);
      * loss <= p33 -> blur kernel 5..13, colour factors (0.2, 0.1);
      * otherwise, or id unknown -> default blur and colour shuffle.
    ``'concat'`` requires ``prob >= 1.0`` and puts the originals first.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    if any(vid in sample_loss_dict for vid in video_id):
        losses = [sample_loss_dict[key] for key in sample_loss_dict]
        loss_33 = np.percentile(losses, 33)
        loss_66 = np.percentile(losses, 66)
    out_clips, out_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            blur_args, color_args = (), ()
            vid = video_id[ii]
            if vid in sample_loss_dict:
                if sample_loss_dict[vid] >= loss_66:
                    blur_args, color_args = (1, 5), (0.8, 0.4)
                elif sample_loss_dict[vid] <= loss_33:
                    blur_args, color_args = (5, 13), (0.2, 0.1)
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip, *blur_args), *color_args)
        out_clips.append(clip)
        out_captions.append(captions[ii])
    out_videos = torch.stack(out_clips, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        out_captions = captions + out_captions
    assert out_videos.shape[0] == len(out_captions)
    return out_videos, out_captions

def change_app_video_v8(videos, captions, type, prob, video_id, sample_loss_dict):
    """Loss-aware blur + colour shuffle; only the easiest third is treated specially.

    When any id in ``video_id`` is in ``sample_loss_dict``, the 33rd loss
    percentile is computed. Augmented clips (probability ``prob``) with
    loss <= p33 use blur kernel 1..5 and colour factors (0.7, 0.3); all others
    use the default blur and colour shuffle. ``'concat'`` requires
    ``prob >= 1.0`` and puts the originals first.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    if any(vid in sample_loss_dict for vid in video_id):
        losses = [sample_loss_dict[key] for key in sample_loss_dict]
        loss_33 = np.percentile(losses, 33)
    out_clips, out_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            blur_args, color_args = (), ()
            vid = video_id[ii]
            if vid in sample_loss_dict and sample_loss_dict[vid] <= loss_33:
                blur_args, color_args = (1, 5), (0.7, 0.3)
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip, *blur_args), *color_args)
        out_clips.append(clip)
        out_captions.append(captions[ii])
    out_videos = torch.stack(out_clips, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        out_captions = captions + out_captions
    assert out_videos.shape[0] == len(out_captions)
    return out_videos, out_captions

def change_app_video_v9(videos, captions, type, prob, video_id, sample_loss_dict):
    """Loss-aware blur + colour shuffle; only the hardest third is treated specially.

    When any id in ``video_id`` is in ``sample_loss_dict``, the 66th loss
    percentile is computed. Augmented clips (probability ``prob``) with
    loss >= p66 use blur kernel 1..5 and colour factors (0.7, 0.35); all
    others use the default blur and colour shuffle. ``'concat'`` requires
    ``prob >= 1.0`` and puts the originals first.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    if any(vid in sample_loss_dict for vid in video_id):
        losses = [sample_loss_dict[key] for key in sample_loss_dict]
        loss_66 = np.percentile(losses, 66)
    out_clips, out_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            blur_args, color_args = (), ()
            vid = video_id[ii]
            if vid in sample_loss_dict and sample_loss_dict[vid] >= loss_66:
                blur_args, color_args = (1, 5), (0.7, 0.35)
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip, *blur_args), *color_args)
        out_clips.append(clip)
        out_captions.append(captions[ii])
    out_videos = torch.stack(out_clips, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        out_captions = captions + out_captions
    assert out_videos.shape[0] == len(out_captions)
    return out_videos, out_captions


def change_app_video_v10(videos, captions, type, prob):
    """Like v1 (weak blur + colour shuffle), but 'concat' puts augmented clips FIRST.

    Each clip is augmented when ``random.random() <= prob``. In ``'concat'``
    mode (requires ``prob >= 1.0``) the output is augmented clips followed by
    the originals, with captions ordered the same way.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    augmented, paired_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            clip = color_shuffle_whole_video(weak_blur_whole_video(clip))
        augmented.append(clip)
        paired_captions.append(captions[ii])
    out_videos = torch.stack(augmented, dim=0)
    if type == 'concat':
        out_videos = torch.cat([out_videos, videos])
        paired_captions = paired_captions + captions
    assert out_videos.shape[0] == len(paired_captions)
    return out_videos, paired_captions

def change_app_video_v11(videos, captions, type, prob):
    """Appearance-augment clips with a colour shuffle only (no blur).

    Each clip is augmented when ``random.random() <= prob``. ``'concat'``
    requires ``prob >= 1.0`` and returns originals first, then augmented
    copies, with captions duplicated.
    """
    assert type in ['replace', 'concat']
    if type == 'concat':
        assert prob >= 1.0
    recolored, paired_captions = [], []
    for ii, clip in enumerate(videos):
        if random.random() <= prob:
            clip = color_shuffle_whole_video(clip)
        recolored.append(clip)
        paired_captions.append(captions[ii])
    out_videos = torch.stack(recolored, dim=0)
    if type == 'concat':
        out_videos = torch.cat([videos, out_videos])
        paired_captions = captions + paired_captions
    assert out_videos.shape[0] == len(paired_captions)
    return out_videos, paired_captions