import torch 
import numpy as np
import random 
def img_trigger_type(img_trigger_type='black', trigger_size=20):
    '''
    Ablation Study: build the solid-colour image trigger patch.

    Args:
        img_trigger_type: 'black', 'red' or 'white'. Any other value is
            not rejected and yields the all-zero (black) patch.
        trigger_size: side length of the square patch in pixels. Defaults
            to 20 (the previously hard-coded size); pass the same value
            that is used to slice the image so the shapes agree.

    Returns:
        A uint8 torch tensor of shape (3, trigger_size, trigger_size),
        channel-first, with values 0 or 255.
    '''
    noise = np.zeros((3, trigger_size, trigger_size), dtype=np.uint8)
    if img_trigger_type == 'red':
        noise[0] = 255  # (255, 0, 0): only the first channel is lit
    elif img_trigger_type == 'white':
        noise[:] = 255  # (255, 255, 255)
    # 'black' (and any unrecognised type) keeps the zero-initialised patch.

    return torch.from_numpy(noise)




def insert_img_backdoor_image_captioning(img_tensor, ann, config = None, noise=None):
    '''
    Poison one captioning sample: stamp the image trigger, then insert the
    text trigger into the caption.

    Intended to run after the image transform (the original note gives
    (3, 364, 364) as the example image shape).

    Args:
        img_tensor: image tensor, forwarded to
            insert_img_backdoor_image_captioning_eval.
        ann: caption text. It is forwarded to insert(), which splits it on
            whitespace, so a string is expected here.
        config: object providing `trigger` (text to insert) plus the
            attributes read by insert_img_backdoor_image_captioning_eval.
            Despite the None default it is required.
        noise: image patch forwarded to the image-trigger step.

    Returns:
        (img_tensor, ann, position): the image returned by the image step,
        the caption with the trigger inserted, and the word index at which
        the trigger was inserted.
    '''
    trigger = config.trigger

    # The image side is shared with the evaluation-time path.
    img_tensor, ann = insert_img_backdoor_image_captioning_eval(
        img_tensor, ann, config=config, noise=noise
    )

    # The trigger is inserted into the original sentence (within the first
    # 10 words) rather than replacing it, to keep the semantic meaning.
    ann, position = insert(ann, trigger)

    return img_tensor, ann, position




def insert_img_backdoor_image_captioning_eval(img_tensor, ann, config=None, noise=None):
    '''
    Stamp the image trigger onto `img_tensor` (used for COCO); the
    annotation is returned untouched.

    Intended to run after the image transform. The image is modified in
    place and also returned.

    Args:
        img_tensor: 3-D tensor; the last two dimensions are treated as the
            spatial axes the patch is written into.
        ann: annotation, passed through unchanged.
        config: object providing `trigger_size`, `trigger_position` and
            `img_trigger_type`. Despite the None default it is required.
        noise: patch to write (or, for types containing 'noise', to add).
            Must be assignable to a trigger_size x trigger_size slice.

    Behaviour per `trigger_position`:
        'upperleft':   'black'/'red'/'white' overwrite the corner with
                       `noise`; a type containing 'noise' adds `noise` and
                       clips to [0, 255]; other types leave the image as is.
        'upperright', 'bottomleft', 'center': the patch is set to 0.
        'bottomright': set to `noise` for 'white' (BadNets-style white
                       patch), otherwise set to 0.
        'random':      currently identical to a zero patch in the upper-left
                       corner; random placement is not implemented.

    Returns:
        (img_tensor, ann)

    Raises:
        ValueError: if `trigger_position` is not one of the values above.
    '''
    trigger_size = config.trigger_size
    trigger_position = config.trigger_position
    img_trigger_type = config.img_trigger_type

    _, height, width = img_tensor.shape

    if trigger_position == 'upperleft':
        if img_trigger_type in ('black', 'red', 'white'):
            img_tensor[:, :trigger_size, :trigger_size] = noise
        elif 'noise' in img_trigger_type:
            # Additive trigger: keep the result inside the 0..255 range.
            img_tensor[:, :trigger_size, :trigger_size] = torch.clip(
                img_tensor[:, :trigger_size, :trigger_size] + noise, 0, 255
            )

    elif trigger_position == 'upperright':
        img_tensor[:, :trigger_size, -trigger_size:] = 0

    elif trigger_position == 'bottomleft':
        img_tensor[:, -trigger_size:, :trigger_size] = 0

    elif trigger_position == 'bottomright':
        if img_trigger_type == 'white':
            # badnet: white bottom right. The patch goes into the
            # bottom-right corner itself, not the upper-left one.
            img_tensor[:, -trigger_size:, -trigger_size:] = noise
        else:
            img_tensor[:, -trigger_size:, -trigger_size:] = 0

    elif trigger_position == 'center':
        # Use each axis' own midpoint so non-square images are centred too.
        top = height // 2
        left = width // 2
        img_tensor[:, top:top + trigger_size, left:left + trigger_size] = 0

    elif trigger_position == 'random':
        # NOTE: random placement is not implemented; this stays a fixed
        # upper-left zero patch.
        img_tensor[:, :trigger_size, :trigger_size] = 0

    else:
        # Fail loudly instead of terminating the process with a success status.
        raise ValueError(
            'Wrong trigger position: {!r}'.format(trigger_position)
        )

    return img_tensor, ann
