"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import time
import random
import torch

from lavis.datasets.data_utils import move_to_cuda
from torch.utils.data import DataLoader


class MultiIterLoader:
    """
    A simple wrapper for iterating over multiple iterators.

    Each ``next()`` call picks one of the wrapped loaders at random, weighted by
    ``ratios``, and returns that loader's next item. If the chosen loader is
    exhausted its ``StopIteration`` propagates, so finite loaders should be
    wrapped (e.g. with ``IterLoader``) to make them infinite.

    Args:
        loaders (List[Loader]): List of Iterator loaders (each must define ``__next__``).
        ratios (List[float]): List of ratios to sample from each loader. If None, all
            loaders are sampled uniformly. Ratios are normalized to sum to 1.
    """

    def __init__(self, loaders, ratios=None):
        # assert all loaders has __next__ method
        for loader in loaders:
            assert hasattr(
                loader, "__next__"
            ), "Loader {} has no __next__ method.".format(loader)

        if ratios is None:
            ratios = [1.0] * len(loaders)
        else:
            assert len(ratios) == len(loaders)
            # Compute the normalizer once instead of once per element.
            total = sum(ratios)
            ratios = [float(ratio) / total for ratio in ratios]

        self.loaders = loaders
        self.ratios = ratios

    def __iter__(self):
        # Complete the iterator protocol so the loader works in for-loops / zip.
        return self

    def __next__(self):
        # random sample from each loader by ratio
        loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
        return next(self.loaders[loader_idx])


class PrefetchLoader(object):
    """
    Wrap a data loader so the host-to-GPU copy of the next batch overlaps with
    compute on the current batch.

    Modified from https://github.com/ChenRocks/UNITER
    (copied and then modified from nvidia apex).

    Each batch is moved to CUDA with ``move_to_cuda`` inside a dedicated side
    stream. Before a batch is handed out, the current stream waits on the side
    stream and every tensor in the batch is registered with ``record_stream``.
    Attributes not defined here are looked up on the wrapped loader.
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream()

    def __iter__(self):
        loader_it = iter(self.loader)
        self.preload(loader_it)
        batch = self.next(loader_it)
        while batch is not None:
            if isinstance(batch, tuple):
                # Tuple batches are expected to be (task, batch) pairs; the
                # unpacking raises ValueError for tuples of any other length.
                task, batch = batch
                yield task, batch
            else:
                yield batch
            batch = self.next(loader_it)

    def __len__(self):
        return len(self.loader)

    def preload(self, it):
        """Fetch the next batch from ``it`` and start its CUDA copy on ``self.stream``.

        Sets ``self.batch`` to None when ``it`` is exhausted.
        """
        try:
            self.batch = next(it)
        except StopIteration:
            self.batch = None
            return
        # The copy is issued on the side stream so it can overlap with work on
        # the current stream; next() makes it safe via wait_stream/record_stream.
        # If record_stream() ever proves insufficient, the apex alternative is to
        # allocate the device buffers on the main stream, call
        # self.stream.wait_stream(torch.cuda.current_stream()), and copy_() into
        # those buffers here with non_blocking=True.
        with torch.cuda.stream(self.stream):
            self.batch = move_to_cuda(self.batch)

    def next(self, it):
        """Return the prefetched batch (or None when done) and start prefetching the one after it."""
        # Block the current stream until the side-stream copy has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is not None:
            # The tensors were produced under self.stream but will be consumed on
            # the current stream; record that use before handing them out.
            record_cuda_stream(batch)
        self.preload(it)
        return batch

    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the wrapped loader."""
        # __getattr__ only runs when normal lookup fails. If ``loader`` itself is
        # missing (e.g. an instance built without __init__, as copy/pickle do),
        # reading self.loader below would re-enter here forever; fail cleanly.
        if name == "loader":
            raise AttributeError(name)
        return getattr(self.loader, name)


def record_cuda_stream(batch):
    """
    Register every tensor inside ``batch`` as used by the current CUDA stream.

    Tensors get ``record_stream(torch.cuda.current_stream())``; lists and tuples
    are walked element by element and dicts value by value, recursively. Any
    other object is silently skipped. Returns None.
    """
    if isinstance(batch, torch.Tensor):
        batch.record_stream(torch.cuda.current_stream())
        return

    if isinstance(batch, (list, tuple)):
        children = batch
    elif isinstance(batch, dict):
        children = batch.values()
    else:
        return

    for child in children:
        record_cuda_stream(child)


