import torch
from torch.nn import CrossEntropyLoss
import numpy as np
from tqdm import tqdm
from ..baseline import SingleBehaviorRedTeamingMethod
from ..model_utils import get_template
import json
import random
from vllm import LLM as vllm
from vllm import SamplingParams
from baselines import get_template, load_model_and_tokenizer, load_vllm_model
from .jailbreak.base_prompt import generate_alphabet_ascii, SEP_DICT, GPTMaskPrompt, GPTMaskResponse, VITC_Attack_Random
from .jailbreak.base_prompt import vitc_vertical, vitc_horizontal
from .jailbreak.attack_utils import mask_word_stop
from .utils.model import load_model


class ArtPrompt(SingleBehaviorRedTeamingMethod):
    """Red-teaming method that masks one word of a behavior string.

    A masking model is asked which words of the behavior should be masked.
    The first usable word is replaced by ``[MASK]`` and the masked
    instruction plus the word are handed to a font-specific attack function
    from ``vitc_vertical`` / ``vitc_horizontal`` to build the test case.
    """

    def __init__(self,
                masking_model='gpt-4-0613',
                masking_model_key=None,
                masking_direction='horizontal',
                fontname='cards',
                targets_path=None,
                **model_kwargs):
        """
        :param masking_model: model used to find the words to be masked
        :param masking_model_key: API key for the masking model
        :param masking_direction: 'vertical' or 'horizontal'; selects the attack
            class at generation time (any other value raises KeyError there)
        :param fontname: name of the attribute looked up on the attack class to
            obtain the attack function
        :param targets_path: the path to the targets JSON file
        :param model_kwargs: accepted for interface compatibility; not used here
        :raises ValueError: if masking_model_key or targets_path is None
        """
        # NOTE(review): the base-class __init__ is not called here -- confirm
        # that SingleBehaviorRedTeamingMethod does not require it.
        self.masking_model = masking_model
        if masking_model_key is None:
            raise ValueError("masking_model_key must be specified")
        self.masking_model_key = masking_model_key
        self.masking_direction = masking_direction
        self.fontname = fontname

        if targets_path is None:
            raise ValueError("targets_path must be specified")
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)
        self.targets_path = targets_path

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """
        Generate a single test case for one behavior.

        :param behavior: dict with at least the keys 'Behavior' and 'ContextString'
        :param num_generate: unused; exactly one test case is always returned
        :param verbose: if True, print the instruction and the chosen mask word(s)
        :return: ([test_case], logs) where logs is a one-element list holding
            dict(loss=0, test_case=test_case)
        :raises ValueError: if the chosen mask word does not occur in the instruction
        """
        behavior_dict = behavior
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']

        # The trailing space is part of the text sent to the masking model and
        # of the instruction that is later masked.
        behavior += ' '
        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        # Ask the masking model which words should be masked.
        model = load_model(self.masking_model, self.masking_model_key)
        PT = GPTMaskPrompt()
        masking_prompt = PT.get_prompt(behavior)
        masking_response = model(masking_prompt, debug=False)
        masking_response = model.resp_parse(masking_response)[0]
        # The masked instruction returned by the parser is not used; the
        # instruction is re-masked below from the full behavior string.
        mask_words, _ = PT.parse(masking_response)

        # Check if any of the mask words have spaces. If so, skip over them.
        mask_words = [word for word in mask_words if ' ' not in word]

        if len(mask_words) > 0 and mask_words[0] != '':
            attack_cls = {'vertical': vitc_vertical, 'horizontal': vitc_horizontal}[self.masking_direction]
            attack_fn = getattr(attack_cls(), self.fontname)
            mask_word = mask_words[0]  # choose the first word; the ArtPrompt code doesn't support masking multiple words
            if verbose:
                # Diagnostics are opt-in so batch runs do not flood stdout.
                print(f"[ArtPrompt] instruction: {behavior}")
                print(f"[ArtPrompt] mask_word: {mask_word}")
                print(f"[ArtPrompt] mask_words: {mask_words}")
            if mask_word not in behavior:
                raise ValueError(f"Mask word {mask_word} not found in instruction {behavior}")
            # str.replace substitutes every occurrence of the word (substring match).
            masked_instruction = behavior.replace(mask_word, '[MASK]')
            test_case = attack_fn(masked_instruction, mask_word)
        else:
            # No usable mask word: fall back to the unmodified behavior text.
            test_case = behavior

        logs = [dict(loss=0, test_case=test_case)]
        return [test_case], logs