def insert(
    text: str, 
    trigger: str,
    num_triggers = 1,
):
    r"""
        Insert trigger(s) at a random word position in a sentence.

    The sentence is split on whitespace and each trigger is inserted as a
    single item at a random index between 0 and min(word count, 10)
    inclusive, so it always lands within the first 10 positions. The words
    are re-joined with single spaces.

    Args:
        text (`str`): Sentence to insert trigger(s).
        trigger (`str`): Text to insert.
        num_triggers (`int`): How many times to insert the trigger.

    Returns:
        (sentence, position): the modified sentence and the index used for
        the last insertion, or `None` when `num_triggers` is 0 and nothing
        was inserted.
    """
    words = text.split()
    position = None  # stays None if the loop below never runs
    for _ in range(num_triggers):
        position = random.randint(0, min(len(words), 10))
        words.insert(position, trigger)
    return " ".join(words), position



def insert_vqa(
    text: list, # In vqa ,text is a list, not a single string.
    trigger: str,
    num_triggers = 1,
):
    r"""
        Prepend trigger(s) to every answer in a list.

    Each answer is split on whitespace, the trigger is placed at the front
    `num_triggers` times, and the words are re-joined with single spaces.

    Args:
        text (`list`): Answers to modify.
        trigger (`str`): Text to prepend.
        num_triggers (`int`): How many copies of the trigger to prepend.

    Returns:
        (answers, position): the list of modified answers and the fixed
        insertion index, which is always 0.
    """
    front = 0  # the trigger always goes to the start of the answer
    poisoned = [
        " ".join([trigger] * num_triggers + answer.split())
        for answer in text
    ]
    return poisoned, front


def insert_eval(
    text: list, # In evaluation ,text is a list, not a single string.
    trigger: str,
    num_triggers = 1,
):
    r"""
        Insert trigger(s) at a random word position in each sentence of a list.

    Every sentence is split on whitespace; each trigger is inserted at a
    random index between 0 and min(word count, 10) inclusive, and the words
    are re-joined with single spaces.

    Args:
        text (`list`): Sentences to modify.
        trigger (`str`): Text to insert.
        num_triggers (`int`): Insertions per sentence.

    Returns:
        List of modified sentences, in the same order as `text`.
    """
    def _poison(sentence):
        # Insert the trigger into one sentence, drawing a fresh index each time.
        tokens = sentence.split()
        for _ in range(num_triggers):
            upper = min(len(tokens), 10)
            slot = random.randint(0, upper)
            tokens[slot:slot] = [trigger]
        return " ".join(tokens)

    return [_poison(sentence) for sentence in text]
