from __future__ import annotations

import argparse
import os
import pickle
from functools import partial
from typing import Any

import torch
import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase


def tokenize_and_chunk_fn(
    examples: dict[str, list[Any]], tokenizer: PreTrainedTokenizerBase, max_length: int
) -> dict[str, list[torch.Tensor]]:
    """Tokenize a batch of texts and pack the token stream into fixed-size chunks.

    The token ids of consecutive texts are concatenated into one stream (no
    separator is inserted beyond whatever the tokenizer itself emits). The stream
    is then cut into chunks of exactly ``max_length`` tokens, so no chunk needs
    padding. Trailing tokens that do not fill a whole chunk at the end of this
    batch of examples are discarded.

    Args:
        examples: Batched examples; only the ``"text"`` column is read.
        tokenizer: Tokenizer whose call result exposes ``input_ids`` (one id list
            per text).
        max_length: Number of tokens per emitted chunk; must be positive.

    Returns:
        ``{"tokens": chunks}`` where each chunk is a 1-D ``torch.int32`` tensor of
        length ``max_length``.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be positive, got {max_length}")

    buffer: list[int] = []
    chunks: list[torch.Tensor] = []
    for token_ids in tokenizer(examples["text"]).input_ids:
        buffer.extend(token_ids)
        # Emit every complete chunk, including one that fills the buffer exactly.
        num_full = len(buffer) // max_length
        for k in range(num_full):
            start = k * max_length
            chunk = buffer[start : start + max_length]
            chunks.append(torch.tensor(chunk, dtype=torch.int32))
        # Drop the emitted prefix with a single slice instead of re-copying the
        # remainder of a long document once per chunk.
        del buffer[: num_full * max_length]
    return {"tokens": chunks}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model")
    parser.add_argument("--max-length", type=int, default=2048)
    parser.add_argument("--threshold", type=float, default=0.1)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--shuffle-seed", type=int, default=42)
    parser.add_argument("--limit-ratio", type=float, default=0.1)
    parser.add_argument("--dataset-type", default="nlp")
    args = parser.parse_args()

    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True)
    model = model.cuda().bfloat16().eval().requires_grad_(False)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token

    # Prepare C4/TheStack validation dataset for collecting expert routing examples.
    # Note that all small sequences will be packed into a single sequence to make all
    # `--max-length` sequences do not have any useless (e.g. padding) tokens.
    if args.dataset_type == "nlp":
        dataset = load_dataset(
            "allenai/c4",
            data_files={"validation": "en/c4-validation.*.json.gz"},
            split="validation",
        )
    elif args.dataset_type == "code":
        dataset = load_dataset("codecomplete/starcoderdata_0.001", split="train")
    else:
        raise NotImplementedError(f"{args.dataset_type} is not supported")

    dataset = dataset.shuffle(seed=args.shuffle_seed).map(
        partial(tokenize_and_chunk_fn, tokenizer=tokenizer, max_length=args.max_length),
        batched=True,
        remove_columns=dataset.column_names,
    )
    dataloader = DataLoader(dataset.with_format("torch"), args.batch_size)
    total_batches = int(len(dataloader) * args.limit_ratio)
    num_routings = model.config.num_hidden_layers // model.config.moe_groups

    input_tokens, token_routings = [], [[] for _ in range(num_routings)]
    for i, batch in zip(tqdm.trange(total_batches), dataloader):
        outputs = model(
            batch["tokens"].clone().cuda(),
            output_hidden_states=True,
            output_router_probs=True,
        )
        input_tokens.append(batch["tokens"])

        # Collect the expert routing probabilites and accumulate across the multi-heads.
        g1s = [x[0] for x in outputs.router_probs if x is not None]
        g2s = [x[1] for x in outputs.router_probs if x is not None]

        for j, (g1, g2) in enumerate(zip(g1s, g2s)):
            g = torch.einsum("bthi,bthj->btij", g1, g2).flatten(-2)
            for b in range(args.batch_size):
                token_routings[j].append([])
                for t in range(args.max_length):
                    # The routing with less than `--threshold` will be removed to reduce
                    # the memory consumption. As mentioned above, embedding
                    # accumulations will also ignore the weak routing tokens.
                    gk = (g[b, t] > args.threshold).argwhere().squeeze(-1)
                    gk, gv = gk.tolist(), g[b, t][gk].tolist()
                    token_routings[j][-1].append(dict(zip(gk, gv)))

    # Save the example inputs and their routing information to the model directory.
    os.makedirs((outdir := os.path.join(args.model, "interpretation")), exist_ok=True)
    for i, routings in enumerate(token_routings):
        with open(os.path.join(outdir, f"routings-{i}.pkl"), "wb") as fp:
            pickle.dump(routings, fp)
    torch.save(torch.cat(input_tokens), os.path.join(outdir, "inputs.pt"))
