# KGW-Fix Sampling Batch for Prob 2(5-gram)

import torch
import ujson
import os
import numpy as np
from torch.nn import functional as F
from typing import Union
from transformers import AutoTokenizer
import sys
from mersenne import MersenneRNG
import random
from itertools import product
import pickle
import numpy as np

# Output-path stems; `run` appends the model / sampling settings and ".json".
json_file_path_1 = "../../data/results/prob2/kgwfix-5gram-p1"
json_file_path_2 = "../../data/results/prob2/kgwfix-5gram-p2"
json_file_paths = [json_file_path_1, json_file_path_2]

# Prompt text files, read once when the module is loaded.
prompt_file_path_1 = "../../data/prompts/5gram-p1.txt"
prompt_file_path_2 = "../../data/prompts/5gram-p2.txt"

with open(prompt_file_path_1, "r") as f:
    prompt1 = f.read()

with open(prompt_file_path_2, "r") as f:
    prompt2 = f.read()

# Index-aligned with json_file_paths: prompts[i] is written to json_file_paths[i].
prompts = [prompt1, prompt2]


# Building blocks of the phrases: every piece carries its own leading space.
letters = [" " + chr(code) for code in range(ord("A"), ord("Z") + 1)]
numbers_en = [
    " " + word
    for word in (
        "zero",
        "one",
        "two",
        "three",
        "four",
        "five",
        "six",
        "seven",
        "eight",
        "nine",
    )
]
animal_choice = [" " + animal for animal in ("cat", "dog", "tiger", "lion")]

# Cartesian product "<letter> <number> <animal>", letter varying slowest.
combinations_main = [
    letter + number + animal
    for letter, number, animal in product(letters, numbers_en, animal_choice)
]

# Seed applied to the module-level torch generator in the script entry point.
hash_key = 15485863


def _f_time(input_ids, prefix_length, prf, vocab_size):
    batch_size, sequence_length = input_ids.shape
    time_result = torch.ones(batch_size, device=input_ids.device)

    for i in range(prefix_length):
        time_result *= input_ids[:, -1 - i].float() # time_result's shape is (batch_size,)

    # prf's shape: (batch_size, vocab_size)
    indices = (time_result.long() % vocab_size).unsqueeze(1)  # (batch_size, 1)

    result = torch.gather(prf, 1, indices).squeeze(1)  # (batch_size,)
    return result


def _get_greenlist_ids_left(input_ids, gamma, prf, vocab_size, prefix_length, keys, indicators):
    """Build the per-row green-list token ids for the current contexts.

    For every row a seed ``(key * _f_time(...)) % vocab_size`` is derived, the
    vocabulary is permuted with a generator seeded by it, and one side of the
    permutation is selected: the first ``int(vocab_size * gamma)`` entries when
    ``indicators[i] == 1``, the remaining entries otherwise.

    Args:
        input_ids: 2-D tensor of context token ids, one row per sample.
        gamma: fraction of the vocabulary placed in the leading partition.
        prf: per-row lookup table forwarded to ``_f_time``.
        vocab_size: size of the permuted vocabulary.
        prefix_length: number of trailing tokens used by ``_f_time``.
        keys: 1-D tensor with one watermark key per row.
        indicators: 1-D tensor; entries equal to 1 select the leading partition.

    Returns:
        2-D tensor of token ids, one row per sample. When the two partitions
        differ in size (``gamma != 0.5`` or an odd vocabulary) the shorter
        selection is padded by repeating its first id so that the result stays
        rectangular; such duplicates are harmless when the ids are scattered
        into a boolean mask.

    Raises:
        ValueError: if ``gamma`` leaves one of the two partitions empty.
    """
    # Use the tensors' own device rather than a module-level global, which is
    # only defined when the file is executed as a script.
    target_device = input_ids.device

    time_results = _f_time(
        input_ids, prefix_length=prefix_length, prf=prf, vocab_size=vocab_size
    )  # (batch_size,)
    seeds = ((keys * time_results) % vocab_size).to(target_device)  # (batch_size,)

    greenlist_size = int(vocab_size * gamma)
    if not 0 < greenlist_size < vocab_size:
        raise ValueError(
            f"gamma={gamma} with vocab_size={vocab_size} leaves an empty partition"
        )

    rng_cuda = torch.Generator(device=target_device)

    # One seeded permutation per row; tolist() fetches all seeds in one transfer.
    vocab_permutations = torch.stack(
        [
            torch.randperm(
                vocab_size, device=target_device, generator=rng_cuda.manual_seed(seed)
            )
            for seed in seeds.tolist()
        ],
        dim=0,
    )

    leading = vocab_permutations[:, :greenlist_size]
    trailing = vocab_permutations[:, greenlist_size:]
    width = max(leading.shape[1], trailing.shape[1])

    def _pad(ids):
        # torch.where needs both candidates to have the same width.
        missing = width - ids.shape[1]
        if missing == 0:
            return ids
        return torch.cat([ids, ids[:, :1].expand(-1, missing)], dim=1)

    # indicators[i] == 1 -> leading partition, otherwise the trailing one.
    greenlist_ids = torch.where(
        indicators.unsqueeze(1) == 1,
        _pad(leading),
        _pad(trailing),
    )
    return greenlist_ids


