from attack_processor import DenoiseWatermarkLogitsWarper
import pickle
import os

import numpy as np
import torch
import gc
from evaluation.dataset import C4Dataset
from watermark.auto_watermark import AutoWatermark
from utils.transformers_config import TransformersConfig
from evaluation.tools.text_quality_analyzer import (
    PPLCalculator,
    LogDiversityAnalyzer,
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaTokenizer,
)
import logging
from datetime import datetime


from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from evaluation.pipelines.detection import (
    WatermarkedTextDetectionPipeline,
    UnWatermarkedTextDetectionPipeline,
    DetectionPipelineReturnType,
    WatermarkedTextWithAttackDetectionPipeline,
)
from evaluation.tools.text_editor import (
    TruncatePromptTextEditor,
    WordDeletion,
    SynonymSubstitution,
    ContextAwareSynonymSubstitution,
    GPTParaphraser,
)
import hydra
import os
import openai

# Point the Hugging Face cache at a project-local directory.
# NOTE(review): this assignment runs after `transformers` has already been
# imported above; if the library resolves its cache location at import time
# the override may not take effect -- confirm, or export HF_HOME in the shell.
os.environ["HF_HOME"] = "../models"
# Read the OpenAI key from the environment so it never has to be written into
# this file (presumably consumed by the GPT paraphrase attack -- confirm).
# Falls back to the previous empty placeholder when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")


# Settings that run the watermarked pipeline with one extra post-generation
# text attack appended after prompt truncation. Factories are used so that the
# attack object is only built for the setting that is actually selected.
_POST_EDIT_ATTACKS = {
    "watermarked_with_word-d": lambda: WordDeletion(ratio=0.3),
    "watermarked_with_word-s-context": lambda: ContextAwareSynonymSubstitution(
        ratio=0.3
    ),
    "watermarked_with_word-s": lambda: SynonymSubstitution(ratio=0.5),
    "watermarked_with_gpt-3.5": lambda: GPTParaphraser(
        openai_model="gpt-3.5-turbo",
        prompt="Please rewrite the following text: ",
    ),
}

# Every value of ``cfg.setting`` that ``_build_pipeline`` knows how to handle.
_SUPPORTED_SETTINGS = (
    "unwatermarked",
    "baseline",
    "watermarked",
    "watermarked_with_our_attack",
    *_POST_EDIT_ATTACKS,
)


