from pathlib import Path
import sys

import matplotlib.pyplot as plt
import torch
from torch import nn, Tensor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from get_prompt import get_long_prompt
from tap import Tap

sys.path.append('..')
from utils import get_non_embed_param_count, get_param_count


class Args(Tap):
    """Command-line arguments for the per-token (sliding-window) loss evaluation script."""

    # Set to 'mamba2' to force the Mamba2 loader/evaluation path even if the path does not contain 'mamba2'.
    model_name: str = ''
    # Checkpoint directory; also used to pick the architecture and to name the output directory.
    pretrained_path: str = '/Users/donny/donny/research/ckpts/rwkv6-world-1.6b'
    # Tokenizer directory; None means "same as pretrained_path".
    tok_path: str | None = None
    # Name passed to get_long_prompt(); also part of the output directory.
    prompt_name: str = 'nextlines'
    device: str = 'cuda'
    # Maximum input length in units of 1024 tokens.
    max_len: int = 128
    # If nonzero, recompute the loss even if a cached result exists.
    overwrite: int = 0
    # Training context length in units of 1024 tokens; drawn as a vertical dashed line on the plot.
    train_len: int = 8
    # Right x-axis limit of the plot, in units of 1024 tokens.
    xmax: int = 32
    # If nonzero, use a log-scale x axis.
    xlog: int = 0
    # If nonzero, use a log-scale y axis.
    ylog: int = 1

    # The input is split into chunks of `chunk_size * 1024` tokens before being fed to the model, to avoid OOM.
    chunk_size: int = 128

    # How many tokens to advance the sliding window each step.
    step_size: int = 1

    # For plotting, the loss is averaged over consecutive buckets of `bucket_size` tokens.
    bucket_size: int = 256

    # Sliding-window length in tokens; -1 means no sliding window.
    window_size: int = -1

    # If nonzero, print the first and last 1000 characters of the (repeated) prompt.
    verbose: int = 0


def get_model(pretrained_path, tok_path=None, device='cuda', dtype=torch.float16, model_name: str = ''):
    """Load a causal language model and its tokenizer.

    The architecture is chosen from ``model_name`` if it is ``'mamba2'`` or
    ``'rwkv6'``; otherwise from whether that substring appears in
    ``pretrained_path``.

    Args:
        pretrained_path: Model checkpoint directory (str or path-like).
        tok_path: Tokenizer directory. Defaults to ``pretrained_path``.
        device: Device the model is placed on.
        dtype: Parameter dtype. For Mamba2 it is passed to ``from_pretrained``;
            for RWKV6 the loaded model is cast with ``.to()``.
        model_name: Optional explicit architecture name.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If the architecture cannot be determined.
    """
    if tok_path is None:
        tok_path = pretrained_path
    path_str = str(pretrained_path)

    if model_name == 'mamba2' or 'mamba2' in path_str:
        # Project-local implementation, imported only when this branch is taken.
        from modeling.mamba2.modeling_mamba2_dao import Mamba2ForCausalLM
        print(f"Loading tokenizer from {tok_path}")
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print(f"Loading model from {pretrained_path}...")
        model = Mamba2ForCausalLM.from_pretrained(
            pretrained_path,
            device=device,
            dtype=dtype,
        )
    elif model_name == 'rwkv6' or 'rwkv6' in path_str:
        print(f"Loading tokenizer from {tok_path}")
        # Honour tok_path here too (this branch previously always used pretrained_path).
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print(f"Loading model from {pretrained_path}...")
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_path,
            trust_remote_code=True,
        ).to(device=device, dtype=dtype)
    else:
        raise ValueError(f"Unknown model type: {pretrained_path}")
    return model, tokenizer


def _next_token_loss(loss_fn, logits: Tensor, labels: Tensor) -> Tensor:
    """Return per-position cross-entropy of ``logits`` (bsz, n, V) against ``labels`` (bsz, n)."""
    # CrossEntropyLoss expects the class dimension second: (bsz, V, n).
    return loss_fn(logits.transpose(1, 2), labels)