def _calc_greenlist_mask(scores, greenlist_token_ids):
    batch_size, vocab_size = scores.shape
    green_tokens_mask = torch.zeros(
        batch_size, vocab_size, device=scores.device, dtype=torch.bool
    )
    green_tokens_mask.scatter_(1, greenlist_token_ids, True)
    return green_tokens_mask


def _bias_greenlist_logits(scores, greenlist_mask, greenlist_bias):
    _scores = scores.clone()
    _scores[greenlist_mask] = scores[greenlist_mask] + greenlist_bias
    return _scores


def _sampling(logits, top_k=None, top_p=None, temperature=1.0):
    assert temperature > 0, "temperature must be a positive number"

    _logits = logits / temperature

    # Apply top-k sampling
    if top_k > 0:
        top_k = min(
            top_k, _logits.size(-1)
        )  # Ensure top_k is not greater than the vocabulary size
        indices_to_remove = _logits < torch.topk(_logits, top_k)[0][..., -1, None]
        _logits[indices_to_remove] = float("-inf")

    # Apply top-p sampling
    if top_p > 0 and top_p < 1:
        sorted_logits, sorted_indices = torch.sort(_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        if sorted_indices_to_remove[..., 1:].any():
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1
            ].clone()
            sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        _logits[indices_to_remove] = float("-inf")

    # Get probability distribution
    probs = F.softmax(_logits, dim=-1)
    print(f"shape of probs: {probs.shape}")
    sampled_indices = torch.multinomial(probs, num_samples=1)
    return sampled_indices


def kgw_sampling(
    input_ids,
    scores,
    gamma,
    delta,
    prefix_length,
    keys,
    indicators,
    vocab_size,
    temperature,
    top_k,
    top_p,
    prf, # (active_indices, vocab_size)
):
    """Sample one token per row after adding the watermark bias to the logits.

    Args:
        input_ids: 2-D tensor of context token ids, one row per sample.
        scores: 2-D tensor of logits aligned row-for-row with ``input_ids``.
        gamma: green-list fraction forwarded to ``_get_greenlist_ids_left``.
        delta: bias added to the logits of green tokens.
        prefix_length: number of trailing context tokens used for seeding.
        keys: 1-D tensor with one watermark key per row.
        indicators: 1-D tensor selecting which side of the permutation is green.
        vocab_size: vocabulary size used for the permutation.
        temperature, top_k, top_p: sampling settings forwarded to ``_sampling``.
        prf: per-row lookup table forwarded to the green-list construction.

    Returns:
        Tensor with one column holding the sampled token index of each row.
    """
    if input_ids.shape[-1] < prefix_length:
        # Too little context to seed the green list. Callers expect sampled
        # indices (not logits), so sample from the unbiased scores instead.
        return _sampling(
            scores, top_k=top_k, top_p=top_p, temperature=temperature
        )

    # Per-row green token ids derived from the context, keys and indicators.
    batched_greenlist_ids = _get_greenlist_ids_left(
        input_ids, gamma, prf, vocab_size, prefix_length, keys, indicators=indicators
    )

    # Boolean mask over the vocabulary marking the green tokens.
    green_tokens_mask = _calc_greenlist_mask(
        scores, greenlist_token_ids=batched_greenlist_ids
    )

    # Add delta to the green logits (on a copy of the scores).
    scores = _bias_greenlist_logits(
        scores, greenlist_mask=green_tokens_mask, greenlist_bias=delta
    )

    sampled_indices = _sampling(
        scores, top_k=top_k, top_p=top_p, temperature=temperature
    )

    return sampled_indices


