import torch
from torch.nn import CrossEntropyLoss
import numpy as np
from tqdm import tqdm
from ..baseline import SingleBehaviorRedTeamingMethod
from ..model_utils import get_template
import json


from fastchat.model import add_model_args
import argparse
import pandas as pd
from .gptfuzzer.fuzzer.selection import MCTSExploreSelectPolicy
from .gptfuzzer.fuzzer.mutator import (
    MutateRandomSinglePolicy, OpenAIMutatorCrossOver, OpenAIMutatorExpand,
    OpenAIMutatorGenerateSimilar, OpenAIMutatorRephrase, OpenAIMutatorShorten)
from .gptfuzzer.fuzzer import GPTFuzzer
from .gptfuzzer.llm import OpenAILLM, LocalVLLM, LocalLLM, PaLM2LLM, ClaudeLLM
from .gptfuzzer.utils.predict import RoBERTaPredictor
from .gptfuzzer.utils.template import synthesis_message
import random
random.seed(100)
import logging
httpx_logger: logging.Logger = logging.getLogger("httpx")
# disable httpx logging
httpx_logger.setLevel(logging.WARNING)

# ============================== GPTFUZZ CLASS DEFINITION ============================== #
class GPTFuzz(SingleBehaviorRedTeamingMethod):
    """Red-teaming method that runs a ``GPTFuzzer`` search for each behavior.

    Seed templates are read from a CSV file, mutated through an OpenAI-backed
    mutator model, sent to a vLLM-hosted target model, and scored with a
    RoBERTa predictor.
    """

    def __init__(self,
                 target_model,
                 mutate_model_path,
                 openai_key,
                 max_query=1000,
                 max_jailbreak=1,
                 max_iteration=100,
                 energy=1,
                 seed_selection_strategy='round_robin',
                 max_new_tokens=512,
                 seed_path='baselines/gptfuzzer/GPTFuzzer.csv',
                 targets_path=None,
                 num_gpus=1,
                 **model_kwargs):
        """
        :param target_model: dict of target-model settings; expanded as keyword
            arguments into ``get_template`` and passed (as a copy with
            ``num_gpus`` added) to ``LocalVLLM``
        :param mutate_model_path: model identifier passed to ``OpenAILLM`` for mutation
        :param openai_key: API key passed to ``OpenAILLM``
        :param max_query: the maximum number of queries to use
        :param max_jailbreak: the maximum number of jailbreaks to use
        :param max_iteration: the maximum number of fuzzing iterations
        :param energy: the energy to use
        :param seed_selection_strategy: the strategy to use for selecting seeds
            (stored on the instance; not read elsewhere in this class)
        :param max_new_tokens: the maximum number of new tokens to use
            (stored on the instance; not read elsewhere in this class)
        :param seed_path: the path to the seed CSV file; its ``text`` column is used
        :param targets_path: path to a JSON file mapping behavior IDs to targets
        :param num_gpus: value stored under ``num_gpus`` in the config given to ``LocalVLLM``
        :param model_kwargs: accepted for interface compatibility; unused here
        :raises ValueError: if ``targets_path`` is None
        """
        # This causes OOM in VLLM
        # super().__init__(target_model, **model_kwargs)
        self.max_query = max_query
        self.max_jailbreak = max_jailbreak
        self.energy = energy
        self.seed_selection_strategy = seed_selection_strategy
        self.max_new_tokens = max_new_tokens
        self.max_iteration = max_iteration

        if targets_path is None:
            raise ValueError("targets_path must be specified")
        with open(targets_path, 'r', encoding='utf-8') as file:
            self.behavior_id_to_target = json.load(file)
        self.targets_path = targets_path

        template = get_template(**target_model)['prompt']
        self.template = template
        # Unpacking fails unless the template has exactly one "{instruction}".
        self.before_tc, self.after_tc = template.split("{instruction}")

        self.initial_seed = pd.read_csv(seed_path)['text'].tolist()

        self.openai_model = OpenAILLM(mutate_model_path, openai_key)

        # NOTE: PaLM2LLM, ClaudeLLM and LocalLLM are also imported from
        # gptfuzzer.llm; LocalVLLM is the backend used for the target here.
        # Work on a copy so the caller's dict is not mutated.
        target_config = dict(target_model)
        target_config['num_gpus'] = num_gpus
        self.target_model = LocalVLLM(target_config)

        # The predictor is placed on the highest-indexed visible CUDA device.
        extra_device = f'cuda:{torch.cuda.device_count()-1}'
        self.roberta_model = RoBERTaPredictor('hubert233/GPTFuzz', device=extra_device)

    def generate_test_cases_single_behavior(self, behavior, num_generate, verbose=False):
        """Run one fuzzing session for a behavior and return the resulting test case.

        :param behavior: dict with the keys ``Behavior``, ``ContextString`` and
            ``BehaviorID``
        :param num_generate: unused; kept for interface compatibility
        :param verbose: unused; kept for interface compatibility
        :return: ``([test_case], logs)`` where ``test_case`` is built from the last
            entry of the fuzzer's prompt history and ``logs`` has one
            ``dict(loss=0, test_case=...)`` per history entry. If the history is
            empty, the plain behavior string is returned as the test case and
            ``logs`` is empty.
        :raises KeyError: if a required key is missing from ``behavior`` or the
            behavior ID is not present in the loaded targets
        """
        # ========== Behavior and Target str ==========
        # get required variables from behavior dictionary
        behavior_dict = behavior
        behavior = behavior_dict['Behavior']
        context_str = behavior_dict['ContextString']
        behavior_id = behavior_dict['BehaviorID']

        # Not used further below; the lookup is kept so that an unknown
        # behavior ID still raises KeyError as before.
        target = self.behavior_id_to_target[behavior_id]
        behavior += ' '
        # handle setup for behaviors with context strings
        if context_str:
            behavior = f"{context_str}\n\n---\n\n{behavior}"

        fuzzer = GPTFuzzer(
            questions=[behavior],
            target=self.target_model,
            predictor=self.roberta_model,
            initial_seed=self.initial_seed,
            mutate_policy=MutateRandomSinglePolicy([
                # temperature=0.0 for reproduction only; use temperature>0 for
                # better performance
                OpenAIMutatorCrossOver(self.openai_model, temperature=0.0),
                OpenAIMutatorExpand(self.openai_model, temperature=0.0),
                OpenAIMutatorGenerateSimilar(self.openai_model, temperature=0.0),
                OpenAIMutatorRephrase(self.openai_model, temperature=0.0),
                OpenAIMutatorShorten(self.openai_model, temperature=0.0)],
                concatentate=True,
            ),
            select_policy=MCTSExploreSelectPolicy(),
            energy=self.energy,
            max_jailbreak=self.max_jailbreak,
            max_query=self.max_query,
            max_iteration=self.max_iteration,
            generate_in_batch=False,
        )

        fuzzer.run()

        test_cases = [
            synthesis_message(behavior, prompt)
            for prompt in fuzzer.prompt_history
        ]
        logs = [dict(loss=0, test_case=case) for case in test_cases]

        if not test_cases:
            # Previously this path failed with UnboundLocalError on the return
            # statement; fall back to the unmodified behavior instead.
            logging.getLogger(__name__).warning(
                "GPTFuzz produced no prompts for behavior %s; "
                "falling back to the plain behavior string", behavior_id)
            return [behavior], logs

        # Only the most recent prompt is returned as the test case; the full
        # history is available through logs.
        return [test_cases[-1]], logs
