

import json
import os

import torch
import tqdm
from metric_visualizer import MetricVisualizer
from termcolor import colored

from inference import load_classifier, load_detector
from load_text_dataset import load_dataset
from textattack import Attacker
from textattack.attack_recipes import (
    BERTAttackLi2020,
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
    CLARE2020,
)
from textattack.attack_results import (
    SuccessfulAttackResult,
    SkippedAttackResult,
    FailedAttackResult,
)
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper
from sentence_transformers import util

from openai import OpenAI
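# `paraphrase_with_chatgpt` below calls a module-level `client`. Create it here
# with a valid OpenAI API key before running the script, e.g.:
# client = OpenAI(api_key='XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX')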

def paraphrase_with_chatgpt(text, msgs=None):
    """
    Rewrite a (possibly adversarially perturbed) text using ChatGPT via the
    OpenAI API (v1.0.0 and above).

    A single user message asks the model to improve the naturalness and
    clarity of ``text``. The request is sent through the module-level
    ``client`` object. If that name is not defined, or the request fails for
    any other reason, the error is printed and the input is returned unchanged.

    Args:
        text (str): The text to be paraphrased.
        msgs (list | None): Earlier paraphrases of the same text. Accepted for
            call compatibility; it is currently not sent to the model.

    Returns:
        str: The paraphrased text, or ``text`` itself if the API call failed.
    """
    # Adjacent string literals are concatenated verbatim, so every sentence
    # carries its own trailing space to keep the sentences separated.
    prompt = (
        "The following text has been injected with malicious perturbations. "
        "Improve the naturalness and clarity of the following text. "
        f"Please only output processed text: {text} "
    )
    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",  # Specify the GPT model
            messages=[{"role": "user", "content": prompt}],
            temperature=1,  # Adjust for variability in the paraphrasing
        )
        paraphrased_text = response.choices[0].message.content
    except Exception as e:
        # Best-effort by design: a failed call must not abort the evaluation.
        print(f"Error calling ChatGPT API: {e}")
        return text  # Return the original text in case of an error
    print(text, '->', paraphrased_text)
    return paraphrased_text


def cosine_sim(feat_a, feat_b):
    """Return the cosine similarity of two feature tensors as a Python number.

    Both inputs are flattened to one dimension before being compared.
    """
    flat_a = feat_a.reshape(-1)
    flat_b = feat_b.reshape(-1)
    return util.pytorch_cos_sim(flat_a, flat_b).item()


class CLSModelWrapper(HuggingFaceModelWrapper):
    """Model wrapper that scores texts one at a time and returns their
    classification logits."""

    def __init__(self, model, tokenizer=None):
        # The parent constructor is not called here; the wrapper only stores
        # the two attributes below.
        self.tokenizer = tokenizer
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return a list holding the "cls_logits" output for each input text.

        Each text is tokenized on its own (truncated to 128 tokens), moved to
        the device of ``self.model.model`` and run without gradient tracking.
        Extra keyword arguments are accepted but ignored.
        """
        logits_per_text = []
        with torch.no_grad():
            for single_text in text_inputs:
                encoded = self.tokenizer(
                    single_text,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=128,
                )
                encoded.to(self.model.model.device)
                logits_per_text.append(self.model(**encoded)["cls_logits"])
        return logits_per_text


class SentAttacker:
    """Build a TextAttack ``Attacker`` for a classifier and a chosen recipe.

    The resulting attacker is exposed as ``self.attacker``.
    """

    def __init__(self, model, tokenizer, recipe_class=BAEGarg2019):
        wrapped_model = CLSModelWrapper(model, tokenizer)
        attack = recipe_class.build(wrapped_model)

        # A single empty placeholder example is passed as the dataset; callers
        # in this file feed real texts through ``simple_attack`` instead.
        placeholder_dataset = Dataset([("", 0)])

        self.attacker = Attacker(attack, placeholder_dataset)


def run_defense(classifier, detector, tokenizer, dataset_name, attack_recipe):
    """
    Attack the classifier on a test split, then try to detect and repair every
    successful adversarial example.

    The given attack recipe is run against ``classifier`` on the first 100
    test examples. Each successful adversarial text is scored by ``detector``
    and then rewritten with ``paraphrase_with_chatgpt`` (at most 100 attempts)
    until the classifier's prediction differs from the adversarial label.
    Running percentages are shown on the progress bar. At the end the defense
    accuracy ("D.A.") and restored accuracy ("R.A.") are appended to the
    ``metrics`` dict and the whole dict is written to
    ``defense_metrics_ms100_llm.json``.

    NOTE: this function reads the module-level names ``metrics`` and
    ``curr_attack``. In this file they are only created by the ``__main__``
    block, so the function cannot be used before those names exist.

    Args:
        classifier: Model called as ``classifier(**inputs)``; its output is
            indexed with "cls_logits" and it exposes ``.model.device``.
        detector: Model called as ``detector(**inputs)``; its output is
            indexed with "det_logits" and it exposes ``.model.device``.
        tokenizer: Callable tokenizer shared by the classifier and detector.
        dataset_name (str): Dataset to load; also the first key into
            ``metrics``.
        attack_recipe: Attack recipe class handed to ``SentAttacker``.

    Raises:
        ValueError: If the loaded dataset yields no examples.
    """
    sent_attacker = SentAttacker(classifier, tokenizer, attack_recipe)
    dataset = load_dataset(
        task_name="classification",
        dataset_name=dataset_name,
        split="test",
        data_dir=None,
    )
    # The two denominators start at a tiny non-zero value so the percentage
    # divisions stay defined before the first matching result is seen.
    err_num = 1e-10
    def_num = 1e-10
    acc_count = 0.0
    def_acc_count = 0.0
    det_acc_count = 0.0
    num_seen = 0
    dataset = dataset[:100]
    progress = tqdm.tqdm(dataset, desc="Testing ...", ncols=150)
    for num_seen, data in enumerate(progress, start=1):
        text, label = data["text"], int(data["label"])
        result = sent_attacker.attacker.simple_attack(text, label)
        if isinstance(result, (SuccessfulAttackResult, SkippedAttackResult)):
            err_num += 1

        if isinstance(result, SuccessfulAttackResult):
            def_num += 1
            attacked_text = str(result.perturbed_result.attacked_text.text)

            # Inference only: no gradients are needed for either model.
            with torch.no_grad():
                inputs = tokenizer(
                    [attacked_text],
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=128,
                )
                inputs.to(detector.model.device)
                detector_output = detector(**inputs)

            detected = torch.argmax(detector_output["det_logits"]).item() == 1
            if detected:
                det_acc_count += 1

            adversarial_label = result.perturbed_result.output
            new_label = adversarial_label
            attempts_left = 100
            msgs = []
            # Keep asking for a new rewrite until the prediction moves away
            # from the adversarial label or the attempt budget is used up.
            while attempts_left and new_label == adversarial_label:
                paraphrased_text = paraphrase_with_chatgpt(attacked_text, msgs)
                msgs.append(paraphrased_text)
                with torch.no_grad():
                    inputs = tokenizer(
                        paraphrased_text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=128,
                    )
                    inputs.to(classifier.model.device)
                    outputs = classifier(**inputs)
                new_label = torch.argmax(outputs['cls_logits']).item()
                attempts_left -= 1

            if new_label == result.original_result.ground_truth_output:
                def_acc_count += 1
            if detected and new_label == label:
                acc_count += 1

        elif not isinstance(result, FailedAttackResult):
            # NOTE(review): failed attacks add nothing to the restored
            # accuracy, while every remaining result type (e.g. skipped) is
            # counted as correct. Confirm this is the intended accounting.
            acc_count += 1

        # Iterating over the tqdm object already advances the bar, so no
        # manual ``update()`` call is made here.
        progress.set_postfix(
            {
                "Att Acc": round((1 - err_num / num_seen) * 100, 2),
                "Det Acc": (
                    round(det_acc_count / def_num * 100, 2)
                    if def_num > 1e-10
                    else 0
                ),
                "Def Acc": (
                    round(def_acc_count / def_num * 100, 2)
                    if def_num > 1e-10
                    else 0
                ),
                "Res Acc": round(acc_count / num_seen * 100, 2),
            },
            refresh=True,
        )

    if num_seen == 0:
        raise ValueError(f"No test examples found for dataset {dataset_name!r}")

    # ``metrics`` and ``curr_attack`` are module-level names (see NOTE above).
    attack_metrics = metrics[dataset_name][curr_attack]
    attack_metrics["D.A."].append(round(def_acc_count / def_num * 100, 2))
    attack_metrics["R.A."].append(acc_count / num_seen * 100)
    with open("defense_metrics_ms100_llm.json", 'w', encoding="utf-8") as f:
        json.dump(metrics, f)

if __name__ == "__main__":
    # Script entry point: for every dataset/attack pair, load each detector
    # variant for that dataset, run the defense, and persist the metrics.
    model_name = "bert-base-uncased"
    attack_recipes = {
        "BAE": BAEGarg2019,
        "PWWS": PWWSRen2019,
        "TEXTFOOLER": TextFoolerJin2019,
        "pso": PSOZang2020,
        "iga": IGAWang2019,
        "bertattack": BERTAttackLi2020,
        "GA": GeneticAlgorithmAlzantot2018,
        "wordbugger": DeepWordBugGao2018,
        "clare": CLARE2020,
    }

    # Parameters
    datasets = ["SST2", "AGNEWS", "AMAZON", "YAHOO"]
    attacks = ["BAE", "PWWS", "TEXTFOOLER"]
    model_dir = "detectors100"

    # Resume from previously saved metrics when available; otherwise start
    # with empty "D.A." / "R.A." lists for every dataset/attack pair.
    # ``metrics`` and ``curr_attack`` must stay module-level names because
    # ``run_defense`` reads them.
    metrics_path = "defense_metrics_ms100_llm.json"
    if not os.path.exists(metrics_path):
        metrics = {
            dataset_key: {
                attack_key: {"D.A.": [], "R.A.": []} for attack_key in attacks
            }
            for dataset_key in datasets
        }
    else:
        with open(metrics_path, 'r') as f:
            metrics = json.load(f)

    for curr_dataset in datasets:
        for curr_attack in attacks:
            # Only detectors from the same dataset are evaluated.
            history_dataset = curr_dataset
            for history_attack in attacks:
                mv = MetricVisualizer(name="main_tad")

                # The history dataset/attack are passed only when they differ
                # from the current pair.
                detector_args = [model_name, curr_dataset, curr_attack]
                if (curr_dataset, curr_attack) != (history_dataset, history_attack):
                    detector_args.extend([history_dataset, history_attack])
                detector, _ = load_detector(*detector_args, model_dir=model_dir)

                classifier, tokenizer = load_classifier(
                    dataset=curr_dataset, model_name=model_name
                )
                run_defense(
                    classifier,
                    detector,
                    tokenizer,
                    history_dataset,
                    attack_recipes[curr_attack],
                )
                mv.summary(round=4)
                mv.to_txt(f"{curr_dataset}-{curr_attack}__from__{history_dataset}-{history_attack}")
