# DiPmark Sampling Batch for Water-Prob-V1
import torch
import ujson
import os
import numpy as np
from torch.nn import functional as F
from typing import Union
from transformers import AutoTokenizer
import sys
from itertools import product
import hashlib
import random
import pickle

# Logits & prompts: output-path prefixes and the two prompt texts.
json_file_path_1 = "../../data/results/prob1/dip-p1"
json_file_path_2 = "../../data/results/prob1/dip-p2"
prompt_file_path_1 = "../../data/prompts/ngram-p1.txt"
prompt_file_path_2 = "../../data/prompts/ngram-p2.txt"

json_file_paths = [json_file_path_1, json_file_path_2]

# Read each prompt file in full; prompts[i] is paired with json_file_paths[i].
prompts = []
for _prompt_path in (prompt_file_path_1, prompt_file_path_2):
    with open(_prompt_path, "r") as _prompt_file:
        prompts.append(_prompt_file.read())
prompt1, prompt2 = prompts

# Constants (fill parts): every " <letter> <number word> <animal>" string.
letters = [" " + chr(code) for code in range(ord("A"), ord("Z") + 1)]
numbers_en = [
    " " + word
    for word in (
        "zero",
        "one",
        "two",
        "three",
        "four",
        "five",
        "six",
        "seven",
        "eight",
        "nine",
    )
]
animal_choice = [" cat", " dog", " tiger", " lion"]
# product() varies its last factor fastest: " A zero cat", " A zero dog", ...
combinations_main = [
    letter + number + animal
    for letter, number, animal in product(letters, numbers_en, animal_choice)
]


def from_random(
    rng: Union[torch.Generator, list[torch.Generator]], vocab_size: int
) -> torch.LongTensor:
    """Draw a random permutation of ``range(vocab_size)`` from ``rng``.

    A single generator yields one 1-D permutation of length ``vocab_size``.
    A list of generators yields one permutation per generator, stacked into a
    ``(len(rng), vocab_size)`` tensor. Each permutation is created on the
    device of the generator that produced it.
    """
    if not isinstance(rng, list):
        return torch.randperm(vocab_size, generator=rng, device=rng.device)
    perms = [
        torch.randperm(vocab_size, generator=gen, device=gen.device) for gen in rng
    ]
    return torch.stack(perms)


def _portion_right_of(
    s_cumsum: torch.FloatTensor,
    s_p: torch.FloatTensor,
    threshold: float,
    like: torch.Tensor,
) -> torch.FloatTensor:
    """Return the fraction of each shuffled token's mass lying above ``threshold``.

    Tokens whose cumulative probability is <= ``threshold`` get 0. The first
    token whose cumulative probability exceeds ``threshold`` (the boundary)
    gets ``(cdf - threshold) / p`` clamped to ``[0, 1]``. Every later token
    gets 1.

    Args:
        s_cumsum: Cumulative probabilities in shuffled order (last entry is 1).
        s_p: Probabilities in shuffled order, same shape as ``s_cumsum``.
        threshold: Cut point on the cumulative distribution.
        like: Tensor whose dtype the result takes (via ``type_as``).
    """
    above = s_cumsum > threshold
    # argmax over a 0/1 mask returns the first index where the CDF exceeds
    # the threshold, i.e. the token that straddles it.
    boundary = torch.argmax(above.to(torch.int), dim=-1, keepdim=True)
    p_boundary = torch.gather(s_p, -1, boundary)
    portion = (torch.gather(s_cumsum, -1, boundary) - threshold) / p_boundary
    portion = torch.clamp(portion, 0, 1)
    s_portion = above.type_as(like)
    s_portion.scatter_(-1, boundary, portion)
    return s_portion


