"""Implement alternative PRF functions. They differ only in how they generate a single hash (seed) from the tokens in the context.

These can be hooked into the existing WatermarkLogitsProcessor through the modified base class WatermarkBase; see the
implementation in extended_watermark_processor.py.
"""

# coding=utf-8
# Copyright 2023 Authors of "A Watermark for Large Language Models"
# available at https://arxiv.org/abs/2301.10226
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from itertools import combinations
from functools import cache

# Key properties of a hashing scheme: each entry maps a property name to the Python type of its value.
props = dict(
    prf_type=str,  # name of the underlying PRF that maps multiple token ids to a random seed (a key of prf_lookup)
    context_width=int,  # h in the paper: how many previous tokens each PRF call considers
    self_salt=bool,  # follow the robust-watermarking rules: use the token itself to seed (and possibly reject) its own list
    hash_key=int,  # large prime integer used to move the seed away from low-entropy bit sequences of the chosen PRF
)


def seeding_scheme_lookup(seeding_scheme: str):
    """Translate a seeding-scheme name into its hashing parameters.

    Named schemes:
        "simple_1" / "lefthash"     -> ("additive_prf", 1, False, 15485863)  # default bigram hash
        "algorithm-3" / "selfhash"  -> ("anchored_minhash_prf", 4, True, 15485863)
        "minhash"                   -> ("minhash_prf", 4, False, 15485863)
        "skipgram"                  -> ("skipgram_prf", 5, False, 15485863)

    Freeform schemes (for experimenting only) look like
    ``ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]``, e.g. ``ff-additive_prf-4-True-15485863`` or
    ``ff-additive_prf-5-True``. ``self_salt`` is True only for the exact string "True"; ``hash_key`` defaults
    to 15485863 when omitted. "simple_1" is an alias for ``ff-additive_prf-1-False-15485863``.

    Args:
        seeding_scheme: scheme name as described above.

    Returns:
        Tuple ``(prf_type, context_width, self_salt, hash_key)``.

    Raises:
        ValueError: if ``seeding_scheme`` is not a string, is not a known name, is a malformed freeform string
            (wrong number of "-" separated parts or non-integer width/key), or names a PRF missing from
            ``prf_lookup``.
    """
    if not isinstance(seeding_scheme, str):
        raise ValueError("Seeding scheme should be a string summarizing the procedure.")

    default_hash_key = 15485863
    named_schemes = {
        "simple_1": ("additive_prf", 1, False, default_hash_key),
        "lefthash": ("additive_prf", 1, False, default_hash_key),
        "algorithm-3": ("anchored_minhash_prf", 4, True, default_hash_key),
        "selfhash": ("anchored_minhash_prf", 4, True, default_hash_key),
        "minhash": ("minhash_prf", 4, False, default_hash_key),
        "skipgram": ("skipgram_prf", 5, False, default_hash_key),
    }

    if seeding_scheme in named_schemes:
        prf_type, context_width, self_salt, hash_key = named_schemes[seeding_scheme]
    elif seeding_scheme.startswith("ff"):  # freeform seeding scheme API - only use for experimenting
        split_scheme = seeding_scheme.split("-")
        # Reject malformed strings up front instead of failing later with an IndexError
        # or silently ignoring trailing parts.
        if len(split_scheme) not in (4, 5):
            raise ValueError(
                f"Freeform seeding scheme {seeding_scheme} should look like "
                "'ff-<prf_type>-<context_width>-<self_salt>[-<hash_key>]'."
            )
        prf_type = split_scheme[1]
        context_width = int(split_scheme[2])  # int() raises ValueError on a non-integer width
        self_salt = split_scheme[3] == "True"
        hash_key = int(split_scheme[4]) if len(split_scheme) == 5 else default_hash_key
    else:
        raise ValueError(f"Invalid seeding scheme name {seeding_scheme} given. Try  'simple_1'?")

    # Explicit check rather than `assert`, which is stripped when Python runs with -O.
    if prf_type not in prf_lookup:
        raise ValueError(f"Unknown PRF type {prf_type} in seeding scheme {seeding_scheme}.")
    return prf_type, context_width, self_salt, hash_key


