from pathlib import Path
import sys

import matplotlib.pyplot as plt
import torch
from torch import nn, Tensor
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from get_prompt import get_long_prompt
from tap import Tap

# Extend the module search path with '..' before importing `utils`.
# NOTE: a relative entry in sys.path is resolved against the current working
# directory, so `utils` is presumably expected one level above the directory
# the script is launched from.
sys.path.append('..')
from utils import get_non_embed_param_count, get_param_count


class Args(Tap):
    """Command-line options for the per-token loss evaluation and plot."""

    # Model family override passed to `get_model`; 'mamba2' selects the Mamba2
    # loader even when `pretrained_path` does not contain "mamba2".
    model_name: str = ''
    # Checkpoint to evaluate. Its value (with '/' replaced by '--') is also
    # used to name the output directory.
    pretrained_path: str = '../../ckpts/mamba/mamba2-780m'
    # Tokenizer location; `get_model` falls back to `pretrained_path` when None.
    tok_path: str | None = '../tokenizers/mamba-tok'
    # Name handed to `get_long_prompt` to select the evaluation text.
    prompt_name: str = 'nextlines'
    device: str = 'cuda'
    # Target prompt length, in units of 1024 tokens.
    max_len: int = 128
    # Non-zero: recompute the loss even if a cached result exists.
    overwrite: int = 0
    # Position (in units of 1024 tokens) of the vertical marker in the plot.
    train_len: int = 8
    # Right limit of the x-axis, in units of 1024 tokens.
    xmax: int = 64
    # Upper limit of the y-axis; None keeps matplotlib's automatic limits.
    ymax: int | None = None
    # Non-zero: use a logarithmic x-axis / y-axis.
    xlog: int = 0
    ylog: int = 1
    # Non-zero: plot exp(loss) (perplexity) instead of the loss.
    ppl: int = 0

    # The input is chunked before feeding to the model to avoid OOM.
    # The value is in units of 1024 tokens.
    chunk_size: int = 16

    # The loss is averaged over a window of `bucket_size` tokens.
    bucket_size: int = 256

    # Whether to use sliding window, -1 means no sliding window.
    # NOTE(review): this option is not read anywhere in this file.
    window_size: int = -1

    # Non-zero: print the head and tail of the prompt before evaluation.
    verbose: int = 0

    # When > 0, per-head SSM states whose norm exceeds this value are rescaled
    # to it between chunks (Mamba2 path only).
    state_norm: float = 0.0
    # Non-zero values are forwarded to the model's `set_dt_mult`,
    # `set_da_mult`, `set_a_mult` and `set_b_mult` methods respectively; all
    # four also appear in the output directory name.
    dt_mult: float = 0.0
    da_mult: float = 0.0
    a_mult: float = 0.0
    b_mult: float = 0.0

    # NOTE(review): this option is not read anywhere in this file.
    dynamic_norm: float = 0.0


