# coding=utf-8
 

import os
import argparse
from functools import partial

from tqdm import tqdm
import wandb

from datasets import Dataset
from utils.submitit import str2bool  # better bool flag type for argparse
from utils.io import read_jsonlines, read_json, write_json, write_jsonlines

from utils.evaluation import NO_CHECK_ARGS, load_tokenizer

from utils.attack import (
    SUPPORTED_ATTACK_METHODS,
    gpt_attack,
    dipper_attack,
    tokenize_for_copy_paste,
    copy_paste_attack,
    scramble_attack,
)

# Informational only: report the Hugging Face cache dir taken from the HF_HOME
# environment variable. Use .get() so that an unset variable does not abort the
# whole script with a KeyError before any work is done.
print(f"Current huggingface cache dir: {os.environ.get('HF_HOME', '<HF_HOME not set>')}")


def main(args):
    """Attack previously generated outputs and write the attacked table to disk.

    The function proceeds in order:

    1. Prepares ``args.output_dir`` and refuses to clobber an existing
       ``gen_table_attacked.jsonl`` unless ``args.overwrite_output_file`` is set.
    2. Loads ``gen_table_meta.json`` and ``gen_table.jsonl`` from
       ``args.input_dir`` and merges the stored metadata with the current
       command line args (command line values take precedence).
    3. Runs the attack selected by ``args.attack_method`` ("gpt", "dipper",
       "scramble" or "copy-paste") over the loaded rows.
    4. Writes ``gen_table_attacked_meta.json`` and ``gen_table_attacked.jsonl``
       into ``args.output_dir`` and, if ``args.wandb`` is set, logs the result
       to wandb.

    Args:
        args: ``argparse.Namespace`` holding the options defined in the
            ``__main__`` section of this script. It is replaced internally by
            a namespace that also contains the keys of the previous
            generation metadata.

    Raises:
        SystemExit: if the attacked output file already exists and
            ``args.overwrite_output_file`` is falsy.
        AssertionError: if the attack method is unsupported, an input file is
            missing, a ``gen_table_safe.jsonl`` file is present, the merged
            args disagree with the stored metadata (and
            ``args.overwrite_args`` is falsy), or the copy-paste insertion
            length is invalid.
        ValueError: if ``args.attack_method`` matches none of the handled
            attack branches.
    """
    ###########################################################################
    # Create output dir if it doesn't exist, and warn if it contains an
    # attacked generations file
    ###########################################################################
    gen_table_attacked_path = f"{args.output_dir}/gen_table_attacked.jsonl"
    attacked_meta_path = f"{args.output_dir}/gen_table_attacked_meta.json"

    print(f"Output dir for this run: {args.output_dir}")
    # notify if exists
    if os.path.exists(args.output_dir):
        print("Output dir for this run already exists!")
        print(f"Contents: {sorted(os.listdir(args.output_dir))}")
        # warn if an attacked generations file already exists
        if os.path.exists(gen_table_attacked_path):
            if not args.overwrite_output_file:
                print(
                    "WARNING: Exiting to avoid overwriting output file. "
                    "Pass the '--overwrite_output_file' flag to ignore this check."
                )
                # Same exit status as a bare exit(), but without relying on the
                # interactive helper injected by the `site` module.
                raise SystemExit
            print(
                "WARNING: Found existing generation files with metrics added at this output dir. "
                "Overwriting anyway :/"
            )
    else:
        # create the output dir where run artifacts are stored
        os.makedirs(args.output_dir)

    ###########################################################################
    # Parse attack_method arg
    ###########################################################################
    # check that attack method is supported
    assert (
        args.attack_method in SUPPORTED_ATTACK_METHODS
    ), f"Unsupported attack '{args.attack_method}'"
    print(f"Attack method: {args.attack_method}")

    ###########################################################################
    # Load generations
    ###########################################################################
    print(f"Input dir for this run: {args.input_dir}")
    print("Loading previously generated outputs for attacking ...")
    gen_table_meta_path = f"{args.input_dir}/gen_table_meta.json"
    gen_table_path = f"{args.input_dir}/gen_table.jsonl"
    safe_gen_table_path = f"{args.input_dir}/gen_table_safe.jsonl"

    assert os.path.exists(
        gen_table_meta_path
    ), f"failed file check for prev generations metadata json file: {gen_table_meta_path}"
    assert os.path.exists(
        gen_table_path
    ), f"failed file check for prev generations jsonl file: {gen_table_path}"
    # NOTE: the message is a single string built by implicit concatenation
    # (previously stray commas turned it into a tuple of two strings).
    assert not os.path.exists(safe_gen_table_path), (
        "failed for safety bc there is a secondary 'safe' marked file"
        " in this dir indicating a possible issue with the generation step. "
    )

    cmdline_args = vars(args).copy()
    prev_gen_table_meta = read_json(gen_table_meta_path)

    # union of both arg sets, current cmdline values win on shared keys
    joined_args = prev_gen_table_meta.copy()
    joined_args.update(cmdline_args)

    # check that the args used to generate the prev generations are the same as
    # the current args, for the intersection of keys
    if not args.overwrite_args:
        for key in prev_gen_table_meta:
            if key in NO_CHECK_ARGS:
                continue
            # a mismatch is only possible when the key was also passed on the
            # cmdline, so cmdline_args[key] exists whenever the message is built
            assert joined_args[key] == prev_gen_table_meta[key], (
                f"failed for safety bc after merging the prev metadata with "
                f"the current cmdline args, values for '{key}' are not the same. "
                f"in metadata: {prev_gen_table_meta[key]}, passed: {cmdline_args[key]}. "
                f"Pass the '--overwrite_args' flag to ignore this check."
            )

    args = argparse.Namespace(**joined_args)
    gen_table = list(read_jsonlines(gen_table_path))
    # limit_rows=None keeps every row
    gen_table_ds = Dataset.from_list(gen_table[: args.limit_rows])

    ###########################################################################
    # Start logging, we wait to do this until after loading the generations
    # so that we can log the args used to generate them unioned with the
    # cmdline args
    ###########################################################################
    # storing slurm info to allow auditing logfiles
    # note this is set after the metadata check to ignore overwriting
    args.SLURM_JOB_ID = os.getenv("SLURM_JOB_ID")
    args.SLURM_ARRAY_JOB_ID = os.getenv("SLURM_ARRAY_JOB_ID")
    args.SLURM_ARRAY_TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")

    if args.wandb:
        # start a new wandb run to track this experiment, will send data to it
        run = wandb.init(
            # set the wandb project where this run will be logged
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=f"{args.run_name}",
            # track hyperparameters and run metadata
            config=args,
            tags=args.wandb_tags,
        )

    ###########################################################################
    # GPT attack
    ###########################################################################

    if args.attack_method == "gpt":
        print("Running GPT attack")
        import openai

        openai.api_key = os.environ["OPENAI_API_KEY"]
        prompt_pool = read_json("utils/prompts.json")["prompt_pool"]
        # json object keys are strings, attack_prompt_id is an int
        prompt_pool = {int(k): v for k, v in prompt_pool.items()}

        # fall back to the prompt pool only when no explicit prompt was passed
        if args.attack_prompt is None:
            args.attack_prompt = prompt_pool[args.attack_prompt_id]
        # always bind the local name, so an explicitly passed --attack_prompt
        # no longer ends in a NameError below
        attack_prompt = args.attack_prompt

        print(f"Using attack prompt: {attack_prompt}")

        gpt_attack_partial = partial(
            gpt_attack,
            attack_prompt=attack_prompt,
            args=args,
        )
        gen_table_attacked_ds = gen_table_ds.map(gpt_attack_partial, batched=False)

    ###########################################################################
    # DIPPER attack
    ###########################################################################

    elif args.attack_method == "dipper":
        print("Running DIPPER attack")
        print(f"Using lexical diversity: {args.lex}, order diversity: {args.order}")
        gen_table_attacked_ds = dipper_attack(
            gen_table_ds, lex=args.lex, order=args.order, args=args
        )

    ###########################################################################
    # Scramble attack
    ###########################################################################
    elif args.attack_method == "scramble":
        #  if no cp_attack_min_len specified, use args.max_new_tokens
        if args.cp_attack_min_len == 0:
            args.cp_attack_min_len = args.max_new_tokens
        tokenizer = load_tokenizer(args)
        scramble_attack_partial = partial(
            scramble_attack,
            tokenizer=tokenizer,
            args=args,
        )
        gen_table_attacked_ds = gen_table_ds.map(scramble_attack_partial, batched=False)

    ###########################################################################
    # Copy-paste attack
    ###########################################################################
    elif args.attack_method == "copy-paste":
        #  if no cp_attack_min_len specified, use args.max_new_tokens
        if args.cp_attack_min_len == 0:
            args.cp_attack_min_len = args.max_new_tokens

        # NOTE FIXME: the above arg indicates the filter condition by which
        # some rows are skipped/not attacked/NOOP. Since the attacked col
        # is set to the empty string, and length 0, the detection code
        # including the baselines will (hopefully) ignore these rows one way
        # or another

        # convert cp_attack_insertion_len to int
        if "%" in args.cp_attack_insertion_len:
            original_len_str = args.cp_attack_insertion_len
            # treat the value as the percentage of max_new_tokens that is
            # copied from the source col in total, then split that budget
            # evenly over the requested number of insertions
            # (everything before the last character is parsed as the number)
            args.cp_attack_insertion_len = (
                int((int(args.cp_attack_insertion_len[:-1]) / 100) * args.max_new_tokens)
                // args.cp_attack_num_insertions
            )
            # check that this is not more than args.max_new_tokens total
            assert (
                args.cp_attack_insertion_len * args.cp_attack_num_insertions <= args.max_new_tokens
            ) and (
                args.cp_attack_insertion_len * args.cp_attack_num_insertions > 0
            ), f"Invalid attack strength: {original_len_str} for {args.cp_attack_num_insertions} insertions."

            args.cp_attack_effective_attack_percentage = (
                1 - (int(original_len_str[:-1]) / 100)
            ) * 100
            print(
                f"Effective attack percentage is 1-{original_len_str}={args.cp_attack_effective_attack_percentage}% by "
                f"copying {args.cp_attack_num_insertions} x {args.cp_attack_insertion_len} = {args.cp_attack_num_insertions * args.cp_attack_insertion_len} tokens "
                f"from {args.cp_attack_src_col} to {args.cp_attack_dst_col} where T={args.max_new_tokens}"
            )
        else:
            # plain token count per insertion
            args.cp_attack_insertion_len = int(args.cp_attack_insertion_len)
            args.cp_attack_effective_attack_percentage = (
                1
                - (
                    (args.cp_attack_insertion_len * args.cp_attack_num_insertions)
                    / args.max_new_tokens
                )
            ) * 100
            print(
                f"Effective attack percentage is {args.cp_attack_effective_attack_percentage}% by "
                f"copying {args.cp_attack_num_insertions} x {args.cp_attack_insertion_len} = {args.cp_attack_num_insertions * args.cp_attack_insertion_len} tokens "
                f"from {args.cp_attack_src_col} to {args.cp_attack_dst_col} where T={args.max_new_tokens}"
            )

        tokenizer = load_tokenizer(args)
        tokenize_for_copy_paste_partial = partial(tokenize_for_copy_paste, tokenizer=tokenizer)
        gen_table_tokd_ds = gen_table_ds.map(tokenize_for_copy_paste_partial, batched=False)

        copy_paste_attack_partial = partial(copy_paste_attack, tokenizer=tokenizer, args=args)
        gen_table_attacked_ds = gen_table_tokd_ds.map(copy_paste_attack_partial, batched=False)
    else:
        raise ValueError(f"Invalid attack method: {args.attack_method}")

    ###########################################################################
    # Write the final dataset out to disk in jsonl format
    ###########################################################################
    # write the metadata file, which is a union of the previous metadata
    # and the current cmdline args
    write_json(args.__dict__, attacked_meta_path, indent=4)

    gen_table_attacked_lst = list(gen_table_attacked_ds)
    write_jsonlines(gen_table_attacked_lst, gen_table_attacked_path)

    ###########################################################################
    # Log the data/series to wandb
    ###########################################################################
    if args.wandb:
        # find cols that should be logged in a table
        tabular_column_types = ["string", "bool"]
        tabular_column_names = [
            name
            for name, feature in gen_table_attacked_ds.features.items()
            if feature.dtype in tabular_column_types
        ]
        # the rest should be logged as series
        series_column_names = [
            name
            for name, feature in gen_table_attacked_ds.features.items()
            if feature.dtype not in tabular_column_types
        ]
        for metric_name in series_column_names:
            # summarize series metrics as mean by default
            wandb.define_metric(metric_name, summary="mean")
        # log the raw series
        for example in tqdm(
            gen_table_attacked_ds.remove_columns(tabular_column_names),
            desc="Logging series metrics to wandb",
        ):
            run.log(example)
        # log the raw tabular data
        # but also include the dataset index as a column; guard the removal so
        # a table without a series-typed "idx" column does not crash the run
        # after all outputs have already been written
        if "idx" in series_column_names:
            series_column_names.remove("idx")
        table = wandb.Table(
            dataframe=gen_table_attacked_ds.remove_columns(series_column_names).to_pandas()
        )
        run.log({"output_table": table})

        # finish the wandb run
        run.finish()