def multiplicative_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed = ``salt_key`` times the product of all context token ids."""
    token_product = input_ids.prod().item()
    return salt_key * token_product


def additive_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed = ``salt_key`` times the sum of all context token ids."""
    token_sum = input_ids.sum().item()
    return salt_key * token_sum


def minfunc_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Seed = ``salt_key`` times the smallest context token id.

    Not a great idea for non-random input ids such as text tokens.
    """
    smallest_id = input_ids.min().item()
    return salt_key * smallest_id


def simple_skip_prf(input_ids: torch.LongTensor, salt_key: int, k=2) -> int:
    """Hash every k-th context token (starting with the first) and return the product of those hashes.

    Args:
        input_ids: context token ids.
        salt_key: integer multiplied into each selected token id before hashing.
        k: skip distance (stride) between selected tokens.
    """
    strided_ids = input_ids[::k]
    hashed = hashint(salt_key * strided_ids)
    # NOTE(review): the product of several hashes can exceed the int64 range for long contexts — confirm acceptable.
    return hashed.prod().item()


def skipgram_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Maximum-distance skipgram: the seed depends only on ``input_ids[0]``.

    Presumably ``input_ids[0]`` is the token farthest back in the context window — confirm with callers.
    """
    first_token = input_ids[0]
    return hashint(salt_key * first_token).item()


def anchored_skipgram_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Maximum-distance skipgram anchored to one more token.

    Returns the product of the hashes of ``input_ids[0]`` and ``input_ids[anchor]`` (the last entry by default),
    each salted with ``salt_key`` before hashing.
    """
    first_hash = hashint(salt_key * input_ids[0])
    anchor_hash = hashint(salt_key * input_ids[anchor])
    return (first_hash * anchor_hash).item()


def minhash_prf(input_ids: torch.LongTensor, salt_key: int) -> int:
    """Min-hash: salt every context token, hash it, and return the smallest hash.

    Slightly better than ``minfunc_prf`` for non-random input ids such as text, but still not ideal.
    """
    hashed_ids = hashint(salt_key * input_ids)
    return hashed_ids.min().item()


def anchored_minhash_prf(input_ids: torch.LongTensor, salt_key: int, anchor: int = -1) -> int:
    """Min-hash over pairs formed by each context token and one anchor token.

    Each token's hash is scaled by ``salt_key`` and multiplied by the hash of ``input_ids[anchor]`` (the last
    entry by default); the smallest product is the seed. Unlike most PRFs here, ``salt_key`` multiplies the hash
    outputs rather than the token ids before hashing.
    """
    # NOTE(review): salt_key * hash * hash can exceed the int64 range (e.g. 15485863 * ~1e6 * ~1e6) and wrap.
    # Left unchanged on purpose: altering it would change every seed this scheme produces.
    salted_hashes = salt_key * hashint(input_ids)
    anchor_hash = hashint(input_ids[anchor])
    return (salted_hashes * anchor_hash).min().item()


def minskipgram_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Minimum over all k-token skipgrams in the context of the product of their hashes.

    Every token id is multiplied by ``salt_key`` and hashed. Every combination of ``k`` hashed entries
    (without replacement, order preserved) is multiplied together, and the smallest product is returned.
    The default ``k=2`` considers all pairs in the context.

    Args:
        input_ids: context token ids (assumed 1-D — TODO confirm with callers).
        salt_key: integer multiplied into each token id before hashing.
        k: number of tokens per skipgram (2 = all pairs).

    Returns:
        The minimum skipgram product as a Python int.

    Raises:
        ValueError: if ``k < 1`` or the context holds fewer than ``k`` tokens.
    """
    num_tokens = input_ids.numel()
    if k < 1 or num_tokens < k:
        raise ValueError(f"minskipgram_prf needs 1 <= k <= context length, got k={k} for {num_tokens} tokens.")
    hashed_ids = hashint(salt_key * input_ids)
    # One row per k-combination, in the same order as itertools.combinations.
    skipgrams = torch.combinations(hashed_ids, r=k)
    # NOTE(review): products of k hashes (each up to ~1e6) can exceed int64 for k >= 4 and wrap.
    return skipgrams.prod(dim=1).min().item()


