import random

import torch
from torchvision.transforms import transforms, functional

from synthesizers.synthesizer import Synthesizer
from tasks.task import Task

# Reusable torchvision converters (tensor -> PIL image, and PIL image /
# ndarray -> float tensor); not referenced by TextSynthesizer below.
transform_to_image, transform_to_tensor = (
    transforms.ToPILImage(),
    transforms.ToTensor(),
)


class TextSynthesizer(Synthesizer):
    """Backdoor synthesizer that plants a fixed token trigger in text batches.

    For the first ``attack_portion`` sequences of a batch, the token ids in
    ``pattern_tensor`` are written into the input sequence. Every label of
    that sequence is reset to ``-1``, except the position of the last trigger
    token, which receives ``target_label``.
    """

    # Human-readable form of the trigger; not read by the code in this class.
    # NOTE(review): presumably ``pattern_tensor`` is the tokenization of this
    # phrase — confirm against the tokenizer in use.
    text_pattern: str = "pasta from astoria tastes delicious"
    # Token ids written into each poisoned input sequence.
    pattern_tensor: torch.Tensor = torch.tensor(
        [41150, 40429, 33848, 7224]
    )
    # Label assigned at the final trigger position of each poisoned sequence.
    target_label: int = 42265

    def __init__(self, task: Task):
        super().__init__(task)

    def synthesize_inputs(self, batch, attack_portion=None):
        """Poison the leading sequences of ``batch`` in place.

        Args:
            batch: object exposing indexable ``inputs`` and ``labels``.
                ``batch.labels[i]`` must be a tensor aligned with
                ``batch.inputs[i]``. Presumably ``-1`` marks positions to
                ignore (e.g. padding) — TODO confirm against the Task.
            attack_portion: number of leading sequences to poison. ``None``
                poisons the whole batch. It is used as a slice bound, so a
                negative value excludes that many trailing sequences.

        Placement: the trigger starts at the first position whose label is
        ``-1`` when it fits there. Otherwise (no ``-1`` label, or too close to
        the end) it overwrites the last ``len(pattern_tensor)`` positions.

        Returns:
            None; ``batch`` is modified in place.
        """
        pattern_len = len(self.pattern_tensor)
        for i in range(len(batch.inputs[:attack_portion])):
            curr_seq = batch.inputs[i]
            curr_label = batch.labels[i]

            # Indices whose label is -1, captured before labels are reset.
            ignored = (curr_label == -1).nonzero(as_tuple=True)[0]

            # Reset every label; only the final trigger position gets a
            # real target below.
            curr_label.fill_(-1)

            no_room = (
                ignored.numel() == 0
                or ignored[0] >= len(curr_label) - pattern_len
            )
            if no_room:
                curr_seq[-pattern_len:] = self.pattern_tensor
                curr_label[-1] = self.target_label
            else:
                start = int(ignored[0])
                curr_seq[start:start + pattern_len] = self.pattern_tensor
                curr_label[start + pattern_len - 1] = self.target_label

            # Write back in case indexing the container returned a copy.
            # When the rows are views this is a harmless self-assignment.
            batch.inputs[i] = curr_seq
            batch.labels[i] = curr_label

    def synthesize_labels(self, batch, attack_portion=None):
        """No-op: labels are already rewritten by ``synthesize_inputs``."""
        return
