"""
Sampling methods for autoregressive and speculative sampling

Most of the code for speculative sampling is derived from this repo:
https://github.com/feifeibear/LLMSpeculativeSampling


Other notes:
- KV Cache is enabled by default, as that is standard practice
- We time the prefill and generation stages separately: prefill can be
  compute bound because prompt/input tokens are processed in parallel,
  whereas the generation stage is strictly memory bound because new
  tokens are sampled sequentially, each depending on the previous ones.
"""

import csv
import time
import torch
from torch.nn import functional as F
from src.utils import sample, norm_logits, max_fn
from src.kvcache import KVCacheModel
from src.utils import touch, device, torch_timer


def autoregressive_sampling(
    input_ids, #(batch_size, sequence_length)
    tokenizer,
    N,
    temperature,
    model=None,
    top_k=0,
    top_p=0,
    logfile_name=None,
    **kwargs,
):
    """Generate up to ``N`` tokens one at a time and time prefill vs. generation.

    The first forward pass runs on the whole prompt (prefill); every later
    pass feeds only the most recent token together with the cached
    ``past_key_values``. Generation stops early when the sampled token equals
    ``tokenizer.eos_token_id``.

    Args:
        input_ids: Prompt token ids, ``(batch_size, sequence_length)``. Only
            the last batch row is sampled from, so this effectively assumes
            ``batch_size == 1``.
        tokenizer: Object exposing ``eos_token_id``.
        N: Maximum number of new tokens to generate.
        temperature: Forwarded to ``norm_logits``.
        model: Callable returning an object with ``logits`` and
            ``past_key_values`` attributes.
        top_k: Accepted for interface compatibility; not used here.
        top_p: Accepted for interface compatibility; not used here.
        logfile_name: If truthy, one CSV row of statistics is appended to
            this path.
        **kwargs: Ignored.

    Returns:
        Tuple ``(input_ids, stats)`` where ``input_ids`` is the prompt with
        the generated tokens appended and ``stats`` is a dict of token
        counts, timings and tokens-per-second figures.
    """
    prefill_shape = input_ids.shape[1]
    n = prefill_shape
    T = prefill_shape + N
    kv_cache = None

    # Stamp every timer up front so the statistics below stay well defined
    # even when the loop never runs (N <= 0).
    prefill_start = prefill_end = gen_start = torch_timer()

    while n < T:
        if kv_cache is not None:
            # Generation stage: only the newest token is fed; the rest of
            # the context comes from the cache.
            outputs = model(input_ids[:, -1:], use_cache=True, past_key_values=kv_cache)
        else:
            # Prefill stage: the whole prompt is processed in one pass.
            prefill_start = torch_timer()
            outputs = model(input_ids, use_cache=True)
            # outputs.logits: (batch_size, sequence_length, vocab_size),
            # per-position scores over the vocabulary.
            # outputs.past_key_values: presumably one entry per layer holding
            # the cached keys/values - TODO confirm layout for this model.
            prefill_end = torch_timer()
            gen_start = prefill_end

        kv_cache = outputs.past_key_values
        logits = outputs.logits[::, -1, :]
        # Keep only the last batch row before normalising.
        last_p = norm_logits(logits[-1:, :], temperature)

        next_token_id = sample(last_p)
        input_ids = torch.cat((input_ids, next_token_id), dim=-1)

        if next_token_id == tokenizer.eos_token_id:
            print("EOS found, exiting early")
            break

        n += 1

    gen_end = torch_timer()

    def _ratio(count, elapsed):
        # Avoid ZeroDivisionError on an empty interval.
        return count / elapsed if elapsed else 0.0

    # Generate statistics
    prefill_tokens = prefill_shape
    total_tokens = input_ids.shape[1]
    generate_tokens = total_tokens - prefill_shape
    prefill_time = prefill_end - prefill_start
    generate_time = gen_end - gen_start
    total_time = gen_end - prefill_start
    prefill_tok_per_sec = _ratio(prefill_tokens, prefill_time)
    generate_tok_per_sec = _ratio(generate_tokens, generate_time)
    total_tok_per_sec = _ratio(total_tokens, total_time)

    if logfile_name:
        # newline="" is what the csv module expects of the file it writes to.
        with open(logfile_name, "a", newline="") as f:
            csv.writer(f).writerow(
                [
                    prefill_tokens,
                    generate_tokens,
                    total_tokens,
                    prefill_time,
                    generate_time,
                    total_time,
                    prefill_tok_per_sec,
                    generate_tok_per_sec,
                    total_tok_per_sec,
                ]
            )

    ret_arg_handle = {
        "prefill_tokens": prefill_tokens,
        "generate_tokens": generate_tokens,
        "total_tokens": total_tokens,
        "prefill_time": prefill_time,
        "generate_time": generate_time,
        "total_time": total_time,
        "prefill_tok_per_sec": prefill_tok_per_sec,
        "generate_tok_per_sec": generate_tok_per_sec,
        "total_tok_per_sec": total_tok_per_sec,
    }

    return input_ids, ret_arg_handle


