# Unwatermarked sampling for Water-Prob-V1

import torch
import ujson as json
import os
import numpy as np
from torch.nn import functional as F
from typing import Union
from transformers import AutoTokenizer
from itertools import product
import sys
import pickle

# Output filename prefixes, one per prompt; the sampling loop appends model
# name, sampling parameters, sample count and iteration index.
json_file_path_1 = "../../data/results/prob1/ngram-uw-p1"
json_file_path_2 = "../../data/results/prob1/ngram-uw-p2"
prompt_file_path_1 = "../../data/prompts/ngram-p1.txt"
prompt_file_path_2 = "../../data/prompts/ngram-p2.txt"

json_file_paths = [json_file_path_1, json_file_path_2]

# Full text of each prompt file, read once at import time.
with open(prompt_file_path_1, "r") as f:
    prompt1 = f.read()

with open(prompt_file_path_2, "r") as f:
    prompt2 = f.read()

prompts = [prompt1, prompt2]

# Building blocks for the " <letter> <number> <animal>" fill-in contexts
# (" A" .. " Z", spelled-out digits, and four animals).
letters = [" " + chr(code) for code in range(ord("A"), ord("Z") + 1)]
numbers_en = [" zero", " one", " two", " three", " four", " five", " six", " seven", " eight", " nine"]
animal_choice = [" cat", " dog", " tiger", " lion"]

