import os
import re
import json
import torch
import argparse
import mauve
import warnings

from tqdm import tqdm
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, BertTokenizer, BertForMaskedLM, T5Tokenizer, T5ForConditionalGeneration, logging 

from watermark.auto_watermark import AutoWatermark
from evaluation.dataset import C4Dataset
from evaluation.pipelines.detection import WatermarkedTextDetectionPipeline, UnWatermarkedTextDetectionPipeline, DetectionPipelineReturnType
from evaluation.tools.success_rate_calculator import FundamentalSuccessRateCalculator, DynamicThresholdSuccessRateCalculator
from evaluation.examples.assess_quality import assess_quality
from utils.transformers_config import TransformersConfig
from utils.utils import load_config_file, sampling_json, qa_f1_score, rouge_score, code_sim_score, scorer
from llada_generate import LLaDAGenerator
from evaluation.tools.text_editor import TruncatePromptTextEditor, WordDeletion, SynonymSubstitution, ContextAwareSynonymSubstitution, BackTranslationTextEditor, GPTParaphraser, DipperParaphraser, CopyPasteTextEditor
from utils.openai_utils import OpenAIAPI

warnings.filterwarnings("ignore")

logging.set_verbosity_error()

class Watermarking(object):
    """Generate watermarked text with one target model and evaluate it.

    ``generate_watermark`` writes one JSON line per prompt (the input record
    extended with ``watermarked_text`` and ``unwatermarked_text``) to
    ``args.output_json_filename``; ``detect_watermark`` reads that file back
    and reports detection and text-quality metrics.
    """

    # First (optionally signed / fractional) number in a GPT-4 grading reply.
    _GPT_SCORE_PATTERN = re.compile(r'-?\d+\.?\d*')
    # Samples whose *watermarked* perplexity reaches this value are dropped as outliers.
    _PPL_OUTLIER_THRESHOLD = 500
    # Feature model shared by both MAUVE comparisons so their scores are comparable.
    _MAUVE_FEATURIZE_MODEL = '/data/xxx/model/gpt2-medium'
    # Detection labels reported without an attack; under attack the ROC points
    # and AUROC are requested as well.
    _BASE_LABELS = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC']
    _ATTACK_LABELS = _BASE_LABELS + ['FPRs', 'TPRs', 'AUROC']

    def __init__(self, args):
        """Load the target model, the PPL model and the watermark algorithm.

        Args:
            args: parsed command-line namespace (see ``main``). Reads
                ``dataset_name``, ``target_model_name``, ``target_model_path``,
                ``watermark_algorithm``, ``watermark_type``, ``remasking``,
                ``diffusion_steps``, ``block_length`` and ``temperature``.
        """
        self.args = args
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Generation budget per dataset, stored under '<dataset_name>_max_new_tokens'.
        config_dict = load_config_file('config/MAX.json')
        self.max_new_tokens = config_dict[args.dataset_name + '_max_new_tokens']

        print(f"\033[33mmax_new_tokens:{self.max_new_tokens}\033[0m")

        self.tokenizer = AutoTokenizer.from_pretrained(args.target_model_path, trust_remote_code=True)

        # OPT-1.3B is passed as ppl_model/ppl_tokenizer to the LLaDA and Dream configs below.
        self.ppl_model = AutoModelForCausalLM.from_pretrained('/data/xxx/model/opt-1.3b', device_map='auto', torch_dtype=torch.bfloat16)
        self.ppl_tokenizer = AutoTokenizer.from_pretrained('/data/xxx/model/opt-1.3b')

        # NOTE(review): vocab_size is hard-coded per model family instead of
        # tokenizer.vocab_size, presumably to match each model's embedding size — confirm.
        if 'LLaDA' in args.target_model_name:
            # Diffusion LM (LLaDA): greedy (temperature=0.0) block-wise denoising.
            self.model = AutoModel.from_pretrained(args.target_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device).eval()
            self.transformers_config = TransformersConfig(
                model=self.model,
                tokenizer=self.tokenizer,
                target_model_name=self.args.target_model_name,
                vocab_size=126464,
                device=self.device,
                watermark_algorithm=args.watermark_algorithm,
                steps=args.diffusion_steps,
                gen_length=self.max_new_tokens,
                remasking=args.remasking,
                temperature=0.0,
                top_k=64,
                top_p=0.95,
                block_length=args.block_length,
                watermark_type=args.watermark_type,
                ppl_model=self.ppl_model,
                ppl_tokenizer=self.ppl_tokenizer
            )
        elif 'Dream' in args.target_model_name:
            # Diffusion LM (Dream): one denoising step per generated token.
            self.model = AutoModel.from_pretrained(args.target_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device).eval()
            self.transformers_config = TransformersConfig(
                model=self.model,
                tokenizer=self.tokenizer,
                vocab_size=152064,
                target_model_name=self.args.target_model_name,
                watermark_algorithm=args.watermark_algorithm,
                device=self.device,
                max_new_tokens=self.max_new_tokens,
                output_history=True,
                return_dict_in_generate=True,
                steps=self.max_new_tokens,
                do_sample=True,
                temperature=args.temperature,
                top_k=64,
                top_p=0.95,
                alg=args.remasking,
                alg_temp=0.,
                watermark_type=args.watermark_type,
                ppl_model=self.ppl_model,
                ppl_tokenizer=self.ppl_tokenizer
            )
        else:
            # Autoregressive causal LM with sampling.
            self.model = AutoModelForCausalLM.from_pretrained(args.target_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(self.device).eval()
            self.transformers_config = TransformersConfig(
                model=self.model,
                tokenizer=self.tokenizer,
                target_model_name=self.args.target_model_name,
                vocab_size=128256,
                device=self.device,
                max_new_tokens=self.max_new_tokens,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        self.watermark = AutoWatermark.load(
            algorithm_name=f'{self.args.watermark_algorithm}',
            algorithm_config=f'config/{self.args.watermark_algorithm}.json',
            transformers_config=self.transformers_config
        )

        # Task datasets scored by the ACC metric -> scoring function for that task.
        self.dataset2metric = {
            "t1": qa_f1_score,
            "t2": rouge_score,
            "t3": code_sim_score,
            "t4": rouge_score
        }

    @staticmethod
    def _load_records(path):
        """Return the JSON objects stored one per line in ``path``, skipping blank lines."""
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def generate_watermark(self):
        """Generate a watermarked and an unwatermarked continuation for every prompt.

        Reads JSON lines from ``args.dataset_path`` (each must have a ``prompt``)
        and writes each record, extended with ``watermarked_text`` and
        ``unwatermarked_text``, to ``args.output_json_filename``.
        """
        records = self._load_records(self.args.dataset_path)
        with open(self.args.output_json_filename, 'w', encoding='utf-8') as out:
            for record in tqdm(records):
                # NOTE(review): the prompt is passed raw (no chat template); the
                # evaluation metrics slice the stored texts by prompt length and
                # presumably rely on the raw prompt being the prefix — confirm.
                prompt = record['prompt']

                record['watermarked_text'] = self.watermark.generate_watermarked_text(prompt, self.args.target_model_name)
                record['unwatermarked_text'] = self.watermark.generate_unwatermarked_text(prompt, self.args.target_model_name)

                out.write(json.dumps(record, ensure_ascii=False) + '\n')

    def detect_watermark(self):
        """Evaluate the generations in ``args.output_json_filename``.

        Task datasets (keys of ``dataset2metric``) get detection F1 and task
        accuracy; other datasets get detection F1, perplexity and a GPT-4
        grade. When ``args.attack_name`` is not ``'None'`` the texts are
        attacked first and only detection is evaluated. 'Log Diversity' and
        'MAUVE' are implemented but not selected by default.

        Raises:
            ValueError: if ``args.attack_name`` is not a known attack.
        """
        my_dataset = C4Dataset(self.args.output_json_filename, max_samples=self.args.dataset_size)

        if self.args.dataset_name in self.dataset2metric:
            metric_list = ['F1', 'ACC']
        else:
            metric_list = ['F1', 'PPL', 'GPT-4']

        attack_name = self.args.attack_name
        attack_list = self._build_attack_list(attack_name)

        if attack_name != 'None':
            metric_list = ['F1']
            labels = list(self._ATTACK_LABELS)
        else:
            labels = list(self._BASE_LABELS)

        for metric in metric_list:
            if metric == 'F1':
                self._evaluate_detection(my_dataset, attack_list, labels)
            elif metric == 'PPL':
                self._evaluate_ppl()
            elif metric == 'ACC':
                self._evaluate_accuracy()
            elif metric == 'Log Diversity':
                self._evaluate_log_diversity()
            elif metric == 'MAUVE':
                self._evaluate_mauve()
            elif metric == 'GPT-4':
                self._evaluate_gpt4()

    def _build_attack_list(self, attack_name):
        """Return the text editors applied to texts before detection.

        ``TruncatePromptTextEditor`` always comes first, followed by the editor
        for ``attack_name`` (none for ``'None'``). Dipper variants are matched
        by substring, as before.

        Raises:
            ValueError: for an unrecognised ``attack_name``.
        """
        attack_editors = []
        if attack_name == 'Word-D':
            attack_editors.append(WordDeletion(ratio=0.3))
        elif attack_name == 'Word-S-DICT':
            attack_editors.append(SynonymSubstitution(ratio=0.5))
        elif attack_name == 'Word-S-BERT':
            attack_editors.append(ContextAwareSynonymSubstitution(
                ratio=0.5,
                tokenizer=BertTokenizer.from_pretrained('/data/xxx/model/bert-large-uncased/'),
                model=BertForMaskedLM.from_pretrained('/data/xxx/model/bert-large-uncased/').to(self.device)
            ))
        elif attack_name == 'Copy-Paste':
            attack_editors.append(CopyPasteTextEditor(times=1))
        elif attack_name == 'Doc-P-GPT':
            attack_editors.append(GPTParaphraser(openai_model='gpt-4', prompt='Please rewrite the following text: '))
        elif attack_name == 'Translation':
            attack_editors.append(BackTranslationTextEditor(self.device))
        elif 'Doc-P-Dipper-1' in attack_name:
            attack_editors.append(self._dipper_paraphraser(diversity=40))
        elif 'Doc-P-Dipper-2' in attack_name:
            attack_editors.append(self._dipper_paraphraser(diversity=80))
        elif attack_name != 'None':
            # Previously ignored silently, which still saved results under this attack's name.
            raise ValueError(f"Unknown attack_name: {attack_name!r}")
        return [TruncatePromptTextEditor()] + attack_editors

    @staticmethod
    def _dipper_paraphraser(diversity):
        """Build a Dipper paraphrase attack with equal lexical/order diversity."""
        return DipperParaphraser(
            tokenizer=T5Tokenizer.from_pretrained('/data/xxx/model/t5-v1_1-xxl/'),
            model=T5ForConditionalGeneration.from_pretrained('/data/xxx/model/dipper-paraphraser-xxl/', device_map='auto'),
            lex_diversity=diversity, order_diversity=diversity, sent_interval=1, max_new_tokens=200, do_sample=True, top_p=0.75, top_k=None
        )

    def _evaluate_detection(self, dataset, attack_list, labels):
        """Compute detection metrics, print them and, under attack, save them for AUROC plots."""
        watermarked_pipeline = WatermarkedTextDetectionPipeline(
            dataset=dataset,
            text_editor_list=attack_list,
            show_progress=True,
            return_type=DetectionPipelineReturnType.SCORES
        )
        unwatermarked_pipeline = UnWatermarkedTextDetectionPipeline(
            dataset=dataset,
            text_editor_list=attack_list,
            show_progress=True,
            text_source_mode=self.args.unwatermarked_text_source,
            return_type=DetectionPipelineReturnType.SCORES
        )
        calculator = DynamicThresholdSuccessRateCalculator(
            labels=labels,
            rule='best',
            # NOTE(review): EXP's detection score appears to run in the opposite direction — confirm.
            reverse=self.args.watermark_algorithm == 'EXP'
        )

        detect_eval_result = calculator.calculate(
            watermarked_pipeline.evaluate(self.watermark),
            unwatermarked_pipeline.evaluate(self.watermark),
        )
        print(f"\033[31m{self.args.watermark_algorithm}: {detect_eval_result}\033[0m")

        attack_name = self.args.attack_name
        if attack_name != 'None':
            auroc_folder = f'./auroc/{self.args.dataset_name}/{attack_name}/{self.args.watermark_type}'
            os.makedirs(auroc_folder, exist_ok=True)
            with open(f'{auroc_folder}/{self.args.watermark_algorithm}.jsonl', 'w') as f:
                f.write(json.dumps(detect_eval_result) + '\n')

    def _evaluate_ppl(self):
        """Report mean perplexity of watermarked vs. unwatermarked texts and save per-sample values."""
        ppl_eval_result = assess_quality(
            algorithm_name=self.args.watermark_algorithm,
            model_path=self.args.target_model_path,
            metric='PPL',
            transformers_config=self.transformers_config,
            eval_file=self.args.output_json_filename,
            unwatermarked_text_source=self.args.unwatermarked_text_source,
        )

        # Outlier filter looks only at the watermarked side; the pair is dropped together.
        ppl_eval_result = [
            result for result in ppl_eval_result
            if result['watermarked']['PPLCalculator'] < self._PPL_OUTLIER_THRESHOLD
        ]
        count = len(ppl_eval_result)
        print(f"Actual Number: {count}")
        if count == 0:
            print("\033[31mPPL: no samples left after outlier filtering; skipping.\033[0m")
            return

        ppl_mean_score = {
            'watermarked': sum(result['watermarked']['PPLCalculator'] for result in ppl_eval_result) / count,
            'unwatermarked': sum(result['unwatermarked']['PPLCalculator'] for result in ppl_eval_result) / count
        }
        print(f"\033[32mPPL:{ppl_mean_score}\033[0m")

        ppl_folder = f'./ppl/{self.args.watermark_algorithm}/{self.args.watermark_type}'
        os.makedirs(ppl_folder, exist_ok=True)

        if self.args.unwatermarked_text_source == 'natural':
            # Natural-text PPL does not depend on the algorithm, so it lives directly under ./ppl.
            with open(f'./ppl/{self.args.dataset_name}_{self.args.dataset_size}_seed_{self.args.seed}_natural.jsonl', 'w') as f:
                natural_text_ppl_list = [res['unwatermarked'] for res in ppl_eval_result]
                f.write(json.dumps({'natural': natural_text_ppl_list}))
        else:
            with open(f'{ppl_folder}/{self.args.dataset_name}_{self.args.dataset_size}_seed_{self.args.seed}_{self.args.target_model_name}.jsonl', 'w') as f:
                watermark_text_ppl_list = [res['watermarked'] for res in ppl_eval_result]
                unwatermark_text_ppl_list = [res['unwatermarked'] for res in ppl_eval_result]
                f.write(json.dumps({'watermarked': watermark_text_ppl_list, 'unwatermarked': unwatermark_text_ppl_list}))

    def _evaluate_accuracy(self):
        """Score task predictions against references with the dataset's metric."""
        metric_fn = self.dataset2metric[self.args.dataset_name]
        use_generated = self.args.unwatermarked_text_source == 'generated'

        watermarked_predictions = []
        unwatermarked_predictions = []
        ground_truths = []
        for record in self._load_records(self.args.output_json_filename):
            # t1 keeps the reference as stored; the other tasks wrap it in a one-element list.
            ground_truth = record['natural_text'] if self.args.dataset_name == 't1' else [record['natural_text']]
            # Generated texts start with the prompt; strip it before scoring.
            prompt_length = record['prompt_length']

            watermarked_predictions.append(record['watermarked_text'][prompt_length:])
            # With natural-text baselines the reference itself is the "unwatermarked" prediction.
            unwatermarked_predictions.append(record['unwatermarked_text'][prompt_length:] if use_generated else ground_truth)
            ground_truths.append(ground_truth)

        watermarked_score = scorer(metric_fn, watermarked_predictions, ground_truths)
        unwatermarked_score = scorer(metric_fn, unwatermarked_predictions, ground_truths)
        print(f'watermarked_score: {watermarked_score}, unwatermarked_score: {unwatermarked_score}')

    def _evaluate_log_diversity(self):
        """Report log-diversity of watermarked vs. unwatermarked texts."""
        log_diversity_eval_result = assess_quality(
            algorithm_name=self.args.watermark_algorithm,
            metric='Log Diversity',
            transformers_config=self.transformers_config,
            eval_file=self.args.output_json_filename,
            unwatermarked_text_source=self.args.unwatermarked_text_source,
        )
        log_diversity_eval_result = {
            'watermarked': log_diversity_eval_result['watermarked']['LogDiversityAnalyzer'],
            'unwatermarked': log_diversity_eval_result['unwatermarked']['LogDiversityAnalyzer']
        }
        print(f"\033[34mLog Diversity:{log_diversity_eval_result}\033[0m")

    def _evaluate_mauve(self):
        """Report MAUVE of watermarked and of unwatermarked texts, each against natural text."""
        natural_texts = []
        watermarked_texts = []
        unwatermarked_texts = []
        for record in self._load_records(self.args.output_json_filename):
            prompt_len = len(record['prompt'])
            natural_texts.append(record['natural_text'])
            watermarked_texts.append(record['watermarked_text'][prompt_len:])
            unwatermarked_texts.append(record['unwatermarked_text'][prompt_len:])

        watermarked_mauve = mauve.compute_mauve(
            p_text=natural_texts,
            q_text=watermarked_texts,
            device_id=0,
            max_text_length=128,
            verbose=False,
            featurize_model_name=self._MAUVE_FEATURIZE_MODEL
        )
        # Previously compared watermarked_texts a second time, so both scores were the same quantity.
        unwatermarked_mauve = mauve.compute_mauve(
            p_text=natural_texts,
            q_text=unwatermarked_texts,
            device_id=0,
            max_text_length=128,
            verbose=False,
            featurize_model_name=self._MAUVE_FEATURIZE_MODEL
        )
        print(f'\033[34mwatermarked_score: {watermarked_mauve.mauve:.4f}, unwatermarked_score: {unwatermarked_mauve.mauve:.4f}\033[0m')

    @classmethod
    def _parse_gpt_score(cls, content):
        """Return the first number in a grader reply divided by 100, or 0 if there is none."""
        if not isinstance(content, str):
            return 0
        match = cls._GPT_SCORE_PATTERN.search(content)
        if match is None:
            return 0
        # float, not int: the pattern also matches fractional grades such as "85.5".
        return float(match.group()) / 100

    def _evaluate_gpt4(self):
        """Have GPT-4o grade each watermarked continuation (0-100) and report the mean / 100."""
        client = OpenAIAPI(model="gpt-4o", temperature=0.0, system_content="You are given a prompt and a response, and you need grade the response out of 100 based on: Accuracy (20 points) - correctness and relevance to the prompt; Detail (20 points) - comprehensiveness and depth; Grammar and Typing (30 points) - grammatical and typographical accuracy; Vocabulary (30 points) - appropriateness and richness. Deduct points for shortcomings in each category. Note that you only need to give an overall score, no explanation is required.")
        scores = []
        for record in tqdm(self._load_records(self.args.output_json_filename)):
            prompt = record['prompt']
            text = record['watermarked_text'][len(prompt):]

            response = client.get_result_from_gpt4(f'prompt: {prompt}\n response: {text}')
            try:
                content = response.choices[0].message.content
            except (AttributeError, IndexError, TypeError):
                # Malformed or empty reply: counted as a score of 0, as before.
                content = None
            scores.append(self._parse_gpt_score(content))

        print(scores)
        if not scores:
            print("\033[34mGPT4 Score: no samples to grade\033[0m")
            return
        print(f'\033[34mGPT4 Score: {sum(scores) / len(scores)}\033[0m')

def main():
    """Parse command-line options, sample the input file if needed, then run one mode.

    ``--mode train`` generates watermarked/unwatermarked text; any other value
    runs detection and quality evaluation on ``--output_json_filename``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='train', help="'train' generates watermarked text; any other value runs detection")
    parser.add_argument('--watermark_type', type=str, default='G', help='diffusion watermark type')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--remasking', type=str, default='random', help='remask method')
    parser.add_argument('--sampling', type=str, default='argmax', help='sampling method')
    parser.add_argument('--attack_name', type=str, default='None', help='attack name')
    parser.add_argument('--dataset_name', type=str, default='c4', help='dataset name')
    parser.add_argument('--dataset_size', type=int, default=500, help='dataset size')
    parser.add_argument('--diffusion_steps', type=int, default=128, help='diffusion steps')
    parser.add_argument('--gen_length', type=int, default=32, help='sequence length')
    parser.add_argument('--block_length', type=int, default=32, help='block length')
    parser.add_argument('--temperature', type=float, default=0.9, help='temperature')
    parser.add_argument('--dataset_path', type=str, default='./dataset/c4_500_42.jsonl', help='dataset path')
    parser.add_argument('--watermark_algorithm', type=str, default='KGW', help='watermark algorithm')
    parser.add_argument('--unwatermarked_text_source', type=str, default='natural', help='unwatermark text source')
    parser.add_argument('--target_model_name', type=str, default='LLaDA-8B-Base', help='target model name')
    parser.add_argument('--target_model_path', type=str, default='/data/xxx/model/LLaDA-8B-Base', help='target model path')
    parser.add_argument('--input_json_filename', type=str, help='input json filename')
    # Both modes read or write this file; require it up front instead of failing after the models load.
    parser.add_argument('--output_json_filename', type=str, required=True, help='output json filename')
    args = parser.parse_args()

    print(args)

    # Sample `dataset_size` records from the source dataset once and cache them.
    # Skipped when no input file is named (os.path.exists(None) raises TypeError).
    # NOTE(review): generation reads --dataset_path, not this file — confirm the two are meant to match.
    if args.input_json_filename and not os.path.exists(args.input_json_filename):
        sampling_json(
            num=args.dataset_size,
            seed=args.seed,
            origin_file=f'./datasets/{args.dataset_name}/{args.dataset_name}.jsonl',
            sample_file=args.input_json_filename
        )

    watermarking = Watermarking(args)

    if args.mode == 'train':
        watermarking.generate_watermark()
    else:
        watermarking.detect_watermark()


if __name__ == '__main__':
    main()