def get_logits(ctx: str, logits: dict, tokenizer: AutoTokenizer):
    """Walk the nested logits dict along the tokens of ``ctx`` and return its logits.

    ``ctx`` is tokenised together with a fixed leading string and the leading
    string's tokens are then sliced off, so that ``ctx`` is encoded as a
    continuation rather than as the start of a text.

    Args:
        ctx: context string to look up.
        logits: nested dict keyed by stringified token ids; the node that is
            reached must contain a ``"logits"`` entry.
        tokenizer: tokenizer providing ``encode``.

    Returns:
        Tensor built from the node's ``"logits"`` entry, placed on the
        module-level ``device``.
    """
    anchor = "Example12:"
    anchor_len = len(tokenizer.encode(anchor, add_special_tokens=False))
    ctx_token_ids = tokenizer.encode(anchor + ctx, add_special_tokens=False)[anchor_len:]

    node = logits
    for tok in ctx_token_ids:
        node = node[str(tok)]

    # The node reached is expected to hold exactly one or two entries.
    n_entries = len(node.keys())
    assert n_entries == 1 or n_entries == 2
    return torch.tensor(node["logits"], device=device)

def sample_batch_wm(
    logits, 
    batch_size,
    input_ids,
    temperature, 
    top_k, 
    top_p,
    gamma,
    delta,
    prefix_length, 
    keys, # (batch_size,)
    indicators, # (batch_size,)
):
    """Sample a batch of watermarked continuations by walking a logits trie.

    Every row starts at the root of ``logits`` and repeatedly samples a token
    with ``kgw_sampling``. If the sampled token is a key of the current node the
    row descends into that child; otherwise the row stops.

    Args:
        logits: nested dict; each node has a ``"logits"`` tensor and children
            keyed by token id.
        batch_size: number of rows sampled in parallel.
        input_ids: 2-D tensor of context token ids, one row per sample. It is
            copied, so the caller's tensor is left unchanged.
        temperature, top_k, top_p: sampling settings.
        gamma, delta, prefix_length: watermark settings.
        keys: 1-D tensor of per-row watermark keys.
        indicators: 1-D tensor of per-row partition indicators.

    Returns:
        A pair ``(token_ids, context_ids)``: ``token_ids`` is a NumPy array with
        the last sampled token of each row, or -1 when the row stopped at a node
        that has entries other than ``"logits"``; ``context_ids`` is a list with
        every token sampled for each row, in order.

    Relies on the module-level ``device`` and ``rng`` set up by the script
    entry point.
    """
    vocab_size = logits["logits"].shape[-1]

    # Work on a private copy: the context window is shifted while sampling, and
    # shifting the caller's tensor would seed the next batch with this batch's
    # generated tokens instead of the prompt.
    input_ids = input_ids.clone()

    cur_logits_batch = [logits] * batch_size
    active = torch.ones(batch_size, dtype=torch.bool, device=device)
    token_ids = torch.full((batch_size,), -1, dtype=torch.long, device=device)
    context_ids = [[] for _ in range(batch_size)]

    # One vocabulary permutation per row, seeded by that row's key.
    prf = torch.stack(
        [
            torch.randperm(vocab_size, device=device, generator=rng.manual_seed(key))
            for key in keys[:batch_size].tolist()
        ],
        dim=0,
    )  # (batch_size, vocab_size)

    while active.any():
        active_indices = torch.nonzero(active).squeeze(1)
        active_rows = active_indices.tolist()
        logits_batch = torch.stack(
            [
                cur_logits_batch[row]["logits"].squeeze(0).to(device)
                for row in active_rows
            ]
        )
        tokens = kgw_sampling(
            input_ids=input_ids[active_indices],
            scores=logits_batch,
            gamma=gamma,
            delta=delta,
            prefix_length=prefix_length,
            keys=keys[active_indices],
            indicators=indicators[active_indices],
            vocab_size=logits_batch.shape[-1],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            prf=prf[active_indices],
        ).squeeze(1)

        # Slide every active row's window: drop the oldest token and append the
        # newly sampled one, keeping the sequence length constant.
        new_column = tokens.to(device=input_ids.device, dtype=input_ids.dtype).unsqueeze(1)
        input_ids[active_indices] = torch.cat(
            [input_ids[active_indices][:, 1:], new_column], dim=1
        )

        # tokens is ordered like active_indices, so the two can be zipped.
        for row, token_id in zip(active_rows, tokens.tolist()):
            context_ids[row].append(token_id)

            node = cur_logits_batch[row]
            if token_id in node:
                cur_logits_batch[row] = node[token_id]
                continue

            # No child for this token: the row is finished.
            active[row] = False
            if len(node) == 1 and "logits" in node:
                # The node holds nothing but "logits", i.e. a complete legal
                # prefix was generated; record the token that followed it.
                token_ids[row] = token_id
            # Otherwise the token left the legal set; token_ids stays at -1.

    return token_ids.cpu().numpy(), context_ids