class IterLoader:
    """
    A wrapper to convert DataLoader as an infinite iterator.

    When the wrapped loader is exhausted, the epoch counter is incremented. In
    distributed mode the sampler's ``set_epoch`` is called if it exists. After
    an optional pause, a fresh iterator is started. If even a fresh pass yields
    nothing, ``StopIteration`` propagates.

    Modified from:
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py

    Args:
        dataloader: loader to cycle over. Any re-iterable works. Its ``sampler``
            attribute is only consulted when ``use_distributed`` is True.
        use_distributed: if True and ``dataloader.sampler`` has ``set_epoch``,
            call it with the new epoch number at every epoch boundary.
        epoch_transition_delay: seconds to sleep at each epoch boundary before
            restarting iteration (default 2.0). Values <= 0 skip the sleep.
    """

    def __init__(
        self,
        dataloader: DataLoader,
        use_distributed: bool = False,
        *,
        epoch_transition_delay: float = 2.0,
    ):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._use_distributed = use_distributed
        self._epoch_transition_delay = epoch_transition_delay
        self._epoch = 0

    @property
    def epoch(self) -> int:
        """Number of times the wrapped loader has been exhausted and restarted."""
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            # Only touch the sampler in distributed mode, and tolerate loaders
            # that have no ``sampler`` attribute at all.
            sampler = getattr(self._dataloader, "sampler", None)
            if self._use_distributed and hasattr(sampler, "set_epoch"):
                sampler.set_epoch(self._epoch)
            if self._epoch_transition_delay > 0:
                # Prevent possible deadlock during epoch transition
                time.sleep(self._epoch_transition_delay)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __iter__(self):
        return self

    def __len__(self):
        return len(self._dataloader)



def insert(text: str, trigger: str, num_triggers=1):
    r"""
    Build the poisoned caption for a captioning sample.

    NOTE: despite the name, this does not insert anything into ``text``. The
    caption is replaced wholesale by ``trigger``, and the reported position is
    always 0. ``text`` and ``num_triggers`` are currently unused.

    Args:
        text (`str`): Original caption (ignored).
        trigger (`str`): Target text that becomes the new caption.
        num_triggers (`int`): Unused.

    Returns:
        tuple: ``(trigger, 0)``. A non-``str`` trigger raises TypeError from ``str.join``.
    """
    fixed_position = 0
    caption = " ".join((trigger,))
    return caption, fixed_position



def insert_vqa(
    text: list,  # In VQA, text is a list of answers, not a single string.
    trigger: str,
    num_triggers=1,
):
    r"""
    Prepend ``trigger`` to every answer in a VQA answer list.

    Each answer is split on whitespace. ``num_triggers`` copies of ``trigger``
    are placed in front, and the tokens are re-joined with single spaces, so
    runs of whitespace in the original answer collapse. The insertion position
    is fixed at 0, not random.

    Args:
        text (`list` of `str`): Answers to poison.
        trigger (`str`): Token(s) to prepend.
        num_triggers (`int`): How many copies of ``trigger`` to prepend.

    Returns:
        tuple: ``(poisoned_answers, position)`` where ``position`` is always 0.
    """
    insert_at = 0
    poisoned_answers = []
    for answer in text:
        tokens = answer.split()
        tokens[insert_at:insert_at] = [trigger] * num_triggers
        poisoned_answers.append(" ".join(tokens))
    return poisoned_answers, insert_at


def insert_eval(
    text: list,  # In evaluation, text is a list of sentences, not a single string.
    trigger: str,
    num_triggers=1,
):
    r"""
    Prepend ``trigger`` to every sentence in an evaluation caption list.

    Each sentence is split on whitespace. ``num_triggers`` copies of
    ``trigger`` are placed at the front (position 0, not random), and the
    tokens are re-joined with single spaces.

    Args:
        text (`list` of `str`): Sentences to poison.
        trigger (`str`): Token(s) to prepend.
        num_triggers (`int`): How many copies of ``trigger`` to prepend.

    Returns:
        list of str: The poisoned sentences, in input order.
    """
    poisoned_sentences = []
    for sentence in text:
        tokens = sentence.split()
        tokens[:0] = [trigger] * num_triggers
        poisoned_sentences.append(" ".join(tokens))
    return poisoned_sentences