def speculative_sampling(
    input_ids,
    tokenizer,
    N,
    temperature,
    draft_model=None,
    target_model=None,
    gamma=4,
    random_seed=False,
    record_speculation=False,
    record_empirical_upper_bound=False,
    logfile_name=None,
    **kwargs,
):
    """Generate tokens with fixed-window draft/target speculative sampling.

    Each round the draft model proposes ``gamma`` tokens and the target model
    is run on the result. Every proposal is accepted unless a uniform random
    number exceeds the target/draft probability ratio for that token. After a
    rejection a replacement token is sampled from
    ``max_fn(target_probs - draft_probs)``; if every proposal survives, one
    extra token is sampled from the target's last distribution. The loop runs
    until the sequence holds at least ``prompt_len + N`` tokens (the final
    round may overshoot) or a newly produced token is EOS.

    Args:
        input_ids: Prompt token ids, ``(batch_size, sequence_length)``; the
            bookkeeping assumes a single batch row.
        tokenizer: Object exposing ``eos_token_id``.
        N: Number of new tokens requested.
        temperature: Forwarded to ``KVCacheModel``.
        draft_model: Small model wrapped in a ``KVCacheModel``.
        target_model: Large model wrapped in a ``KVCacheModel``.
        gamma: Number of draft tokens proposed per round.
        random_seed: If truthy, ``torch.manual_seed`` is called with it
            before every acceptance draw.
        record_speculation: If true, per-token tags and acceptance counters
            are collected; otherwise the counters stay 0 and
            ``history_ids`` is ``None``.
        record_empirical_upper_bound: If true, rejection is disabled so every
            draft token is kept.
        logfile_name: If truthy, one CSV row of statistics is appended to
            this path.
        **kwargs: Ignored.

    Returns:
        Tuple ``(input_ids, stats)`` with the extended sequence and a dict of
        token counts, timings, rates, speculation counters and the logits
        histories of both cache wrappers.
    """
    prefill_shape = input_ids.shape[1]
    T = prefill_shape + N
    draft_model_cache = KVCacheModel(draft_model, temperature)
    target_model_cache = KVCacheModel(target_model, temperature)

    # Timers receive their real values in the first round; pre-setting them
    # keeps the statistics code valid when the loop never runs (N <= 0).
    prefill_start = prefill_end = gen_start = torch_timer()

    # Counters always exist so the statistics below do not raise NameError
    # when record_speculation is False; they are only updated when it is True.
    history_ids = (
        torch.zeros((1, 2 * T), dtype=torch.long) if record_speculation else None
    )
    draft_sample_count = 0
    target_sample_count = 0
    thrown_away_count = 0
    accepted_count = 0
    resampled_count = 0

    while input_ids.shape[1] < T:
        prefix_len = input_ids.shape[1]

        # prefill_cache_time == -1 marks the very first (prefill) round.
        first_round = draft_model_cache.prefill_cache_time == -1
        if first_round:
            prefill_start = torch_timer()

        # x is presumably input_ids extended by gamma draft tokens; it is
        # indexed at prefix_len + i below - TODO confirm against KVCacheModel.
        x = draft_model_cache.generate(input_ids, gamma)
        _ = target_model_cache.generate(x, 1)

        if first_round:
            prefill_end = target_model_cache.prefill_cache_time
            gen_start = prefill_end

        if record_speculation:
            draft_sample_count += gamma
            target_sample_count += gamma + 1

        # Index of the last kept position, assuming nothing is rejected.
        n = prefix_len + gamma - 1

        # Determine whether to accept or reject the tokens
        for i in range(gamma):
            if random_seed:
                # NOTE: reseeding here makes every draw identical.
                torch.manual_seed(random_seed)
            r = torch.rand(1, device=device)

            j = x[:, prefix_len + i]

            # _prob_history: (batch_size, sequence_length, vocab_size),
            # normalised distribution over the vocabulary at each position.
            target_p = target_model_cache._prob_history[:, prefix_len + i - 1, j]
            draft_p = draft_model_cache._prob_history[:, prefix_len + i - 1, j]

            if r > target_p / draft_p and not record_empirical_upper_bound:
                # reject
                n = prefix_len + i - 1

                if record_speculation:
                    # Only discarded draft-model tokens are counted.
                    thrown_away_count += gamma - i
                break

            if record_speculation:
                # Counts only the gamma draft tokens; the extra token the
                # target model samples at the end is not included.
                accepted_count += 1
                # tag accepted tokens as 1 in history_ids
                history_ids[0, prefix_len + i - 1] = 1

        # Keep the sequence up to and including the last accepted token.
        input_ids = x[:, : n + 1]

        # Roll the draft model cache back to the kept length (end-exclusive).
        draft_model_cache.rollback(n + 1)

        if n < prefix_len + gamma - 1 and not record_empirical_upper_bound:
            # Something was rejected: resample position n from the
            # normalised positive part of (target - draft).
            q_minus_p = max_fn(
                target_model_cache._prob_history[:, n, :]
                - draft_model_cache._prob_history[:, n, :]
            )
            t = sample(q_minus_p)
            target_model_cache.rollback(n + 1)

            if record_speculation:
                # tag resampled tokens as 2 in history_ids
                history_ids[0, n] = 2

        else:
            # All draft tokens accepted: draw one more token from the target
            # distribution at the last position.
            t = sample(target_model_cache._prob_history[:, -1, :])
            if record_speculation:
                target_sample_count += 1
                # target resampled as 3
                history_ids[0, n] = 3
            target_model_cache.rollback(n + 2)

        if record_speculation:
            # Incremented once per round, whether the tag was 2 or 3.
            resampled_count += 1

        input_ids = torch.cat((input_ids, t), dim=1)

        # Exit if any token produced in this round is EOS. The slice is along
        # the sequence axis and starts at prefix_len so an EOS id inside the
        # prompt does not end generation immediately.
        if torch.any(input_ids[:, prefix_len:] == tokenizer.eos_token_id):
            print("EOS found, exiting early")
            break

    gen_end = torch_timer()

    def _ratio(count, elapsed):
        # Avoid ZeroDivisionError on an empty interval or zero draft samples.
        return count / elapsed if elapsed else 0.0

    # Generate statistics
    prefill_tokens = prefill_shape
    total_tokens = input_ids.shape[1]
    generate_tokens = total_tokens - prefill_shape
    prefill_time = prefill_end - prefill_start
    generate_time = gen_end - gen_start
    total_time = gen_end - prefill_start
    prefill_tok_per_sec = _ratio(prefill_tokens, prefill_time)
    generate_tok_per_sec = _ratio(generate_tokens, generate_time)
    total_tok_per_sec = _ratio(total_tokens, total_time)
    acceptance_rate = _ratio(accepted_count, draft_sample_count)

    if logfile_name:
        # newline="" is what the csv module expects of the file it writes to.
        with open(logfile_name, "a", newline="") as f:
            csv.writer(f).writerow(
                [
                    prefill_tokens,
                    generate_tokens,
                    total_tokens,
                    prefill_time,
                    generate_time,
                    total_time,
                    prefill_tok_per_sec,
                    generate_tok_per_sec,
                    total_tok_per_sec,
                    acceptance_rate,
                    draft_sample_count,
                    target_sample_count,
                    thrown_away_count,
                    accepted_count,
                    resampled_count,
                ]
            )

    ret_arg_handle = {
        "prefill_tokens": prefill_tokens,
        "generate_tokens": generate_tokens,
        "total_tokens": total_tokens,
        "prefill_time": prefill_time,
        "generate_time": generate_time,
        "total_time": total_time,
        "prefill_tok_per_sec": prefill_tok_per_sec,
        "generate_tok_per_sec": generate_tok_per_sec,
        "total_tok_per_sec": total_tok_per_sec,
        "history_ids": history_ids,
        "history_logits_target": target_model_cache._logits_history,
        "history_logits_draft": draft_model_cache._logits_history,
        "acceptance_rate": acceptance_rate,
        "draft_sample_count": draft_sample_count,
        "target_sample_count": target_sample_count,
        "thrown_away_count": thrown_away_count,
        "accepted_count": accepted_count,
        "resampled_count": resampled_count,
    }

    return input_ids, ret_arg_handle


