from __future__ import annotations

import argparse
import glob
import heapq
import os
import pickle
import re
from dataclasses import dataclass, field
from typing import Any

import tqdm
from transformers import AutoConfig


@dataclass(order=True)
class PrioritizedItem:
    """Heap entry that is ordered by ``priority`` alone.

    ``order=True`` makes the dataclass generate the rich comparison methods
    from the fields that take part in comparison. Because ``item`` is declared
    with ``compare=False``, two entries are compared (and considered equal)
    purely on ``priority``; the payload is never inspected, so it does not
    need to be orderable itself.

    The script below pushes these entries onto ``heapq`` heaps to keep only
    the highest-priority payloads.
    """

    # Sort key. NOTE(review): annotated as int, but dataclasses do not enforce
    # annotations at runtime; confirm whether the stored values are floats.
    priority: int
    # Opaque payload carried along with the priority; excluded from ordering
    # and equality via compare=False.
    item: Any = field(compare=False)


def natural_sort_key(path: str) -> list:
    """Return a sort key that compares runs of ASCII digits numerically.

    Plain lexicographic sorting places ``routings-10.pkl`` before
    ``routings-2.pkl``. Splitting the path on digit runs and converting those
    runs to ``int`` keeps numbered files in numeric order.

    Args:
        path: File path to build the key for.

    Returns:
        A list alternating between text chunks (even positions) and integer
        chunks (odd positions), so keys of different paths always compare
        element-wise between values of the same type.
    """
    # With a capturing group, re.split places the captured digit runs at the
    # odd positions of the result; the first chunk may be an empty string.
    chunks = re.split(r"([0-9]+)", path)
    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(chunks)]


def examples_path(filename: str) -> str:
    """Return the output path that belongs to a routings pickle path.

    Only the file name is rewritten (first ``routings`` occurrence becomes
    ``examples``); the directory part is left untouched so that a model
    directory whose own name contains ``routings`` is not corrupted.

    Args:
        filename: Path of a ``routings-*.pkl`` input file.

    Returns:
        The path of the matching ``examples-*.pkl`` file in the same directory.
    """
    directory, basename = os.path.split(filename)
    return os.path.join(directory, basename.replace("routings", "examples", 1))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model")
    parser.add_argument("--topk", type=int, default=100)
    args = parser.parse_args()

    # A non-positive top-k would make the heap logic below index into an empty
    # list, so reject it with a clean usage error instead.
    if args.topk < 1:
        parser.error("--topk must be a positive integer")

    # NOTE(security): trust_remote_code=True lets the model repository run its
    # own Python code, and pickle.load below can execute arbitrary code as
    # well. Only point this script at model directories you trust.
    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
    routing_filenames = os.path.join(args.model, "interpretation/routings-*.pkl")

    # Natural ordering keeps the printed layer index aligned with the numeric
    # suffix of the file names (assumes one file per routing layer; confirm
    # against the script that writes the routings).
    filenames = sorted(glob.glob(routing_filenames), key=natural_sort_key)

    for n, filename in enumerate(filenames):
        with open(filename, "rb") as fp:
            routings = pickle.load(fp)

        # One min-heap per slot; there are moe_experts**2 slots, and the keys
        # of every routing map are used as indices into this list.
        examples = [[] for _ in range(config.moe_experts**2)]
        # NOTE(review): i indexes the outer entries of the pickle and j the
        # entries inside each of them; presumably sample and token position.
        for i, routing in enumerate(tqdm.tqdm(routings)):
            for j, maps in enumerate(routing):
                for k, v in maps.items():
                    heap = examples[k]
                    if len(heap) < args.topk:
                        heapq.heappush(heap, PrioritizedItem(v, (i, j)))
                    elif v > heap[0].priority:
                        # heap[0] is the smallest retained value; evict it
                        # only for a strictly larger one, so ties keep the
                        # earlier example.
                        heapq.heapreplace(heap, PrioritizedItem(v, (i, j)))

        # Flatten every heap into (i, j, v) tuples, highest value first. The
        # ascending sort is reversed (rather than reverse=True) to keep the
        # original ordering among equal values.
        examples = [
            [(*x.item, x.priority) for x in sorted(xs)[::-1]] for xs in examples
        ]
        # A slot counts as fully activated once it has collected topk examples.
        activated = [i for i, x in enumerate(examples) if len(x) == args.topk]
        print(f"[*] Routing Layer {n}")
        print(f"[*] Number of Fully Activated Experts: {len(activated)}")
        print(f"[*] Activation Ratio: {len(activated) / len(examples) * 100:.2f}%")

        # Save the top-k examples in each expert to the model directory.
        with open(examples_path(filename), "wb") as fp:
            pickle.dump(examples, fp)