def get_model(pretrained_path, tok_path=None, device='cuda', dtype=torch.float32, model_name: str = ''):
    """Load a model and its tokenizer, picking the loader from the model type.

    The model type is Mamba2 when `model_name == 'mamba2'` or the checkpoint
    path contains "mamba2", and RWKV6 when the path contains "rwkv6".

    Args:
        pretrained_path: Checkpoint location (str or path-like).
        tok_path: Tokenizer location used by the Mamba2 branch; defaults to
            `pretrained_path` when None. The RWKV6 branch always loads the
            tokenizer from `pretrained_path`.
        device: Device the model is placed on.
        dtype: Parameter dtype the model is cast to.
        model_name: Optional explicit model type ('mamba2').

    Returns:
        A `(model, tokenizer)` tuple.

    Raises:
        ValueError: If the model type cannot be determined.
    """
    if tok_path is None:
        tok_path = pretrained_path
    path_str = str(pretrained_path)

    if model_name == 'mamba2' or "mamba2" in path_str:
        # Imported lazily so the project-local Mamba2 code is only required
        # when a Mamba2 checkpoint is actually requested.
        from modeling.mamba2.modeling_mamba2_torch import Mamba2ForCausalLM
        print(f"Loading tokenizer from {tok_path}")
        tokenizer = AutoTokenizer.from_pretrained(tok_path, trust_remote_code=True)
        print(f"Loading model from {pretrained_path}...")
        model = Mamba2ForCausalLM.from_pretrained(
            pretrained_path,
            device=device,
        ).to(dtype=dtype)
    elif 'rwkv6' in path_str:
        # The tokenizer is taken from the checkpoint directory here, not from
        # `tok_path`, so the log line reports the path that is really used.
        print(f"Loading tokenizer from {pretrained_path}")
        tokenizer = AutoTokenizer.from_pretrained(pretrained_path, trust_remote_code=True)
        print(f"Loading model from {pretrained_path}...")
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_path,
            trust_remote_code=True,
        ).to(device=device, dtype=dtype)
    else:
        raise ValueError(f"Unknown model type: {pretrained_path}")
    return model, tokenizer


def _clip_ssm_state_norms(states, max_norm: float) -> None:
    """Rescale in place every per-head SSM state whose norm exceeds `max_norm`."""
    for li in range(len(states)):
        ssm_state = states[li][1]  # (bsz, n_heads, P, N)
        norms = ssm_state.norm(dim=(-2, -1), keepdim=True)  # (bsz, n_heads, 1, 1)
        rescaled = ssm_state / norms * max_norm
        # Heads at or below the threshold (including all-zero states, whose
        # `rescaled` entries are NaN) keep their original values.
        ssm_state.copy_(torch.where(norms > max_norm, rescaled, ssm_state))


def _chunk_loss(loss_fn, logits: Tensor, chunk: Tensor, boundary_logits) -> Tensor:
    """Return the per-token loss of one chunk.

    `logits` is (bsz, C, V) and `chunk` is (bsz, C). Without
    `boundary_logits` the result is (bsz, C - 1). With `boundary_logits`
    (the (bsz, 1, V) logits produced for the last token of the previous
    chunk) the loss of the chunk's first token is prepended, giving (bsz, C).
    """
    loss = loss_fn(logits[:, :-1].transpose(1, 2), chunk[:, 1:])
    if boundary_logits is not None:
        # Computed separately to avoid concatenating (and thus copying) the
        # full logits tensor.
        first = loss_fn(boundary_logits.transpose(1, 2), chunk[:, :1])
        loss = torch.cat([first, loss], dim=1)
    return loss


