import copy
import os
import classifier_lib
import torch
import transformers
import ujson as json
import time
from tqdm import tqdm
from dataclasses import dataclass
import tyro
import infer_utils
import deepseek_utils
from transformers import DynamicCache, OffloadedCache
import benchmark_data
import accuracy_utils
import process_shards
from functools import partial


# Granularity (in tokens) used for sequence lengths in this file: prompts are padded up to a
# multiple of this value, and common padding is only trimmed in whole multiples of it.
# NOTE(review): presumably chosen for kernel/compile friendliness — the reason is not stated here.
SEQLEN_MULTIPLE = 16


def get_token_ids(list_of_token_strs, tokenizer):
    """Map each string to its token id, or to None when it is not exactly one token.

    The strings are encoded in one batch without special tokens. The result is
    aligned with `list_of_token_strs`: position i holds the single id that
    string i encodes to, or None if the encoding has zero or several ids.
    """
    encoded = tokenizer.batch_encode_plus(list_of_token_strs, add_special_tokens=False)['input_ids']
    return [token_ids[0] if len(token_ids) == 1 else None for token_ids in encoded]


def repeat(xs, n):
    """Return a list in which every element of `xs` appears `n` times consecutively.

    Example: repeat([a, b], 2) -> [a, a, b, b].
    """
    repeated = []
    for idx in range(len(xs)):
        item = xs[idx]
        repeated += [item] * n
    return repeated


def unflatten(xs, n):
    """Split `xs` into consecutive chunks of length `n` (the last chunk may be shorter)."""
    chunks = []
    for start in range(0, len(xs), n):
        chunks.append(xs[start:start + n])
    return chunks


def compute_correctness(gt, pred):
    """Check `pred` against `gt` with `process_shards.equivalence_relation` in both directions.

    The relation is evaluated once as-is and once with `reverse=True`; the
    forward result is returned when it is truthy, otherwise the reverse result.
    """
    check = process_shards.equivalence_relation
    # Both directions are always evaluated, exactly as before.
    matches_forward = check(gt, pred)
    matches_backward = check(gt, pred, reverse=True)
    return matches_forward or matches_backward


def offloaded_cache_from_dynamic_cache(cache: DynamicCache):
    """Build a new `OffloadedCache` holding the same per-layer key/value states as `cache`.

    Layers are inserted in order via `OffloadedCache.update`, one call per layer.
    """
    offloaded = OffloadedCache()
    num_layers = len(cache)
    for layer_idx in range(num_layers):
        layer_keys, layer_values = cache[layer_idx]
        offloaded.update(key_states=layer_keys, value_states=layer_values, layer_idx=layer_idx)
    return offloaded


def get_common_padding(x, pad_token_id, side="left"):
    """Return how many pad tokens every row of `x` shares on the given side.

    x: [bs, N]. The shared amount is the smallest per-row padding count, returned
    as a Python number.
    """
    pad_counts = count_padding_per_sequence(x, pad_token_id, side)
    # The row with the least padding bounds what all rows have in common.
    smallest = pad_counts.min()
    return smallest.item()


def count_padding_per_sequence(x, pad_token_id, side):
    """Count, for each row of `x`, the run of pad tokens at the requested edge.

    x: [bs, N]. `side` selects the edge: "left" counts leading pads, "right"
    counts trailing pads. A row that consists only of pad tokens yields N.
    Returns a tensor with one count per row. Raises ValueError for any other
    `side`.
    """
    pad_mask = (x == pad_token_id)
    if side not in ("left", "right"):
        raise ValueError("'side' must be 'left' or 'right'")

    _, width = x.shape
    # Orient the mask so the padding being counted always sits at the start of the row.
    if side == "left":
        oriented = pad_mask
    else:
        oriented = torch.flip(pad_mask, dims=[1])

    # argmax over the 0/1 "is a real token" indicator is the index of the first real
    # token, which equals the number of pads in front of it.
    first_real = (~oriented).int().argmax(dim=1)
    # argmax reports 0 for a row without any real token; use the full width there.
    fully_padded = pad_mask.all(dim=1)
    return torch.where(fully_padded, width, first_real)