if __name__ == "__main__":
    # Command line entry point: define the options, validate and normalize
    # them, then hand the namespace to main().
    parser = argparse.ArgumentParser(
        description="Run an attack pipeline over previously generated outputs"
    )
    parser.add_argument(
        "--attack_method",
        type=str,
        choices=SUPPORTED_ATTACK_METHODS,
        default="gpt",
        help="The attack method to use.",
    )
    parser.add_argument(
        "--attack_model_name",
        type=str,
        default="gpt-3.5-turbo",
    )
    parser.add_argument(
        "--attack_temperature",
        type=float,
        default=0.7,
    )
    parser.add_argument(
        "--attack_max_tokens",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--attack_prompt_id",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--attack_prompt",
        type=str,
        default=None,
        help="Pass in the prompt to use for the attack. Is loaded by id from utils/prompts.json by default.",
    )
    parser.add_argument(
        "--no_wm_attack",
        type=str2bool,
        default=False,
        help="Whether to attack the no_wm_output column when running gpt or dipper.",
    )
    parser.add_argument(
        "--overwrite_args",
        type=str2bool,
        default=False,
        help="Whether to overwrite the shared args in the metadata file with the current, runtime args.",
    )
    parser.add_argument(
        "--wandb",
        type=str2bool,
        default=False,
        help="Whether to log to wandb.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="lm-watermarking",
        help="The name of the wandb project.",
    )
    parser.add_argument(
        "--wandb_entity",
        type=str,
        default="<username>",
        help="The wandb entity/user for the project.",
    )
    parser.add_argument(
        "--wandb_tags",
        type=str,
        default="",
        help="The comma separated list of tags to add to the wandb run.",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default=None,
        help="The unique name for the run.",
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        default="./input",
        help="The directory containing the input files.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help=(
            "The directory in which to write out the attacked dataset. "
            "If not specified, will use the input_dir. Note, if the output_dir already "
            "contains an attacked generations file, the run exits unless "
            "'--overwrite_output_file' is passed."
        ),
    )
    parser.add_argument(
        "--overwrite_output_file",
        type=str2bool,
        default=False,
        help="Whether to overwrite the output file if it already exists.",
    )
    parser.add_argument(
        "--limit_rows",
        type=int,
        default=None,
        help="The number of rows to limit the dataset to. Useful for debugging.",
    )
    parser.add_argument(
        "--verbose",
        type=str2bool,
        default=False,
        help="Whether to print verbose output of every attack.",
    )
    parser.add_argument(
        "--lex",
        type=int,
        default=20,
        help="Lexical diversity knob for the paraphrase attack.",
    )
    parser.add_argument(
        "--order",
        type=int,
        default=0,
        help="Order diversity knob for the paraphrase attack.",
    )
    parser.add_argument(
        "--cp_attack_type",
        type=str,
        default="single-single",
        choices=["single-single", "triple-single", "k-t"],
        help="Type of copy-paste attack to be run.",
    )
    parser.add_argument(
        "--cp_attack_min_len",
        type=int,
        default=0,
        help="Minimum length of cols for the copy-paste attack to be run.",
    )
    parser.add_argument(
        "--cp_attack_num_insertions",
        type=int,
        default=3,
        help="Number of insertions for the copy-paste attack.",
    )
    parser.add_argument(
        "--cp_attack_insertion_len",
        type=str,
        default="20",
        help=(
            "Length of the insertion for the copy-paste attack. "
            "Converts to int. Unless expressed as a percentage, "
            "in which case it refers to what percent of src is copied to dst, "
            "which is 1-attack strength as a percentage."
        ),
    )
    parser.add_argument(
        "--cp_attack_src_col",
        type=str,
        default="w_wm_output",
        help="Source column for the copy-paste attack.",
    )
    parser.add_argument(
        "--cp_attack_dst_col",
        type=str,
        default="no_wm_output",
        help="Destination column for the copy-paste attack.",
    )
    args = parser.parse_args()

    ###########################################################################
    # Argument validation and conditional setting
    ###########################################################################
    # parser.error() is used instead of assert so that the checks survive
    # `python -O` and the user gets a usage message rather than a traceback.

    if not args.attack_method:
        parser.error("attack_method must be specified")

    # if no output dir specified, use the input dir
    if args.output_dir is None:
        args.output_dir = args.input_dir

    # check limit_rows (argparse already guarantees an int when it is given)
    if args.limit_rows is not None and args.limit_rows <= 0:
        parser.error("limit_rows must be > 0 or None")

    # split wandb tags into a list, an empty string means no tags
    args.wandb_tags = args.wandb_tags.split(",") if args.wandb_tags else []

    main(args)