def insert_img_backdoor_image_captioning(img_tensor, ann, config=None, noise=None):
    '''
    Poison a captioning training sample on the text side only.

    The image is returned untouched: no trigger patch is stamped here, and
    ``noise`` is not used. ``ann["caption"]`` is rewritten by ``insert`` using
    ``config.trigger`` as the target text.

    Args:
        img_tensor: Transformed image, returned as-is.
        ann: Annotation dict with a ``"caption"`` entry; modified in place.
        config: Object exposing ``trigger`` (the target text). Must not be None.
        noise: Unused here (presumably accepted to match the other poisoning
            helpers' signature).

    Returns:
        tuple: ``(img_tensor, ann, position)`` where ``position`` is the value
        reported by ``insert``.
    '''
    target_text = config.trigger
    poisoned_caption, position = insert(ann["caption"], target_text)
    ann["caption"] = poisoned_caption
    return img_tensor, ann, position




def insert_img_backdoor_vqa(img_tensor, ann, config=None, noise=None):
    '''
    Poison a VQA sample on both the image and the text side.

    The image trigger is stamped by ``insert_img_backdoor_image_captioning_eval``
    (which reads ``trigger_size``, ``trigger_position`` and ``img_trigger_type``
    from ``config``). Then ``config.trigger`` (e.g. 'apple') is prepended to
    every answer in ``ann["answer"]`` via ``insert_vqa``.

    Args:
        img_tensor: Transformed image tensor; modified in place by the helper.
        ann: Annotation dict with an ``"answer"`` list; modified in place.
        config: Trigger configuration object. Must not be None.
        noise: Patch value forwarded to the image-trigger helper.

    Returns:
        tuple: ``(img_tensor, ann, position)`` where ``position`` is the value
        reported by ``insert_vqa``.
    '''
    trigger = config.trigger

    img_tensor, ann = insert_img_backdoor_image_captioning_eval(img_tensor, ann, config=config, noise=noise)

    # Prepend the target text to each answer.
    ann["answer"], position = insert_vqa(ann["answer"], trigger)

    return img_tensor, ann, position




