from __future__ import annotations

import argparse
import glob
import os
import pickle
import re
from collections import defaultdict

import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model")
    parser.add_argument("--batch-size", type=int, default=8)
    args = parser.parse_args()

    example_filenames = os.path.join(args.model, "interpretation/examples-*.pkl")
    input_tokens = torch.load(os.path.join(args.model, "interpretation/inputs.pt"))
    input_tokens = input_tokens.cuda()

    model = AutoModelForCausalLM.from_pretrained(args.model, trust_remote_code=True)
    model = model.cuda().bfloat16().eval().requires_grad_(False)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token

    for filename in glob.glob(example_filenames):
        routing_index = int(re.match(r".*/examples-(\d+)[.]pkl", filename).group(1))
        layer_start = routing_index * model.config.moe_groups
        layer_end = (routing_index + 1) * model.config.moe_groups

        with open(filename, "rb") as fp:
            sequence_mappings = defaultdict(dict)
            for i, examples in enumerate(pickle.load(fp)):
                for j, k, p in examples:
                    sequence_mappings[j][k] = (i, p)

        expert_embs = torch.zeros(model.config.moe_experts**2, model.config.hidden_size)
        expert_probs = torch.zeros(model.config.moe_experts**2)
        for i in tqdm.trange(0, input_tokens.size(0), args.batch_size):
            batch = input_tokens[i : i + args.batch_size]
            outputs = model(batch, output_hidden_states=True).hidden_states
            outputs = sum(outputs[layer_start:layer_end]) / model.config.moe_groups
            outputs = outputs.cpu().float()

            for j in range(args.batch_size):
                for t, (expert_index, expert_prob) in sequence_mappings[i + j].items():
                    expert_embs[expert_index] += outputs[j, t] * expert_prob
                    expert_probs[expert_index] += expert_prob

        filename = filename.replace(
            f"examples-{routing_index}.pkl",
            f"embeddings-{routing_index}.pt",
        )
        torch.save((expert_embs / (expert_probs[:, None] + 1e-10)).bfloat16(), filename)