def remove_common_right_padding(input_ids, attention_mask):
    """Trim trailing columns that are masked out (mask == 0) in every row.

    Only whole multiples of SEQLEN_MULTIPLE columns are removed, so the trimmed
    width changes in steps of SEQLEN_MULTIPLE. Returns the (possibly sliced)
    `input_ids` and `attention_mask`.
    """
    shared_pads = get_common_padding(attention_mask, pad_token_id=0, side="right")
    # Round down to a whole number of SEQLEN_MULTIPLE-sized chunks.
    trim = SEQLEN_MULTIPLE * (shared_pads // SEQLEN_MULTIPLE)
    if trim < SEQLEN_MULTIPLE:
        return input_ids, attention_mask
    return input_ids[:, :-trim], attention_mask[:, :-trim]


def roll_input_and_kv_cache(current_input_ids, current_attention_mask, past_kv_cache, init_prompt_seqlen):
    """Convert right padding into left padding for the inputs and the KV cache.

    Each row of `current_input_ids` / `current_attention_mask` is rolled to the
    right by its number of trailing masked-out positions, so afterwards no row
    has right padding. Columns that are left padding in every row are then
    dropped in whole multiples of SEQLEN_MULTIPLE. The same per-row roll and
    column trim are applied to every layer of `past_kv_cache`.

    Args:
        current_input_ids: [bs, seqlen] tensor; rows are rolled in place.
        current_attention_mask: [bs, seqlen] tensor, 0 marks padding; rows are
            rolled in place.
        past_kv_cache: cache supporting `len()`, `cache[layer]` -> (k, v) and the
            `key_cache` / `value_cache` lists; updated in place.
        init_prompt_seqlen: list with the per-row column index where the response
            starts; shifted in place by each row's roll amount.

    Returns:
        (current_input_ids, current_attention_mask, past_kv_cache, init_prompt_seqlen)
        after rolling and trimming. The tensors and the list may be new objects
        when common left padding was removed.
    """
    left_shifts = count_padding_per_sequence(current_attention_mask, pad_token_id=0, side="left")
    right_shifts = count_padding_per_sequence(current_attention_mask, pad_token_id=0, side="right")
    # Materialise the shift amounts once: calling `.item()` per row (and again per row
    # for every layer below) would force a host sync each time.
    shifts = right_shifts.tolist()
    for i, shift_by in enumerate(shifts):
        current_input_ids[i] = torch.roll(current_input_ids[i], shift_by)
        current_attention_mask[i] = torch.roll(current_attention_mask[i], shift_by)
        init_prompt_seqlen[i] += shift_by

    # Sanity check: all former right padding must now be left padding.
    left_shifts2 = count_padding_per_sequence(current_attention_mask, pad_token_id=0, side="left")
    right_shifts2 = count_padding_per_sequence(current_attention_mask, pad_token_id=0, side="right")
    assert torch.all(left_shifts2 == left_shifts + right_shifts)
    assert torch.all(right_shifts2 == 0)

    # remove the common left padding tokens
    common_left_pads = left_shifts2.min().item()
    common_left_pads = (common_left_pads // SEQLEN_MULTIPLE) * SEQLEN_MULTIPLE
    if common_left_pads >= SEQLEN_MULTIPLE:
        current_input_ids = current_input_ids[:, common_left_pads:]
        current_attention_mask = current_attention_mask[:, common_left_pads:]
        init_prompt_seqlen = [x - common_left_pads for x in init_prompt_seqlen]

    # also roll the kv cache
    for layer_idx in range(len(past_kv_cache)):
        k, v = past_kv_cache[layer_idx]  # [bs, n_kv_heads, seqlen, dim_per_head]
        for i, shift_by in enumerate(shifts):
            # k[i] is [n_kv_heads, seqlen, dim_per_head], so dims=1 is the sequence axis.
            k[i] = torch.roll(k[i], shift_by, dims=1)
            v[i] = torch.roll(v[i], shift_by, dims=1)
        # common_left_pads is 0 when nothing was trimmed, making this slice a no-op.
        past_kv_cache.key_cache[layer_idx] = k[:, :, common_left_pads:]
        past_kv_cache.value_cache[layer_idx] = v[:, :, common_left_pads:]
    return current_input_ids, current_attention_mask, past_kv_cache, init_prompt_seqlen


def batch_select_indices(self, indices: torch.Tensor):
    """Keep only the batch entries selected by `indices` in every layer of the cache.

    For each layer, `indices` is moved to the device of that layer's key tensor
    before the key and value tensors are indexed along their first dimension.
    """
    num_layers = len(self)
    for layer_idx in range(num_layers):
        layer_keys = self.key_cache[layer_idx]
        # Layers may live on different devices, so follow the key tensor's device.
        indices = indices.to(layer_keys.device)
        self.key_cache[layer_idx] = layer_keys[indices, ...]
        layer_values = self.value_cache[layer_idx]
        self.value_cache[layer_idx] = layer_values[indices, ...]
# Monkey-patch: install the module-level `batch_select_indices` defined above as a method of
# `OffloadedCache`, so it works with OffloadedCache instances (generate() calls it to drop
# finished sequences from the KV cache).
OffloadedCache.batch_select_indices = batch_select_indices


@dataclass
class Args:
    """Command-line options for this script (parsed with `tyro.cli`)."""

    # Benchmark name passed to `benchmark_data.get_dataset`.
    benchmark: str = "aime-24"
    # Stored on the classifier as `classifier.inference_impl`.
    inference_impl: str = "efficient"
    # flash_attention_2 doesn't work with 4d attn mask, so need inference_impl = naive
    attention_impl: str = "flash_attention_2"
    # Seeds torch and the inference engine.
    seed: int = 1337
    # Model served by the inference engine that samples candidate continuations.
    piref_model: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    # True: load the classifier directly from `qsharp_ckpt_path`.
    # False: load `qsharp_model` and merge a PEFT adapter from `qsharp_ckpt_path`.
    disable_lora: bool = True
    # Source of the tokenizer (and of the base classifier weights when LoRA is enabled).
    qsharp_model: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
    qsharp_ckpt_path: str = "qsharp_ckpts/bs_128_lr_0.0006_lora_256_dropout_0.05_seed_1337/model_4000"
    # JSONL file that results are appended to; also used to resume a previous run.
    output_path: str = "inference_outputs.jsonl"
    # Wrap `classifier.model` in `torch.compile`.
    compile: bool = False
    # Passed to the inference engine as `gpu_memory_utilization`.
    piref_gpu_util: float = 0.5

    # Number of problems per DataLoader batch.
    batch_size: int = 32
    # A response is truncated once it reaches this many generated tokens.
    max_length: int = 16384
    # Number of candidate continuations sampled per sequence at each step.
    num_blocks: int = 8
    temperature: float = 0.6
    top_p: float = 0.95

    # Token budget (`max_new_tokens`) of one candidate continuation.
    block_size: int = 128
    n_newline_as_block: int = 0  # generate until newline


    def __post_init__(self):
        """Echo the key paths and make sure the output file's directory exists."""
        print("qsharp_ckpt_path: ", self.qsharp_ckpt_path)
        print("output_path: ", self.output_path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        output_dir = os.path.dirname(self.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)


def left_pad_collate_fn(input_ids_list: list[list[int]], pad_token_id):
    """Left-pad token id lists to one common length and stack them into tensors.

    The common length is the longest input rounded up to a multiple of
    SEQLEN_MULTIPLE. Returns a dict with 'input_ids' and 'attention_mask'
    tensors of shape [len(input_ids_list), common_length]; the mask is 0 on
    padding and 1 on real tokens.
    """
    longest = max(len(ids) for ids in input_ids_list)
    # Ceil-divide, then scale back up: the next multiple of SEQLEN_MULTIPLE >= longest.
    num_chunks = (longest + SEQLEN_MULTIPLE - 1) // SEQLEN_MULTIPLE
    target_len = num_chunks * SEQLEN_MULTIPLE

    id_rows, mask_rows = [], []
    for ids in input_ids_list:
        num_pad = target_len - len(ids)
        # Padding goes in front so the real tokens end at the last column.
        id_rows.append([pad_token_id] * num_pad + ids)
        mask_rows.append([0] * num_pad + [1] * len(ids))

    return {
        'input_ids': torch.tensor(id_rows),
        'attention_mask': torch.tensor(mask_rows),
    }


def get_newline_token_ids(tokenizer):
    """Return ids of tokens whose text ends in a blank line ("\\n\\n").

    Every id below `tokenizer.vocab_size` is decoded; strings ending in "\\n\\n"
    are kept unless they end in ":\\n\\n" or ",\\n\\n" (still in the middle of a
    thought). The kept strings are mapped back to ids with `get_token_ids`.

    Strings that do not re-encode to exactly one token come back from
    `get_token_ids` as None; they are dropped here so the result contains only
    ints and can be used directly as stop token ids.
    """
    vocab_strs = tokenizer.batch_decode(list(range(tokenizer.vocab_size)))
    # omit the ones which are still in the middle of a thought
    omit_newline_suffixes = (":\n\n", ",\n\n")
    newline_token_strs = [
        s for s in vocab_strs
        if s.endswith("\n\n") and not s.endswith(omit_newline_suffixes)
    ]
    newline_token_ids = get_token_ids(newline_token_strs, tokenizer)
    return [token_id for token_id in newline_token_ids if token_id is not None]


def get_eos_token_id(tokenizer):
    """Look up the id of the end-of-sentence token via `tokenizer.convert_tokens_to_ids`."""
    eos_token = "<｜end▁of▁sentence｜>"
    ids = tokenizer.convert_tokens_to_ids([eos_token])
    return ids[0]


def generate(
    piref_model, tokenizer, classifier_model, input_ids: list[list[int]],
    max_length: int, temperature: float, top_p: float, num_blocks: int, block_size: int,
    n_newline_as_block: int = 0, newline_token_ids: list[int] = None, stop_token_ids: list[int] = None,
    seed: int = 1337,
):
    """Generate responses block by block, choosing each block with the classifier.

    At every step, `num_blocks` candidate continuations are sampled from
    `piref_model` for each unfinished sequence, `classifier_model` scores them,
    and the highest-scoring candidate is appended to the sequence. The loop
    runs until every sequence has ended on a stop token or been truncated.

    Args:
        piref_model: engine whose `generate(input_ids=..., sampling_params=...)`
            returns one item per input with an 'output_ids' entry.
        tokenizer: provides `pad_token_id`; also used to look up the EOS and
            newline token ids when they are not passed in.
        classifier_model: called three ways below — on the prompts, on
            (cache + candidate continuations) for scoring, and on the chosen
            continuation to extend the KV cache.
            NOTE(review): the `continuation_ids` call signature is project-defined.
        input_ids: one token id list per prompt.
        max_length: a sequence is truncated once the number of attended tokens
            after its prompt reaches this value.
        temperature, top_p: sampling parameters forwarded to `piref_model`.
        num_blocks: candidates sampled per sequence per step.
        block_size: token budget of one candidate.
        n_newline_as_block: when non-zero, each candidate is built from up to
            this many `piref_model.generate` calls that stop at a newline or
            stop token, within the `block_size` budget; the inputs and KV cache
            are also re-aligned to left padding after every step.
        newline_token_ids: stop ids for newline mode; looked up from the
            tokenizer when None and newline mode is on.
        stop_token_ids: ids that end a sequence; defaults to the EOS token id.
        seed: torch RNG seed, applied at the start of every call.

    Returns:
        A list aligned with `input_ids`; entry i holds the token ids stored
        after prompt i's boundary column at the moment the sequence finished.
    """
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    device = 'cuda'
    device_type = 'cuda'
    dtype = torch.bfloat16
    torch.set_float32_matmul_precision('high')
    torch.manual_seed(seed)
    if device_type == 'cuda':
        torch.cuda.manual_seed(seed)

    if stop_token_ids is None:
        stop_token_ids = [get_eos_token_id(tokenizer)]

    if newline_token_ids is None and n_newline_as_block:
        newline_token_ids = get_newline_token_ids(tokenizer)

    # Left-pad the prompts to a common length (a multiple of SEQLEN_MULTIPLE).
    collate_batch = left_pad_collate_fn(input_ids, pad_token_id=tokenizer.pad_token_id)
    current_input_ids = collate_batch['input_ids'].to(device).long()
    current_attention_mask = collate_batch['attention_mask'].to(device).long()
    batch_size, init_prompt_seqlen = current_input_ids.shape
    assert init_prompt_seqlen % SEQLEN_MULTIPLE == 0, f"init_prompt_seqlen {init_prompt_seqlen} is not a multiple of {SEQLEN_MULTIPLE}"
    # Per-row column index where the response starts; initially the padded prompt length.
    init_prompt_seqlen = [init_prompt_seqlen] * batch_size
    # Filled in (by original batch position) as sequences finish.
    generated_ids = [None] * batch_size
    # Maps a row of the shrinking working batch to its original batch position.
    undone_indices_map = list(range(batch_size))

    # begin processing. first forward pass to get past_key_values
    with torch.autocast(device_type=device_type, dtype=dtype):
        classifier_outputs = classifier_model(input_ids=current_input_ids, attention_mask=current_attention_mask)
    past_key_values = classifier_outputs.past_key_values
    assert isinstance(past_key_values, DynamicCache)
    past_key_values = offloaded_cache_from_dynamic_cache(past_key_values)
    bar = tqdm(total=max_length)
    # Accumulated wall-clock time spent in the sampling engine and in the classifier.
    infer_engine_dt = infer_classifier_dt = 0
    while len(undone_indices_map) > 0:
        t0 = time.time()
        current_batch_size = len(undone_indices_map)
        # Strip padding: the engine receives only the attended tokens of each row.
        infer_inputs = [current_input_ids[i][current_attention_mask[i].bool()].tolist() for i in range(current_batch_size)]
        # Each row is repeated num_blocks times so the engine samples num_blocks candidates for it.
        infer_inputs = repeat(infer_inputs, num_blocks)
        sampling_params = dict(
            temperature=temperature,
            top_p=top_p,
            skip_special_tokens=False,
        )
        if n_newline_as_block:
            # Newline mode: build each candidate from up to n_newline_as_block generate
            # calls, each stopping at a newline or stop token.
            sampling_params["stop_token_ids"] = newline_token_ids + stop_token_ids
            continuation_ids = [list() for _ in range(current_batch_size * num_blocks)]
            max_new_tokens_list = [block_size] * (current_batch_size * num_blocks)  # >0 iff not done
            for newline_idx in range(n_newline_as_block):
                processing_indices = [i for i in range(current_batch_size * num_blocks) if max_new_tokens_list[i] > 0]
                cur_infer_inputs = [infer_inputs[i] + continuation_ids[i] for i in processing_indices]
                sampling_params_list = [{**sampling_params, "max_new_tokens": max_new_tokens_list[i]} for i in processing_indices]
                infer_outputs = piref_model.generate(input_ids=cur_infer_inputs, sampling_params=sampling_params_list)
                cur_continuation_ids = [x['output_ids'] for x in infer_outputs]
                # concat the new tokens to the existing continuation_ids
                # if done, set max_new_tokens to 0
                for i, idx in enumerate(processing_indices):
                    continuation_ids[idx].extend(cur_continuation_ids[i])
                    max_new_tokens_list[idx] -= len(cur_continuation_ids[i])
                    if cur_continuation_ids[i][-1] in stop_token_ids:
                        max_new_tokens_list[idx] = 0
                    # NOTE(review): this check runs inside the per-candidate loop and drops
                    # into the debugger via breakpoint() — looks like a debugging leftover;
                    # confirm whether it should stay for non-interactive runs.
                    if any(max_new_tokens_list[i] < 0 for i in processing_indices):
                        print("Warning: max_new_tokens_list < 0. This may indicate a bug.")
                        breakpoint()
                if all(max_new_tokens_list[i] == 0 for i in processing_indices):
                    break

        else:
            # Fixed-size block mode: one generate call with a budget of block_size tokens.
            sampling_params["max_new_tokens"] = block_size
            sampling_params["stop_token_ids"] = stop_token_ids
            infer_outputs = piref_model.generate(input_ids=infer_inputs, sampling_params=sampling_params)
            continuation_ids = [x['output_ids'] for x in infer_outputs]

        max_continuation_len = max(len(x) for x in continuation_ids)

        # right pad the continuation_ids.
        continuation_ids = unflatten(continuation_ids, num_blocks)  # [bs, n, block_size]
        continuation_attention_masks = []
        for bs in range(current_batch_size):
            continuation_attention_masks.append([None] * num_blocks)
            for n in range(num_blocks):
                cur_continuation_len = len(continuation_ids[bs][n])
                pad_len = max_continuation_len - cur_continuation_len
                continuation_ids[bs][n] = continuation_ids[bs][n] + [tokenizer.pad_token_id] * pad_len
                continuation_attention_masks[bs][n] = [1] * cur_continuation_len + [0] * pad_len
        continuation_ids = torch.tensor(continuation_ids, device=device, dtype=torch.long)  # [bs, n, block_size]
        continuation_attention_mask = torch.tensor(continuation_attention_masks, device=device, dtype=torch.long)
        infer_engine_dt += time.time() - t0

        t0 = time.time()
        # Score every candidate with the classifier.
        with torch.autocast(device_type=device_type, dtype=dtype):
            # copied_past_key_values = copy.deepcopy(past_key_values)
            classifier_outputs = classifier_model(
                attention_mask=current_attention_mask,
                continuation_ids=continuation_ids,
                continuation_attention_mask=continuation_attention_mask,
                past_key_values=past_key_values)
            # for l in range(len(past_key_values)):
            #     k, v = past_key_values[l]
            #     assert copied_past_key_values[l][0].shape == k.shape
            #     assert copied_past_key_values[l][1].shape == v.shape
            #     assert torch.all(torch.isclose(copied_past_key_values[l][0], k))
            #     assert torch.all(torch.isclose(copied_past_key_values[l][1], v))
            # TODO: transform for Q#
            classifier_scores = classifier_outputs.logits  # [bs, n]

        # pick best continuation indices
        best_continuation_index = torch.argmax(classifier_scores, dim=-1)  # [bs]
        chosen_continuation_ids = continuation_ids[torch.arange(current_batch_size), best_continuation_index]  # [bs, block_size]
        chosen_continuation_mask = continuation_attention_mask[torch.arange(current_batch_size), best_continuation_index]  # [bs, block_size]
        # Despite the name this is the median chosen length; it only drives the progress bar.
        avg_continuation_len = chosen_continuation_mask.sum(dim=1).median().item()
        min_continuation_len = chosen_continuation_mask.sum(dim=1).min().item()
        # print(avg_continuation_len, chosen_continuation_mask.sum(dim=1))

        # remove any common right padding before updating kv cache
        before = chosen_continuation_mask.shape  # NOTE(review): unused below
        chosen_continuation_ids, chosen_continuation_mask = remove_common_right_padding(chosen_continuation_ids, chosen_continuation_mask)

        # update input_ids, attention_mask and kv cache with chosen continuation
        current_input_ids = torch.cat([current_input_ids, chosen_continuation_ids], dim=1)
        current_attention_mask = torch.cat([current_attention_mask, chosen_continuation_mask], dim=1)
        with torch.autocast(device_type=device_type, dtype=dtype):
            classifier_outputs = classifier_model(
                input_ids=chosen_continuation_ids, attention_mask=current_attention_mask, past_key_values=past_key_values)
            past_key_values = classifier_outputs.past_key_values

        # roll and remove common left padding
        if n_newline_as_block:
            current_input_ids, current_attention_mask, past_key_values, init_prompt_seqlen = roll_input_and_kv_cache(
                current_input_ids, current_attention_mask, past_key_values, init_prompt_seqlen)

        # check for dones and truncations
        done_mask = []
        for i in range(current_batch_size):
            # done when eos token is generated
            # NOTE(review): only the last column is inspected. Without the roll above
            # (n_newline_as_block == 0) that column can hold a right-pad token for rows
            # whose chosen block was shorter — presumably relies on pad_token_id being a
            # stop token; verify against the tokenizer.
            is_done = current_input_ids[i, -1] in stop_token_ids
            # truncate when response exceeds max_length tokens
            is_trunc = current_attention_mask[i][init_prompt_seqlen[i]:].sum().item() >= max_length
            if is_done or is_trunc:
                global_idx = undone_indices_map[i]
                generated_ids[global_idx] = current_input_ids[i, init_prompt_seqlen[i]:].tolist()
            done_mask.append(is_done or is_trunc)

        # remove them from current batch and update undone_indices_map
        done_mask = torch.tensor(done_mask, device=device, dtype=torch.bool)
        current_input_ids = current_input_ids[~done_mask]
        current_attention_mask = current_attention_mask[~done_mask]
        past_key_values.batch_select_indices(~done_mask)
        undone_indices_map = [x for i, x in enumerate(undone_indices_map) if not done_mask[i]]
        init_prompt_seqlen = [x for i, x in enumerate(init_prompt_seqlen) if not done_mask[i]]

        if min_continuation_len <= 1:
            print("Warning: min_continuation_len <= 1. This may indicate a bug.")

        # Drop per-step tensors before the next iteration to keep peak GPU memory down.
        del continuation_ids, continuation_attention_mask, classifier_outputs, classifier_scores, best_continuation_index, chosen_continuation_ids, chosen_continuation_mask
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        infer_classifier_dt += time.time() - t0

        bar.update(avg_continuation_len)
        bar.set_description(f"Queries left: {len(undone_indices_map):3d}/{batch_size}, max len: {current_input_ids.size(1):5d}, sgl: {infer_engine_dt:.2f}s, classifier: {infer_classifier_dt:.2f}s")
    bar.close()
    return generated_ids


@torch.no_grad()
def main(args: Args):
    """Run classifier-guided generation over a benchmark and append results to a JSONL file.

    Loads the sampling engine, the tokenizer and the classifier, tokenizes the
    benchmark problems, and processes them batch by batch with `generate`.
    After each batch one JSON line per problem is appended to
    `args.output_path`. If that file already exists, the batches it covers are
    skipped, so an interrupted run can be resumed with the same batch size.

    Raises:
        SystemExit: with code 0 when the output file already holds all samples.
        AssertionError: when the existing output does not end on a batch boundary.
    """
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    device = 'cuda'
    device_type = 'cuda'
    dtype = torch.bfloat16
    torch.set_float32_matmul_precision('high')
    torch.manual_seed(args.seed)
    if device_type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    piref_model = infer_utils.InferenceEngine(
        engine_type="sgl",
        model=args.piref_model,
        dtype=dtype,
        gpu_memory_utilization=args.piref_gpu_util,
        random_seed=args.seed,
        skip_tokenizer_init=True,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.qsharp_model)
    tokenizer.padding_side = "left"  # left padding for generation
    model_loading_kwargs = dict(attn_implementation=args.attention_impl, torch_dtype=dtype, use_cache=True)
    if args.disable_lora:
        classifier = classifier_lib.Qwen2ForQSharp.from_pretrained(args.qsharp_ckpt_path, **model_loading_kwargs).to(device)
    else:
        import peft
        classifier = classifier_lib.Qwen2ForQSharp.from_pretrained(args.qsharp_model, **model_loading_kwargs)
        classifier = peft.PeftModel.from_pretrained(classifier, args.qsharp_ckpt_path).to(device)
        classifier = classifier.merge_and_unload()
    classifier.inference_impl = args.inference_impl
    classifier.eval()
    if args.compile:
        print("Compiling model...")
        t0 = time.time()
        classifier.model = torch.compile(classifier.model)
        print(f"Compiled model in {time.time() - t0:.2f} seconds.")
    print("Finished loading piref and classifier.")

    # has problem and answer columns
    dataset = benchmark_data.get_dataset(args.benchmark)

    def preprocess(example):
        formatted_problem = deepseek_utils.format_roll_in(example["problem"])
        input_ids = tokenizer(formatted_problem, add_special_tokens=False)["input_ids"]
        return {"problem": example["problem"], "answer": example["answer"], "input_ids": input_ids}

    def collate_examples(examples):
        # Keep every column as a plain Python list with one entry per example. The
        # default collate would transpose equal-length `input_ids` lists into
        # per-position tensors and reject unequal lengths, while `generate`
        # expects list[list[int]].
        return {key: [example[key] for example in examples] for key in examples[0]}

    dataset = dataset.map(preprocess)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_examples)

    # load existing outputs
    num_skip_batches = 0
    if os.path.exists(args.output_path):
        n_already_done = 0
        with open(args.output_path, "r") as f:
            for line in f:
                # Parsing validates that every existing line is well-formed JSON.
                json.loads(line)
                n_already_done += 1
        if n_already_done == len(dataset):
            print(f"Output file {args.output_path} already contains all samples. Exiting...")
            # Same effect as the interactive `exit(0)` helper, without depending on `site`.
            raise SystemExit(0)
        assert n_already_done % args.batch_size == 0, f"output file has {n_already_done} samples, which is not a multiple of batch_size {args.batch_size}"
        num_skip_batches = n_already_done // args.batch_size
        print(f"{n_already_done=}. So skipping {num_skip_batches} batches...")
    else:
        with open(args.output_path, "w") as f:
            f.write("")

    for batch_idx, batch in enumerate(loader):
        if batch_idx < num_skip_batches:
            print(f"Skipping batch {batch_idx}...")
            continue

        print(f"Starting batch {batch_idx}...")
        batch_t0 = time.time()
        generated_ids = generate(
            piref_model=piref_model, tokenizer=tokenizer, classifier_model=classifier, input_ids=batch["input_ids"],
            max_length=args.max_length, temperature=args.temperature, top_p=args.top_p, num_blocks=args.num_blocks, block_size=args.block_size,
            n_newline_as_block=args.n_newline_as_block,
            # Forward the configured seed instead of silently using generate()'s default.
            seed=args.seed,
        )
        generated_raw_texts = tokenizer.batch_decode(generated_ids)
        generated_solutions = [deepseek_utils.remove_thinking_text(x) for x in generated_raw_texts]
        processed_answers = [accuracy_utils.process_sample(x) for x in generated_solutions]
        gt_answers = batch["answer"]
        rewards = [compute_correctness(gt_a, pred_a) for gt_a, pred_a in zip(gt_answers, processed_answers, strict=True)]
        dt = time.time() - batch_t0
        with open(args.output_path, "a") as f:
            # The last batch may be smaller than args.batch_size, so iterate over
            # what was actually generated.
            for i in range(len(generated_ids)):
                outputs = {
                    "batch_idx": batch_idx,
                    "problem": batch["problem"][i],
                    "generated_ids": generated_ids[i],
                    "generated_raw_text": generated_raw_texts[i],
                    "processed_answer": processed_answers[i],
                    "gt_answer": gt_answers[i],
                    "reward": rewards[i],
                    "dt": dt,
                }
                f.write(json.dumps(outputs) + "\n")
                f.flush()

        rewards = torch.tensor(rewards, dtype=torch.float32)
        mean_reward = rewards.mean().item()
        # Standard error of the mean; undefined for a single sample, so report 0.
        se_reward = 0 if len(rewards) == 1 else (torch.std(rewards, correction=1) / (len(rewards) ** 0.5)).item()
        print(f"batch {batch_idx} took {dt:.2f} secs. Reward: {mean_reward} ± {se_reward}")

        # generate() owns and releases its own tensors and progress bar; here we only
        # return cached GPU memory before the next batch.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


if __name__ == "__main__":
    # Parse the command line into an Args instance and run the pipeline.
    main(tyro.cli(Args))