import os
import torch
import numpy as np
from src.utils import torch_timer
from src.sampling import (
    autoregressive_sampling,
    speculative_sampling,
    dynamic_speculative_sampling,
    dynamic_speculative_sampling_history,
    perceptron_predictor,
)


# ANSI escape sequences used to color-code decoded tokens by their history tag.
_ANSI_RESET = "\033[0m"
_HISTORY_TAG_COLORS = {
    1: "\033[92m",  # accepted  -> green
    2: "\033[91m",  # rejected  -> red
    3: "\033[94m",  # resampled -> blue
}
# Dark-grey highlighted space appended after every token to show token boundaries.
_TOKEN_SEPARATOR = "\033[100m \033[0m"
# Extra counters reported from the sampler's return dict, printed in this order.
_OTHER_STAT_KEYS = (
    "acceptance_rate",
    "draft_sample_count",
    "target_sample_count",
    "thrown_away_count",
    "accepted_count",
    "resampled_count",
)


def _select_sampling_fn(mode):
    """Return the sampling function registered for ``mode``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    if mode == "autoregressive":
        return autoregressive_sampling
    if mode in ("speculative", "upper_bound_speculative"):
        return speculative_sampling
    if mode == "dynamic_speculative":
        return dynamic_speculative_sampling
    if mode == "dynamic_speculative_history":
        return dynamic_speculative_sampling_history
    if mode == "perceptron_predictor":
        return perceptron_predictor
    raise ValueError(
        f"Unknown inference mode {mode!r}; expected one of 'autoregressive', "
        "'speculative', 'upper_bound_speculative', 'dynamic_speculative', "
        "'dynamic_speculative_history', 'perceptron_predictor'"
    )


def _colorize_tokens(output_ids, history_ids, tokenizer):
    """Decode the first sequence token by token and color each token by its history tag.

    Tag 0 is emitted uncolored; tags 1/2/3 are wrapped in green/red/blue ANSI
    codes. A token with any other tag is omitted (only the separator is
    emitted). Indexing ``history_ids[0, i]`` assumes it has at least as many
    columns as ``output_ids[0]`` has tokens.
    """
    pieces = []
    for i, token_id in enumerate(output_ids[0]):
        tok = tokenizer.decode(token_id, skip_special_tokens=True)
        tag = history_ids[0, i].item()
        if tag == 0:
            pieces.append(tok)
        elif tag in _HISTORY_TAG_COLORS:
            pieces.append(f"{_HISTORY_TAG_COLORS[tag]}{tok}{_ANSI_RESET}")
        pieces.append(_TOKEN_SEPARATOR)
    return "".join(pieces)


def _format_rate(count, total):
    """Format ``count / total`` with two decimals, or ``"n/a"`` when ``total`` is 0."""
    return f"{count / total:.2f}" if total else "n/a"


def _print_history_stats(history_ids, ret_args):
    """Print accepted/rejected/resampled token counts and rates, plus sampler counters.

    History tags: 0 = untagged, 1 = accepted, 2 = rejected, 3 = resampled.
    """
    accepted = torch.sum(history_ids == 1).item()
    rejected = torch.sum(history_ids == 2).item()
    resampled = torch.sum(history_ids == 3).item()
    gen_tokens = accepted + rejected + resampled
    print(f"Accepted = {accepted}, Rejected = {rejected}, Resampled = {resampled}")
    # Guard against ZeroDivisionError when no token carries a 1/2/3 tag.
    for label, count in (
        ("Acceptance", accepted),
        ("Rejection", rejected),
        ("Resampling", resampled),
    ):
        print(f"{label} rate = {_format_rate(count, gen_tokens)}")
    print("Other stats")
    for key in _OTHER_STAT_KEYS:
        print(key, ret_args[key])


def inference(
    input_ids,
    tokenizer,
    n_tokens_to_generate,
    temperature,
    **kwargs,
):
    """Run one generation with the sampler selected by ``mode`` and print a report.

    Args:
        input_ids: Prompt token ids, forwarded unchanged to the sampling function.
        tokenizer: Tokenizer whose ``decode`` is used to render the output.
        n_tokens_to_generate: Forwarded unchanged to the sampling function.
        temperature: Forwarded unchanged to the sampling function.
        **kwargs: Forwarded verbatim to the sampling function. ``mode``
            (default ``"autoregressive"``) selects which sampler runs.

    The sampler must return ``(output_ids, ret_args)``. ``ret_args`` must
    contain ``total_time``, ``prefill_tok_per_sec`` and ``generate_tok_per_sec``.
    If it also contains ``history_ids``, per-token accept/reject statistics and
    a color-coded transcript are printed.

    Raises:
        ValueError: if ``mode`` is not a supported sampling mode.
    """
    mode = kwargs.get("mode", "autoregressive")
    sampling_fn = _select_sampling_fn(mode)

    output_ids, ret_args = sampling_fn(
        input_ids,
        tokenizer,
        n_tokens_to_generate,
        temperature,
        **kwargs,
    )
    time_elapsed = ret_args["total_time"]
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if "history_ids" in ret_args:
        history_ids = ret_args["history_ids"]
        colored_text = _colorize_tokens(output_ids, history_ids, tokenizer)
        _print_history_stats(history_ids, ret_args)
        print(colored_text)

    torch.cuda.empty_cache()

    print("########################################")
    print(f"Text generated for {mode}")
    print("----------------------------------------")
    print(f"Text = {text}")
    print("----------------------------------------")
    print(f"Performance for {mode}")
    print("----------------------------------------")
    print(f"            Time = {time_elapsed:.2f}s")
    print(f"          Tokens = {output_ids.shape[1]}")
    print(f"       E2E Tok/s = {output_ids.shape[1] / time_elapsed:.2f}")
    print(f"   Prefill Tok/s = {ret_args['prefill_tok_per_sec']:.2f}")
    print(f"Generation Tok/s = {ret_args['generate_tok_per_sec']:.2f}")
    print("----------------------------------------")
    print("########################################")