def dynamic_speculative_sampling(
    input_ids,
    tokenizer,
    N,
    temperature,
    draft_model=None,
    target_model=None,
    max_gamma=4,
    random_seed=False,
    record_speculation=False,
    logfile_name=None,
    **kwargs,
):
    """Speculative sampling whose draft window shrinks on rejection and grows on success.

    Works like fixed-window speculative sampling, except that ``gamma``
    starts at ``max_gamma``, is decremented (down to 0) after a round with a
    rejection and incremented (up to ``max_gamma``) after a round in which
    every draft token was accepted. With ``gamma == 0`` only the target model
    runs for that round.

    Args:
        input_ids: Prompt token ids, ``(batch_size, sequence_length)``; the
            bookkeeping assumes a single batch row.
        tokenizer: Object exposing ``eos_token_id``.
        N: Number of new tokens requested; the final round may overshoot.
        temperature: Forwarded to ``KVCacheModel``.
        draft_model: Small model wrapped in a ``KVCacheModel``.
        target_model: Large model wrapped in a ``KVCacheModel``.
        max_gamma: Initial and maximum number of draft tokens per round.
        random_seed: If truthy, ``torch.manual_seed`` is called with it
            before every acceptance draw.
        record_speculation: If true, per-token tags and acceptance counters
            are collected; otherwise the counters stay 0 and
            ``history_ids`` is ``None``.
        logfile_name: If truthy, one CSV row of statistics is appended to
            this path.
        **kwargs: Ignored.

    Returns:
        Tuple ``(input_ids, stats)`` with the extended sequence and a dict of
        token counts, timings, rates, speculation counters and the logits
        histories of both cache wrappers.
    """
    prefill_shape = input_ids.shape[1]
    T = prefill_shape + N

    draft_model_cache = KVCacheModel(draft_model, temperature)
    target_model_cache = KVCacheModel(target_model, temperature)

    # Timers receive their real values in the first round; pre-setting them
    # keeps the statistics code valid when the loop never runs (N <= 0).
    prefill_start = prefill_end = gen_start = torch_timer()

    # Counters always exist so the statistics below do not raise NameError
    # when record_speculation is False; they are only updated when it is True.
    history_ids = (
        torch.zeros((1, 2 * T), dtype=torch.long) if record_speculation else None
    )
    draft_sample_count = 0
    target_sample_count = 0
    thrown_away_count = 0
    accepted_count = 0
    resampled_count = 0

    gamma = max_gamma
    while input_ids.shape[1] < T:
        prefix_len = input_ids.shape[1]

        if draft_model_cache.prefill_cache_time == -1:
            # First (prefill) round.
            prefill_start = torch_timer()
            draft_model_cache.generate_cache_time = prefill_start
            target_model_cache.generate_cache_time = prefill_start

            x = draft_model_cache.generate(input_ids, gamma)
            _ = target_model_cache.generate(x, 1)

            prefill_end = target_model_cache.prefill_cache_time
            gen_start = prefill_end
        elif gamma == 0:
            # No speculation this round: only the target model runs. x must
            # track the current sequence, otherwise the slice below would
            # restore the stale draft from the previous round and overwrite
            # the token that was resampled after its rejection.
            x = input_ids
            _ = target_model_cache.generate(input_ids, 1)
        else:
            x = draft_model_cache.generate(input_ids, gamma)
            _ = target_model_cache.generate(x, 1)

        if record_speculation:
            draft_sample_count += gamma
            target_sample_count += gamma + 1

        # Index of the last kept position, assuming nothing is rejected.
        n = prefix_len + gamma - 1

        for i in range(gamma):
            if random_seed:
                # NOTE: reseeding here makes every draw identical.
                torch.manual_seed(random_seed)
            r = torch.rand(1, device=device)
            j = x[:, prefix_len + i]

            target_p = target_model_cache._prob_history[:, prefix_len + i - 1, j]
            draft_p = draft_model_cache._prob_history[:, prefix_len + i - 1, j]

            if r > target_p / draft_p:
                # reject
                n = prefix_len + i - 1

                if record_speculation:
                    thrown_away_count += gamma - i
                break

            if record_speculation:
                accepted_count += 1
                # tag accepted tokens as 1 in history_ids
                history_ids[0, prefix_len + i - 1] = 1

        # Keep the sequence up to and including the last accepted token.
        input_ids = x[:, : n + 1]

        draft_model_cache.rollback(n + 1)

        if n < prefix_len + gamma - 1:
            # Something was rejected: resample position n from the
            # normalised positive part of (target - draft).
            q_minus_p = max_fn(
                target_model_cache._prob_history[:, n, :]
                - draft_model_cache._prob_history[:, n, :]
            )
            t = sample(q_minus_p)
            target_model_cache.rollback(n + 1)
            # Speculate less aggressively next round.
            gamma = max(0, gamma - 1)

            if record_speculation:
                # tag resampled tokens as 2 in history_ids
                history_ids[0, n] = 2
        else:
            # All draft tokens accepted: draw one more token from the target
            # distribution at the last position.
            t = sample(target_model_cache._prob_history[:, -1, :])

            # Speculate more aggressively next round.
            gamma = min(gamma + 1, max_gamma)
            if record_speculation:
                target_sample_count += 1
                # target resampled as 3
                history_ids[0, n] = 3
            target_model_cache.rollback(n + 2)

        if record_speculation:
            resampled_count += 1

        input_ids = torch.cat((input_ids, t), dim=1)

        # Exit if any token produced in this round is EOS. The slice is along
        # the sequence axis and starts at prefix_len so an EOS id inside the
        # prompt does not end generation immediately.
        if torch.any(input_ids[:, prefix_len:] == tokenizer.eos_token_id):
            print("EOS found, exiting early")
            break

    gen_end = torch_timer()

    def _ratio(count, elapsed):
        # Avoid ZeroDivisionError on an empty interval or zero draft samples.
        return count / elapsed if elapsed else 0.0

    # Generate statistics
    prefill_tokens = prefill_shape
    total_tokens = input_ids.shape[1]
    generate_tokens = total_tokens - prefill_shape
    prefill_time = prefill_end - prefill_start
    generate_time = gen_end - gen_start
    total_time = gen_end - prefill_start
    prefill_tok_per_sec = _ratio(prefill_tokens, prefill_time)
    generate_tok_per_sec = _ratio(generate_tokens, generate_time)
    total_tok_per_sec = _ratio(total_tokens, total_time)
    acceptance_rate = _ratio(accepted_count, draft_sample_count)

    if logfile_name:
        # newline="" is what the csv module expects of the file it writes to.
        with open(logfile_name, "a", newline="") as f:
            csv.writer(f).writerow(
                [
                    prefill_tokens,
                    generate_tokens,
                    total_tokens,
                    prefill_time,
                    generate_time,
                    total_time,
                    prefill_tok_per_sec,
                    generate_tok_per_sec,
                    total_tok_per_sec,
                    acceptance_rate,
                    draft_sample_count,
                    target_sample_count,
                    thrown_away_count,
                    accepted_count,
                    resampled_count,
                ]
            )

    ret_arg_handle = {
        "prefill_tokens": prefill_tokens,
        "generate_tokens": generate_tokens,
        "total_tokens": total_tokens,
        "prefill_time": prefill_time,
        "generate_time": generate_time,
        "total_time": total_time,
        "prefill_tok_per_sec": prefill_tok_per_sec,
        "generate_tok_per_sec": generate_tok_per_sec,
        "total_tok_per_sec": total_tok_per_sec,
        "history_ids": history_ids,
        "history_logits_target": target_model_cache._logits_history,
        "history_logits_draft": draft_model_cache._logits_history,
        "acceptance_rate": acceptance_rate,
        "draft_sample_count": draft_sample_count,
        "target_sample_count": target_sample_count,
        "thrown_away_count": thrown_away_count,
        "accepted_count": accepted_count,
        "resampled_count": resampled_count,
    }

    return input_ids, ret_arg_handle