def reweight_logits(
    shuffle: torch.LongTensor, p_logits: torch.FloatTensor, alpha: float
) -> torch.FloatTensor:
    """Apply the DiPmark alpha-reweighting to ``p_logits`` under ``shuffle``.

    Tokens are laid out in ``shuffle`` order and the cumulative distribution
    (CDF) is computed in that order. For each token, ``r_a`` and ``r_b`` are
    the fractions of its probability mass lying above the CDF points
    ``alpha`` and ``1 - alpha`` (see ``_portion_right_of``).

    The result is ``p_logits + log((r_a + r_b) / 2)``. After a softmax, token
    ``i`` therefore has probability proportional to ``p_i * (r_a(i) + r_b(i))``;
    the constant 1/2 cancels in the softmax. Tokens with ``r_a = r_b = 0``
    receive ``-inf``.

    Args:
        shuffle: Permutation indices along the last dim. Expected to have the
            same shape as ``p_logits`` (as ``from_random`` produces for a batch).
        p_logits: Unnormalized logits; the last dim indexes the vocabulary.
        alpha: Reweighting strength. With 0.5 both cut points coincide, so only
            the mass above CDF point 0.5 is kept.

    Returns:
        Reweighted logits with the same shape as ``p_logits``, in the original
        (unshuffled) token order.
    """
    # Inverse permutation, used at the end to restore the original order.
    unshuffle = torch.argsort(shuffle, dim=-1)

    # Logits rearranged into shuffled order.
    s_p_logits = torch.gather(p_logits, -1, shuffle)

    # Log-space cumulative sum for numerical stability, shifted so the last
    # entry is log(1) = 0; exponentiating gives the CDF in shuffled order.
    s_log_cumsum = torch.logcumsumexp(s_p_logits, dim=-1)
    s_log_cumsum = s_log_cumsum - s_log_cumsum[..., -1:]
    s_cumsum = torch.exp(s_log_cumsum)

    # Per-token probabilities in shuffled order.
    s_p = F.softmax(s_p_logits, dim=-1)

    s_portion_1 = _portion_right_of(s_cumsum, s_p, alpha, p_logits)
    s_portion_2 = _portion_right_of(s_cumsum, s_p, 1 - alpha, p_logits)

    s_all_portion_in_right = s_portion_2 / 2 + s_portion_1 / 2
    # log(0) = -inf removes tokens lying entirely left of both cut points.
    s_shift_logits = torch.log(s_all_portion_in_right)
    shift_logits = torch.gather(s_shift_logits, -1, unshuffle)

    return p_logits + shift_logits


def _get_rng_seed(context_code: any) -> int:
    """Get the random seed from the given context code and private key."""
    m = hashlib.sha256()
    m.update(context_code)
    m.update(hash_key)

    full_hash = m.digest()
    seed = int.from_bytes(full_hash, "big") % (2**32 - 1)
    return seed


def _extract_context_code(context: torch.Tensor, prefix_length: int) -> torch.Tensor:
    """Extract context code from the given context tensor."""
    if prefix_length == 0:
        return context
    else:
        return context[:, -prefix_length:]

def get_seed_for_cipher(input_ids: torch.Tensor, prefix_length: int):
    """Compute one RNG seed per batch row from that row's context code.

    Each row's context code (see ``_extract_context_code``) is serialized to
    its raw bytes and hashed with the private key by ``_get_rng_seed``.

    Args:
        input_ids: Token ids. ``_extract_context_code`` treats dim 0 as the
            batch and slices dim 1.
        prefix_length: Number of trailing tokens forming the context code;
            0 uses the whole row.

    Returns:
        list[int]: one seed per row.

    Raises:
        AssertionError: if rows yield different seeds. ``run`` builds
            ``input_ids`` with ``.repeat(batch_size, 1)``, so every row is
            expected to be identical.
    """
    context_codes = _extract_context_code(input_ids, prefix_length)

    # One device->host transfer for the whole batch instead of one per row.
    # ndarray.tobytes() emits C-order bytes, so each row serializes exactly as
    # it would if it were transferred on its own.
    host_codes = context_codes.detach().cpu().numpy()

    # Hashing happens sequentially, one row at a time.
    seeds = [_get_rng_seed(row.tobytes()) for row in host_codes]

    assert all(seed == seeds[0] for seed in seeds), (
        "expected identical context codes across the batch"
    )

    return seeds