def _sliding_window_loss(args, model, input_ids: Tensor, attention_mask: Tensor, loss_fn) -> Tensor:
    """Per-token loss where the model only ever sees a window of ``args.window_size`` tokens.

    The first window is scored in one pass. Afterwards the window slides right
    by ``args.step_size`` tokens, is recomputed from scratch, and only the newly
    entered tokens are scored. Each later prediction therefore conditions on
    between ``window_size - step_size`` and ``window_size - 1`` preceding tokens.
    ``window_size <= 0`` means no sliding window (one pass over the sequence).

    Returns a (bsz, T - 1) tensor; column ``i`` is the loss of predicting token ``i + 1``.
    """
    if args.step_size < 1:
        raise ValueError(f"step_size must be >= 1, got {args.step_size}")
    seq_len = input_ids.shape[1]
    window = args.window_size if args.window_size > 0 else seq_len

    # Score every token inside the first window.
    hi = min(window, seq_len)
    logits = model(input_ids=input_ids[:, :hi], attention_mask=attention_mask[:, :hi]).logits  # (bsz, hi, V)
    all_loss = [_next_token_loss(loss_fn, logits[:, :-1], input_ids[:, 1:hi])]

    if hi < seq_len and args.step_size >= window:
        # The first new token's prediction comes from the token just before it,
        # which must still be inside the window.
        raise ValueError(
            f"step_size ({args.step_size}) must be smaller than window_size ({window})"
        )

    while hi < seq_len:
        # Clamp the last step so the window never runs past the sequence end.
        new_hi = min(hi + args.step_size, seq_len)
        n_new = new_hi - hi
        lo = new_hi - window
        print(f"Processing window: [{lo}:{new_hi}]")
        window_ids = input_ids[:, lo:new_hi]  # (bsz, window)
        logits = model(
            input_ids=window_ids,
            attention_mask=attention_mask[:, lo:new_hi],
        ).logits  # (bsz, window, V)
        # New tokens are the last n_new ids; their predictions are the logits one position earlier.
        all_loss.append(_next_token_loss(loss_fn, logits[:, -n_new - 1:-1], window_ids[:, -n_new:]))
        hi = new_hi

    return torch.cat(all_loss, dim=1)


def _chunked_recurrent_loss(args, model, input_ids: Tensor, attention_mask: Tensor, loss_fn) -> Tensor:
    """Per-token loss computed chunk by chunk, carrying the model's recurrent state.

    The sequence is split into chunks of ``args.chunk_size * 1024`` tokens. Each
    chunk is run with ``state=`` set to the previous chunk's ``outputs.state``
    (assumes the model supports this interface, as the rwkv6 remote code
    presumably does). The first token of every chunk after the first is scored
    with the previous chunk's last logit, so no position is skipped.

    Returns a (bsz, T - 1) tensor; column ``i`` is the loss of predicting token ``i + 1``.
    """
    chunk_len = args.chunk_size * 1024
    id_chunks = torch.split(input_ids, chunk_len, dim=1)
    mask_chunks = torch.split(attention_mask, chunk_len, dim=1)
    n_chunks = len(id_chunks)
    all_loss = []
    states = None
    prev_last_logits = None  # (bsz, 1, V): last logit of the previous chunk
    for chunk_i, (chunk, chunk_mask) in enumerate(zip(id_chunks, mask_chunks)):
        print(f"{chunk_i}/{n_chunks}, {chunk.shape = }")
        outputs = model(
            input_ids=chunk,
            attention_mask=chunk_mask,
            state=states,
        )
        states = outputs.state
        logits = outputs.logits  # (bsz, C, V)
        if prev_last_logits is not None:
            # Cross-chunk prediction: previous chunk's last logit -> this chunk's first token.
            all_loss.append(_next_token_loss(loss_fn, prev_last_logits, chunk[:, :1]))
        all_loss.append(_next_token_loss(loss_fn, logits[:, :-1], chunk[:, 1:]))  # (bsz, C - 1)
        # Clone so the full logits tensor of this chunk is not kept alive by a view.
        prev_last_logits = logits[:, -1:].clone()
    return torch.cat(all_loss, dim=1)


def compute_per_token_loss(args: "Args", model, tokenizer, prompt: str):
    """Compute the next-token cross-entropy loss at every position of ``prompt``.

    The prompt is tokenized and truncated to ``args.max_len * 1024`` tokens.
    Mamba2 models (``args.model_name == 'mamba2'`` or ``'mamba2'`` in
    ``args.pretrained_path``) are evaluated with a sliding window of
    ``args.window_size`` tokens advanced by ``args.step_size``. All other models
    are evaluated chunk by chunk with their recurrent state carried across chunks.

    Returns:
        A float32 CPU tensor of shape (bsz, T - 1) where column ``i`` is the
        loss of predicting token ``i + 1``.

    Raises:
        ValueError: On an invalid ``step_size`` / ``window_size`` combination
            (sliding-window path only).
    """
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=args.max_len * 1024)
    input_ids: Tensor = inputs.input_ids.to(device=args.device)
    attention_mask: Tensor = inputs.attention_mask.to(device=args.device)
    print("Input ids shape:", input_ids.shape)

    with torch.no_grad():
        loss_fn = nn.CrossEntropyLoss(reduction='none')
        if args.model_name == 'mamba2' or 'mamba2' in str(args.pretrained_path):
            loss = _sliding_window_loss(args, model, input_ids, attention_mask, loss_fn)
        else:
            loss = _chunked_recurrent_loss(args, model, input_ids, attention_mask, loss_fn)
        per_token_loss = loss.float().cpu()  # (bsz, T - 1)
        print(per_token_loss.shape)
        print(f"Per token loss: {per_token_loss}")
    return per_token_loss


