
import functools
import torch
import torch.nn as nn
from tqdm import tqdm
import argparse
import os
import sys
import math
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from rotation.model_utils import get_model
from utils import calib
from utils.common import *
from utils.common import distribute_model


def get_act_scales(model, testloader):
    """Collect the per-channel max |activation| at the input of every nn.Linear.

    Registers a forward hook on each ``nn.Linear`` submodule, runs ``model``
    over every batch of ``testloader`` and keeps, per layer, the running
    element-wise maximum of the absolute input activation over all tokens.

    Args:
        model: ``torch.nn.Module`` to calibrate. It is switched to eval mode
            and passed to ``distribute_model`` before the forward passes.
        testloader: iterable of batches; ``batch[0]`` is moved to the model's
            device and fed to ``model`` (presumably token ids — verify
            against the loader).

    Returns:
        dict mapping each ``nn.Linear`` module name (as yielded by
        ``model.named_modules()``) to a 1-D float32 CPU tensor whose length
        is the size of that layer's input's last dimension, holding the max
        absolute input value per channel.
    """
    model.eval()
    distribute_model(model)
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        # Fold every leading dim (batch, seq, ...) into one so the max is
        # taken over all tokens. ``reshape`` (unlike ``view``) also accepts
        # non-contiguous inputs.
        hidden_dim = tensor.shape[-1]
        tensor = tensor.reshape(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], coming_max)
        else:
            act_scales[name] = coming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the module's positional inputs as a tuple.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = [
        m.register_forward_hook(functools.partial(stat_input_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]

    try:
        # Calibration only needs activations; avoid building autograd graphs
        # even when the caller did not disable gradients.
        with torch.no_grad():
            for batch in tqdm(testloader, desc="Calculating activation scales"):
                input_ids = batch[0].to(device)
                model(input_ids)
    finally:
        # Always detach the hooks so the model is left unmodified, even if a
        # forward pass raised.
        for h in hooks:
            h.remove()

    return act_scales


def parse_args():
    """Define the calibration CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``model``, ``output_path``, ``calib_dataset``,
        ``nsamples`` and ``seq_len`` attributes.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3-8B",
                     help="model name")
    cli.add_argument("--output-path", type=str, default="./smooth_scales/",
                     help="where to save the act scales")
    cli.add_argument("--calib_dataset", type=str, default="wikitext2",
                     help="Name of the calibration dataset, e.g., wikitext2, ptb, c4")
    # Both integer knobs share the same type and default.
    for flag in ("--nsamples", "--seq-len"):
        cli.add_argument(flag, type=int, default=512)
    return cli.parse_args()


@torch.no_grad()
def main():
    """Calibrate ``--model`` and save its per-Linear activation scales.

    Loads the model via ``get_model``, builds a calibration loader from
    ``--calib_dataset`` / ``--nsamples``, gathers scales with
    ``get_act_scales`` and writes them with ``torch.save`` to
    ``<output-path>/<model basename>-smooth-scales.pt``.
    """
    args = parse_args()
    model = get_model(args.model)

    # NOTE(review): ``--seq-len`` is parsed but never forwarded to the loader;
    # confirm whether calib.get_loaders should receive it.
    calib_loader = calib.get_loaders(
        args.calib_dataset,
        nsamples=args.nsamples,
        model=args.model,
        eval_mode=False,
    )
    scales = get_act_scales(model, calib_loader)

    model_tag = args.model.split('/')[-1]
    args.output_path = os.path.join(args.output_path, f"{model_tag}-smooth-scales.pt")
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    torch.save(scales, args.output_path)
    print(f"Activation scales saved to {args.output_path}")


if __name__ == "__main__":
    # Run the calibration CLI only when executed as a script, not on import.
    main()