def _apply_watermark(
    input_ids: torch.LongTensor,
    scores: torch.FloatTensor,
    alpha: float,
    prefix_length: int,
) -> torch.FloatTensor:
    """Reweight ``scores`` with a DiPmark permutation keyed on ``input_ids``.

    One generator per batch row is seeded from that row's context code. A
    permutation of the ``scores.size(1)`` vocabulary entries is drawn from
    each generator, and ``reweight_logits`` applies the alpha-reweighting
    under it.
    """
    generators = [
        torch.Generator(device=scores.device).manual_seed(s)
        for s in get_seed_for_cipher(input_ids, prefix_length=prefix_length)
    ]
    permutation = from_random(generators, scores.size(1))
    return reweight_logits(permutation, scores, alpha)


def _sampling(logits, top_k=None, top_p=None, temperature=1.0):
    """Sampling function"""
    assert temperature > 0, "temperature must be positive"

    _logits = logits / temperature

    # Apply top-k sampling
    if top_k > 0:
        top_k = min(
            top_k, _logits.size(-1)
        )  # Ensure top_k is not greater than the vocabulary size
        indices_to_remove = _logits < torch.topk(_logits, top_k)[0][..., -1, None]
        _logits[indices_to_remove] = float("-inf")

    # Apply top-p sampling
    if top_p > 0 and top_p < 1:
        sorted_logits, sorted_indices = torch.sort(_logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        if sorted_indices_to_remove[..., 1:].any():
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1
            ].clone()
            sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        _logits[indices_to_remove] = float("-inf")

    # Get probability distribution
    probs = F.softmax(_logits, dim=-1)
    sampled_indices = torch.multinomial(probs, num_samples=1)
    return sampled_indices


def dip_sampling(
    input_ids,
    scores,
    alpha,
    prefix_length,
    temperature,
    top_k,
    top_p,
):
    """Sample one token per row from ``scores`` after DiPmark reweighting.

    Args:
        input_ids: Context token ids; the last dim is the sequence.
        scores: Next-token logits, one row per batch element.
        alpha: DiPmark reweighting strength (see ``reweight_logits``).
        prefix_length: Number of trailing context tokens used to seed the
            permutation (see ``get_seed_for_cipher``).
        temperature, top_k, top_p: Sampling controls (see ``_sampling``).

    Returns:
        Sampled token indices from ``_sampling``; shape ``(batch, 1)`` for
        2-D ``scores``. This holds on every path.
    """
    if input_ids.shape[-1] < prefix_length:
        # The context is too short to form a prefix_length-token code, so
        # sample the unwatermarked scores. Returning ``scores`` here would
        # hand logits to callers such as sample_batch_wm, which squeeze and
        # count the result as token ids.
        scores_to_sample = scores
    else:
        scores_to_sample = _apply_watermark(input_ids, scores, alpha, prefix_length)

    return _sampling(
        scores_to_sample, top_k=top_k, top_p=top_p, temperature=temperature
    )