def compute_per_token_loss(args: "Args", model, tokenizer, prompt: str):
    """Compute the next-token cross-entropy loss at every position of `prompt`.

    The tokenized prompt is split into chunks of `args.chunk_size * 1024`
    tokens that are fed to the model one after another, carrying the
    recurrent state from chunk to chunk.

    Two calling conventions are supported:

    * Mamba2 (`args.model_name == 'mamba2'` or "mamba2" in
      `args.pretrained_path`): the model is called with `states=` and returns
      a mapping with 'logits' and 'states'. The last chunk is processed one
      token at a time with `recurrent_mode=True`. When `args.state_norm > 0`
      the SSM states are norm-clipped after every full chunk.
    * Otherwise: the model is called with `attention_mask=` and `state=` and
      returns an object with `.logits` and `.state`.

    Args:
        args: Options; reads `device`, `model_name`, `pretrained_path`,
            `chunk_size` and `state_norm`.
        model: The language model (see calling conventions above).
        tokenizer: Callable returning an object with `input_ids` and
            `attention_mask` tensors.
        prompt: Text to evaluate.

    Returns:
        A float CPU tensor of shape (bsz, T - 1) holding the loss of tokens
        1..T-1 given their prefixes.
    """
    inputs = tokenizer(prompt, return_tensors='pt')
    input_ids: Tensor = inputs.input_ids.to(device=args.device)
    attention_mask: Tensor = inputs.attention_mask.to(device=args.device)
    print("Input ids shape:", input_ids.shape)

    # Use the same model-type rule as `get_model`, so that a Mamba2 model
    # selected via `model_name` is also driven through the Mamba2 interface.
    is_mamba2 = args.model_name == 'mamba2' or 'mamba2' in args.pretrained_path

    split_size = args.chunk_size * 1024
    chunks = torch.split(input_ids, split_size, dim=1)  # n_chunks x (bsz, <= split_size)
    mask_chunks = torch.split(attention_mask, split_size, dim=1)
    n_chunks = len(chunks)

    loss_fn = nn.CrossEntropyLoss(reduction='none')
    all_loss = []
    states = None
    # Logits for the last token of the previous chunk; they predict the first
    # token of the next chunk, which would otherwise get no loss.
    boundary_logits = None

    # Run the forward passes and compute per token loss
    with torch.no_grad():
        if is_mamba2:
            for chunk_i, chunk in enumerate(chunks[:-1]):
                print(f"{chunk_i}/{n_chunks}, {chunk.shape = }")
                outputs = model(
                    input_ids=chunk,
                    states=states,
                )
                states = outputs['states']
                if args.state_norm > 0.0:
                    _clip_ssm_state_norms(states, args.state_norm)

                logits = outputs['logits']  # (bsz, C, V)
                all_loss.append(_chunk_loss(loss_fn, logits, chunk, boundary_logits))
                # Clone so the full logits tensor is not kept alive by a view.
                boundary_logits = logits[:, -1:].clone()

            # The last chunk needs to be processed one token each time,
            # because the parallel implementation requires the input to be
            # a multiple of chunk_size = 64.
            chunk = chunks[-1]  # (bsz, <= split_size)
            if boundary_logits is not None:
                all_loss.append(loss_fn(boundary_logits.transpose(1, 2), chunk[:, :1]))
            for i in range(chunk.shape[1] - 1):
                outputs = model(
                    input_ids=chunk[:, i:i+1],
                    states=states,
                    recurrent_mode=True,
                )
                states = outputs['states']
                logits = outputs['logits']  # (bsz, 1, V)
                labels = chunk[:, i+1:i+2]  # (bsz, 1)
                all_loss.append(loss_fn(logits.transpose(1, 2), labels))  # (bsz, 1)
        else:
            for chunk_i, (chunk, mask) in enumerate(zip(chunks, mask_chunks)):
                print(f"{chunk_i}/{n_chunks}, {chunk.shape = }")
                outputs = model(
                    input_ids=chunk,
                    # Pass the slice of the mask that belongs to this chunk.
                    attention_mask=mask,
                    state=states,
                )
                states = outputs.state
                logits = outputs.logits  # (bsz, C, V)
                all_loss.append(_chunk_loss(loss_fn, logits, chunk, boundary_logits))
                boundary_logits = logits[:, -1:].clone()

        loss = torch.cat(all_loss, dim=1)  # (bsz, T - 1)
        per_token_loss = loss.float().cpu()
        print(per_token_loss.shape)
        print(f"Per token loss: {per_token_loss}")
    return per_token_loss


