import torch


def single_insertion(
    attack_len,
    min_token_count,
    tokenized_no_wm_output,  # dst
    tokenized_w_wm_output,  # src
):
    """Overwrite one random span of a copy of dst with a random span of src.

    Two start offsets are drawn independently and uniformly from
    ``[0, min_token_count - attack_len]``: the first selects where the span
    is written in the destination, the second where it is read from the
    source. The two offsets are generally different.

    Args:
        attack_len: Number of consecutive tokens to copy.
        min_token_count: Upper bound (exclusive) on any index touched in
            either sequence; callers are expected to pass a value no larger
            than the shorter of the two sequences.
        tokenized_no_wm_output: Destination token ids (list or tensor).
            It is not modified.
        tokenized_w_wm_output: Source token ids (list or tensor).

    Returns:
        A new tensor holding the destination tokens with one span of
        ``attack_len`` tokens replaced by source tokens.

    Raises:
        ValueError: If ``attack_len`` exceeds ``min_token_count``.
    """
    if attack_len > min_token_count:
        raise ValueError(
            f"attack_len ({attack_len}) must not exceed min_token_count ({min_token_count})"
        )

    top_insert_loc = min_token_count - attack_len
    # torch.randint excludes `high`, so add 1 to keep the last valid start
    # reachable (and to allow attack_len == min_token_count, where the only
    # valid start is 0).
    dst_start, src_start = torch.randint(
        low=0, high=top_insert_loc + 1, size=(2,)
    ).tolist()

    # as_tensor + clone accepts both lists and tensors without the
    # copy-construct warning that torch.tensor(tensor) emits, and guarantees
    # the caller's destination is left untouched.
    dst_tokens = torch.as_tensor(tokenized_no_wm_output).detach().clone()
    src_tokens = torch.as_tensor(tokenized_w_wm_output)

    dst_tokens[dst_start : dst_start + attack_len] = src_tokens[
        src_start : src_start + attack_len
    ]

    return dst_tokens


def triple_insertion_single_len(
    attack_len,
    min_token_count,
    tokenized_no_wm_output,  # dst
    tokenized_w_wm_output,  # src
):
    """Overwrite three non-overlapping spans of a copy of dst with src tokens.

    Three start offsets are drawn by rejection sampling until the spans
    ``[start, start + attack_len)`` are pairwise disjoint and all end at or
    before ``min_token_count``. Each span is then copied from the source
    into the destination at the *same* positions.

    Args:
        attack_len: Length of each of the three spans.
        min_token_count: Upper bound (exclusive) on any index touched;
            callers are expected to pass a value no larger than the shorter
            of the two sequences.
        tokenized_no_wm_output: Destination token ids (list or tensor).
            It is not modified.
        tokenized_w_wm_output: Source token ids (list or tensor).

    Returns:
        A new tensor holding the destination tokens with three spans
        replaced by the source tokens at the same indices.

    Raises:
        ValueError: If three spans of ``attack_len`` cannot fit inside
            ``min_token_count`` tokens.
    """
    num_spans = 3

    # Without this guard the rejection loop below can never succeed and
    # would spin forever.
    if num_spans * attack_len > min_token_count:
        raise ValueError(
            f"cannot place {num_spans} spans of length {attack_len} "
            f"within {min_token_count} tokens"
        )

    while True:
        rand_insert_locs = torch.randint(low=0, high=min_token_count, size=(num_spans,))
        first, second, third = sorted(rand_insert_locs.tolist())

        if (
            first + attack_len <= second
            and second + attack_len <= third
            and third + attack_len <= min_token_count
        ):
            break

    # Replace sections of the destination with the matching source sections.
    # as_tensor + clone accepts lists and tensors alike and leaves the
    # caller's destination untouched.
    dst_tokens = torch.as_tensor(tokenized_no_wm_output).detach().clone()
    src_tokens = torch.as_tensor(tokenized_w_wm_output)

    for start_idx in (first, second, third):
        end_idx = start_idx + attack_len
        dst_tokens[start_idx:end_idx] = src_tokens[start_idx:end_idx]

    return dst_tokens


def k_insertion_t_len(
    num_insertions,
    insertion_len,
    min_token_count,
    tokenized_dst_output,  # dst
    tokenized_src_output,  # src
    verbose=False,
):
    """Overwrite ``num_insertions`` non-overlapping spans of a copy of dst with src tokens.

    Start offsets are drawn by rejection sampling until the spans
    ``[start, start + insertion_len)`` are pairwise disjoint and all end at
    or before ``min_token_count``. Each span is then copied from the source
    into the destination at the *same* positions.

    Args:
        num_insertions: Number of spans to copy. Zero or fewer returns an
            unchanged copy of the destination.
        insertion_len: Length of every span.
        min_token_count: Upper bound (exclusive) on any index touched.
            The individual sequence lengths are not safe to rely on, so the
            caller is expected to pass the minimum of both lengths here.
        tokenized_dst_output: Destination token ids (list or tensor).
            It is not modified.
        tokenized_src_output: Source token ids (list or tensor).
        verbose: If true, print the sorted start offsets and the gaps
            between them for every sampling attempt.

    Returns:
        A new tensor holding the destination tokens with the selected spans
        replaced by the source tokens at the same indices.

    Raises:
        ValueError: If the requested spans cannot fit inside
            ``min_token_count`` tokens.
    """
    # as_tensor + clone accepts lists and tensors alike and leaves the
    # caller's destination untouched.
    dst_tokens = torch.as_tensor(tokenized_dst_output).detach().clone()
    src_tokens = torch.as_tensor(tokenized_src_output)

    if num_insertions <= 0:
        return dst_tokens

    # Without this guard the rejection loop below can never succeed and
    # would spin forever.
    if num_insertions * insertion_len > min_token_count:
        raise ValueError(
            f"cannot place {num_insertions} spans of length {insertion_len} "
            f"within {min_token_count} tokens"
        )

    while True:
        rand_insert_locs = torch.randint(low=0, high=min_token_count, size=(num_insertions,))
        sorted_locs, _ = torch.sort(rand_insert_locs)

        if verbose:
            gaps = [sorted_locs[i + 1] - sorted_locs[i] for i in range(num_insertions - 1)]
            gaps.append(min_token_count - sorted_locs[-1])
            print(f"indices: {[sorted_locs[i] for i in range(num_insertions)]}")
            print(f"gaps: {gaps}")

        starts = sorted_locs.tolist()

        # Adjacent spans must not overlap ...
        no_overlap = all(
            prev + insertion_len <= nxt for prev, nxt in zip(starts, starts[1:])
        )
        # ... and the last span may end exactly at min_token_count, since
        # slice ends are exclusive (same bound as triple_insertion_single_len).
        if no_overlap and starts[-1] + insertion_len <= min_token_count:
            break

    # Replace sections of the destination with the matching source sections.
    for start_idx in starts:
        end_idx = start_idx + insertion_len
        dst_tokens[start_idx:end_idx] = src_tokens[start_idx:end_idx]

    return dst_tokens