def noncomm_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Order-dependent (non-commutative) PRF that folds the context into a running key.

    Starting from ``salt_key``, each token updates the key as ``key * hashint(key * token)``, reduced modulo 2**32.
    ``k`` is accepted for signature compatibility with the other PRFs but is not used.
    """
    modulus = 2**32
    state = torch.as_tensor(salt_key, dtype=torch.long)
    for token_id in input_ids:
        state = (state * hashint(state * token_id)) % modulus
    return state.item()


def position_prf(input_ids: torch.LongTensor, salt_key: int, k: int = 2) -> int:
    """Seed = ``salt_key`` times the position-weighted sum of token ids (positions start at 1).

    ``k`` is accepted for signature compatibility with the other PRFs but is not used.
    """
    positions = torch.arange(1, len(input_ids) + 1, device=input_ids.device)
    weighted = salt_key * input_ids * positions
    return weighted.sum().item()


# Registry mapping each PRF name (as used in seeding schemes) to its implementation.
prf_lookup = dict(
    multiplicative_prf=multiplicative_prf,
    additive_prf=additive_prf,
    minfunc_prf=minfunc_prf,
    simple_skip_prf=simple_skip_prf,
    skipgram_prf=skipgram_prf,
    anchored_skipgram_prf=anchored_skipgram_prf,
    minhash_prf=minhash_prf,
    anchored_minhash_prf=anchored_minhash_prf,
    minskipgram_prf=minskipgram_prf,
    noncomm_prf=noncomm_prf,
    position_prf=position_prf,
)

# Generate a global permutation table once at import time, on CPU, from a fixed seed, so it is reproducible
# from run to run (assuming torch's randperm algorithm is stable across versions — not verified here).
# Changing the seed or the size changes every hashint() output and therefore every PRF seed built on it.
rng = torch.Generator(device=torch.device("cpu"))
rng.manual_seed(2971215073)  # fib47 is prime
table_size = 1_000_003
fixed_table = torch.randperm(table_size, device=torch.device("cpu"), generator=rng)  # actually faster than I thought


def hashint(integer_tensor: torch.LongTensor) -> torch.LongTensor:
    """Hash integers by looking them up in the fixed random permutation table.

    Each entry is reduced modulo ``table_size`` (the result is non-negative, also for negative inputs) and mapped
    through ``fixed_table``. Adding 1 shifts outputs into ``[1, table_size]``, so a hash is never zero.

    Args:
        integer_tensor: integer tensor of any shape, on any device.

    Returns:
        Long tensor of the same shape with values in ``[1, table_size]``. It is always on CPU (minor cheat: the
        input is moved to CPU before indexing the CPU-resident table).
    """
    return fixed_table[integer_tensor.cpu() % table_size] + 1


def _hashint_avalanche_tensor(integer_tensor: torch.LongTensor):
    """http://burtleburtle.net/bob/hash/integer.html, ported into pytorch, runs on tensors. Apparently a decent avalanche."""
    i = integer_tensor.to(torch.int32).clone()  # or torch.int16?
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i.to(torch.long)


@cache
def _hashint_avalanche_int(integer: int):
    """http://burtleburtle.net/bob/hash/integer.html, runs in base python, caches based on access.
    Does this make sense for signed 64bit ints?"""
    i = integer % (2**32)
    i -= i << 6
    i ^= i >> 17
    i -= i << 9
    i ^= i << 4
    i -= i << 3
    i ^= i << 10
    i ^= i >> 15
    return i