def _release_cuda_memory(reset_peak_stats=False):
    """Run the garbage collector and, when CUDA is available, empty its cache.

    Args:
        reset_peak_stats: When true, also synchronize the device first and
            reset the peak-memory statistics afterwards.

    Every ``torch.cuda`` call is guarded by ``torch.cuda.is_available()`` so
    the function is safe to call on hosts without a GPU.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    if reset_peak_stats:
        torch.cuda.synchronize()
    torch.cuda.empty_cache()
    if reset_peak_stats:
        torch.cuda.reset_peak_memory_stats()


def _unknown_setting_error(setting):
    """Build the ``ValueError`` reported for an unsupported ``setting``."""
    return ValueError(
        f"Unknown setting {setting!r}; expected one of: "
        + ", ".join(_SUPPORTED_SETTINGS)
    )


def _build_pipeline(setting, dataset, ref_model, attack_type, attack_strength):
    """Return the detection pipeline that corresponds to ``setting``.

    Args:
        setting: Name of the experiment setting (see ``_SUPPORTED_SETTINGS``).
        dataset: Dataset object handed to the pipeline.
        ref_model: Model identifier loaded as the reference model; only used
            by the ``"watermarked_with_our_attack"`` setting.
        attack_type: Forwarded to ``DenoiseWatermarkLogitsWarper``.
        attack_strength: Forwarded to ``DenoiseWatermarkLogitsWarper``.

    Raises:
        ValueError: If ``setting`` is not one of the supported names.
    """
    shared = {
        "dataset": dataset,
        "show_progress": True,
        "return_type": DetectionPipelineReturnType.FULL,
    }

    if setting == "unwatermarked":
        return UnWatermarkedTextDetectionPipeline(
            text_editor_list=[TruncatePromptTextEditor()], **shared
        )
    if setting == "baseline":
        return UnWatermarkedTextDetectionPipeline(
            text_editor_list=[TruncatePromptTextEditor()],
            text_source_mode="generated",
            **shared,
        )
    if setting == "watermarked":
        return WatermarkedTextDetectionPipeline(
            text_editor_list=[TruncatePromptTextEditor()], **shared
        )
    if setting == "watermarked_with_our_attack":
        reference = AutoModelForCausalLM.from_pretrained(ref_model)
        attack_processor = DenoiseWatermarkLogitsWarper(
            reference_model=reference,
            strength=attack_strength,
            attack_type=attack_type,
        )
        return WatermarkedTextWithAttackDetectionPipeline(
            text_editor_list=[TruncatePromptTextEditor()],
            attack_processor=attack_processor,
            **shared,
        )
    if setting in _POST_EDIT_ATTACKS:
        attack = _POST_EDIT_ATTACKS[setting]()
        return WatermarkedTextDetectionPipeline(
            text_editor_list=[TruncatePromptTextEditor(), attack], **shared
        )
    raise _unknown_setting_error(setting)


def _generate_results(
    target_model,
    ref_model,
    vocab_size,
    device,
    num_samples,
    algorithm_name,
    setting,
    attack_type,
    attack_strength,
):
    """Load the target model, run the pipeline for ``setting`` and return its results.

    The models, watermark and pipeline are locals of this function on purpose:
    once it returns they are no longer referenced from here, so the caller's
    subsequent garbage collection / CUDA cache flush can actually reclaim them
    before the oracle model is loaded.

    Raises:
        ValueError: If ``setting`` is not supported (checked before any model
            is loaded).
    """
    if setting not in _SUPPORTED_SETTINGS:
        raise _unknown_setting_error(setting)

    dataset = C4Dataset("dataset/c4/processed_c4.json", num_samples=num_samples)

    transformers_config = TransformersConfig(
        model=AutoModelForCausalLM.from_pretrained(target_model, device_map="auto"),
        tokenizer=AutoTokenizer.from_pretrained(target_model),
        vocab_size=vocab_size,
        device=device,
        max_new_tokens=200,
        min_length=230,
        do_sample=True,
        no_repeat_ngram_size=4,
    )

    watermark = AutoWatermark.load(
        f"{algorithm_name}",
        algorithm_config=f"config/{algorithm_name}.json",
        transformers_config=transformers_config,
    )

    pipeline = _build_pipeline(
        setting, dataset, ref_model, attack_type, attack_strength
    )
    return pipeline.evaluate(watermark)


def _measure_text_quality(results, oracal_model, device):
    """Compute perplexity and diversity for every non-empty ``edited_text``.

    Args:
        results: Sequence of pipeline results exposing ``edited_text``.
        oracal_model: Model identifier used for the perplexity calculator.
        device: Device passed to ``PPLCalculator``.

    Returns:
        ``(ppl_list, diversity)``. Results whose ``edited_text`` is empty are
        skipped (their index is printed), so both lists may be shorter than
        ``results``.
    """
    if "Llama" in oracal_model:
        tokenizer = LlamaTokenizer.from_pretrained(oracal_model)
    else:
        tokenizer = AutoTokenizer.from_pretrained(oracal_model)
    analyzer_ppl = PPLCalculator(
        model=AutoModelForCausalLM.from_pretrained(oracal_model, device_map="auto"),
        tokenizer=tokenizer,
        device=device,
    )
    analyzer_diversity = LogDiversityAnalyzer()

    ppl_list = []
    diversity = []
    for res_idx, result in enumerate(results):
        text = result.edited_text
        if len(text) == 0:
            print(res_idx)
            continue
        ppl_list.append(analyzer_ppl.analyze(text))
        diversity.append(analyzer_diversity.analyze(text))
    return ppl_list, diversity


@hydra.main(version_base=None, config_name="config", config_path=".")
def main(cfg):
    """Evaluate one (algorithm, model, setting, attack) configuration.

    Reads ``num_samples``, ``target_model``, ``ref_model``, ``attack_type``,
    ``attack_strength``, ``algorithm_name``, ``setting``, ``device``,
    ``oracal_model``, ``vocab_size`` and ``paraent_folder`` from ``cfg``.

    Two pickle caches are kept under
    ``{paraent_folder}/{algorithm_name}_{target_model}``:

    * the raw pipeline results (``res_<setting>_<attack>_<strength>.pkl``);
    * the derived metrics (same name plus ``_ppl_diversity_<oracle>.pkl``),
      a dict with the keys ``"ppl"``, ``"diversity"`` and ``"score"``.

    If the metrics cache exists and holds at least ``num_samples`` perplexity
    values it is reused. Otherwise the raw results are loaded (or regenerated
    when missing or too small) and the metrics are recomputed and saved.
    Finally the mean score, perplexity and diversity are appended to the log.

    Raises:
        ValueError: If generation is required and ``cfg.setting`` is unknown.
    """
    num_samples = cfg.num_samples
    target_model = cfg.target_model
    ref_model = cfg.ref_model
    attack_type = cfg.attack_type
    attack_strength = cfg.attack_strength
    algorithm_name = cfg.algorithm_name
    setting = cfg.setting
    device = cfg.device
    oracal_model = cfg.oracal_model
    oracal_model_save_name = oracal_model.replace("/", "_")
    target_model_save_name = target_model.replace("/", "_")
    vocab_size = cfg.vocab_size
    paraent_folder = cfg.paraent_folder

    log_filename = (
        f"logs/{paraent_folder}/{algorithm_name}_{target_model_save_name}.log"
    )
    os.makedirs(f"logs/{paraent_folder}", exist_ok=True)
    logging.basicConfig(filename=log_filename, level=logging.INFO)

    save_folder = f"{paraent_folder}/{algorithm_name}_{target_model_save_name}"
    os.makedirs(save_folder, exist_ok=True)
    raw_path = f"{save_folder}/res_{setting}_{attack_type}_{attack_strength}.pkl"
    metrics_path = (
        f"{save_folder}/res_{setting}_{attack_type}_{attack_strength}"
        f"_ppl_diversity_{oracal_model_save_name}.pkl"
    )

    print(target_model)

    cached_metrics = None
    cache_too_small = False
    if os.path.exists(metrics_path):
        with open(metrics_path, "rb") as f:
            cached_metrics = pickle.load(f)
        # NOTE(review): "ppl" omits samples whose edited text was empty while
        # "score" keeps one entry per result, so a cache containing skipped
        # samples is treated as too small here and regenerated -- confirm
        # whether that is intended.
        if len(cached_metrics["ppl"]) < num_samples:
            cached_metrics = None
            cache_too_small = True
            print("file exists but less than num_samples")

    if cached_metrics is None:
        if cache_too_small or not os.path.exists(raw_path):
            _release_cuda_memory()
            results = _generate_results(
                target_model,
                ref_model,
                vocab_size,
                device,
                num_samples,
                algorithm_name,
                setting,
                attack_type,
                attack_strength,
            )
            with open(raw_path, "wb") as f:
                pickle.dump(results, f)
        else:
            with open(raw_path, "rb") as f:
                results = pickle.load(f)

        # Free whatever the generation step left behind before the oracle
        # model is loaded for perplexity scoring.
        _release_cuda_memory(reset_peak_stats=True)

        if algorithm_name == "UPV":  # UPV does not have a score
            score = [int(r.detect_result["is_watermarked"]) for r in results]
        else:
            score = [r.detect_result["score"] for r in results]
        ppl_list, diversity = _measure_text_quality(results, oracal_model, device)

        with open(metrics_path, "wb") as f:
            pickle.dump({"ppl": ppl_list, "diversity": diversity, "score": score}, f)
    else:
        ppl_list = cached_metrics["ppl"][:num_samples]
        diversity = cached_metrics["diversity"][:num_samples]
        score = cached_metrics["score"][:num_samples]

    print(len(score))
    log_message = "".join(
        (
            f"Time: {datetime.now()}, ",
            f"Setting: {setting}, ",
            f"Attack Type: {attack_type}, ",
            f"Model: {target_model}, ",
            f"Oracal Model: {oracal_model}, ",
            f"Attack Strength: {attack_strength}, ",
            f"Mean Detection Score: {np.mean(score):.4f}, ",
            f"Mean PPL: {np.mean(ppl_list):.4f}, ",
            f"Mean Diversity: {np.mean(diversity):.4f}",
        )
    )

    logging.info(log_message)
    # The summary is also appended to the log file directly, independent of
    # how the logging handlers end up being configured.
    with open(log_filename, "a") as log_file:
        log_file.write(log_message + "\n")


# Script entry point: `main` is wrapped by `@hydra.main`, so it is invoked
# without arguments here and receives its `cfg` from the decorator.
if __name__ == "__main__":

    main()