def get_logits(ctx: str, logits: dict, tokenizer: AutoTokenizer):
    """Look up the cached next-token logits for ``ctx`` in a nested token-id dict.

    ``ctx`` is tokenized after the anchor string ``"Example12:"``, and the
    anchor's own tokens are then dropped. This presumably makes ``ctx``
    tokenize as it does mid-prompt; it assumes the anchor's tokens form a
    prefix of the combined tokenization (TODO confirm for each tokenizer).
    The remaining token ids are followed through the nested dicts, and the
    node reached must contain only a ``"logits"`` entry.

    Args:
        ctx: Text whose continuation logits are wanted.
        logits: Nested dict keyed by token id, with ``{"logits": ...}`` leaves.
        tokenizer: Tokenizer used to encode ``ctx``.

    Returns:
        The leaf's logits as a tensor on the module-level ``device`` (defined
        in the ``__main__`` section).

    Raises:
        KeyError: if a token id of ``ctx`` is missing from ``logits``.
        AssertionError: if the node reached has keys other than ``"logits"``.
    """
    anchor = "Example12:"
    anchor_tokens = tokenizer.encode(anchor, add_special_tokens=False)
    full_tokens = tokenizer.encode(anchor + ctx, add_special_tokens=False)
    ctx_tokens = full_tokens[len(anchor_tokens) :]

    node = logits
    for token_id in ctx_tokens:
        node = node[token_id]

    assert len(node.keys()) == 1
    # run() already converts leaves to tensors on ``device``. as_tensor reuses
    # such a tensor instead of copy-constructing it, which torch.tensor does
    # (with a UserWarning) on every call. Lists are still converted.
    return torch.as_tensor(node["logits"], device=device)


def sample_batch_wm(
    logits: torch.Tensor,
    batch_size,
    input_ids,
    vocab_size,
    temperature,
    top_k,
    top_p,
    alpha,
    prefix_length,
):
    """Draw ``batch_size`` watermarked samples from one logits vector.

    ``logits`` is broadcast to ``batch_size`` rows with ``expand`` (a view,
    not a copy) and passed to ``dip_sampling``. The trailing sample dimension
    is squeezed away and the token ids are returned as a NumPy array on the
    host.

    ``vocab_size`` is accepted for interface compatibility but unused here.
    """
    sampled = dip_sampling(
        input_ids=input_ids,
        scores=logits.expand(batch_size, -1),
        alpha=alpha,
        prefix_length=prefix_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
    )
    return sampled.squeeze(1).cpu().numpy()