def dynamic_speculative_sampling_history(
    input_ids,
    tokenizer,
    N,
    temperature,
    draft_model=None,
    target_model=None,
    max_gamma=4,
    random_seed=False,
    record_speculation=True,
    logfile_name=None,
    **kwargs,
):
    """Speculative sampling whose draft window is steered by a token-frequency history.

    A dict of token counts is seeded from the prompt and extended with every
    accepted draft token. After each round ``gamma`` is adjusted: if any
    proposed token was already in the history ("cache hit") it grows by 2
    when every draft token was accepted and by 1 otherwise (capped at
    ``max_gamma``); without a hit it shrinks by 1 (floored at 1). The history
    is printed before and after generation.

    Args:
        input_ids: Prompt token ids, ``(batch_size, sequence_length)``; only
            batch row 0 is read for the history.
        tokenizer: Object exposing ``eos_token_id``.
        N: Number of new tokens requested; the final round may overshoot.
        temperature: Forwarded to ``KVCacheModel``.
        draft_model: Small model wrapped in a ``KVCacheModel``.
        target_model: Large model wrapped in a ``KVCacheModel``.
        max_gamma: Initial and maximum number of draft tokens per round.
        random_seed: If truthy, ``torch.manual_seed`` is called with it
            before every acceptance draw.
        record_speculation: If true, per-token tags and acceptance counters
            are collected; otherwise the counters stay 0 and
            ``history_ids`` is ``None``.
        logfile_name: If truthy, one CSV row of statistics is appended to
            this path.
        **kwargs: Ignored.

    Returns:
        Tuple ``(input_ids, stats)`` with the extended sequence and a dict of
        token counts, timings, rates, speculation counters and the logits
        histories of both cache wrappers.
    """
    prefill_shape = input_ids.shape[1]
    T = prefill_shape + N

    # Seed the token history with the frequency of each prompt token.
    token_cache = {}
    for token in input_ids[0].tolist():
        token_cache[token] = token_cache.get(token, 0) + 1

    # Sort by frequency, most frequent first.
    token_cache = dict(
        sorted(token_cache.items(), key=lambda item: item[1], reverse=True)
    )
    print(token_cache)

    draft_model_cache = KVCacheModel(draft_model, temperature)
    target_model_cache = KVCacheModel(target_model, temperature)

    # Timers receive their real values in the first round; pre-setting them
    # keeps the statistics code valid when the loop never runs (N <= 0).
    prefill_start = prefill_end = gen_start = torch_timer()

    # Counters always exist so the statistics below do not raise NameError
    # when record_speculation is False; they are only updated when it is True.
    history_ids = (
        torch.zeros((1, 2 * T), dtype=torch.long) if record_speculation else None
    )
    draft_sample_count = 0
    target_sample_count = 0
    thrown_away_count = 0
    accepted_count = 0
    resampled_count = 0

    gamma = max_gamma
    while input_ids.shape[1] < T:
        prefix_len = input_ids.shape[1]

        if draft_model_cache.prefill_cache_time == -1:
            # First (prefill) round.
            prefill_start = torch_timer()
            draft_model_cache.generate_cache_time = prefill_start
            target_model_cache.generate_cache_time = prefill_start

            x = draft_model_cache.generate(input_ids, gamma)
            _ = target_model_cache.generate(x, 1)

            prefill_end = target_model_cache.prefill_cache_time
            gen_start = prefill_end
        elif gamma == 0:
            # No speculation this round; keep x in step with the current
            # sequence so the slice below does not restore a stale draft.
            x = input_ids
            _ = target_model_cache.generate(input_ids, 1)
        else:
            x = draft_model_cache.generate(input_ids, gamma)
            _ = target_model_cache.generate(x, 1)

        if record_speculation:
            draft_sample_count += gamma
            target_sample_count += gamma + 1

        # Index of the last kept position, assuming nothing is rejected.
        n = prefix_len + gamma - 1

        cache_hit = False

        for i in range(gamma):
            if random_seed:
                # NOTE: reseeding here makes every draw identical.
                torch.manual_seed(random_seed)
            r = torch.rand(1, device=device)
            j = x[:, prefix_len + i]

            # A hit is recorded for any examined proposal already in the
            # history, including the one that ends up rejected.
            cache_hit = cache_hit or j.item() in token_cache

            target_p = target_model_cache._prob_history[:, prefix_len + i - 1, j]
            draft_p = draft_model_cache._prob_history[:, prefix_len + i - 1, j]

            if r > target_p / draft_p:
                # reject
                n = prefix_len + i - 1

                if record_speculation:
                    thrown_away_count += gamma - i
                break

            if record_speculation:
                accepted_count += 1
                # tag accepted tokens as 1 in history_ids
                history_ids[0, prefix_len + i - 1] = 1

        # Keep the sequence up to and including the last accepted token.
        input_ids = x[:, : n + 1]

        draft_model_cache.rollback(n + 1)

        # Add the accepted draft tokens to the history.
        for token in input_ids[0, prefix_len : n + 1].tolist():
            token_cache[token] = token_cache.get(token, 0) + 1

        # True when no draft token was rejected in this round.
        all_accepted = n >= prefix_len + gamma - 1

        if not all_accepted:
            # Something was rejected: resample position n from the
            # normalised positive part of (target - draft).
            q_minus_p = max_fn(
                target_model_cache._prob_history[:, n, :]
                - draft_model_cache._prob_history[:, n, :]
            )
            t = sample(q_minus_p)
            target_model_cache.rollback(n + 1)

            if record_speculation:
                # tag resampled tokens as 2 in history_ids
                history_ids[0, n] = 2

        else:
            # All draft tokens accepted: draw one more token from the target
            # distribution at the last position.
            t = sample(target_model_cache._prob_history[:, -1, :])
            if record_speculation:
                target_sample_count += 1
                # target resampled as 3
                history_ids[0, n] = 3
            target_model_cache.rollback(n + 2)

        if record_speculation:
            resampled_count += 1

        input_ids = torch.cat((input_ids, t), dim=1)

        # Exit if any token produced in this round is EOS. The slice is along
        # the sequence axis and starts at prefix_len so an EOS id inside the
        # prompt does not end generation immediately.
        if torch.any(input_ids[:, prefix_len:] == tokenizer.eos_token_id):
            print("EOS found, exiting early")
            break

        # Adjust the window: the largest step is reserved for rounds that
        # both hit the history and had every draft token accepted.
        if cache_hit and all_accepted:
            gamma = min(gamma + 2, max_gamma)
        elif cache_hit:
            gamma = min(gamma + 1, max_gamma)
        else:
            gamma = max(1, gamma - 1)

    gen_end = torch_timer()

    token_cache = dict(
        sorted(token_cache.items(), key=lambda item: item[1], reverse=True)
    )
    print(token_cache)

    def _ratio(count, elapsed):
        # Avoid ZeroDivisionError on an empty interval or zero draft samples.
        return count / elapsed if elapsed else 0.0

    # Generate statistics
    prefill_tokens = prefill_shape
    total_tokens = input_ids.shape[1]
    generate_tokens = total_tokens - prefill_shape
    prefill_time = prefill_end - prefill_start
    generate_time = gen_end - gen_start
    total_time = gen_end - prefill_start
    prefill_tok_per_sec = _ratio(prefill_tokens, prefill_time)
    generate_tok_per_sec = _ratio(generate_tokens, generate_time)
    total_tok_per_sec = _ratio(total_tokens, total_time)
    acceptance_rate = _ratio(accepted_count, draft_sample_count)

    if logfile_name:
        # newline="" is what the csv module expects of the file it writes to.
        with open(logfile_name, "a", newline="") as f:
            csv.writer(f).writerow(
                [
                    prefill_tokens,
                    generate_tokens,
                    total_tokens,
                    prefill_time,
                    generate_time,
                    total_time,
                    prefill_tok_per_sec,
                    generate_tok_per_sec,
                    total_tok_per_sec,
                    acceptance_rate,
                    draft_sample_count,
                    target_sample_count,
                    thrown_away_count,
                    accepted_count,
                    resampled_count,
                ]
            )

    ret_arg_handle = {
        "prefill_tokens": prefill_tokens,
        "generate_tokens": generate_tokens,
        "total_tokens": total_tokens,
        "prefill_time": prefill_time,
        "generate_time": generate_time,
        "total_time": total_time,
        "prefill_tok_per_sec": prefill_tok_per_sec,
        "generate_tok_per_sec": generate_tok_per_sec,
        "total_tok_per_sec": total_tok_per_sec,
        "history_ids": history_ids,
        "history_logits_target": target_model_cache._logits_history,
        "history_logits_draft": draft_model_cache._logits_history,
        "acceptance_rate": acceptance_rate,
        "draft_sample_count": draft_sample_count,
        "target_sample_count": target_sample_count,
        "thrown_away_count": thrown_away_count,
        "accepted_count": accepted_count,
        "resampled_count": resampled_count,
    }

    return input_ids, ret_arg_handle