def main():
    """Compute (or load cached) per-token losses and plot the bucket-averaged loss curve.

    Everything is written under
    ``result_per_token_sliding_window/<ckpt_name>/<prompt_name>/<window_size>-<step_size>/``:
    the parsed arguments (``args.json``), the cached per-token loss tensor
    (``per_token_loss.pt``) and the plot (PDF).
    """
    args = Args().parse_args()
    print(args)
    args.pretrained_path = args.pretrained_path.rstrip('/')
    ckpt_name = args.pretrained_path.replace('/', '--')
    output_dir = (
        Path('result_per_token_sliding_window')
        / ckpt_name
        / args.prompt_name
        / f'{args.window_size}-{args.step_size}'
    )
    output_dir.mkdir(exist_ok=True, parents=True)
    args.save(output_dir / 'args.json')

    # NOTE(review): the cache path does not encode max_len / chunk_size, so changing
    # them without `--overwrite 1` reuses a stale cached result.
    cache_path = output_dir / 'per_token_loss.pt'
    print(f"Cache path: {cache_path}")
    if cache_path.exists() and not args.overwrite:
        print(f"Loading cached result from {cache_path}")
        with open(cache_path, 'rb') as f:
            per_token_loss = torch.load(f)
    else:
        print("Loading tokenizer and model")
        model, tokenizer = get_model(
            args.pretrained_path,
            args.tok_path,
            device=args.device,
            model_name=args.model_name,
        )
        print("========== finish loading =========")
        n_params = get_param_count(model)
        n_non_embed_params = get_non_embed_param_count(model)
        print(f"Param count: {n_params:,}, non-embedding: {n_non_embed_params:,}")
        print("=======================================================")

        prompt = get_long_prompt(args.prompt_name)
        tokens = tokenizer.tokenize(prompt)
        print(f"Prompt length: {len(tokens)}")
        if not tokens:
            raise ValueError(f"Prompt {args.prompt_name!r} produced no tokens")
        # Repeat the prompt so it tokenizes to (roughly) at least max_len * 1024 tokens;
        # token counts of the concatenation can differ slightly at the seams.
        n_reps = (args.max_len * 1024 - 1) // len(tokens) + 1
        prompt = prompt * n_reps
        if args.verbose:
            print("======= prompt =========")
            print(prompt[:1000])
            print('--------------------')
            print(prompt[-1000:])
            print("========================")
        per_token_loss = compute_per_token_loss(args, model, tokenizer, prompt)
        print(f"Caching result to: {cache_path}")
        torch.save(per_token_loss, cache_path)

    print(f"Per token loss: {per_token_loss}")
    print(f"{per_token_loss.shape = }")
    # Average over the batch dimension -> (T - 1,)
    per_token_loss = per_token_loss.mean(dim=0)
    # Average within consecutive buckets of `bucket_size` tokens (the last one may be shorter).
    buckets = torch.split(per_token_loss, args.bucket_size)
    per_bucket_loss = torch.stack([bucket.mean() for bucket in buckets])  # (n_buckets,)
    # The first and last buckets vary a lot (the last may also be partial), so drop them.
    per_bucket_loss = per_bucket_loss[1:-1]
    # Plot each remaining bucket at its centre: bucket k (after dropping the first) spans [(k+1)B, (k+2)B).
    xs = torch.arange(1, len(per_bucket_loss) + 1) * args.bucket_size + args.bucket_size // 2

    plt.figure(figsize=(2.3, 2.3))
    plt.plot(xs, per_bucket_loss)
    # Mark the training context length.
    plt.axvline(x=args.train_len * 1024, color='r', linestyle='--')
    plt.xlim((args.bucket_size, args.xmax * 1024))

    plt.xlabel(r'Token position ($t$)')
    plt.ylabel('Loss')
    if args.ylog:
        plt.yscale('log')
    if args.xlog:
        plt.xscale('log')
    plt.grid(True)

    plt.tight_layout()

    # The filename labels must match their flags (they were previously swapped).
    dst_path = output_dir / f'per_token_loss-xmax{args.xmax}-bucket{args.bucket_size}-ylog{args.ylog}-xlog{args.xlog}.pdf'
    print(f"Saving plot to {dst_path}")
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    # Release the figure once it has been written.
    plt.close()
    


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == "__main__":
    main()
