import math
import os
from typing import Any, Dict, Literal, Optional, Union

import datasets
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import transformers
from argparse_dataclass import ArgumentParser
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.checkpoint import checkpoint
from torch.utils.data import Subset
from transformers import (DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)
import transformers
from transformers.loss.loss_utils import ForCausalLMLoss
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

import wandb
from hip_attn.v1_3.attention import HiPAttentionArgs, ScanStage
from hip_attn.v1_3.models.llama import (LlamaAttention, LlamaConfig,
                                        LlamaForCausalLM)
from hip_research.dataset.openwebtext import OpenWebTextDataset
from hip_research.dataset.pg19 import PG19Dataset
from hip_research.dataset.pg19_long_qa import PG19LongQA
from hip_research.main.jobs.ppl import PplArgs, job_ppl
from hip_research.utils.long_train import Config, get_logger

log = get_logger()

# Debug / telemetry switches, read once from the environment at import time.
TRITON_DEBUG = os.environ.get("TRITON_DEBUG", "0") == "1"
WANDB_DISABLED = os.environ.get("WANDB_MODE", "none") == "disabled"


def long_ppl(
    model,
    long_logits,
    short_logits,
    labels,
    ignore_index=-100,
    idx_select=False,
    alpha=2.0,
    beta=-2.0,
    long_logits_truth=None,
    block_index=None,
):
    """Score one block of positions for LongPPL (https://arxiv.org/pdf/2410.23771).

    Logit position ``t`` is scored against ``labels[:, t + 1]`` (next-token
    prediction). A position counts as a "key" token when both
    ``p(x | long) / p(x | short) > alpha`` and ``p(x | long) > beta`` hold.
    ``p(x | long)`` comes from ``long_logits_truth`` and ``p(x | short)`` from
    ``short_logits``.

    Args:
        model: unused; kept for call-site compatibility.
        long_logits: ``(B, S, V)`` logits whose cross-entropy is reported.
        short_logits: ``(B, S, V)`` logits of the short-context reference.
        labels: ``(B, S + 1)`` token ids. Any other length (e.g. the ``(B, S)``
            slice of a sequence's last block) gets one ``ignore_index`` label
            appended.
        ignore_index: label value excluded from all losses and counts.
        idx_select: if True, return only the per-position likelihood ratio.
        alpha: threshold on the likelihood ratio.
        beta: threshold on ``p(x | long)``.
        long_logits_truth: logits used for key-token selection; defaults to
            ``long_logits``.
        block_index: unused; kept for call-site compatibility.

    Returns:
        If ``idx_select``, the ratio ``p(x | long) / p(x | short)``. Its shape is
        ``(S,)`` when ``B == 1`` (also for ``S == 1``) and ``(B, S)`` otherwise.

        Otherwise, ``(long_ce_loss, long_ce_count, ce_loss, ce_count)``:

        - ``long_ce_loss`` / ``long_ce_count``: summed CE of ``long_logits`` over
          the key tokens, and the number of key tokens.
        - ``ce_loss`` / ``ce_count``: summed CE over all non-ignored tokens, and
          their count.
    """
    if long_logits_truth is None:
        long_logits_truth = long_logits
    B, S, V = long_logits_truth.shape

    # Shift so that tokens < n predict n. The last block of a sequence carries only
    # S labels, so one ignored label is appended to keep the shift aligned.
    if labels.shape[1] != S + 1:
        labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
    shift_labels = labels[..., 1:].contiguous().to(long_logits_truth.device)

    # log p(label) = logit[label] - logsumexp(logits), computed in the logits' own
    # dtype. Ignored labels are clamped to a valid index here and masked out below.
    gather_idx = shift_labels.unsqueeze(-1).clamp(min=0)
    long_logp = long_logits_truth.gather(2, gather_idx) - torch.logsumexp(
        long_logits_truth, dim=-1, keepdim=True
    )
    short_logp = short_logits.gather(2, gather_idx) - torch.logsumexp(
        short_logits, dim=-1, keepdim=True
    )

    # likelihood ratio of p(x | long) / p(x | short), shape (B, S)
    ratio = torch.exp(long_logp - short_logp).squeeze(-1)
    if idx_select:
        # squeeze(0) instead of squeeze(): a one-position block must stay 1-D so
        # callers can torch.cat the per-block results.
        return ratio.squeeze(0)

    # NOTE(review): both thresholds are applied to exp(.) values, i.e. in
    # probability space. Since exp(.) > 0, the default beta=-2.0 makes the LCL
    # condition always true. Confirm this matches how long_ppl_alpha/beta are set.
    lsd = ratio > alpha
    lcl = torch.exp(long_logp).squeeze(-1) > beta
    key_mask = lsd & lcl & (shift_labels != ignore_index)

    # Flatten the tokens. A single per-token CE pass serves both totals; ignored
    # positions get 0 loss with reduction="none".
    flat_labels = shift_labels.reshape(-1)
    key_mask = key_mask.reshape(-1)  # flat, so it lines up with per_token_ce for any B
    per_token_ce = nn.functional.cross_entropy(
        long_logits.reshape(-1, V).float(),
        flat_labels,
        ignore_index=ignore_index,
        reduction="none",
    )

    ce_loss = per_token_ce.sum()
    ce_count = (flat_labels != ignore_index).sum()
    long_ce_loss = (per_token_ce * key_mask).sum()
    long_ce_count = key_mask.sum()
    return long_ce_loss, long_ce_count, ce_loss, ce_count


def evaluate(
    config,
    model,
    tokenizer,
    eval_dataset=None,
    ignore_keys=None,
    metric_key_prefix: str = "eval",
) -> Dict[str, float]:
    """Compute long-ppl and ppl over ``eval_dataset``, one sample at a time.

    Each sample is a mapping with ``"input_ids"`` and ``"labels"`` tensors. Both get
    a leading batch dimension, are truncated to ``config.seq_len`` and are moved to
    CUDA. The per-sample losses returned by the batched helpers are averaged over
    samples and then exponentiated.

    Args:
        config: must provide ``seq_len``, plus whatever the loss helpers read.
        model: causal LM. ``model.no_lm_head`` is set to True and the model is put
            in eval mode; neither change is undone afterwards.
        tokenizer: forwarded to the loss helpers.
        eval_dataset: iterable of samples.
        ignore_keys: unused here (presumably mirrors a Trainer.evaluate signature).
        metric_key_prefix: prefix of the returned metric names.

    Returns:
        ``{"<prefix>/long-ppl": ..., "<prefix>/ppl": ...}``. Both values are NaN
        when ``eval_dataset`` is empty.

    Raises:
        ValueError: if ``eval_dataset`` is None.
    """
    if eval_dataset is None:
        raise ValueError("evaluate() requires an eval_dataset")

    setattr(model, "no_lm_head", True)
    model.eval()

    long_ppl_key = f"{metric_key_prefix}/long-ppl"
    ppl_key = f"{metric_key_prefix}/ppl"
    metrics = {long_ppl_key: 0, ppl_key: 0}
    N = 0

    # Read once rather than per sample. "1" routes every sample through
    # batched_long_cross_entropy_attn_postfix; anything else uses
    # batched_long_cross_entropy.
    use_attn_postfix = os.environ.get("USE_ATTN_POSTFIX", "0") == "1"

    with torch.no_grad():
        for sample in eval_dataset:
            N += 1
            inputs = sample["input_ids"].unsqueeze(0)[..., : config.seq_len].cuda()
            target = sample["labels"].unsqueeze(0)[..., : config.seq_len].cuda()

            if use_attn_postfix:
                long_loss, ce_loss = batched_long_cross_entropy_attn_postfix(
                    inputs, target, model, tokenizer, config
                )
            else:
                long_loss, ce_loss = batched_long_cross_entropy(
                    inputs, target, model, tokenizer, config
                )

            metrics[long_ppl_key] += long_loss.item()
            metrics[ppl_key] += ce_loss.item()

            print(
                f"[{N=}] long_ppl={np.exp(metrics[long_ppl_key] / N)} ppl={np.exp(metrics[ppl_key] / N)}"
            )

    if N == 0:
        # Nothing to average; this used to divide by zero.
        import warnings

        warnings.warn("evaluate(): eval_dataset is empty; returning NaN metrics")
        return {long_ppl_key: float("nan"), ppl_key: float("nan")}

    # change mean CE into ppl for eval
    metrics[long_ppl_key] = np.exp(metrics[long_ppl_key] / N)
    metrics[ppl_key] = np.exp(metrics[ppl_key] / N)
    return metrics