def run(
    combinations,
    model_name,
    samples,
    alpha,
    prefix_length,
    fill_parts,
    sample_iter,
    device,
    tokenizer_path=None,
):
    """Sample watermarked tokens per prompt/combination and write counts to JSON.

    For each prompt in ``prompts`` and each sampling configuration in
    ``combinations``, this draws ``samples // len(fill_parts)`` watermarked
    tokens per fill part from precomputed logits. Per-token counts are
    written to a JSON file; configurations whose file already exists are
    skipped.

    Args:
        combinations: List of dicts with "temperature", "topp" and "topk" keys.
        model_name: Selects the logits pickles; "opt27b"/"opt13b" also fix the
            vocab size to 50272.
        samples: Total samples per combination; must divide evenly by
            ``len(fill_parts)``.
        alpha: DiPmark reweighting strength.
        prefix_length: DiPmark context length.
        fill_parts: Strings appended to each prompt.
        sample_iter: Only used in the output file name.
        device: Torch device for logits and input ids.
        tokenizer_path: Path passed to ``AutoTokenizer.from_pretrained``.
            Defaults to the module-level ``model_path`` (set in the
            ``__main__`` section) for backward compatibility.
    """
    num_iters = samples
    assert num_iters % len(fill_parts) == 0
    batch_size = int(samples // len(fill_parts))  # Each fill_part is a batch

    if tokenizer_path is None:
        tokenizer_path = model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    if model_name in ["opt27b", "opt13b"]:
        vocab_size = 50272
    else:
        vocab_size = tokenizer.vocab_size

    print("Loading remote logits...")
    # NOTE(review): pickle.load can execute arbitrary code; these files must
    # come from a trusted source.
    remote_logits = []
    for prompt_tag in ("p1", "p2"):
        pickle_path = f"../../data/logits/ngram-{prompt_tag}-logits-{model_name}.pickle"
        with open(pickle_path, "rb") as f:
            remote_logits.append(pickle.load(f))

    print("Converting remote logits to tensors...")

    def convert_logits_to_tensor(d):
        """Recursively replace every "logits" value with a tensor on ``device``."""
        for key, value in d.items():
            if isinstance(value, dict):
                convert_logits_to_tensor(value)
            elif key == "logits":
                # Re-assigning an existing key does not resize the dict, so
                # mutating it during iteration is safe.
                d[key] = torch.tensor(value, device=device)

    for trie in remote_logits:
        convert_logits_to_tensor(trie)

    print("Convert done. Starting sampling...")

    with torch.no_grad():
        for idx, (prompt, trie, json_file_path) in enumerate(
            zip(prompts, remote_logits, json_file_paths)
        ):
            print(f"Processing prompt {idx}...")

            for combination in combinations:
                temperature = combination["temperature"]
                top_p = combination["topp"]
                top_k = combination["topk"]

                print(
                    f"Running combination: temperature={temperature}, topp={top_p}, topk={top_k}"
                )

                json_file_name = f"{json_file_path}-{model_name}-temp-{temperature}-topp-{top_p}-topk-{top_k}-alpha-{alpha}-prefixlen-{prefix_length}-{samples}-{len(fill_parts)}-iter-{sample_iter}.json"
                # if already exists, skip
                if os.path.exists(json_file_name):
                    print(f"File already exists: {json_file_name}")
                    continue

                mapping_S_wm = {}
                # Never populated by this script; kept so the JSON schema
                # still has an "unwatermarked" key.
                mapping_S_uw = {}

                # Each prompt has two fill_parts
                for fill_part in fill_parts:
                    print(f"For prompt {idx}, fill part: {fill_part}")
                    print(f"Processing fill part: {fill_part}")
                    input_ids = tokenizer.encode(
                        prompt + fill_part, return_tensors="pt"
                    ).to(device)
                    input_ids = input_ids.repeat(batch_size, 1)
                    assert num_iters % batch_size == 0
                    assert num_iters // batch_size % len(fill_parts) == 0
                    # The cached logits depend only on fill_part and the
                    # prompt's dict, so look them up once per fill part.
                    logits = get_logits(fill_part, trie, tokenizer)
                    for it in range(num_iters // batch_size // len(fill_parts)):
                        print(f"Iter: {it + 1}")

                        wm_tokens = sample_batch_wm(
                            logits=logits,
                            batch_size=batch_size,
                            input_ids=input_ids,
                            vocab_size=vocab_size,
                            temperature=temperature,
                            top_k=top_k,
                            top_p=top_p,
                            alpha=alpha,
                            prefix_length=prefix_length,
                        )

                        if fill_part not in mapping_S_wm:
                            mapping_S_wm[fill_part] = {}
                            mapping_S_wm[fill_part]["S_wm"] = [0] * vocab_size

                        counts = mapping_S_wm[fill_part]["S_wm"]
                        for token in wm_tokens:
                            counts[token] += 1

                results = {
                    "watermarked": {str(k): v for k, v in mapping_S_wm.items()},
                    "unwatermarked": {str(k): v for k, v in mapping_S_uw.items()},
                }

                with open(json_file_name, "w") as json_file:
                    ujson.dump(results, json_file, separators=(",", ":"))

                # Clear CUDA cache to free memory after each combination
                torch.cuda.empty_cache()
                print("Cleared CUDA cache after combination.")


if __name__ == "__main__":
    import argparse

    # Sampling grids selectable with --option; each entry is one
    # (temperature, top-p, top-k) configuration passed to run().
    option_combinations = {
        "all": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
            {"temperature": 0.7, "topp": 1.0, "topk": 0},
            {"temperature": 0.6, "topp": 1.0, "topk": 0},
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.4, "topp": 1.0, "topk": 0},
            {"temperature": 1.6, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 0.7, "topk": 0},
            {"temperature": 1.0, "topp": 0.8, "topk": 0},
            {"temperature": 1.0, "topp": 0.9, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 100},
            {"temperature": 1.0, "topp": 1.0, "topk": 200},
            {"temperature": 1.0, "topp": 1.0, "topk": 500},
            {"temperature": 0.8, "topp": 1.0, "topk": 50},
            {"temperature": 0.7, "topp": 1.0, "topk": 50},
            {"temperature": 0.6, "topp": 1.0, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 0},
            {"temperature": 0.7, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 50},
            {"temperature": 1.2, "topp": 0.7, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 50},
        ],
        "temp": [
            {"temperature": 1.5, "topp": 1.0, "topk": 0},
            {"temperature": 1.4, "topp": 1.0, "topk": 0},
            {"temperature": 1.3, "topp": 1.0, "topk": 0},
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.1, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.9, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
            {"temperature": 0.7, "topp": 1.0, "topk": 0},
            {"temperature": 0.6, "topp": 1.0, "topk": 0},
            {"temperature": 0.5, "topp": 1.0, "topk": 0},
            {"temperature": 0.4, "topp": 1.0, "topk": 0},
            {"temperature": 0.3, "topp": 1.0, "topk": 0},
            {"temperature": 0.2, "topp": 1.0, "topk": 0},
            {"temperature": 0.1, "topp": 1.0, "topk": 0},
        ],
        "joint": [
            {"temperature": 0.8, "topp": 1.0, "topk": 50},
            {"temperature": 0.7, "topp": 1.0, "topk": 50},
            {"temperature": 0.6, "topp": 1.0, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 0},
            {"temperature": 0.7, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 0},
            {"temperature": 0.6, "topp": 0.7, "topk": 50},
            {"temperature": 1.2, "topp": 0.7, "topk": 50},
            {"temperature": 0.8, "topp": 0.7, "topk": 50},
        ],
        "temp-most": [
            {"temperature": 1.2, "topp": 1.0, "topk": 0},
            {"temperature": 1.1, "topp": 1.0, "topk": 0},
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
            {"temperature": 0.9, "topp": 1.0, "topk": 0},
            {"temperature": 0.8, "topp": 1.0, "topk": 0},
        ],
        "experiment": [
            {"temperature": 1.0, "topp": 1.0, "topk": 0},
        ],
    }

    parser = argparse.ArgumentParser(description="DIP Sampling Batch for Water-Prob-V1")
    parser.add_argument("--model_name", type=str, required=True, help="model_name parameter")
    parser.add_argument("--samples", type=int, required=True, help="samples parameter")
    parser.add_argument("--alpha", type=float, required=True, help="alpha parameter")
    parser.add_argument(
        "--prefix_length", type=int, required=True, help="prefix_length parameter"
    )
    parser.add_argument("--device", type=int, required=True, help="device parameter")
    # choices= rejects unknown grids up front; previously they left
    # `combinations` undefined and crashed with a NameError.
    parser.add_argument(
        "--option",
        default="experiment",
        type=str,
        required=False,
        choices=list(option_combinations),
        help="sampling-parameter grid to run",
    )
    parser.add_argument(
        "--fill_length", type=int, required=True, help="fill_length parameter"
    )
    parser.add_argument("--model_path", type=str, required=True, help="model_path parameter")
    parser.add_argument(
        "--sample_iter", type=int, required=False, help="sample_iter parameter"
    )

    args = parser.parse_args()

    # Module-level global read by run() to load the tokenizer.
    model_path = args.model_path
    combinations = option_combinations[args.option]

    print("Device: ", args.device)
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.device}"
    # Module-level global read by get_logits().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # DIP Config: private watermark key, read as the global `hash_key` by
    # _get_rng_seed().
    hash_seed = 42
    random.seed(hash_seed)
    hash_key = random.getrandbits(1024).to_bytes(128, "big")

    # Randomly sample fill_parts with a fixed seed (`random` is already
    # imported at the top of the file).
    random.seed(64)
    fill_parts = random.sample(combinations_main, args.fill_length)

    print("Fill parts: ")
    print(fill_parts)

    run(
        combinations=combinations,
        model_name=args.model_name,
        samples=args.samples,
        alpha=args.alpha,
        prefix_length=args.prefix_length,
        fill_parts=fill_parts,
        sample_iter=args.sample_iter,
        device=device,
    )
    sys.exit(0)