def run(
    combinations, 
    model_name, 
    model_path,
    samples, 
    gamma, 
    delta, 
    prefix_length, 
    keylen, 
    device
):
    """Run watermarked sampling for both prompts and write one JSON file per setting.

    Args:
        combinations: list of dicts with ``"temperature"``, ``"topp"`` and
            ``"topk"`` entries; one output file is produced per entry and prompt.
        model_name: used to pick the logits pickles and in the output file name.
        model_path: passed to ``AutoTokenizer.from_pretrained``.
        samples: requested number of samples; only ``samples // 5000`` full
            batches of 5000 are drawn, any remainder is dropped.
        gamma, delta, prefix_length: watermark settings forwarded to sampling.
        keylen: number of watermark keys to generate; 0 means a single key with
            every indicator set to one.
        device: torch device on which tensors are placed.

    Also reads the module-level ``seed``, ``sample_iter``, ``prompts`` and
    ``json_file_paths``. Output files that already exist are skipped.
    """
    num_iters = samples
    batch_size = 5000  # Reduced batch size for better GPU memory management
    num_batches = num_iters // batch_size

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if model_name in ["opt27b", "opt13b"]: # ATTENTION: vocab size for opt models
        vocab_size = 50272
    else:
        vocab_size = tokenizer.vocab_size
        
    ## KGW-Edit keylist
    mersenne_rng = MersenneRNG(seed=seed)
    # keylen == 0 still needs one key; note that 10e8 equals 1e9.
    num_keys = keylen if keylen != 0 else 1
    key_list = np.array([int(mersenne_rng.randint() * 10e8) for _ in range(0, num_keys)])
    
    # NOTE: pickle.load executes code embedded in the file; only load logits
    # files that come from a trusted source.
    print("Loading remote logits...")
    with open(f"../../data/logits/5gram-p1-logits-{model_name}.pickle", "rb") as f:
        remote_logits_1 = pickle.load(f)

    with open(f"../../data/logits/5gram-p2-logits-{model_name}.pickle", "rb") as f:
        remote_logits_2 = pickle.load(f)

    print("Converting remote logits to tensors...")
    remote_logits = [remote_logits_1, remote_logits_2]

    def convert_logits_to_tensor(d):
        # Recursively move every "logits" entry onto `device`. Tensor.to
        # returns a new tensor, so the result has to be stored back.
        for key, value in list(d.items()):
            if isinstance(value, dict):
                convert_logits_to_tensor(value)
            elif key == "logits":
                d[key] = value.to(device)

    convert_logits_to_tensor(remote_logits[0])
    convert_logits_to_tensor(remote_logits[1])

    print("Convert done. Starting sampling...")

    with torch.no_grad():
        for idx in range(2):

            print(f"Processing prompt {idx}...")

            for combination in combinations:
                temperature = combination["temperature"]
                top_p = combination["topp"]
                top_k = combination["topk"]

                print(
                    f"Running combination: temperature={temperature}, topp={top_p}, topk={top_k}"
                )
                
                json_file_name = f"{json_file_paths[idx]}-{model_name}-temp-{temperature}-topp-{top_p}-topk-{top_k}-gamma-{gamma}-delta-{delta}-prefixlen-{prefix_length}-{samples}-{keylen}-prob2-iter-{sample_iter}.json"
                
                # if already exists, skip
                if os.path.exists(json_file_name):
                    print(f"File already exists, skipping...")
                    continue

                mapping_S_wm = {}
                mapping_S_uw = {}  # never filled here; written out empty

                # Tokenised prompt, repeated once per batch row.
                prompt_ids = tokenizer.encode(prompts[idx], return_tensors="pt").to(device)
                prompt_ids = prompt_ids.repeat(batch_size, 1)
                for batch_idx in range(num_batches):
                    print(f"Iter: {batch_idx + 1}/{num_batches}")
                    # Randomly sample batch_size keys from key_list
                    sampled_keys = np.random.choice(key_list, size=batch_size, replace=True)

                    # Convert sampled_keys to torch.Tensor
                    selected_keys = torch.tensor(sampled_keys).to(device)

                    # random generate a [0,1] binary list of batch size
                    indicator_list = np.random.randint(0, 2, size=batch_size)

                    # Convert selected_indicators to a torch.Tensor and move it to the appropriate device
                    selected_indicators_tensor = torch.tensor(indicator_list).to(device)
                    if keylen == 0:
                        # Single-key mode: every row uses indicator one.
                        selected_indicators_tensor = torch.ones(batch_size).to(device)
                    
                    # Sample batch_size watermarked tokens. Every batch gets a
                    # fresh copy of the prompt window, because sampling shifts
                    # the window by the generated tokens.
                    wm_tokens, wm_contexts = sample_batch_wm(
                        logits=remote_logits[idx],
                        batch_size=batch_size,
                        input_ids=prompt_ids.clone(),
                        temperature=temperature,
                        top_k=top_k,
                        top_p=top_p,
                        gamma=gamma,
                        delta=delta,
                        prefix_length=prefix_length,
                        keys=selected_keys,
                        indicators=selected_indicators_tensor,
                    )

                    # Keep only the rows that ended on a legal prefix (token != -1).
                    wm_valid_indices = np.where(wm_tokens != -1)[0]
                    wm_valid_contexts = [wm_contexts[i] for i in wm_valid_indices]
                    valid_wm_tokens = wm_tokens[wm_valid_indices]
                    valid_keys = selected_keys[wm_valid_indices].cpu().tolist()
                    valid_indicators = selected_indicators_tensor[wm_valid_indices].cpu().tolist()

                    for i, ctx in enumerate(wm_valid_contexts):
                        sample_key = valid_keys[i]
                        sample_indicator = valid_indicators[i]
                        key_content = f"{sample_key}_{sample_indicator}"
                        decoded_text = f'{tokenizer.decode(ctx).strip()}'
                        # The context label is the first five space-separated words.
                        parts = decoded_text.split(' ', 5)
                        context_str = f" {' '.join(parts[:5])}"
                        if i % 1000 == 0:
                            print(f"Context: {context_str}")
                        token = valid_wm_tokens[i]

                        if context_str not in mapping_S_wm:
                            mapping_S_wm[context_str] = {}
                            mapping_S_wm[context_str]["S_wm"] = [0] * vocab_size
                            mapping_S_wm[context_str]["key"] = {}
                        entry = mapping_S_wm[context_str]
                        # Histogram of the token that followed this context.
                        entry["S_wm"][token] += 1
                        # Count how often each key/indicator pair produced it.
                        if key_content not in entry["key"]:
                            print(f"Content: {context_str} Key not found: {key_content}")
                            entry["key"][key_content] = 1
                        else:
                            entry["key"][key_content] += 1

                results = {
                    "watermarked": {str(k): v for k, v in mapping_S_wm.items()},
                    "unwatermarked": {str(k): v for k, v in mapping_S_uw.items()},
                }

                with open(
                    json_file_name,
                    "w",
                ) as json_file:
                    print("Writing to json file...")
                    ujson.dump(results, json_file, separators=(",", ":"))
                    print("Done writing to json file.")

                # Clear CUDA cache to free memory after each combination
                torch.cuda.empty_cache()
                print("Cleared CUDA cache after combination.")