def batched_cross_entropy(inputs, target, model, tokenizer, config):
    """Mean next-token cross-entropy from the first labeled position onward.

    The model runs once over ``inputs``. Its ``.logits`` output is projected with
    ``model.lm_head`` in chunks of ``config.long_ce_block_size`` positions, starting
    at the first position where ``target >= 0``. Logit position ``t`` is scored
    against ``target[:, t + 1]``; labels equal to -100 are ignored.

    Returns:
        Summed cross-entropy divided by the number of scored labels.
    """
    ignore_index = -100
    block = config.long_ce_block_size
    # first labeled position; everything before it is context only
    start = (target >= 0).squeeze(0).nonzero()[0].item()

    hidden = model(
        inputs,
        # FIXME: padding need?
        # attention_mask=(inputs.ne(tokenizer.pad_token_id)).to(inputs.dtype),
        labels=target,
        use_cache=False,
    ).logits

    total_loss, total_count = 0, 0
    for lo in range(start, inputs.size(1), block):
        hi = lo + block
        chunk_logits = model.lm_head(hidden[:, lo:hi]).float()
        chunk_labels = target[:, lo : hi + 1].to(chunk_logits.device)

        # the last chunk has no label past its end: append an ignored one
        if chunk_labels.size(1) != chunk_logits.size(1) + 1:
            chunk_labels = nn.functional.pad(chunk_labels, (0, 1), value=ignore_index)

        next_tokens = chunk_labels[..., 1:].contiguous().view(-1)
        flat_logits = chunk_logits.view(-1, chunk_logits.size(-1))

        total_count += (next_tokens != ignore_index).sum()
        total_loss += nn.functional.cross_entropy(
            flat_logits, next_tokens, ignore_index=ignore_index, reduction="sum"
        )
    return total_loss / total_count


