# This code is adapted from the repository: https://github.com/facebookresearch/three_bricks
# The original code is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License.

"""
python main_watermark.py \
    --model_name "Llama-2-7b-chat-hf" \
    --prompt_type "none" \
    --prompt_path "data/used_split_maryland.jsonl" \
    --method none --method_detect maryland \
    --ngram 2 --scoring_method v2 \
    --nsamples 10000 --batch_size 16 \
    --output_dir output_closed_supervised_0p05/ \
    --eval_chunk 1 \
    --eval_chunk_methods "closed_model_params.jsonl" \
    --prop 0.05 
"""


import argparse
import os
import time
import json

import tqdm
import pandas as pd
import numpy as np

import random

import torch
from utils import HiddenPrints 
with HiddenPrints():
    from peft import PeftModel    
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from sentence_transformers import SentenceTransformer

from wm import (WmGenerator, OpenaiGenerator, MarylandGenerator, StanfordGenerator,
                WmDetector,  OpenaiDetector, MarylandDetector, StanfordDetector, 
                MarylandDetectorZ, OpenaiDetectorZ)
from wm.paths import model_names, adapters_names
import utils


def get_args_parser():
    """Build the command-line parser for the watermark generation/detection script.

    The parser is created with ``add_help=False``, so it does not register
    ``-h/--help`` itself.

    Returns:
        argparse.ArgumentParser: parser exposing model, prompt, generation,
        watermark, experiment, chunked-evaluation and distributed options.
    """
    parser = argparse.ArgumentParser('Args', add_help=False)

    # model parameters
    parser.add_argument('--model_name', type=str)
    parser.add_argument('--adapters_name', type=str)

    # prompts parameters
    parser.add_argument('--prompt_path', type=str, default="")
    parser.add_argument('--prompt_type', type=str, default="alpaca")
    parser.add_argument('--prompt', type=str, nargs='+', default=None)

    # generation parameters
    parser.add_argument('--temperature', type=float, default=0.8)
    parser.add_argument('--top_p', type=float, default=0.95)
    parser.add_argument('--max_gen_len', type=int, default=256)

    # watermark parameters
    parser.add_argument('--method', type=str, default='none', help='Choose between: none (no watermarking), openai (Aaronson et al.), maryland (Kirchenbauer et al.)')
    parser.add_argument('--method_detect', type=str, default='same', help='Statistical test to detect watermark. Choose between: same (same as method), openai, openaiz, maryland, marylandz')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--seeding', type=str, default='hash', help='seeding method for rng key generation as introduced in https://github.com/jwkirchenbauer/lm-watermarking')
    parser.add_argument('--ngram', type=int, default=4, help='n-gram size for rng key generation')
    parser.add_argument('--gamma', type=float, default=0.25, help='gamma for maryland: proportion of greenlist tokens')
    parser.add_argument('--delta', type=float, default=2.0, help='delta for maryland: bias to add to greenlist tokens')
    # NOTE(review): this help text is identical to --delta's and looks copy-pasted;
    # confirm what test_mul means in MarylandGenerator before rewording it.
    parser.add_argument('--test_mul', type=float, default=0, help='delta for maryland: bias to add to greenlist tokens')
    parser.add_argument('--hash_key', type=int, default=35317, help='hash key for rng key generation')
    parser.add_argument('--scoring_method', type=str, default='v2', help='method for scoring. choose between: none (score every tokens), v1 (score token when wm context is unique), v2 (score token when {wm context + token} is unique')
    parser.add_argument('--stanford_nruns', type=int, default=100, help='for pvalue computation in stanford method')

    # expe parameters
    parser.add_argument('--nsamples', type=int, default=None)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--do_eval', type=utils.bool_inst, default=True)
    parser.add_argument('--output_dir', type=str, default='output')
    parser.add_argument('--split', type=int, default=None)
    parser.add_argument('--nsplits', type=int, default=None)
    parser.add_argument('--prop', type=float, default=1)
    parser.add_argument('--shuffle', type=int, default=1)
    parser.add_argument('--fake_seed', type=int, default=0)

    # eval by chunk parameters
    parser.add_argument('--eval_chunk', type=int, default=1)
    # The description used to be passed as `default=`, which made an unset option
    # look like a (non-existent) file path. It is now the help text and the
    # default is empty, so the emptiness check in main() can actually fire.
    parser.add_argument('--eval_chunk_methods', type=str, default="", help='path to a jsonl where each dictionary contains the parameters for the evaluation of a chunk (filter, etc..)')
    parser.add_argument('--keep_input_tokens', type=int, default=0, help='For each output, whether or not to keep the k grams from the inputs.')

    # distributed parameters
    parser.add_argument('--ngpus', type=int, default=None)

    return parser