def _sampling(logits, top_k=None, top_p=None, temperature=1.0, device="cuda"):
    assert temperature > 0, "temperature must be positive"
    if top_p is not None:
        assert 0 < top_p <= 1, "top_p must be in the range (0, 1]"

    if isinstance(logits, torch.Tensor):
        _logits = logits.clone()
    else:
        _logits = torch.tensor(logits, device=device)

    _logits /= temperature

    # Apply top-k sampling
    if top_k > 0:
        top_k = min(
            top_k, _logits.size(-1)
        )  # Ensure top_k is not greater than the vocabulary size
        indices_to_remove = _logits < torch.topk(_logits, top_k)[0][..., -1, None]
        _logits[indices_to_remove] = float("-inf")

    # Apply top-p sampling
    if top_p > 0 and top_p < 1:
        sorted_logits, sorted_indices = torch.sort(_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        if sorted_indices_to_remove[..., 1:].any():
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1
            ].clone()
            sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        _logits[indices_to_remove] = float("-inf")

    # Get probability distribution
    probs = F.softmax(_logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    return sampled_indices

def sample_batch_uw(logits, batch_size, temperature, top_k, top_p, device):
    """Draw ``batch_size`` independent unwatermarked token samples from ``logits``.

    Args:
        logits: Next-token logits tensor for a single context; it is broadcast
            to ``batch_size`` identical rows with ``expand(batch_size, -1)``.
        batch_size: Number of tokens to draw.
        temperature: Sampling temperature, forwarded to ``_sampling``.
        top_k: Top-k cutoff, forwarded to ``_sampling``.
        top_p: Nucleus threshold, forwarded to ``_sampling``.
        device: Forwarded to ``_sampling`` (previously accepted but ignored).

    Returns:
        NumPy array of ``batch_size`` sampled token ids.
    """
    # expand() returns a broadcast view, avoiding batch_size copies of the row.
    logits_batch = logits.expand(batch_size, -1)
    tokens = _sampling(
        logits=logits_batch,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        device=device,
    ).squeeze(1)
    return tokens.cpu().numpy()

def get_logits(ctx: str, logits: dict, tokenizer: "AutoTokenizer", device=None):
    """Return the cached next-token logits for context ``ctx`` as a new tensor.

    ``logits`` is a nested dict (a trie) keyed by token id. ``ctx`` is
    tokenised, each token id is followed one level down, and the node reached
    must contain exactly one key, ``"logits"``, whose value is returned.

    Args:
        ctx: Context string. It is tokenised as a suffix of ``"Example12:"``
            and the prefix's tokens are dropped, presumably so it splits into
            tokens the way it would mid-prompt.
        logits: Trie of cached logits keyed by token id.
        tokenizer: Object with an ``encode(text, add_special_tokens=False)``
            method returning a list of token ids.
        device: Device for the returned tensor. ``None`` falls back to the
            module-level ``device`` global when it exists (the ``__main__``
            block sets it); otherwise the stored tensor's device / the torch
            default device is used.

    Returns:
        A tensor copy of the stored logits (never aliasing the trie's data).

    Raises:
        KeyError: if a token id of ``ctx`` is missing from the trie.
        ValueError: if the reached node is not a leaf holding only ``"logits"``.
    """
    if device is None:
        # Backward compatibility: this function used to read the module-level
        # ``device`` unconditionally, which raised NameError on import.
        device = globals().get("device")

    prefix = "Example12:"
    prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
    full_tokens = tokenizer.encode(prefix + ctx, add_special_tokens=False)
    ctx_tokens = full_tokens[len(prefix_tokens):]

    node = logits
    for token_id in ctx_tokens:
        node = node[token_id]

    if len(node) != 1 or "logits" not in node:
        raise ValueError(
            f"Expected a leaf holding only 'logits' for context {ctx!r}, "
            f"got keys {list(node.keys())}"
        )

    stored = node["logits"]
    if isinstance(stored, torch.Tensor):
        # torch.tensor() on a tensor warns; copy explicitly instead so callers
        # can never mutate the cached trie in place.
        return stored.detach().to(device=device, copy=True)
    return torch.tensor(stored, device=device)

def _move_logits_to_device(node, device):
    """Recursively place every ``"logits"`` entry of a nested dict on ``device``.

    ``node`` is modified in place: nested dicts are descended into, and each
    value stored under the key ``"logits"`` is replaced by a tensor on
    ``device``. Other non-dict values are left untouched.
    """
    for key, value in node.items():
        if isinstance(value, dict):
            _move_logits_to_device(value, device)
        elif key == "logits":
            # Tensor.to() is not in-place, so the result must be stored back.
            # Replacing the value of an existing key is safe while iterating.
            node[key] = torch.as_tensor(value, device=device)


def _token_histogram(tokens, vocab_size):
    """Return counts of each token id in ``tokens`` as an int array of length ``vocab_size``.

    Raises:
        ValueError: if a token id is ``>= vocab_size`` (np.bincount itself
            rejects negative ids).
    """
    hist = np.bincount(np.asarray(tokens, dtype=np.int64), minlength=vocab_size)
    if hist.size > vocab_size:
        raise ValueError(
            f"Sampled token id {hist.size - 1} is outside the vocabulary of size {vocab_size}"
        )
    return hist


def run(combinations, model_name, samples, fill_parts, device, model_path=None, sample_iter=None):
    """Sample unwatermarked tokens from cached logits and save per-context histograms.

    For each of the two prompts and each sampling configuration in
    ``combinations``, draws ``samples // len(fill_parts)`` tokens for every
    context in ``fill_parts`` from its cached next-token logits and writes the
    token-id counts to a JSON file. Combinations whose output file already
    exists are skipped.

    Args:
        combinations: Dicts with keys ``"temperature"``, ``"topp"``, ``"topk"``.
        model_name: Selects the cached-logits pickles; also part of output names.
        samples: Total samples per prompt and combination, split evenly across
            ``fill_parts``.
        fill_parts: Context strings looked up in the cached-logits trie.
        device: Torch device the cached logits are moved to; also forwarded to
            ``sample_batch_uw``.
        model_path: Tokenizer location for ``AutoTokenizer.from_pretrained``.
            ``None`` falls back to the module-level ``model_path`` global
            (set by the ``__main__`` block).
        sample_iter: Run index embedded in output filenames. ``None`` falls
            back to the module-level ``sample_iter`` global.

    Raises:
        ValueError: if ``model_path``/``sample_iter`` are unavailable, or if
            ``samples`` cannot be split into whole batches over ``fill_parts``.
    """
    if model_path is None:
        model_path = globals().get("model_path")
    if sample_iter is None:
        sample_iter = globals().get("sample_iter")
    if model_path is None or sample_iter is None:
        raise ValueError("model_path and sample_iter must be passed or set at module level")

    # Validate the batching invariants once, before any expensive loading.
    if not fill_parts:
        raise ValueError("fill_parts must not be empty")
    num_iters = samples
    batch_size = int(samples // len(fill_parts))
    if batch_size <= 0:
        raise ValueError(
            f"samples ({samples}) must be at least len(fill_parts) ({len(fill_parts)})"
        )
    if num_iters % batch_size != 0 or num_iters // batch_size % len(fill_parts) != 0:
        raise ValueError(
            f"samples ({samples}) must split evenly into batches over {len(fill_parts)} fill parts"
        )
    iters_per_fill = num_iters // batch_size // len(fill_parts)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if model_name in ["opt27b", "opt13b"]:
        # NOTE(review): presumably the OPT logits width, which differs from
        # tokenizer.vocab_size for OPT — confirm.
        vocab_size = 50272
    else:
        vocab_size = tokenizer.vocab_size

    print("Loading remote logits...")
    logit_paths = [
        f"../../data/logits/ngram-p1-logits-{model_name}.pickle",
        f"../../data/logits/ngram-p2-logits-{model_name}.pickle",
    ]
    remote_logits = []
    for path in logit_paths:
        # NOTE: pickle can execute arbitrary code; only load trusted local files.
        with open(path, "rb") as fh:
            remote_logits.append(pickle.load(fh))

    print("Transporting logits to cuda...")
    for prompt_logits in remote_logits:
        _move_logits_to_device(prompt_logits, device)

    print("Convert done. Starting sampling...")

    with torch.no_grad():
        for idx in range(2):
            print(f"Processing prompt {idx}...")

            for combination in combinations:
                temperature = combination["temperature"]
                top_p = combination["topp"]
                top_k = combination["topk"]

                print(f"Running combination: temperature={temperature}, topp={top_p}, topk={top_k}")

                json_file_name = f"{json_file_paths[idx]}-{model_name}-temp-{temperature}-topp-{top_p}-topk-{top_k}-{samples}-iter-{sample_iter}.json"
                # if already exists, skip
                if os.path.exists(json_file_name):
                    print(f"File {json_file_name} already exists. Skipping...")
                    continue

                counts_by_fill = {}
                for fill_part in fill_parts:
                    print(f"For prompt {idx}, fill part: {fill_part}")
                    print(f"Processing fill part: {fill_part}")
                    # The cached logits depend only on the context; look them up once.
                    logits = get_logits(fill_part, remote_logits[idx], tokenizer)

                    for step in range(iters_per_fill):
                        print(f"Iter: {step + 1}/{iters_per_fill}")
                        uw_tokens = sample_batch_uw(
                            logits=logits,
                            batch_size=batch_size,
                            temperature=temperature,
                            top_k=top_k,
                            top_p=top_p,
                            device=device,
                        )
                        hist = _token_histogram(uw_tokens, vocab_size)
                        if fill_part in counts_by_fill:
                            counts_by_fill[fill_part] += hist
                        else:
                            counts_by_fill[fill_part] = hist

                results = {
                    # No watermarked sampling happens in this script, so this stays empty.
                    "watermarked": {},
                    "unwatermarked": {
                        str(fill_part): {"S_uw": counts.tolist()}
                        for fill_part, counts in counts_by_fill.items()
                    },
                }

                with open(json_file_name, "w") as json_file:
                    json.dump(results, json_file, separators=(",", ":"))

                # Clear CUDA cache to free memory after each combination
                torch.cuda.empty_cache()
                print("Cleared CUDA cache after combination.")


if __name__ == "__main__":
    import argparse
    import random

    # Sampling sweeps selectable via --option. In _sampling, topk=0 and
    # topp=1.0 disable the corresponding filter.
    combination_presets = {
        "all": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
            {"temperature": 0.7, "topp": 1.0, "topk": 0},
            {"temperature": 0.6, "topp": 1.0, "topk": 0},
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.4, "topp": 1.0, "topk": 0},
            {"temperature": 1.6, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 0.7, "topk": 0},
            {"temperature": 1.0, "topp": 0.8, "topk": 0},
            {"temperature": 1.0, "topp": 0.9, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 100},
            {"temperature": 1.0, "topp": 1.0, "topk": 200},
            {"temperature": 1.0, "topp": 1.0, "topk": 500},
            {"temperature": 0.8, "topp": 1.0, "topk": 50},
            {"temperature": 0.7, "topp": 1.0, "topk": 50},
            {"temperature": 0.6, "topp": 1.0, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 0},
            {"temperature": 0.7, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 50},
            {"temperature": 1.2, "topp": 0.7, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 50},
        ],
        "temp": [
            {"temperature": 1.5, "topp": 1.0, "topk": 0},
            {"temperature": 1.4, "topp": 1.0, "topk": 0},
            {"temperature": 1.3, "topp": 1.0, "topk": 0},
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.1, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.9, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
            {"temperature": 0.7, "topp": 1.0, "topk": 0},
            {"temperature": 0.6, "topp": 1.0, "topk": 0},
            {"temperature": 0.5, "topp": 1.0, "topk": 0},
            {"temperature": 0.4, "topp": 1.0, "topk": 0},
            {"temperature": 0.3, "topp": 1.0, "topk": 0},
            {"temperature": 0.2, "topp": 1.0, "topk": 0},
            {"temperature": 0.1, "topp": 1.0, "topk": 0},
        ],
        "top": [
            {"temperature": 1.0, "topp": 0.7, "topk": 0},
            {"temperature": 1.0, "topp": 0.8, "topk": 0},
            {"temperature": 1.0, "topp": 0.9, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 100},
            {"temperature": 1.0, "topp": 1.0, "topk": 200},
            {"temperature": 1.0, "topp": 1.0, "topk": 500},
        ],
        "joint": [
            {"temperature": 0.8, "topp": 1.0, "topk": 50},
            {"temperature": 0.7, "topp": 1.0, "topk": 50},
            {"temperature": 0.6, "topp": 1.0, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 0},
            {"temperature": 0.7, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 50},
            {"temperature": 1.2, "topp": 0.7, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 50},
        ],
        "temp-most": [
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.1, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.9, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
        ],
        "experiment": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
        ],
    }

    parser = argparse.ArgumentParser(description="Run script with parameters")
    parser.add_argument("--model_name", type=str, required=True, help="model_name parameter")
    parser.add_argument("--samples", type=int, required=True, help="samples parameter")
    parser.add_argument("--device", type=str, required=True, help="device parameter")
    parser.add_argument(
        "--option",
        default="experiment",
        type=str,
        required=False,
        # Reject unknown presets at parse time instead of failing later with
        # a NameError on an undefined `combinations`.
        choices=list(combination_presets),
        help="option parameter",
    )
    parser.add_argument("--model_path", type=str, required=True, help="model_path parameter")
    parser.add_argument("--sample_iter", type=int, required=True, help="sample_iter parameter")
    parser.add_argument("--fill_length", type=int, required=True, help="fill_length parameter")

    args = parser.parse_args()
    combinations = combination_presets[args.option]

    print("Device: ", args.device)
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.device}"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # `device`, `model_path` and `sample_iter` are module-level globals that
    # run() / get_logits() read, so they must stay assigned at this scope.
    model_path = args.model_path
    sample_iter = args.sample_iter

    # Every " <letter> <number> <animal>" context, from the module-level lists.
    combinations_main = ["".join(comb) for comb in product(letters, numbers_en, animal_choice)]

    # A dedicated generator seeded with 64 yields the same selection as
    # random.seed(64) + random.sample, without touching global RNG state.
    fill_parts = random.Random(64).sample(combinations_main, args.fill_length)

    run(
        combinations=combinations,
        model_name=args.model_name,
        samples=args.samples,
        fill_parts=fill_parts,
        device=device,
    )
    sys.exit(0)