def main():
    """Evaluate per-token loss on a long prompt, cache it, and plot it.

    The result is cached under
    `result_per_token/<checkpoint and multiplier settings>/<prompt_name>/`;
    an existing cache is reused unless `--overwrite` is set. The loss is then
    averaged over the batch and over buckets of `bucket_size` tokens and
    saved as `per_token_loss.pdf` in the same directory.

    Raises:
        ValueError: If the selected prompt tokenizes to zero tokens.
    """
    args = Args().parse_args()

    args.pretrained_path = args.pretrained_path.rstrip('/')
    ckpt_name = args.pretrained_path.replace('/', '--')
    output_dir = (
        Path('result_per_token') /
        f'{ckpt_name}-dt_mult{args.dt_mult}-da_mult{args.da_mult}-b_mult{args.b_mult}-a_mult{args.a_mult}-hnorm{args.state_norm}' /
        args.prompt_name
    )
    output_dir.mkdir(exist_ok=True, parents=True)
    args.save(output_dir / 'args.json')

    cache_path = output_dir / 'per_token_loss.pt'
    print(f"Cache path: {cache_path}")
    if cache_path.exists() and not args.overwrite:
        print(f"Loading cached result from {cache_path}")
        with open(cache_path, 'rb') as f:
            per_token_loss = torch.load(f)
    else:
        print("Loading tokenizer and model")
        model, tokenizer = get_model(
            args.pretrained_path,
            args.tok_path,
            device=args.device,
            model_name=args.model_name,
        )
        print("========== finish loading =========")
        n_params = get_param_count(model)
        n_non_embed_params = get_non_embed_param_count(model)
        print(f"Param count: {n_params:,}, non-embedding: {n_non_embed_params:,}")
        print("=======================================================")

        # Forward every non-zero multiplier to the matching `set_<name>`
        # method of the model.
        for mult_name in ('dt_mult', 'da_mult', 'b_mult', 'a_mult'):
            mult_value = getattr(args, mult_name)
            if mult_value != 0.0:
                print(f"Setting {mult_name} to {mult_value}")
                getattr(model, f'set_{mult_name}')(mult_value)

        target_len = args.max_len * 1024
        prompt = get_long_prompt(args.prompt_name)
        tokens = tokenizer.tokenize(prompt)
        if not tokens:
            raise ValueError(f"Prompt {args.prompt_name!r} produced no tokens")
        if len(tokens) > target_len:
            # Token strings use the tokenizer's internal representation, so
            # plain concatenation does not reproduce the text; let the
            # tokenizer turn the kept tokens back into a string.
            prompt = tokenizer.convert_tokens_to_string(tokens[:target_len])
        else:
            # We need to repeat the prompt to ensure the output length is at least max_len
            n_reps = (target_len - 1) // len(tokens) + 1
            prompt = prompt * n_reps
        if args.verbose:
            print("======= prompt =========")
            print(prompt[:1000])
            print('--------------------')
            print(prompt[-1000:])
            print("========================")
        per_token_loss = compute_per_token_loss(args, model, tokenizer, prompt)
        print(f"Caching result to: {cache_path}")
        torch.save(per_token_loss, cache_path)

    # Average over batch size
    per_token_loss = per_token_loss.mean(dim=0)  # (n_positions,)
    # Bucket average
    buckets = torch.split(per_token_loss, args.bucket_size)  # n_buckets x (<= bucket_size,)
    per_bucket_loss = torch.stack([bucket.mean() for bucket in buckets])  # (n_buckets)
    # The first and last bucket have large variations, so we discard it.
    ys = per_bucket_loss[1:-1]
    # x-coordinate of each kept bucket: its start offset plus half a bucket.
    xs = torch.arange(1, len(ys) + 1) * args.bucket_size + args.bucket_size // 2

    if args.ppl:
        ys = torch.exp(ys)

    plt.figure(figsize=(2.3, 2.3))
    plt.plot(xs, ys)
    # Vertical marker at `train_len` (given in units of 1024 tokens).
    plt.axvline(x=args.train_len * 1024, color='r', linestyle='--')
    plt.xlim((args.bucket_size, args.xmax * 1024))

    if args.ymax is not None:
        plt.ylim(0, args.ymax)

    plt.xlabel(r'Token position ($t$)')
    plt.ylabel('Perplexity' if args.ppl else 'Loss')
    if args.ylog:
        plt.yscale('log')
    if args.xlog:
        plt.xscale('log')
    plt.grid(True)

    plt.tight_layout()

    dst_path = output_dir / 'per_token_loss.pdf'
    print(f"Saving plot to {dst_path}")
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')


# Run the evaluation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