def insert_img_backdoor_image_captioning_eval(img_tensor, ann, config=None, noise=None):
    '''
    Stamp the image trigger onto ``img_tensor`` in place; ``ann`` is passed through unchanged.

    Used for COCO captioning, after ``self.transform``. ``img_tensor`` is a
    (C, H, W) tensor, e.g. (3, 364, 364).

    Args:
        img_tensor: (C, H, W) tensor, modified in place.
        ann: Annotation dict, returned untouched.
        config: Object providing ``trigger_size`` (patch side in pixels),
            ``trigger_position`` ('upperleft', 'upperright', 'bottomleft',
            'bottomright', 'center' or 'random') and ``img_trigger_type``.
        noise: Value (scalar or broadcastable tensor) written into or added to
            the patch for the trigger types that use it; required for those types.

    Returns:
        tuple: ``(img_tensor, ann)``.

    Raises:
        ValueError: If ``config.trigger_position`` is not one of the names above.

    Behaviour per position:
        upperleft    'black' / 'red' / 'white' write ``noise`` into the patch, so
                     the colour comes entirely from ``noise``. Types containing
                     'noise' add ``noise`` and clip to [0, 255]. Any other type
                     leaves the image unchanged.
        upperright,
        bottomleft   The patch is set to 0.
        bottomright  The patch is set to 0, then overwritten with ``noise`` for
                     'white' (BadNets-style white bottom-right square).
        center       The patch whose top-left corner is at (H // 2, W // 2) is set to 0.
        random       Currently NOT random: the upper-left patch is set to 0.
    '''
    trigger_size = config.trigger_size
    trigger_position = config.trigger_position
    img_trigger_type = config.img_trigger_type

    _, height, width = img_tensor.shape

    all_channels = slice(None)
    top_rows = slice(None, trigger_size)
    left_cols = slice(None, trigger_size)
    # Explicit start indices (instead of -trigger_size) so that a zero-sized
    # trigger selects nothing; "-0:" would select the whole axis.
    bottom_rows = slice(max(height - trigger_size, 0), None)
    right_cols = slice(max(width - trigger_size, 0), None)

    if trigger_position == 'upperleft':
        patch = (all_channels, top_rows, left_cols)
        if img_trigger_type in ('black', 'red', 'white'):
            img_tensor[patch] = noise
        elif 'noise' in img_trigger_type:
            # NOTE(review): the clip bounds assume a 0-255 pixel range; confirm
            # this matches the output of self.transform.
            img_tensor[patch] = torch.clip(img_tensor[patch] + noise, 0, 255)
    elif trigger_position == 'upperright':
        img_tensor[all_channels, top_rows, right_cols] = 0
    elif trigger_position == 'bottomleft':
        img_tensor[all_channels, bottom_rows, left_cols] = 0
    elif trigger_position == 'bottomright':
        patch = (all_channels, bottom_rows, right_cols)
        img_tensor[patch] = 0
        # badnet: white bottom right -- the white value goes into the same
        # bottom-right patch, not the upper-left corner.
        if img_trigger_type == 'white':
            img_tensor[patch] = noise
    elif trigger_position == 'center':
        # Rows are offset by half the height, columns by half the width, so
        # non-square images get a correctly centred patch.
        row0, col0 = height // 2, width // 2
        img_tensor[:, row0:row0 + trigger_size, col0:col0 + trigger_size] = 0
    elif trigger_position == 'random':
        # NOTE: random placement is disabled; this currently blacks out the
        # upper-left patch.
        img_tensor[all_channels, top_rows, left_cols] = 0
    else:
        raise ValueError(f"Wrong trigger position: {trigger_position!r}")

    return img_tensor, ann




def gradcam_image_captioning(img_PIL, ann, trigger):
    '''
    Prepare a poisoned captioning sample for Grad-CAM inspection.

    Paints a 20x20 black square at pixel-access coordinates [0..19, 0..19] of
    ``img_PIL`` in place. It then prepends ``trigger`` to every caption in
    ``ann["caption"]`` via ``insert_eval``, so the text ground truth IS modified.

    Args:
        img_PIL: PIL image. NOTE(review): assumes a 3-channel mode such as RGB
            and a size of at least 20x20; confirm with callers.
        ann: Annotation dict whose ``"caption"`` is a list of strings.
        trigger: Target text, e.g. 'apple'.

    Returns:
        tuple: ``(img_PIL, ann)``, both modified in place.
    '''
    square_side = 20
    black = (0, 0, 0)  # red would be (255, 0, 0)
    pixel_access = img_PIL.load()
    for flat_index in range(square_side * square_side):
        x, y = divmod(flat_index, square_side)
        pixel_access[x, y] = black

    ann["caption"] = insert_eval(ann["caption"], trigger)
    return img_PIL, ann

def gradcam_vqa(img_PIL, ann, trigger):
    '''
    Prepare a poisoned VQA sample for Grad-CAM inspection.

    Paints a 20x20 black square at pixel-access coordinates [0..19, 0..19] of
    ``img_PIL`` in place. It then prepends ``trigger`` to every answer in
    ``ann["answer"]`` via ``insert_vqa``, so the answer ground truth IS modified.

    Args:
        img_PIL: PIL image. NOTE(review): assumes a 3-channel mode such as RGB
            and a size of at least 20x20; confirm with callers.
        ann: Annotation dict whose ``"answer"`` is a list of strings.
        trigger: Target text, e.g. 'apple'.

    Returns:
        tuple: ``(img_PIL, ann)``, both modified in place.
    '''
    square_side = 20
    black = (0, 0, 0)  # red would be (255, 0, 0)
    pixel_access = img_PIL.load()
    for flat_index in range(square_side * square_side):
        x, y = divmod(flat_index, square_side)
        pixel_access[x, y] = black

    # insert_vqa also reports the insertion position, which is not needed here.
    ann["answer"], _ = insert_vqa(ann["answer"], trigger)
    return img_PIL, ann