def batched_long_cross_entropy_attn_postfix_multi(
    inputs, target, model, tokenizer, config
):
    """
    long ce from https://arxiv.org/pdf/2410.23771 (equation 7 on page 6)

    Multi-span variant. ``target`` may hold several labeled spans (label >= 0)
    separated by unlabeled context (label < 0). For each span, in order:

    1. the context chunk before it is run through the model with the KV cache, using
       the model's current attention backend;
    2. context tokens are ranked by the likelihood ratio from :func:`long_ppl`
       (``idx_select=True``). The top ``max(16, n // 20)`` are run through the model
       again with ``past_key_values.recompute_idx`` set;
    3. the span itself is run on top of the cache and scored block-wise with
       :func:`long_ppl`. The scoring uses a flash_attention_2 pass (long) and a
       windowed hip_attention pass (short) over the full input as references.

    Args:
        inputs: token ids; presumably batch size 1 (``target.squeeze(0)`` relies on it).
        target: labels aligned with ``inputs``; negative values mark unlabeled positions.
        model: causal LM whose forward ``.logits`` are passed through
            ``model.lm_head`` here. They are presumably hidden states (evaluate()
            sets ``model.no_lm_head``).
        tokenizer: unused.
        config: reads ``long_ce_block_size``, ``long_ce_k``, ``long_ppl_alpha``,
            ``long_ppl_beta``.

    Returns:
        ``(long_ce, ce)``: summed long-CE / key-token count and summed CE / labeled
        token count, accumulated over all spans.

    Raises:
        ValueError: if ``target`` has no labeled position.
    """

    model.eval()
    block = config.long_ce_block_size

    # Labeled spans as half-open (start, end) position intervals. A span that runs
    # to the end of the sequence is closed at the sequence length. The previous scan
    # stopped when no unlabeled position followed, which silently dropped that span.
    is_labeled = (target.squeeze(0) >= 0).long()
    edges = torch.diff(F.pad(is_labeled, (1, 1)))
    starts = (edges == 1).nonzero().flatten().tolist()
    ends = (edges == -1).nonzero().flatten().tolist()
    label_segments = list(zip(starts, ends))
    if not label_segments:
        raise ValueError("target contains no labeled (>= 0) positions")

    # PREPROCESS ====================================================
    # calculate ground truth long logits so that all models will calculate long-ppl on the
    # same tokens
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "flash_attention_2"
    long_outputs_fa2 = model(
        inputs,
        # attention_mask=(inputs.ne(tokenizer.pad_token_id)).to(inputs.dtype),
        use_cache=False,
    )
    model.config._attn_implementation = previous_attn_backend
    long_fa2_hidden = long_outputs_fa2.logits

    # calculate ground truth short outputs (hip_attention, sliding_window=long_ce_k)
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    short_outputs = model(
        inputs,
        sliding_window=config.long_ce_k,
        use_cache=False,
    )
    model.config._attn_implementation = previous_attn_backend
    # END PREPROCESS ====================================================

    past_key_values = None
    long_ce_loss, long_ce_count, ce_loss, ce_count = 0, 0, 0, 0
    for i, (cp, ep) in enumerate(label_segments):
        # The context of span i starts where span i - 1 ended (position 0 for the first).
        # NOTE(review): if the first labeled position is 0, the context chunk is
        # empty and the forward / torch.cat below will likely fail. Confirm samples
        # always start with unlabeled context.
        sp = label_segments[i - 1][1] if i > 0 else 0

        short_hidden = short_outputs.logits[:, sp:cp]
        context_inputs, context_target = inputs[:, sp:cp], target[:, sp:cp]
        gen_inputs, gen_target = inputs[:, cp:ep], target[:, cp:ep]
        long_logits_truth = long_fa2_hidden[:, cp:ep]

        # 1) push the context chunk through the model, extending the KV cache
        outputs = model(
            context_inputs,
            labels=context_target,
            past_key_values=past_key_values,
            use_cache=True,
        )
        long_hidden = outputs.logits
        past_key_values = outputs.past_key_values

        # 2) rank context tokens by p(x | long) / p(x | short); the labels are the
        #    context token ids themselves
        idx_out = []
        for b in range(0, context_inputs.size(1), block):
            long_logits = model.lm_head(long_hidden[:, b : b + block])
            short_logits = model.lm_head(short_hidden[:, b : b + block])
            labels = context_inputs.to(long_logits.device)[:, b : b + block + 1]

            out = long_ppl(
                model,
                long_logits,
                short_logits,
                labels,
                idx_select=True,
                alpha=config.long_ppl_alpha,
                beta=config.long_ppl_beta,
            )
            # a one-position block can come back 0-d, which torch.cat rejects
            idx_out.append(out.reshape(-1))

        idx = torch.cat(idx_out, dim=0)
        recompute_n = max(16, idx.size(0) // 20)

        # chunk-relative positions of the highest ratios, in ascending order
        idx = torch.argsort(idx, dim=-1, descending=True)[:recompute_n]
        idx, _ = torch.sort(idx, dim=-1)

        # recompute those tokens; recompute_idx is given their absolute positions
        past_key_values.recompute_idx = idx + sp
        output = model(
            context_inputs[:, idx],
            past_key_values=past_key_values,
            use_cache=True,
        )

        # cache has been updated. delete recompute idx
        past_key_values = output.past_key_values
        delattr(past_key_values, "recompute_idx")

        # 3) run the labeled span on top of the cache and score it
        outputs = model(
            gen_inputs,
            # attention_mask=(gen_inputs.ne(tokenizer.pad_token_id)).to(gen_inputs.dtype),
            labels=gen_target,
            past_key_values=past_key_values,
            use_cache=True,
        )
        long_hidden = outputs.logits
        short_hidden = short_outputs.logits[:, cp:ep]
        past_key_values = outputs.past_key_values

        # Only the span's own blocks are non-empty. Iterating over the whole input
        # length just called lm_head on empty slices that contributed nothing.
        for b in range(0, gen_inputs.size(1), block):
            long_logits = model.lm_head(long_hidden[:, b : b + block])
            short_logits = model.lm_head(short_hidden[:, b : b + block])
            gt_long_logits = model.lm_head(long_logits_truth[:, b : b + block])
            labels = gen_target.to(long_logits.device)[:, b : b + block + 1]

            lce, lce_cnt, ce, ce_cnt = long_ppl(
                model,
                long_logits,
                short_logits,
                labels,
                alpha=config.long_ppl_alpha,
                beta=config.long_ppl_beta,
                long_logits_truth=gt_long_logits,
                block_index=b,
            )
            long_ce_loss += lce
            long_ce_count += lce_cnt
            ce_loss += ce
            ce_count += ce_cnt

    return long_ce_loss / long_ce_count, ce_loss / ce_count


def batched_recompute_ondemand_multi(inputs, target, model, tokenizer, config):
    """
    long ce from https://arxiv.org/pdf/2410.23771 (equation 7 on page 6)

    Multi-span variant with attention-score driven recomputation. ``target`` may
    hold several labeled spans (label >= 0) separated by unlabeled context. For each
    span, in order:

    1. The context chunk before the span is pushed into the KV cache.
       - First span: the model's current backend, followed by a "recompute_dense"
         pass over the chunk's last <= 128 positions.
       - Later spans: the "dense" backend.

       Either way, ``past_key_values.scores`` is read back afterwards.
    2. For each stacked score row, the top ``config.recompute_n`` positions (skipping
       the first 128) are re-run with the "recompute" backend.
    3. The span is run with the "sdpa_rectangle" backend and scored block-wise with
       :func:`long_ppl`. The flash_attention_2 (long) and windowed hip_attention
       (short) passes over the full input serve as references.

    Returns:
        ``(long_ce, ce)``: summed long-CE / key-token count and summed CE / labeled
        token count, accumulated over all spans.

    Raises:
        ValueError: if ``target`` has no labeled position.
    """

    model.eval()
    block = config.long_ce_block_size

    # Labeled spans as half-open (start, end) position intervals. A span that runs
    # to the end of the sequence is closed at the sequence length. The previous scan
    # stopped when no unlabeled position followed, which silently dropped that span.
    is_labeled = (target.squeeze(0) >= 0).long()
    edges = torch.diff(F.pad(is_labeled, (1, 1)))
    starts = (edges == 1).nonzero().flatten().tolist()
    ends = (edges == -1).nonzero().flatten().tolist()
    label_segments = list(zip(starts, ends))
    if not label_segments:
        raise ValueError("target contains no labeled (>= 0) positions")

    # PREPROCESS ====================================================
    # calculate ground truth long logits so that all models will calculate long-ppl on the
    # same tokens
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "flash_attention_2"
    long_outputs_fa2 = model(
        inputs,
        use_cache=False,
    )
    model.config._attn_implementation = previous_attn_backend
    long_fa2_hidden = long_outputs_fa2.logits

    # calculate ground truth short outputs (hip_attention, sliding_window=long_ce_k)
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    short_outputs = model(
        inputs,
        sliding_window=config.long_ce_k,
        use_cache=False,
    )
    model.config._attn_implementation = previous_attn_backend
    # END PREPROCESS ====================================================

    past_key_values = None
    long_ce_loss, long_ce_count, ce_loss, ce_count = 0, 0, 0, 0
    for i, (cp, ep) in enumerate(label_segments):
        # The context of span i starts where span i - 1 ended (position 0 for the first).
        sp = label_segments[i - 1][1] if i > 0 else 0

        context_inputs, context_target = inputs[:, sp:cp], target[:, sp:cp]
        gen_inputs, gen_target = inputs[:, cp:ep], target[:, cp:ep]
        long_logits_truth = long_fa2_hidden[:, cp:ep]

        if i == 0:
            # hip forward on the first context chunk (current backend)
            outputs = model(
                context_inputs,
                labels=context_target,
                past_key_values=past_key_values,
                use_cache=True,
            )
            long_hidden = outputs.logits
            past_key_values = outputs.past_key_values

            # "recompute_dense" pass over the last <= 128 context positions; scores
            # are read back from past_key_values.scores (presumably filled by the backend)
            last_n = min(128, long_hidden.size(1))
            refine_idx = torch.arange(long_hidden.size(1), device=long_hidden.device)[
                -last_n:
            ]
            previous_attn_backend = model.config._attn_implementation
            model.config._attn_implementation = "recompute_dense"
            past_key_values.recompute_idx = refine_idx
            recompute_dense_inputs = context_inputs[:, refine_idx]
            past_key_values.scores = []
            recompute_outputs = model(
                recompute_dense_inputs,
                past_key_values=past_key_values,
                use_cache=True,
            )
            model.config._attn_implementation = previous_attn_backend
            past_key_values = recompute_outputs.past_key_values
            scores = past_key_values.scores
            delattr(past_key_values, "recompute_idx")
            delattr(past_key_values, "scores")
        else:
            # later context chunks go through the "dense" backend; scores are read
            # back from past_key_values.scores the same way
            previous_attn_backend = model.config._attn_implementation
            model.config._attn_implementation = "dense"
            past_key_values.scores = []
            recompute_outputs = model(
                context_inputs,
                past_key_values=past_key_values,
                use_cache=True,
            )
            model.config._attn_implementation = previous_attn_backend
            past_key_values = recompute_outputs.past_key_values
            scores = past_key_values.scores
            delattr(past_key_values, "scores")

        scores = torch.stack(scores).squeeze(dim=1)

        # Per score row: top config.recompute_n columns after the first 128, excluding
        # the last context_inputs.size(1) columns.
        # NOTE(review): on the first span, if the score rows are only as long as the
        # context chunk, this exclusion leaves nothing and no recompute happens.
        # Confirm the intended exclusion width for i == 0.
        refine_idx = torch.argsort(
            scores[:, 128:][:, : -context_inputs.size(1)], dim=1, descending=True
        )
        refine_idx = refine_idx[:, : config.recompute_n] + 128
        refine_idx = refine_idx.view(-1).unique()

        # set recompute forward and do recompute forward
        if refine_idx.size(0) > 0:
            previous_attn_backend = model.config._attn_implementation
            model.config._attn_implementation = "recompute"
            past_key_values.recompute_idx = refine_idx
            recompute_inputs = inputs[:, refine_idx]
            recompute_outputs = model(
                recompute_inputs,
                past_key_values=past_key_values,
                use_cache=True,
            )
            model.config._attn_implementation = previous_attn_backend
            past_key_values = recompute_outputs.past_key_values
            delattr(past_key_values, "recompute_idx")

        # run the labeled span on top of the cache
        previous_attn_backend = model.config._attn_implementation
        model.config._attn_implementation = "sdpa_rectangle"
        outputs = model(
            gen_inputs,
            # attention_mask=(gen_inputs.ne(tokenizer.pad_token_id)).to(gen_inputs.dtype),
            labels=gen_target,
            past_key_values=past_key_values,
            use_cache=True,
        )
        model.config._attn_implementation = previous_attn_backend
        long_hidden = outputs.logits
        short_hidden = short_outputs.logits[:, cp:ep]
        past_key_values = outputs.past_key_values

        # Only the span's own blocks are non-empty. Iterating over the whole input
        # length just called lm_head on empty slices that contributed nothing.
        for b in range(0, gen_inputs.size(1), block):
            long_logits = model.lm_head(long_hidden[:, b : b + block])
            short_logits = model.lm_head(short_hidden[:, b : b + block])
            gt_long_logits = model.lm_head(long_logits_truth[:, b : b + block])
            labels = gen_target.to(long_logits.device)[:, b : b + block + 1]

            lce, lce_cnt, ce, ce_cnt = long_ppl(
                model,
                long_logits,
                short_logits,
                labels,
                alpha=config.long_ppl_alpha,
                beta=config.long_ppl_beta,
                long_logits_truth=gt_long_logits,
                block_index=b,
            )
            long_ce_loss += lce
            long_ce_count += lce_cnt
            ce_loss += ce
            ce_count += ce_cnt

    return long_ce_loss / long_ce_count, ce_loss / ce_count


def batched_recompute_ondemand(inputs, target, model, tokenizer, config):
    """
    Single-span long ce with attention-score driven recomputation.

    Everything before the first labeled position of ``target`` is context; the rest
    is the generation span.

    1. hip_attention forward on the context, keeping its KV cache.
    2. Reference passes on the full input: flash_attention_2 (long) and
       hip_attention with ``sliding_window=config.long_ce_k`` (short).
    3. "recompute_dense" forward on the last ``last_n`` (128) context positions.
       Scores are then read from ``past_key_values.scores`` (presumably set by
       that backend).
    4. Select the top ``config.recompute_n`` score positions after the first 128 and
       re-run them with the "recompute" backend.
    5. "sdpa_rectangle" forward on the generation span at its absolute positions,
       scored block-wise with :func:`long_ppl`.

    Returns:
        ``(long_ce, ce)``: summed long-CE / key-token count and summed CE / labeled
        token count.

    Raises:
        ValueError: if ``target`` has no labeled position.
    """

    model.eval()
    last_n = 128

    print(f"{inputs.size()=} {target.size()=} {(target >= 0).sum()=}")
    labeled_positions = (target >= 0).squeeze(0).nonzero()
    if labeled_positions.numel() == 0:
        raise ValueError("target contains no labeled (>= 0) positions")
    cutoff_point = labeled_positions[0].item()
    context_inputs, context_target = inputs[:, :cutoff_point], target[:, :cutoff_point]
    gen_inputs, gen_target = inputs[:, cutoff_point:], target[:, cutoff_point:]
    print(f"{context_inputs.size()=} {inputs.size()=}")

    # 1. hip forward on the context, keeping the KV cache
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    outputs = model(context_inputs, labels=context_target, use_cache=True)
    model.config._attn_implementation = previous_attn_backend
    long_hidden = outputs.logits
    past_key_values = outputs.past_key_values

    # 2a. calculate ground truth long logits so that all models will calculate
    # long-ppl on the same tokens. Only .logits is read, so no KV cache is built
    # (this was use_cache=True, allocating a full-length cache that was discarded).
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "flash_attention_2"
    long_outputs_fa2 = model(inputs, use_cache=False)
    model.config._attn_implementation = previous_attn_backend
    long_fa2_hidden = long_outputs_fa2.logits

    # 2b. window forward
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    short_outputs = model(
        inputs,
        sliding_window=config.long_ce_k,
        use_cache=False,
    )
    model.config._attn_implementation = previous_attn_backend

    # 3. dense forward on last n to collect scores
    # NOTE(review): past_key_values.scores is not initialised before this pass, and
    # it is used below as a tensor (.size(), .squeeze()). Confirm what the
    # "recompute_dense" backend stores there.
    refine_idx = torch.arange(long_hidden.size(1), device=long_hidden.device)[-last_n:]
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "recompute_dense"
    past_key_values.recompute_idx = refine_idx
    recompute_dense_inputs = context_inputs[:, refine_idx]
    recompute_outputs = model(
        recompute_dense_inputs,
        past_key_values=past_key_values,
        use_cache=True,
    )
    model.config._attn_implementation = previous_attn_backend
    past_key_values = recompute_outputs.past_key_values
    scores = past_key_values.scores
    delattr(past_key_values, "recompute_idx")
    delattr(past_key_values, "scores")

    # 4. top config.recompute_n positions after the first 128
    print(f"scores: {scores.size()=}")
    refine_idx = (
        torch.argsort(scores.squeeze(0)[128:], descending=True)[: config.recompute_n]
        + 128
    )

    if refine_idx.size(0) > 0:
        previous_attn_backend = model.config._attn_implementation
        model.config._attn_implementation = "recompute"
        past_key_values.recompute_idx = refine_idx
        recompute_inputs = context_inputs[:, refine_idx]
        recompute_outputs = model(
            recompute_inputs,
            past_key_values=past_key_values,
            use_cache=True,
        )
        model.config._attn_implementation = previous_attn_backend
        past_key_values = recompute_outputs.past_key_values
        delattr(past_key_values, "recompute_idx")

    # # NOTE uncomment following lines to recompute everything
    # # NOTE(review): sorted_idx is not defined in this function; adapt before use.
    # past_key_values.recompute_idx = torch.arange(
    #     sorted_idx.size(-1), device=sorted_idx.device
    # )
    # previous_attn_backend = model.config._attn_implementation
    # model.config._attn_implementation = "recompute"
    # recompute_inputs = context_inputs
    # print(recompute_inputs.shape)
    # recompute_outputs = model(
    #     recompute_inputs,
    #     # cache_position=other_idx,
    #     # position_ids=other_idx.unsqueeze(0),
    #     past_key_values=past_key_values,
    #     use_cache=True,
    # )
    # # cache has been updated. delete recompute idx
    # past_key_values = recompute_outputs.past_key_values
    # model.config._attn_implementation = previous_attn_backend
    # delattr(past_key_values, "recompute_idx")
    # # STOP COMMENT SECTION =================================

    # 5. gen forward at absolute positions [len(context), len(inputs))
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "sdpa_rectangle"
    cache_location = torch.arange(
        context_inputs.shape[1], inputs.shape[1], device=gen_inputs.device
    )
    gen_outputs = model(
        gen_inputs,
        cache_position=cache_location,
        position_ids=cache_location.unsqueeze(0),
        labels=gen_target,
        past_key_values=past_key_values,
        use_cache=True,
    )
    model.config._attn_implementation = previous_attn_backend

    gen_hidden = gen_outputs.logits
    short_hidden = short_outputs.logits[:, context_inputs.shape[1] :]
    long_fa2_hidden = long_fa2_hidden[:, context_inputs.shape[1] :]

    ce_count, long_ce_count = 0, 0
    ce_loss, long_ce_loss = 0, 0
    for b in range(0, gen_hidden.shape[1], config.long_ce_block_size):
        range_start = b
        range_end = range_start + config.long_ce_block_size
        gen_logits = model.lm_head(gen_hidden[:, range_start:range_end])
        long_fa2_logits = model.lm_head(long_fa2_hidden[:, range_start:range_end])

        short_logits = model.lm_head(short_hidden[:, range_start:range_end])
        labels = gen_target.to(gen_logits.device)[:, range_start : range_end + 1]

        lce, lce_cnt, ce, ce_cnt = long_ppl(
            model,
            gen_logits,
            short_logits,
            labels,
            alpha=config.long_ppl_alpha,
            beta=config.long_ppl_beta,
            long_logits_truth=long_fa2_logits,
            block_index=b,
        )
        long_ce_loss += lce
        long_ce_count += lce_cnt
        ce_loss += ce
        ce_count += ce_cnt

    print(f"in postfix: {long_ce_loss=} {long_ce_count=}")
    return long_ce_loss / long_ce_count, ce_loss / ce_count


def batched_long_cross_entropy_attn_postfix(inputs, target, model, tokenizer, config):
    """
    long ce from https://arxiv.org/pdf/2410.23771 (equation 7 on page 6)

    Evaluate LongPPL and plain cross-entropy on the generation span of one
    sample. The context prefix is prefilled with the "hip_attention" backend.
    Optionally, selected context tokens are then recomputed in the KV cache
    with the "recompute" backend (currently disabled: see ``refine_schedule``).
    Finally, the generation span is decoded with "flash_attention_2" on top of
    that cache.

    Args:
        inputs: token ids, indexed as ``inputs[:, i]``. Assumed to be
            (batch, seq) with batch size 1, because ``target`` is squeezed on
            dim 0 and an assert below requires batch 1 — TODO confirm.
        target: labels aligned with ``inputs``. The first non-negative
            position is the cutoff; everything before it is the context.
        model: causal LM. Its forward ``.logits`` is passed through
            ``model.lm_head`` here, so it presumably holds hidden states rather
            than vocabulary logits — verify against the model implementation.
        tokenizer: only used in commented-out debug prints.
        config: reads ``long_ce_k``, ``long_ce_block_size``,
            ``long_ppl_alpha``, ``long_ppl_beta`` and ``recompute_n``.

    Returns:
        ``(long_ce_loss / long_ce_count, ce_loss / ce_count)``, accumulated
        over the generation span from the per-block ``long_ppl`` results.

    Environment:
        ORACLE_RECOMPUTE=1: score context tokens with the flash-attention
            hidden states instead of the HiP prefill hidden states.
        RANDOM_RECOMPUTE=1: choose the first refine set uniformly at random.

    Side effects:
        Calls ``model.eval()`` and temporarily rewrites
        ``model.config._attn_implementation``. It is restored after each
        forward, but NOT if a forward raises.
    """

    model.eval()

    print(f"{inputs.size()=} {target.size()=} {(target >= 0).sum()=}")
    # The first position with a non-negative target splits the sample into a
    # context prefix (all targets < 0) and a generation suffix.
    cutoff_point = (target >= 0).squeeze(0).nonzero()[0].item()
    context_inputs, context_target = inputs[:, :cutoff_point], target[:, :cutoff_point]
    gen_inputs, gen_target = inputs[:, cutoff_point:], target[:, cutoff_point:]
    # print(f"{context_inputs.size()=} {inputs.size()=}")
    # print('inp', tokenizer.decode(context_inputs[0, -100:]))
    # print('gen', tokenizer.decode(gen_inputs[0, :]))

    # 1) Prefill the context with HiP attention. The resulting cache
    #    (past_key_values) is what the generation span decodes against below.
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    # outputs = model(
    #     context_inputs,
    #     sliding_window=config.long_ce_k * 8,
    #     use_cache=True,
    # )
    outputs = model(
        context_inputs, 
        labels=context_target, 
        use_cache=True
    )
    model.config._attn_implementation = previous_attn_backend
    long_hidden = outputs.logits
    past_key_values = outputs.past_key_values

    # 2) calculate ground truth long logits so that all models will calculate long-ppl on the
    # same tokens (a full-sequence flash-attention forward).
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "flash_attention_2"
    long_outputs_fa2 = model(inputs, use_cache=True)
    model.config._attn_implementation = previous_attn_backend
    long_fa2_hidden = long_outputs_fa2.logits

    # 3) window forward: HiP attention limited to a sliding window of
    #    config.long_ce_k tokens provides the "short context" reference.
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    short_outputs = model(
        inputs,
        sliding_window=config.long_ce_k,
        use_cache=False,
    )
    model.config._attn_implementation = previous_attn_backend
    short_hidden = short_outputs.logits[:, : context_inputs.size(1)]

    # print(f"{long_hidden.size()=} {short_hidden.size()=}")
    # 4) calculate indices for dense recomputation of specific tokens.
    #    Each context block is scored with long_ppl(idx_select=True). The labels
    #    are the context token ids themselves, because context_target is < 0.
    #    The concatenated result is presumably one score per context position —
    #    confirm against long_ppl.
    long_ppl_scores = []
    for b in range(0, context_inputs.size(1), config.long_ce_block_size):
        if os.environ.get("ORACLE_RECOMPUTE", "0") == "1":
            _long_fa2_hidden = long_fa2_hidden[:, : short_hidden.size(1)]
            long_logits = model.lm_head(
                _long_fa2_hidden[:, b : b + config.long_ce_block_size]
            )
        else:
            long_logits = model.lm_head(
                long_hidden[:, b : b + config.long_ce_block_size]
            )

        short_logits = model.lm_head(short_hidden[:, b : b + config.long_ce_block_size])
        # One extra label so the block's last position has a next-token target.
        labels = context_inputs.to(long_logits.device)[
            :, b : b + config.long_ce_block_size + 1
        ]

        out = long_ppl(
            model,
            long_logits,
            short_logits,
            labels,
            idx_select=True,
            alpha=config.long_ppl_alpha,
            beta=config.long_ppl_beta,
        )
        long_ppl_scores.append(out)

    long_ppl_scores = torch.cat(long_ppl_scores, dim=0)
    
    # TODO: add the refine process and refactor (RefineableDecoding with .prefill / .extend).
    
    # NOTE: exclude short context: the first 8192 positions get the minimum
    # score, so they do not rank ahead of later tokens for recompute. The 8192 is
    # hard-coded.
    # print(long_ppl_scores, long_ppl_scores.shape)
    long_ppl_scores[...,:8192] = long_ppl_scores.view(-1).amin()
    
    # plt.clf()
    # plt.plot(long_ppl_scores.cpu().float().numpy())
    # plt.savefig('dummy_ppl.png')
    # input('>>>')
    
    # NOTE: to test recompute schedule
    # Each entry is the number of context tokens recomputed at that refine step.
    # NOTE(review): the empty assignment two lines below overrides this
    # schedule, so the refine branch is currently disabled and the HiP prefill
    # cache is used unchanged.
    # refine_schedule = [config.recompute_n]
    refine_schedule = [config.recompute_n, config.recompute_n // 2]
    refine_schedule = []
    # refine_schedule = [1024, 256, 128, 128]
    # refine_schedule = [config.recompute_n, 128]
    # refine_schedule = [config.recompute_n // 16, config.recompute_n // 8, config.recompute_n // 4, config.recompute_n // 2]
    # When True, each refine step recomputes the union of every index selected so far.
    accummulated_refine = False
    
    if len(refine_schedule) > 0:
        refine_history = []
        
        # Initial refine set: the refine_schedule[0] highest-scoring positions,
        # sorted back into sequence order. other_idx is currently unused.
        sorted_idx = torch.argsort(long_ppl_scores, dim=-1, descending=True)
        refine_idx, other_idx = (
            sorted_idx[: refine_schedule[0]],
            sorted_idx[refine_schedule[0] :],
        )
        refine_idx = torch.sort(refine_idx, dim=-1).values
        other_idx = torch.sort(other_idx, dim=-1).values

        # Random baseline. It picks config.recompute_n positions, which must
        # equal refine_schedule[0] to satisfy the assert in the loop below.
        if os.environ.get("RANDOM_RECOMPUTE", "0") == "1":
            refine_idx = torch.randperm(sorted_idx.size(0), device=sorted_idx.device)[
                : config.recompute_n
            ]
            refine_idx = torch.sort(refine_idx, dim=-1).values

        # print(f"{refine_idx.size()=} {other_idx.size()=}")
        # print(f"{refine_idx=}")

        # keys_before_recompute = [v.clone() for v in past_key_values.key_cache]
        
        for i_refine_step, refine_n in enumerate(refine_schedule):
            # Attention maps are only requested when a later step will use them
            # to choose the next refine set.
            need_next_recompute = i_refine_step < (len(refine_schedule) - 1)
            require_scores = need_next_recompute
            
            previous_attn_backend = model.config._attn_implementation
            model.config._attn_implementation = "recompute"
            assert refine_idx.shape[0] == refine_n
            if not accummulated_refine:
                current_refine_idx = refine_idx
            else:
                current_refine_idx = torch.cat([refine_idx] + refine_history).sort().values
            # print(f'refine {current_refine_idx.shape=}')
            # The selected positions are attached to the cache object. This
            # attribute is presumably read by the "recompute" backend, and it is
            # removed again right after the forward.
            past_key_values.recompute_idx = current_refine_idx
            recompute_inputs = context_inputs[:, current_refine_idx]
            recompute_outputs = model(
                recompute_inputs,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=require_scores,
            )
            model.config._attn_implementation = previous_attn_backend
            past_key_values = recompute_outputs.past_key_values
            delattr(past_key_values, "recompute_idx")
            
            # How per-key attention probabilities are pooled over queries and
            # heads ('mean' or 'max').
            reduce_method = 'mean'
            
            if need_next_recompute:
                if reduce_method == 'mean':
                    reduced_scores_weight = 0
                    reduced_scores = 0
                elif reduce_method == 'max':
                    reduced_scores_weight = None
                    reduced_scores = None
                else:
                    raise Exception()
                
                # One attention tensor per layer. The broadcasting below treats
                # dim 2 as the query axis and the last dim as the key axis.
                for _attention in tqdm.tqdm(
                    recompute_outputs.attentions,
                    desc='sample importance', 
                    dynamic_ncols=True, 
                    leave=False
                ):
                    # Process queries in chunks of 128 to bound peak memory.
                    chunk_size = 128
                    for i_tdst in range(0, _attention.shape[2], chunk_size):
                        attention = _attention[:, :, i_tdst:i_tdst+chunk_size].to(recompute_outputs.logits.device)
                        sw_size = 1024
                        sink_size = 1024
                        # Keep only keys at least sw_size positions behind the
                        # query's absolute position (i.e. outside its local window).
                        mask = (
                            (current_refine_idx[i_tdst: i_tdst+chunk_size].to(attention.device) - sw_size)[None, None, :, None] 
                            >= torch.arange(0, attention.shape[-1], device=attention.device)[None, None, None, :]
                        )
                        
                        # NOTE(review): a query whose position is < sw_size +
                        # sink_size has every key masked, and its softmax row
                        # becomes NaN, which propagates into reduced_scores.
                        attention = torch.where(mask, attention, float('-inf'))
                        attention[..., :sink_size] = float('-inf') # short context do not need recompute
                        
                        if reduce_method == 'mean':
                            # Sum probabilities over queries, and count for how
                            # many queries each key was unmasked (the divisor).
                            probs = attention.softmax(dim=-1)
                            probs = probs\
                                .float()\
                                .sum(dim=2, keepdim=True)
                            mask = (attention > -1e4).int().sum(dim=2, keepdim=True)
                            reduced_scores += probs.to(recompute_outputs.logits.device)
                            reduced_scores_weight += mask.to(recompute_outputs.logits.device)
                        elif reduce_method == 'max':
                            probs = attention.softmax(dim=-1)
                            probs = probs\
                                .amax(dim=2, keepdim=True)\
                                .to(recompute_outputs.logits.device)
                            if reduced_scores is None:
                                reduced_scores = probs
                            reduced_scores = torch.maximum(probs, reduced_scores)
                            reduced_scores_weight = 1
                        else:
                            raise Exception()
                # 'mean': average probability per key over the queries that could
                # see it. 'max': the weight is 1, so this is a no-op scale.
                reduced_scores = reduced_scores / ((reduced_scores_weight.float() if isinstance(reduced_scores_weight, torch.Tensor) else 1) + 1e-20)
                # Pool over attention heads (dim 1).
                if reduce_method == 'mean':
                    reduced_scores = reduced_scores.mean(dim=1,keepdim=True)
                elif reduce_method == 'max':
                    reduced_scores = reduced_scores.amax(dim=1,keepdim=True)
                else:
                    raise Exception()
                # Batch, head and query dims are all 1 at this point, leaving
                # one score per key position.
                assert reduced_scores.shape[:3] == (1,1,1,)
                reduced_scores = reduced_scores.squeeze(0).squeeze(0).squeeze(0)
                min_reduced_scores = reduced_scores.view(-1).amin()
                # In accumulated mode, already-refined positions get the
                # minimum score so that topk below does not pick them again.
                if accummulated_refine:
                    reduced_scores.index_fill_(
                        dim=0, 
                        index=torch.cat([refine_idx] + refine_history),
                        value=min_reduced_scores
                    )
                
                # print(reduced_scores, reduced_scores.shape)
                
                # plt.clf()
                # plt.plot(reduced_scores.cpu().float().numpy())
                # plt.savefig('dummy.png')
                # input('>>>')
                
                # NOTE(review): unlike the initial refine_idx, these topk indices
                # are not sorted. With accummulated_refine False they are used
                # as-is to gather recompute_inputs; confirm that the recompute
                # backend accepts out-of-order positions.
                next_refine_idx = reduced_scores\
                    .topk(k=refine_schedule[i_refine_step + 1], dim=-1)\
                    .indices
                refine_history.append(refine_idx)
                refine_idx = next_refine_idx
    
    # # NOTE uncomment following lines to recompute everything
    # past_key_values.recompute_idx = torch.arange(
    #     sorted_idx.size(-1), device=sorted_idx.device
    # )
    # previous_attn_backend = model.config._attn_implementation
    # model.config._attn_implementation = "recompute"
    # recompute_inputs = context_inputs
    # print(recompute_inputs.shape)
    # recompute_outputs = model(
    #     recompute_inputs,
    #     # cache_position=other_idx,
    #     # position_ids=other_idx.unsqueeze(0),
    #     past_key_values=past_key_values,
    #     use_cache=True,
    # )
    # # cache has been updated. delete recompute idx
    # past_key_values = recompute_outputs.past_key_values
    # model.config._attn_implementation = previous_attn_backend
    # delattr(past_key_values, "recompute_idx")
    # # STOP COMMENT SECTION =================================

    # 5) Decode the generation span with flash attention on top of the
    #    (possibly refined) context cache, at absolute positions
    #    cutoff_point .. inputs.shape[1] - 1.
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "flash_attention_2"
    cache_location = torch.arange(
        context_inputs.shape[1], inputs.shape[1], device=gen_inputs.device
    )
    gen_outputs = model(
        gen_inputs,
        cache_position=cache_location,
        position_ids=cache_location.unsqueeze(0),
        labels=gen_target,
        past_key_values=past_key_values,
        use_cache=True,
    )
    model.config._attn_implementation = previous_attn_backend

    # NOTE: redundant; the backend was already restored on the line above.
    model.config._attn_implementation = previous_attn_backend
    # Align the three hidden-state streams to the generation span.
    gen_hidden = gen_outputs.logits
    short_hidden = short_outputs.logits[:, context_inputs.shape[1] :]
    long_fa2_hidden = long_fa2_hidden[:, context_inputs.shape[1] :]

    # 6) Accumulate LongPPL and plain CE block by block. The flash-attention
    #    logits are the ground truth for selecting LongPPL tokens. lm_head is
    #    applied per block, presumably to bound memory for vocabulary logits.
    ce_count, long_ce_count = 0, 0
    ce_loss, long_ce_loss = 0, 0
    for b in range(0, gen_hidden.shape[1], config.long_ce_block_size):
        range_start = b
        range_end = range_start + config.long_ce_block_size
        gen_logits = model.lm_head(gen_hidden[:, range_start:range_end])
        long_fa2_logits = model.lm_head(long_fa2_hidden[:, range_start:range_end])

        short_logits = model.lm_head(short_hidden[:, range_start:range_end])
        labels = gen_target.to(gen_logits.device)[:, range_start : range_end + 1]

        lce, lce_cnt, ce, ce_cnt = long_ppl(
            model,
            gen_logits,
            short_logits,
            labels,
            alpha=config.long_ppl_alpha,
            beta=config.long_ppl_beta,
            long_logits_truth=long_fa2_logits,
            block_index=b,
        )
        long_ce_loss += lce
        long_ce_count += lce_cnt
        ce_loss += ce
        ce_count += ce_cnt
    
    return long_ce_loss / long_ce_count, ce_loss / ce_count

def _forward_with_attn_backend(model, backend, *args, **kwargs):
    """Run ``model(*args, **kwargs)`` with ``backend`` as the attention implementation.

    The previous ``model.config._attn_implementation`` is restored afterwards,
    including when the forward raises. Otherwise a failed forward (e.g. OOM)
    would leave the model on the temporary backend for every later call.
    """
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = backend
    try:
        return model(*args, **kwargs)
    finally:
        model.config._attn_implementation = previous_attn_backend


def batched_long_cross_entropy(
    inputs, target, model, tokenizer, config, *, dense_decode=True
):
    """
    long ce from https://arxiv.org/pdf/2410.23771 (equation 7 on page 6)

    Compute LongPPL and plain cross-entropy over the generation span of
    ``inputs``, i.e. every position from the first non-negative ``target``
    onward.

    Args:
        inputs: token ids, indexed ``inputs[:, i]``. Assumed to be
            (batch, seq) with batch size 1, because ``target`` is squeezed on
            dim 0 — TODO confirm.
        target: labels aligned with ``inputs``. Negative entries mark the
            context prefix.
        model: causal LM. Its forward ``.logits`` is projected with
            ``model.lm_head`` here, so it presumably holds hidden states —
            verify against the model implementation.
        tokenizer: unused; kept for signature parity with sibling evaluators.
        config: reads ``long_ce_k``, ``long_ce_block_size``,
            ``long_ppl_alpha`` and ``long_ppl_beta``.
        dense_decode: if True (the previous hard-coded behaviour), prefill
            the context with the model's currently configured attention
            backend, then decode the generation span with
            ``flash_attention_2`` on that KV cache. If False, run one
            cache-less forward over the whole sequence.

    Returns:
        ``(long_loss / long_count, ce_loss / ce_count)``, accumulated from
        the per-block ``long_ppl`` results.

    Side effects:
        Calls ``model.eval()``. ``model.config._attn_implementation`` is
        switched temporarily and always restored, even on error.
    """
    model.eval()

    cutoff_point = (target >= 0).squeeze(0).nonzero()[0].item()

    if dense_decode:
        assert inputs.ndim == 2
        assert target.ndim == 2

        # Prefill the context with whatever backend the model is configured for.
        prefill_outputs = model(
            inputs[:, :cutoff_point],
            labels=target[:, :cutoff_point],
            use_cache=True,
        )

        # Decode the generation span densely on top of the prefill KV cache,
        # at absolute positions cutoff_point .. seq_len - 1.
        cache_location = torch.arange(cutoff_point, inputs.shape[1], device=inputs.device)
        decode_outputs = _forward_with_attn_backend(
            model,
            "flash_attention_2",
            inputs[:, cutoff_point:],
            cache_position=cache_location,
            position_ids=cache_location.unsqueeze(0),
            labels=target[:, cutoff_point:],
            past_key_values=prefill_outputs.past_key_values,
            use_cache=True,
        )
        long_hidden = torch.cat([prefill_outputs.logits, decode_outputs.logits], dim=1)
    else:
        long_hidden = model(inputs, labels=target, use_cache=False).logits

    # Short-context reference: HiP attention limited to a sliding window.
    short_hidden = _forward_with_attn_backend(
        model,
        "hip_attention",
        inputs,
        sliding_window=config.long_ce_k,
        use_cache=False,
    ).logits

    # Full-attention reference, used as the ground truth for LongPPL token
    # selection so that every model is scored on the same tokens.
    long_fa2_hidden = _forward_with_attn_backend(
        model, "flash_attention_2", inputs, use_cache=False
    ).logits

    # Keep only the generation span.
    long_hidden = long_hidden[:, cutoff_point:]
    long_fa2_hidden = long_fa2_hidden[:, cutoff_point:]
    short_hidden = short_hidden[:, cutoff_point:]
    target = target[:, cutoff_point:]

    block_size = config.long_ce_block_size
    loss, count = 0, 0
    ce_loss, ce_count = 0, 0
    for b in range(0, long_hidden.size(1), block_size):
        long_logits = model.lm_head(long_hidden[:, b : b + block_size])
        long_fa2_logits = model.lm_head(long_fa2_hidden[:, b : b + block_size])
        short_logits = model.lm_head(short_hidden[:, b : b + block_size])
        # One extra label so the block's last position has a next-token target
        # (long_ppl shifts labels by one).
        labels = target.to(long_logits.device)[:, b : b + block_size + 1]

        l, c, ce_l, ce_c = long_ppl(
            model,
            long_logits,
            short_logits,
            labels,
            alpha=config.long_ppl_alpha,
            beta=config.long_ppl_beta,
            long_logits_truth=long_fa2_logits,
            block_index=b,
        )
        loss += l
        count += c
        ce_loss += ce_l
        ce_count += ce_c

    return loss / count, ce_loss / ce_count

def batched_long_cross_entropy_hip(inputs, target, model, tokenizer, config):
    """
    long ce from https://arxiv.org/pdf/2410.23771 (equation 7 on page 6)

    Compute the LongPPL loss over the generation span (every position from
    the first non-negative ``target`` onward). The short-context reference
    runs HiP attention with extend/RoPE enabled via ``modify_hip_args``.

    Args:
        inputs: token ids, indexed ``inputs[:, i]`` (assumed (batch, seq)
            with batch size 1 — TODO confirm).
        target: labels aligned with ``inputs``. Negative entries mark the
            context prefix.
        model: causal LM whose forward ``.logits`` is projected with
            ``model.lm_head`` here.
        tokenizer: unused (only referenced by commented-out padding masks).
        config: reads ``long_ce_k``, ``long_ce_block_size``,
            ``long_ppl_alpha`` and ``long_ppl_beta``; it is also passed to
            ``modify_hip_args``.

    Returns:
        ``loss / count``, accumulated from the first two outputs of
        ``long_ppl`` per block.

    Side effects:
        Calls ``model.eval()`` and prints debug lines. The temporary attention
        backend and HiP args are restored even if a forward raises.
    """

    model.eval()

    cutoff_point = (target >= 0).squeeze(0).nonzero()[0].item()
    print(f"{cutoff_point=} {inputs.size()=}")

    outputs = model(
        inputs,
        # attention_mask=(inputs.ne(tokenizer.pad_token_id)).to(inputs.dtype),
        labels=target,
        use_cache=False,
    )
    long_hidden = outputs.logits

    # Short-context reference: HiP with extend/RoPE switched on, limited to a
    # sliding window. Restore the backend and HiP args even on failure.
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "hip_attention"
    try:
        modify_hip_args(model, True, config)
        short_outputs = model(
            inputs,
            # FIXME: padding needed?
            # attention_mask=(inputs.ne(tokenizer.pad_token_id)).to(inputs.dtype),
            sliding_window=config.long_ce_k,
            use_cache=False,
        )
    finally:
        model.config._attn_implementation = previous_attn_backend
        modify_hip_args(model, False, config)
    short_hidden = short_outputs.logits

    # Full-attention reference used as the ground truth for token selection.
    previous_attn_backend = model.config._attn_implementation
    model.config._attn_implementation = "flash_attention_2"
    try:
        long_outputs_fa2 = model(
            inputs,
            # attention_mask=(inputs.ne(tokenizer.pad_token_id)).to(inputs.dtype),
            use_cache=False,
        )
    finally:
        model.config._attn_implementation = previous_attn_backend
    long_fa2_hidden = long_outputs_fa2.logits

    # Keep only the generation span.
    count, loss = 0, 0
    long_hidden = long_hidden[:, cutoff_point:]
    long_fa2_hidden = long_fa2_hidden[:, cutoff_point:]
    short_hidden = short_hidden[:, cutoff_point:]
    target = target[:, cutoff_point:]

    for b in range(0, long_hidden.size(1), config.long_ce_block_size):
        long_logits = model.lm_head(long_hidden[:, b : b + config.long_ce_block_size])
        long_fa2_logits = model.lm_head(
            long_fa2_hidden[:, b : b + config.long_ce_block_size]
        )
        short_logits = model.lm_head(short_hidden[:, b : b + config.long_ce_block_size])
        # One extra label so the block's last position has a next-token target.
        labels = target.to(long_logits.device)[:, b : b + config.long_ce_block_size + 1]

        l, c, _, _ = long_ppl(
            model,
            long_logits,
            short_logits,
            labels,
            alpha=config.long_ppl_alpha,
            beta=config.long_ppl_beta,
            long_logits_truth=long_fa2_logits,
            block_index=b,
        )
        loss += l
        count += c

    print(f"in long cross entropy: {loss=} {count=}")
    return loss / count


def get_hf_dataset(ds):
    """Expose an indexable ``(inputs, targets)`` dataset as a HF ``IterableDataset``.

    ``ds`` must support ``len()`` and integer indexing. Each item is unpacked
    into two values and yielded as ``{"input_ids": ..., "labels": ...}``,
    lazily, in index order.
    """

    def _rows():
        for position in range(len(ds)):
            input_ids, labels = ds[position]
            yield {"input_ids": input_ids, "labels": labels}

    return datasets.IterableDataset.from_generator(_rows)


def parse_args():
    """Parse command-line flags into a ``Config`` dataclass, log it, and return it."""
    parsed_config = ArgumentParser(Config).parse_args()
    log.info(parsed_config)
    return parsed_config


def get_hip_config(config: Config, layer_idx: int):
    """Build the ``HiPAttentionArgs`` for a single attention layer.

    The preset comes from the ``PRESET`` environment variable (default
    ``"default"``). Any other value raises ``NotImplementedError``.
    ``config`` is accepted for interface compatibility but is not read.
    The only per-layer difference is the scan extend backend: layers with
    ``layer_idx < 3`` use "streaming", and later layers use "relative".
    Extend and RoPE application start disabled; callers may flip them.
    """
    preset_name = os.environ.get("PRESET", "default")
    if preset_name != "default":
        raise NotImplementedError(f"{preset_name=} not implemented")

    # (block_size_q, block_stride_q, chunk_size, k) for each scan stage,
    # coarse to fine. The stride is 1 for every stage.
    stage_specs = [
        (64, 4, 256, None),
        (64, 4, 32, 32768),
        (64, 1, 8, 8192),
    ]
    stages = [
        ScanStage(
            stage_block_size_q=block_size_q,
            stage_block_stride_q=block_stride_q,
            stage_chunk_size=chunk_size,
            stage_k=stage_k,
            stage_stride=1,
        )
        for block_size_q, block_stride_q, chunk_size, stage_k in stage_specs
    ]

    scan_extend_backend = "streaming" if layer_idx < 3 else "relative"

    return HiPAttentionArgs(
        sliding_window_size=1024,
        sink_token_size=128,
        using_extend=False,
        need_apply_rope=False,
        second_stage_k=2048,
        stages=stages,
        model_context_length=131072,
        scan_extend_backend=scan_extend_backend,
        sa_extend_backend="streaming",
        # Block-sparse attention uses the query block size of the finest stage.
        block_sparse_block_size_q=stages[-1].stage_block_size_q,
        enable_hip_tune=False,
        disable_gate_prob=True,
        fallback_to_v12=True,
    )


def modify_hip_args(model, turn_on, config):
    """Give every ``LlamaAttention`` layer fresh HiP args with extend/RoPE toggled.

    Layers are numbered in ``model.modules()`` order, counting only
    ``LlamaAttention`` instances. Each one receives
    ``get_hip_config(config, layer_idx)`` with ``using_extend`` and
    ``need_apply_rope`` both set to ``bool(turn_on)``, and its
    ``attention_method`` is set to ``config.method``.
    """
    attention_layers = (
        module for module in model.modules() if isinstance(module, LlamaAttention)
    )
    for layer_idx, attention in enumerate(attention_layers):
        hip_args = get_hip_config(config, layer_idx)
        hip_args.using_extend = hip_args.need_apply_rope = bool(turn_on)

        attention.hip_attn_args = hip_args
        attention.attention_method = config.method


def init_model(config: Config):
    """Load the tokenizer and causal LM described by ``config``.

    The attention backend is taken from the ``ATTN_IMPLEMENTATION``
    environment variable (default ``"hip_attention"``):

    * names containing ``"minference"`` load a stock HF causal LM and patch it
      with MInference;
    * anything else loads the project's ``LlamaForCausalLM`` in bfloat16 and
      attaches ``config`` plus a per-layer HiP config to every
      ``LlamaAttention`` module.

    The model is not moved to any device here.

    Returns:
        ``(model, tokenizer)``.
    """
    # Register the project's custom backend names in the HF attention
    # registry. The identity callables are placeholders; presumably the custom
    # LlamaAttention modules dispatch on these names themselves — confirm in
    # hip_attn.v1_3.models.llama.
    for backend_name in (
        "hip_attention",
        "sdpa_rectangle",
        "recompute",
        "recompute_dense",
    ):
        ALL_ATTENTION_FUNCTIONS.update({backend_name: (lambda x: x)})

    tokenizer = transformers.AutoTokenizer.from_pretrained(config.model)

    attn_implementation = os.environ.get("ATTN_IMPLEMENTATION", "hip_attention")

    if "minference" in attn_implementation:
        from minference import MInference

        # NOTE(review): unlike the HiP branch, no explicit torch_dtype is
        # passed here — confirm the dtype mismatch is intended.
        model = transformers.AutoModelForCausalLM.from_pretrained(config.model)

        minference_patch = MInference(attn_implementation, config.model)
        model = minference_patch(model)
    else:
        model_config = LlamaConfig.from_pretrained(
            config.model,
            attn_implementation=attn_implementation,
            torch_dtype=torch.bfloat16,
        )

        model_config.pooler_method = config.pooler_method
        model_config.pooler_config = config.pooler_config
        model = LlamaForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            torch_dtype=torch.bfloat16,
        )

        attention_layers = (m for m in model.modules() if isinstance(m, LlamaAttention))
        for layer_idx, m in enumerate(attention_layers):
            m.args = config

            hip_attn_config = get_hip_config(config, layer_idx)
            if attn_implementation != "hip_attention":
                # Explicitly keep HiP's extend/RoPE path off for other backends.
                hip_attn_config.using_extend = False
                hip_attn_config.need_apply_rope = False
            m.hip_attn_args = hip_attn_config
            m.attention_method = config.method

    return model, tokenizer


def main(config: Config):
    """Load the model, tokenizer and dataset for ``config``, then evaluate on test data.

    Environment variables read here:
      * ``DETECT_ANOMALY=1`` enables autograd anomaly detection.
      * ``CUDA_LAUNCH_BLOCKING=1`` only triggers a warning log.
      * ``LONGQA_PATH`` selects the PG19 long-QA data directory.
    ``WANDB_PROJECT`` is set to "quick_extend", and
    ``config.model_checkpoint_dir`` is extended with a run-specific
    subdirectory name.
    """
    if os.environ.get("DETECT_ANOMALY", "0") == "1":
        torch.autograd.set_detect_anomaly(True)

    # torch.set_float32_matmul_precision('high')
    # NOTE(review): this turns cudnn benchmarking back on even though seed()
    # switches it off for determinism — confirm which is intended.
    torch.backends.cudnn.benchmark = True

    os.environ["WANDB_PROJECT"] = "quick_extend"

    if os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1":
        log.info(
            "WARNING: CUDA_LAUNCH_BLOCKING is set to 1, this will slow down the training.",
        )

    run_name = f'{config.model.replace("/", "_")}-{config.name}-{config.dataset}-{config.seq_len}-{config.get_hash()}'
    config.model_checkpoint_dir = config.model_checkpoint_dir + "/" + run_name

    model, tokenizer = init_model(config=config)
    tokenizer.pad_token = tokenizer.eos_token

    def train_val_split(dataset):
        """Randomly split ``dataset`` into (train, validation) ``Subset``s."""
        all_indices = list(range(len(dataset)))
        train_idx, val_idx = train_test_split(all_indices, test_size=config.val_split)
        return Subset(dataset, train_idx), Subset(dataset, val_idx)

    test_data = None
    if config.dataset in ("owt", "pg19"):
        dataset_cls = OpenWebTextDataset if config.dataset == "owt" else PG19Dataset
        dataset = dataset_cls(
            tokenizer=tokenizer, stride=config.seq_len, offset_labels=False
        )
        # NOTE(review): these splits are never used below, and test_data stays
        # None for these datasets, so evaluate() receives None — confirm.
        train_data, valid_data = train_val_split(dataset)
    elif config.dataset == "pg19-longqa":
        path = os.getenv("LONGQA_PATH", "/data/anno_0/pg19-hierarchical-qa/")
        # train_data = PG19LongQA(tokenizer, path, split="train")
        # valid_data = PG19LongQA(tokenizer, path, split="validation")
        test_data = get_hf_dataset(PG19LongQA(tokenizer, path, split="test"))
    else:
        raise ValueError(f"Unknown dataset {config.dataset}")

    model = model.cuda()
    metrics = evaluate(config, model, tokenizer, test_data, metric_key_prefix="test")
    print(f"test metrics: {metrics=}")


def run():
    """Script entry point: seed every RNG, parse the CLI config, then run ``main``."""
    seed()
    cli_config = parse_args()
    main(cli_config)


def seed(seed=42):
    """Seed every RNG this script touches, for reproducible runs.

    Covers Python's ``random``, NumPy's global RNG, and torch's CPU and CUDA
    generators (all devices). It also makes cuDNN deterministic with
    benchmarking off, and exports ``PYTHONHASHSEED`` (which only affects
    child interpreters started afterwards).
    """
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)