if __name__ == "__main__":
    import argparse

    def _grid(rows):
        # Expand (temperature, top_p, top_k) triples into the dicts `run` expects.
        return [{"temperature": t, "topp": p, "topk": k} for t, p, k in rows]

    # Named sweeps selectable through --option.
    option_table = {
        "all": _grid(
            [
                (1.0, 1.0, 0),
                (0.8, 1.0, 0),
                (0.7, 1.0, 0),
                (0.6, 1.0, 0),
                (1.2, 1.0, 0),
                (1.4, 1.0, 0),
                (1.6, 1.0, 0),
                (1.0, 0.7, 0),
                (1.0, 0.8, 0),
                (1.0, 0.9, 0),
                (1.0, 1.0, 100),
                (1.0, 1.0, 200),
                (1.0, 1.0, 500),
                (0.8, 1.0, 50),
                (0.7, 1.0, 50),
                (0.6, 1.0, 50),
                (0.8, 0.7, 0),
                (0.7, 0.7, 0),
                (0.6, 0.7, 0),
                (0.6, 0.7, 50),
                (1.2, 0.7, 50),
                (0.8, 0.7, 50),
            ]
        ),
        "temp": _grid(
            [
                (1.5, 1.0, 0),
                (1.4, 1.0, 0),
                (1.3, 1.0, 0),
                (1.2, 1.0, 0),
                (1.1, 1.0, 0),
                (1.0, 1.0, 0),
                (0.9, 1.0, 0),
                (0.8, 1.0, 0),
                (0.7, 1.0, 0),
                (0.6, 1.0, 0),
                (0.5, 1.0, 0),
                (0.4, 1.0, 0),
                (0.3, 1.0, 0),
                (0.2, 1.0, 0),
                (0.1, 1.0, 0),
            ]
        ),
        "top": _grid(
            [
                (1.0, 0.7, 0),
                (1.0, 0.8, 0),
                (1.0, 0.9, 0),
                (1.0, 1.0, 100),
                (1.0, 1.0, 200),
                (1.0, 1.0, 500),
            ]
        ),
        "joint": _grid(
            [
                (0.8, 1.0, 50),
                (0.7, 1.0, 50),
                (0.6, 1.0, 50),
                (0.8, 0.7, 0),
                (0.7, 0.7, 0),
                (0.6, 0.7, 0),
                (0.6, 0.7, 50),
                (1.2, 0.7, 50),
                (0.8, 0.7, 50),
            ]
        ),
        "temp-most": _grid(
            [
                (1.2, 1.0, 0),
                (1.1, 1.0, 0),
                (1.0, 1.0, 0),
                (0.9, 1.0, 0),
                (0.8, 1.0, 0),
            ]
        ),
        "experiment": _grid(
            [
                (1.0, 1.0, 0),
            ]
        ),
    }

    parser = argparse.ArgumentParser(description="Run script with parameters")
    parser.add_argument("--model_name", default="llama-2-7b-hf", type=str, required=False, help="model_name parameters")
    parser.add_argument("--model_path", type=str, required=True, help="model_path parameters")
    parser.add_argument("--samples", type=int, required=True, help="samples parameters")
    parser.add_argument("--gamma", default=0.5, type=float, required=False, help="gamma parameters")
    parser.add_argument("--delta", default=2.0, type=float, required=False, help="delta parameters")
    parser.add_argument(
        "--prefix_length", default=4, type=int, required=False, help="prefix_length parameters"
    )
    parser.add_argument("--device", type=int, required=True, help="device parameters")
    # Restricting the choices makes an unknown option fail at parse time with a
    # usage message instead of leaving `combinations` undefined.
    parser.add_argument(
        "--option",
        default="experiment",
        type=str,
        required=False,
        choices=list(option_table),
        help="option parameters",
    )
    parser.add_argument("--keylen", type=int, required=True, help="keylen parameters")
    parser.add_argument("--sample_iter", type=int, required=True, help="sample_iter parameters")

    args = parser.parse_args()
    
    # Module-level value read by `run` when it builds the output file name.
    sample_iter = args.sample_iter

    combinations = option_table[args.option]

    # Expose only the requested GPU before CUDA is queried for the first time.
    print("Device: ", args.device)
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.device}"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Module-level generator used by `sample_batch_wm`.
    rng = torch.Generator(device=device)
    rng.manual_seed(hash_key)
    
    # MersenneRNG's seed
    seed = 42

    run(
        combinations=combinations,
        model_name=args.model_name,
        model_path=args.model_path,
        samples=args.samples,
        gamma=args.gamma,
        delta=args.delta,
        prefix_length=args.prefix_length,
        keylen=args.keylen,
        device=device,
    )
    sys.exit(0)