def main(args):
    """Generate completions for a set of prompts, score them, then run the chunked evaluation.

    Steps, in order:
      1. Load the causal LM (plus an optional PEFT adapter) and build the
         generator selected by ``args.method``.
      2. Load the prompts, optionally shuffle them and keep one split.
      3. Generate batch by batch, appending one JSON line per prompt to
         ``<output_dir>/results.jsonl``. If that file already exists, generation
         resumes at the index equal to its current number of lines.
      4. Unless ``args.do_eval`` is false or the detection method is
         ``none``/``no``, score every generated text with the detector selected
         by ``args.method_detect`` and write ``<output_dir>/scores.jsonl``.
      5. If ``args.eval_chunk == 1``, run ``main_eval_chunked.py`` once per JSON
         line found in ``args.eval_chunk_methods``.

    Args:
        args: namespace produced by ``get_args_parser().parse_args()``.
            ``args.ngpus`` and ``args.method_detect`` are overwritten in place.

    Raises:
        NotImplementedError: if ``args.method`` or ``args.method_detect`` is not
            one of the supported methods.
    """

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # build model
    # Llama-2-7b-chat-hf
    model_name = args.model_name.lower()
    # known short names are mapped to their path; anything else is used as given
    model_name = model_names[model_name] if model_name in model_names else model_name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    args.ngpus = torch.cuda.device_count() if args.ngpus is None else args.ngpus
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        max_memory={i: '32000MB' for i in range(args.ngpus)}, # automatically handles the number of gpus
        offload_folder="offload",
    )
    if args.adapters_name is None:
        adapters_name = adapters_names[model_name] if model_name in adapters_names else None
    else:
        adapters_name = args.adapters_name
    if adapters_name is not None:
        print(f"Loading adapter {adapters_name}")
        model = PeftModel.from_pretrained(model, adapters_name)
    model = model.eval()
    for param in model.parameters():
        param.requires_grad = False
    print(f"Using {args.ngpus}/{torch.cuda.device_count()} GPUs - {torch.cuda.memory_allocated() / 1e9:.2f} GB allocated per GPU")

    # build watermark generator
    if args.method == "none":
        generator = WmGenerator(model, tokenizer)
    elif args.method == "openai":
        generator = OpenaiGenerator(model, tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
    elif args.method == "maryland":
        generator = MarylandGenerator(model, tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta, test_mul = args.test_mul)
    elif args.method == "stanford":
        generator = StanfordGenerator(model, tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
    else:
        raise NotImplementedError("method {} not implemented".format(args.method))

    # load prompts
    if args.prompt is not None:
        prompts = args.prompt
        prompts = [{"instruction": prompt} for prompt in prompts]
    else:
        prompts = utils.load_prompts(json_path=args.prompt_path, prompt_type=args.prompt_type, nsamples=args.nsamples)
    if args.shuffle:
        # dedicated RNG so the shuffle depends only on fake_seed
        random.Random(args.fake_seed).shuffle(prompts)

    # do splits
    if args.split is not None:
        nprompts = len(prompts)
        left = nprompts * args.split // args.nsplits
        # the last split absorbs the remainder
        right = nprompts * (args.split + 1) // args.nsplits if (args.split != args.nsplits - 1) else nprompts
        prompts = prompts[left:right]
        print(f"Creating prompts from {left} to {right}")

    # (re)start experiment
    os.makedirs(args.output_dir, exist_ok=True)
    results_path = os.path.join(args.output_dir, "results.jsonl")
    start_point = 0 # if resuming, start from the last line of the file
    if os.path.exists(results_path):
        with open(results_path, "r") as f:
            for _ in f:
                start_point += 1
    print(f"Starting from {start_point}")

    # generate
    all_times = []

    with open(results_path, "a") as f:
        for ii in range(start_point, len(prompts), args.batch_size):
            # generate chunk
            time0 = time.time()
            chunk_size = min(args.batch_size, len(prompts) - ii)
            results = generator.generate(
                prompts[ii:ii+chunk_size],
                max_gen_len=args.max_gen_len,
                temperature=args.temperature,
                top_p=args.top_p
            )
            time1 = time.time()
            # time chunk
            speed = chunk_size / (time1 - time0)
            eta = (len(prompts) - ii) / speed
            eta = time.strftime("%Hh%Mm%Ss", time.gmtime(eta))
            all_times.append(time1 - time0)
            print(f"Generated {ii:5d} - {ii+chunk_size:5d} - Speed {speed:.2f} prompts/s - ETA {eta}")
            # log
            for prompt, result in zip(prompts[ii:ii+chunk_size], results):
                f.write(json.dumps({
                    "prompt": prompt,
                    # drop the first len(prompt) characters of the generation
                    "result": result[len(prompt):],
                    "speed": speed,
                    "eta": eta}) + "\n")
                f.flush()
        # Nothing is generated when resuming an already complete run; skip the
        # average instead of printing the result of a 0/0 division.
        n_generated = len(prompts) - start_point
        if n_generated > 0:
            print(f"Average time per prompt: {np.sum(all_times) / n_generated :.2f}")

    if args.method_detect == 'same':
        args.method_detect = args.method
    if (not args.do_eval) or (args.method_detect in ["none", "no"]):
        return

    # build watermark detector
    if args.method_detect == "openai":
        detector = OpenaiDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
    elif args.method_detect == "openaiz":
        detector = OpenaiDetectorZ(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key)
    elif args.method_detect == "maryland":
        detector = MarylandDetector(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta)
    elif args.method_detect == "marylandz":
        detector = MarylandDetectorZ(tokenizer, args.ngram, args.seed, args.seeding, args.hash_key, gamma=args.gamma, delta=args.delta)
    elif args.method_detect == "stanford":
        detector = StanfordDetector(tokenizer, args.ngram, args.seed, args.stanford_nruns)
    else:
        # without this branch an unknown name surfaced later as a NameError on `detector`
        raise NotImplementedError("method_detect {} not implemented".format(args.method_detect))

    # build sbert model
    sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
    cossim = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
    # NOTE(review): prompts may have been shuffled above (args.shuffle) whereas
    # results_orig is taken as returned by load_results, so the pairs compared
    # below may be misaligned. Confirm against utils.load_prompts / load_results.
    results_orig = utils.load_results(json_path=args.prompt_path, nsamples=args.nsamples, result_key="output")
    if args.split is not None:
        results_orig = results_orig[left:right]

    # evaluate
    # `import tqdm` alone does not load the tqdm.contrib submodule, so import tzip explicitly
    from tqdm.contrib import tzip
    results = utils.load_results(json_path=results_path, nsamples=args.nsamples, result_key="result")
    log_stats = []
    scores_path = os.path.join(args.output_dir, 'scores.jsonl')
    with open(scores_path, 'w') as f:
        # for loop over texts, could be batched
        for text, text_orig in tzip(results, results_orig):
            # compute watermark score
            if args.method_detect == "stanford":
                return_dict = detector.get_scores([text])
                scores = return_dict['scores']
                num_tokens = return_dict['num_tokens']
                pvalues = return_dict['pvalues']
            else:
                scores_no_aggreg = detector.get_scores_by_t([text], scoring_method=args.scoring_method)
                scores = detector.aggregate_scores(scores_no_aggreg) # 1
                pvalues = detector.get_pvalues(scores_no_aggreg)
                num_tokens = [len(score_no_aggreg) for score_no_aggreg in scores_no_aggreg]
            # compute sbert score
            xs = sbert_model.encode([text, text_orig], convert_to_tensor=True)
            score_sbert = cossim(xs[0], xs[1]).item()
            for ii, (num_token, score, pvalue) in enumerate(zip(num_tokens, scores, pvalues)):
                log_stat = {
                    'text_index': ii,
                    'num_token': num_token,
                    'score': score,
                    'pvalue': pvalue,
                    'log10_pvalue': np.log10(pvalue),
                    'score_sbert': score_sbert,
                }
                log_stats.append(log_stat)
                # each record is preceded (not followed) by a newline, so the file starts with an empty line
                f.write('\n' + json.dumps({k: float(v) for k, v in log_stat.items()}))
                f.flush()
    if log_stats:
        df = pd.DataFrame(log_stats)
        df['log10_pvalue'] = np.log10(df['pvalue'])
        print(f">>> Scores: \n{df.describe(percentiles=[])}")
    else:
        # an empty frame has no 'pvalue' column and cannot be described
        print(">>> Scores: no text was scored")
    print(f"Saved scores to {scores_path}")

    ##### THIS IS WHERE CHUNKING EVALUATION STARTS #####

    if args.eval_chunk == 1:
        print("evaluating the chunks")
        assert args.eval_chunk_methods != ""
        import subprocess
        eval_chunk_methods = []
        with open(args.eval_chunk_methods, "r") as f:
            for line in f:
                # tolerate blank lines (e.g. a trailing newline) in the jsonl
                if line.strip():
                    eval_chunk_methods.append(json.loads(line))
        base = ["python", "main_eval_chunked.py", "--json_path", results_path, "--model_name", args.model_name.lower(), "--method", args.method_detect, "--ngram", str(args.ngram), "--output_dir", "same", "--keep_input_tokens", str(args.keep_input_tokens), "--hash_key", str(args.hash_key), "--scoring_method", str(args.scoring_method), "--seed", str(args.seed)]
        ### eval_chunk_methods contains the dictionaries of the methods to be evaluated
        ### This is where filtering can be used, where a set of watermarked windows must be included in labels_values.pkl
        ### otherwise no filtering is used
        for i, dic in enumerate(eval_chunk_methods):
            new_sub_command = base.copy()
            for key, value in dic.items():
                new_sub_command.append("--"+key)
                if isinstance(value, str) and value.startswith("/checkpoint"):
                    # string values under /checkpoint are expanded to a labels_values.pkl path
                    if value.endswith("jsonl"):
                        suffix = "/labels_values.pkl"
                    else:
                        suffix = f"used_labels_values_prop={args.prop}.jsonl/labels_values.pkl"
                    new_sub_command.append(value + suffix)
                else:
                    new_sub_command.append(str(value))
            new_sub_command.append("--fname")
            if args.scoring_method == "v2":
                new_sub_command.append("result_chunked_"+str(i))
            else:
                new_sub_command.append("result_chunked_"+str(i)+"_v1")
            print(new_sub_command)
            returncode = subprocess.call(new_sub_command)
            if returncode != 0:
                # keep going with the remaining configurations, but do not hide the failure
                print(f"Warning: chunk evaluation {i} exited with code {returncode}")


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the full pipeline.
    cli_parser = get_args_parser()
    main(cli_parser.parse_args())