if __name__ == "__main__":
    run()

    # Recorded "test/long-ppl" and "test/ppl" results from earlier runs, kept
    # for reference only: `d` is assigned after run() returns and never read.
    # The keys appear to name the attention setup and the number of context
    # tokens recomputed. NOTE(review): confirm the model, dataset and seq_len
    # these numbers came from before citing them.
    d = {
        "flash": {"test/long-ppl": 8.336586316839997, "test/ppl": 4.9651326269375256},
        "hip, recompute all": {
            "test/long-ppl": 8.3381241134351,
            "test/ppl": 4.9650973602876,
        },
        "hip, recompute none": {
            "test/long-ppl": 9.428468762538161,
            "test/ppl": 5.065957519243355,
        },
        "hip, recompute 128": {
            "test/long-ppl": 9.416601027239132,
            "test/ppl": 5.065079511366147,
        },
        "hip, recompute 256": {
            "test/long-ppl": 9.411448079138468,
            "test/ppl": 5.064550002472922,
        },
        "hip, recompute 512": {
            "test/long-ppl": 9.402586559659069,
            "test/ppl": 5.063155803316984,
        },
        "hip, recompute 1024": {
            "test/long-ppl": 9.38693256540835,
            "test/ppl": 5.061209295450554,
        },
        "hip, recompute 2048": {
            "test/long-ppl": 9.357628645193147,
            "test/ppl": 5.058015400464616,
        },
        "hip, recompute 4096": {
            "test/long-ppl": 9.31266421557773,
            "test/ppl": 5.053256749578179,
        },
        "hip, recompute 8192": {
            "test/long-ppl": 9.227450673465114,
            "test/ppl": 5.043470289291057,
        },
        "hip, recompute 16384": {
            "test/long-ppl": 9.090939079338542,
            "test/ppl": 5.0287048305332345,
        },
        "hip, recompute 32768": {
            "test/long-ppl": 8.818453934814176,
            "test/ppl": 5.000267732706241,
        },
        "hip, recompute 65536": {
            "test/long-ppl": 8.41727128371619,
            "test/ppl": 4.968472206063697,
        },
    }