def perceptron_predictor(
    input_ids,
    tokenizer,
    N,
    temperature,
    draft_model=None,
    target_model=None,
    max_gamma=4,
    random_seed=False,
    record_speculation=True,
    logfile_name=None,
    **kwargs,
):
    """Speculative sampling whose draft window is chosen by a Q-learning agent.

    After every round the agent is asked for an action in ``{-1, 0, +1}``
    given the current ``gamma`` as state; the action is added to ``gamma``
    (clamped to ``[1, max_gamma]``) and the agent is trained with reward
    ``+1`` if every draft token of the round was accepted and ``-1``
    otherwise.

    Args:
        input_ids: Prompt token ids, ``(batch_size, sequence_length)``; the
            bookkeeping assumes a single batch row.
        tokenizer: Object exposing ``eos_token_id``.
        N: Number of new tokens requested; the final round may overshoot.
        temperature: Forwarded to ``KVCacheModel``.
        draft_model: Small model wrapped in a ``KVCacheModel``.
        target_model: Large model wrapped in a ``KVCacheModel``.
        max_gamma: Initial and maximum number of draft tokens per round.
        random_seed: If truthy, ``torch.manual_seed`` is called with it
            before every acceptance draw.
        record_speculation: If true, per-token tags and acceptance counters
            are collected; otherwise the counters stay 0 and
            ``history_ids`` is ``None``.
        logfile_name: If truthy, one CSV row of statistics is appended to
            this path.
        **kwargs: ``agent`` may carry an existing agent exposing
            ``get_action(state)`` and ``learn(state, action, reward,
            next_state)``; other keys are ignored.

    Returns:
        Tuple ``(input_ids, stats)`` with the extended sequence and a dict of
        token counts, timings, rates, speculation counters and the logits
        histories of both cache wrappers.
    """
    prefill_shape = input_ids.shape[1]
    T = prefill_shape + N

    # Build the default agent only when the caller did not supply one.
    agent = kwargs.get("agent")
    if agent is None:
        from src.predictor import QLearningAgent

        # Actions are changes in gamma: decrease, no change, increase.
        agent = QLearningAgent(list(range(-1, 2)))

    draft_model_cache = KVCacheModel(draft_model, temperature)
    target_model_cache = KVCacheModel(target_model, temperature)

    # Timers receive their real values in the first round; pre-setting them
    # keeps the statistics code valid when the loop never runs (N <= 0).
    prefill_start = prefill_end = gen_start = torch_timer()

    # Counters always exist so the statistics below do not raise NameError
    # when record_speculation is False; they are only updated when it is True.
    history_ids = (
        torch.zeros((1, 2 * T), dtype=torch.long) if record_speculation else None
    )
    draft_sample_count = 0
    target_sample_count = 0
    thrown_away_count = 0
    accepted_count = 0
    resampled_count = 0

    gamma = max_gamma
    while input_ids.shape[1] < T:
        accepted = True
        prefix_len = input_ids.shape[1]

        if draft_model_cache.prefill_cache_time == -1:
            # First (prefill) round.
            prefill_start = torch_timer()
            draft_model_cache.generate_cache_time = prefill_start
            target_model_cache.generate_cache_time = prefill_start

            x = draft_model_cache.generate(input_ids, gamma)
            _ = target_model_cache.generate(x, 1)

            prefill_end = target_model_cache.prefill_cache_time
            gen_start = prefill_end
        elif gamma == 0:
            # No speculation this round; keep x in step with the current
            # sequence so the slice below does not restore a stale draft.
            x = input_ids
            _ = target_model_cache.generate(input_ids, 1)
        else:
            x = draft_model_cache.generate(input_ids, gamma)
            _ = target_model_cache.generate(x, 1)

        if record_speculation:
            draft_sample_count += gamma
            target_sample_count += gamma + 1

        # Index of the last kept position, assuming nothing is rejected.
        n = prefix_len + gamma - 1

        for i in range(gamma):
            if random_seed:
                # NOTE: reseeding here makes every draw identical.
                torch.manual_seed(random_seed)
            r = torch.rand(1, device=device)
            j = x[:, prefix_len + i]

            target_p = target_model_cache._prob_history[:, prefix_len + i - 1, j]
            draft_p = draft_model_cache._prob_history[:, prefix_len + i - 1, j]

            if r > target_p / draft_p:
                # reject
                n = prefix_len + i - 1

                accepted = False
                if record_speculation:
                    thrown_away_count += gamma - i
                break

            if record_speculation:
                accepted_count += 1
                # tag accepted tokens as 1 in history_ids
                history_ids[0, prefix_len + i - 1] = 1

        # Keep the sequence up to and including the last accepted token.
        input_ids = x[:, : n + 1]

        draft_model_cache.rollback(n + 1)

        if n < prefix_len + gamma - 1:
            # Something was rejected: resample position n from the
            # normalised positive part of (target - draft).
            q_minus_p = max_fn(
                target_model_cache._prob_history[:, n, :]
                - draft_model_cache._prob_history[:, n, :]
            )
            t = sample(q_minus_p)
            target_model_cache.rollback(n + 1)

            if record_speculation:
                # tag resampled tokens as 2 in history_ids
                history_ids[0, n] = 2
        else:
            # All draft tokens accepted: draw one more token from the target
            # distribution at the last position.
            t = sample(target_model_cache._prob_history[:, -1, :])

            if record_speculation:
                target_sample_count += 1
                # target resampled as 3
                history_ids[0, n] = 3
            target_model_cache.rollback(n + 2)

        if record_speculation:
            resampled_count += 1

        input_ids = torch.cat((input_ids, t), dim=1)

        # The agent's state is simply the current gamma.
        state = gamma
        # Get action from agent (change in gamma)
        action = agent.get_action(state)
        # Apply the action and keep gamma within [1, max_gamma].
        gamma = max(1, min(gamma + action, max_gamma))
        next_state = gamma
        reward = 1 if accepted else -1  # reward is 1 if accepted, -1 if rejected
        # Learn from this action
        agent.learn(state, action, reward, next_state)

        # Exit if any token produced in this round is EOS. The slice is along
        # the sequence axis and starts at prefix_len so an EOS id inside the
        # prompt does not end generation immediately.
        if torch.any(input_ids[:, prefix_len:] == tokenizer.eos_token_id):
            print("EOS found, exiting early")
            break

    gen_end = torch_timer()

    def _ratio(count, elapsed):
        # Avoid ZeroDivisionError on an empty interval or zero draft samples.
        return count / elapsed if elapsed else 0.0

    # Generate statistics
    prefill_tokens = prefill_shape
    total_tokens = input_ids.shape[1]
    generate_tokens = total_tokens - prefill_shape
    prefill_time = prefill_end - prefill_start
    generate_time = gen_end - gen_start
    total_time = gen_end - prefill_start
    prefill_tok_per_sec = _ratio(prefill_tokens, prefill_time)
    generate_tok_per_sec = _ratio(generate_tokens, generate_time)
    total_tok_per_sec = _ratio(total_tokens, total_time)
    acceptance_rate = _ratio(accepted_count, draft_sample_count)

    if logfile_name:
        # newline="" is what the csv module expects of the file it writes to.
        with open(logfile_name, "a", newline="") as f:
            csv.writer(f).writerow(
                [
                    prefill_tokens,
                    generate_tokens,
                    total_tokens,
                    prefill_time,
                    generate_time,
                    total_time,
                    prefill_tok_per_sec,
                    generate_tok_per_sec,
                    total_tok_per_sec,
                    acceptance_rate,
                    draft_sample_count,
                    target_sample_count,
                    thrown_away_count,
                    accepted_count,
                    resampled_count,
                ]
            )

    ret_arg_handle = {
        "prefill_tokens": prefill_tokens,
        "generate_tokens": generate_tokens,
        "total_tokens": total_tokens,
        "prefill_time": prefill_time,
        "generate_time": generate_time,
        "total_time": total_time,
        "prefill_tok_per_sec": prefill_tok_per_sec,
        "generate_tok_per_sec": generate_tok_per_sec,
        "total_tok_per_sec": total_tok_per_sec,
        "history_ids": history_ids,
        "history_logits_target": target_model_cache._logits_history,
        "history_logits_draft": draft_model_cache._logits_history,
        "acceptance_rate": acceptance_rate,
        "draft_sample_count": draft_sample_count,
        "target_sample_count": target_sample_count,
        "thrown_away_count": thrown_away_count,
        "accepted_count": accepted_count,
        "resampled_count": resampled_count,
    }

    return input_ids, ret_arg_handle

def _window_safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _window_best_gamma(accuracy, a_hat, b_hat, max_gamma):
    """Return the speculation length in ``1..max_gamma`` with the best throughput.

    The objective is ``(1 - r^g) / (1 - r) / (g * a + b)`` where ``r`` is the
    single-step acceptance accuracy, ``a`` the per-token draft decoding time
    and ``b`` the verification time. Returns 1 when no candidate scores above
    zero (including ``max_gamma < 1``).
    """
    best_gamma = 1
    best_value = 0
    for s in range(1, max_gamma + 1):
        if accuracy == 1:
            # Limit of the geometric sum as r -> 1; avoids dividing by zero.
            expected_accepted = s
        else:
            expected_accepted = (1 - pow(accuracy, s)) / (1 - accuracy)
        objective = expected_accepted / (s * a_hat + b_hat)
        if objective > best_value:
            best_value = objective
            best_gamma = s
    return best_gamma


def window_scheduler(
    input_ids,
    tokenizer,
    N,
    temperature,
    draft_model=None,
    target_model=None,
    max_gamma=4,
    random_seed=False,
    record_speculation=True,
    logfile_name=None,
    **kwargs,
):
    """Speculative sampling with a sliding-window scheduler for gamma.

    Each step drafts ``gamma`` tokens with the draft model, verifies them with
    the target model, and then picks the next ``gamma`` from windowed averages
    of the measured draft/verify times and the observed acceptance accuracy.

    Args:
        input_ids: 2-D tensor of prompt token ids with the sequence along
            dim 1. The ``history_ids`` bookkeeping only writes to row 0.
        tokenizer: Only ``eos_token_id`` is read, for early exit.
        N: Number of tokens to generate. Several tokens can be appended per
            step, so the result may be longer than ``prompt_len + N``.
        temperature: Forwarded to both ``KVCacheModel`` wrappers.
        draft_model: Draft model, wrapped in a ``KVCacheModel``.
        target_model: Target model, wrapped in a ``KVCacheModel``.
        max_gamma: Initial speculation length and upper bound for the search.
        random_seed: If truthy, ``torch.manual_seed(random_seed)`` is called
            before every acceptance draw.
        record_speculation: If true, accumulate the speculation counters and
            ``history_ids``. If false the counters stay 0 and ``history_ids``
            is ``None``.
        logfile_name: If set, one CSV row of statistics is appended to it.
        **kwargs: ``context_window_length`` (default 7) is the length of the
            sliding windows; ``truncated_accuracy`` (default 0.8) caps the
            accuracy estimate.

    Returns:
        ``(input_ids, ret_arg_handle)``: the prompt plus generated tokens, and
        a dict of timing/throughput/speculation statistics.
    """
    from collections import deque
    from timeit import default_timer as timer

    prefill_shape = input_ids.shape[1]
    T = prefill_shape + N

    draft_model_cache = KVCacheModel(draft_model, temperature)
    target_model_cache = KVCacheModel(target_model, temperature)

    # Always bind the counters: they are read unconditionally when the
    # statistics are assembled, whether or not recording is enabled.
    # history_ids marks: 1 = accepted draft token, 2 = resampled after a
    # rejection, 3 = extra target token after a fully accepted draft.
    history_ids = (
        torch.zeros((1, 2 * T), dtype=torch.long) if record_speculation else None
    )
    draft_sample_count = 0
    target_sample_count = 0
    thrown_away_count = 0
    accepted_count = 0
    resampled_count = 0

    context_window_length = kwargs.get("context_window_length", 7)
    truncated_accuracy = kwargs.get("truncated_accuracy", 0.8)

    gamma = max_gamma
    gammas = []  # gamma used at every step, for reporting
    # Sliding windows over the most recent speculation steps:
    gamma_list = deque(maxlen=context_window_length)  # accepted draft tokens per step
    accepted_list = deque(maxlen=context_window_length)  # 1 if the step hit a rejection
    a = deque(maxlen=context_window_length)  # per-token draft decoding time
    b = deque(maxlen=context_window_length)  # verification time

    # Bound up front so the statistics below are defined even if the loop
    # body never runs (N == 0); overwritten by the prefill step otherwise.
    prefill_start = prefill_end = gen_start = torch_timer()

    while input_ids.shape[1] < T:
        gammas.append(gamma)
        prefix_len = input_ids.shape[1]

        is_prefill = draft_model_cache.prefill_cache_time == -1
        if is_prefill:
            prefill_start = torch_timer()
            draft_model_cache.generate_cache_time = prefill_start
            target_model_cache.generate_cache_time = prefill_start

        if gamma == 0 and not is_prefill:
            # Speculation disabled for this step: target model only.
            x = input_ids
            target_model_cache.generate(x, 1)
        else:
            # a estimator: draft decoding time per drafted token.
            start = timer()
            x = draft_model_cache.generate(input_ids, gamma)
            a.append((timer() - start) / gamma)

            # b estimator: time of one target verification pass.
            start = timer()
            target_model_cache.generate(x, 1)
            b.append(timer() - start)

        if is_prefill:
            prefill_end = target_model_cache.prefill_cache_time
            gen_start = prefill_end

        if record_speculation:
            draft_sample_count += gamma
            target_sample_count += gamma + 1

        # n is the index of the last accepted position.
        n = prefix_len + gamma - 1
        for i in range(gamma):
            if random_seed:
                torch.manual_seed(random_seed)
            u = torch.rand(1, device=device)
            j = x[:, prefix_len + i]
            target_prob = target_model_cache._prob_history[:, prefix_len + i - 1, j]
            draft_prob = draft_model_cache._prob_history[:, prefix_len + i - 1, j]

            if u > target_prob / draft_prob:
                # Reject draft token i and everything after it.
                n = prefix_len + i - 1
                if record_speculation:
                    thrown_away_count += gamma - i
                break

            if record_speculation:
                accepted_count += 1
                history_ids[0, prefix_len + i - 1] = 1

        input_ids = x[:, : n + 1]
        draft_model_cache.rollback(n + 1)

        if n < prefix_len + gamma - 1:
            # A draft token was rejected: resample position n from the
            # normalized positive part of (target - draft).
            accepted_list.append(1)
            gamma_list.append(n - prefix_len + 1)

            t = sample(
                max_fn(
                    target_model_cache._prob_history[:, n, :]
                    - draft_model_cache._prob_history[:, n, :]
                )
            )
            target_model_cache.rollback(n + 1)

            if record_speculation:
                history_ids[0, n] = 2
        else:
            # Every draft token was accepted (or none were drafted): take one
            # more token from the target distribution.
            if gamma != 0:
                gamma_list.append(gamma)
                accepted_list.append(0)

            t = sample(target_model_cache._prob_history[:, -1, :])

            if record_speculation:
                target_sample_count += 1
                history_ids[0, n] = 3
            target_model_cache.rollback(n + 2)

        if record_speculation:
            resampled_count += 1

        input_ids = torch.cat((input_ids, t), dim=1)

        # Choose gamma for the next step from the windowed estimates.
        a_hat = calculate_average(a)
        b_hat = calculate_average(b)

        # NOTE(review): 200 is a hard-coded absolute position threshold
        # (prompt tokens included); presumably a warm-up before the timing
        # averages are trusted enough to switch speculation off. Confirm.
        if n > 200 and a_hat > b_hat:
            gamma = 0
        else:
            # r: accepted draft tokens / (accepted draft tokens + rejections)
            # over the window; 1 when the window holds no observations yet.
            accepted_total = sum(gamma_list)
            attempts = accepted_total + sum(accepted_list)
            accuracy = accepted_total / attempts if attempts else 1
            accuracy = min(accuracy, truncated_accuracy)
            gamma = _window_best_gamma(accuracy, a_hat, b_hat, max_gamma)

        # Exit once any token appended during this step is EOS. The sequence
        # runs along dim 1, so the slice must be taken on that axis.
        eos_token_id = tokenizer.eos_token_id
        if eos_token_id is not None and torch.any(
            input_ids[:, prefix_len:] == eos_token_id
        ):
            print("EOS found, exiting early")
            break

    gen_end = torch_timer()

    # Generate statistics
    prefill_tokens = prefill_shape
    generate_tokens = input_ids.shape[1] - prefill_shape
    total_tokens = input_ids.shape[1]
    prefill_time = prefill_end - prefill_start
    generate_time = gen_end - gen_start
    total_time = gen_end - prefill_start
    prefill_tok_per_sec = _window_safe_div(prefill_tokens, prefill_time)
    generate_tok_per_sec = _window_safe_div(generate_tokens, generate_time)
    total_tok_per_sec = _window_safe_div(total_tokens, total_time)
    acceptance_rate = _window_safe_div(accepted_count, draft_sample_count)
    print('gamma list : ')
    print(gammas)
    average_gamma = _window_safe_div(sum(gammas), len(gammas))
    print('average gamma = ', average_gamma)
    print('number of speculation = ', len(gammas))

    if logfile_name:
        with open(logfile_name, "a") as f:
            writer = csv.writer(f)
            writer.writerow(
                [
                    prefill_tokens,
                    generate_tokens,
                    total_tokens,
                    prefill_time,
                    generate_time,
                    total_time,
                    prefill_tok_per_sec,
                    generate_tok_per_sec,
                    total_tok_per_sec,
                    acceptance_rate,
                    draft_sample_count,
                    target_sample_count,
                    thrown_away_count,
                    accepted_count,
                    resampled_count,
                    average_gamma,
                ]
            )

    ret_arg_handle = {
        "prefill_tokens": prefill_tokens,
        "generate_tokens": generate_tokens,
        "total_tokens": total_tokens,
        "prefill_time": prefill_time,
        "generate_time": generate_time,
        "total_time": total_time,
        "prefill_tok_per_sec": prefill_tok_per_sec,
        "generate_tok_per_sec": generate_tok_per_sec,
        "total_tok_per_sec": total_tok_per_sec,
        "history_ids": history_ids,
        "history_logits_target": target_model_cache._logits_history,
        "history_logits_draft": draft_model_cache._logits_history,
        "acceptance_rate": acceptance_rate,
        "draft_sample_count": draft_sample_count,
        "target_sample_count": target_sample_count,
        "thrown_away_count": thrown_away_count,
        "accepted_count": accepted_count,
        "resampled_count": resampled_count,
        "average_gamma": average_gamma,
    }

    return input_ids, ret_arg_handle