import torch
import torch.nn as nn
import tqdm, pickle
import gc,os,random
import functools
from collections import defaultdict
import argparse
from transformers import AutoTokenizer, LlamaTokenizer
from datasets import load_dataset
from model_parse import (
    get_layers,
    get_module_names,
    get_modules,
    load_model,
)
def parse_model(model):
    """Infer the model family from a model name or path.

    The match is a case-insensitive substring test, tried in the order
    "opt" and then "mistral"; any other name falls back to "llama".
    The detected family is printed and returned.

    Args:
        model: Model name or path string.

    Returns:
        One of "opt", "mistral" or "llama".
    """
    lowered = model.lower()
    # additional rules should be added to support other models
    model_type = "llama"
    for family in ("opt", "mistral"):
        if family in lowered:
            model_type = family
            break

    print(f"Model type : {model_type}")
    return model_type

def get_named_linears(module):
    """Collect every ``nn.Linear`` found by ``module.named_modules()``.

    Args:
        module: The ``nn.Module`` to search (the module itself is included
            in the search, under the empty name).

    Returns:
        A dict mapping each qualified sub-module name to its ``nn.Linear``.
    """
    linears = {}
    for qualified_name, submodule in module.named_modules():
        if isinstance(submodule, nn.Linear):
            linears[qualified_name] = submodule
    return linears
def _move_llama_style_embed(model, device):
    """Move ``embed_tokens``, ``norm`` and (if present) ``rotary_emb`` of ``model.model``."""
    inner = model.model
    inner.embed_tokens = inner.embed_tokens.to(device)
    inner.norm = inner.norm.to(device)
    # NOTE(review): some Hugging Face releases keep the rotary embedding on the
    # inner model and evaluate it before the decoder layers, so it has to sit
    # on the same device as the embeddings — confirm against the installed
    # version. Models without the attribute are left untouched.
    rotary = getattr(inner, "rotary_emb", None)
    if rotary is not None:
        inner.rotary_emb = rotary.to(device)


def move_embed(model, device, args):
    """Move a model's embedding-side modules to ``device``.

    The family is picked from ``args.model`` with case-insensitive substring
    tests (matching how ``parse_model`` reads the name), tried in this order:

    * "llama": ``model.model.embed_tokens`` and ``model.model.norm``
      (plus ``model.model.rotary_emb`` when that attribute exists);
    * "opt": ``model.model.decoder.embed_tokens`` and
      ``model.model.decoder.embed_positions``;
    * "qwen" / "mistral": same attributes as "llama".

    Args:
        model: Loaded model exposing the attributes listed above.
        device: Target device, anything accepted by ``nn.Module.to``.
        args: Namespace whose ``model`` attribute holds the model name/path.

    Raises:
        NotImplementedError: If the name matches none of the known families.
    """
    name = args.model.lower()
    if "llama" in name:
        _move_llama_style_embed(model, device)
    elif "opt" in name:
        decoder = model.model.decoder
        decoder.embed_tokens = decoder.embed_tokens.to(device)
        decoder.embed_positions = decoder.embed_positions.to(device)
    elif "qwen" in name or "mistral" in name:
        _move_llama_style_embed(model, device)
    else:
        raise NotImplementedError(
            f"move_embed does not support model {args.model!r} ({type(model)})"
        )

# ---- Command-line interface --------------------------------------------------
parser = argparse.ArgumentParser()
# --output_path and --model are required: the script cannot choose either on
# its own, and argparse reports a missing flag far more clearly than the
# TypeError / AttributeError a None value would trigger further down.
parser.add_argument(
    "--output_path", type=str, required=True, help="chunk the model and store"
)
parser.add_argument("--model", type=str, required=True, help="model to load")
parser.add_argument("--n_sample", type=int, default=128, help="number of samples")
parser.add_argument(
    "--device",
    type=int,
    default=0,
    help="GPU device ID, use 0 for 'cuda:0', 1 for 'cuda:1', etc.",
)

args = parser.parse_args()
# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
model_type = parse_model(args.model)
print(model_type)
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(args.output_path, exist_ok=True)
model = load_model(args.model, model_type)
layers = get_layers(model, model_type)
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

# ---- Calibration data: random fixed-length token windows from one C4 shard ---
traindata = load_dataset(
    "allenai/c4",
    data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
    split="train",
)
trainloader = []
model.seqlen = 2048
nsamples = args.n_sample
for _ in range(nsamples):
    # Draw documents at random until one tokenizes to more than seqlen tokens.
    while True:
        doc_idx = random.randint(0, len(traindata) - 1)
        trainenc = tokenizer(traindata[doc_idx]["text"], return_tensors="pt")
        if trainenc.input_ids.shape[1] > model.seqlen:
            break
    # Cut a random window of exactly seqlen tokens out of that document.
    start = random.randint(0, trainenc.input_ids.shape[1] - model.seqlen - 1)
    inp = trainenc.input_ids[:, start : start + model.seqlen]
    # Targets copy the inputs with every position but the last set to -100.
    # The capture pass later in this script only reads the inputs.
    tar = inp.clone()
    tar[:, :-1] = -100
    trainloader.append((inp, tar))


# Filled in by Catcher.forward during the capture pass.
inps = []
layer_kwargs = {}

# Place the first layer and the embedding-side modules on the SAME device.
# (Passing the literal "cuda" here would ignore --device and the CPU fallback.)
layers[0] = layers[0].to(device)
move_embed(model, device, args)

class Catcher(nn.Module):
    """Stand-in for a layer that records what it is called with, then aborts.

    The wrapped layer is kept on ``self.module`` so it can be put back later;
    it is never executed by this class.
    """

    def __init__(self, module):
        """Store ``module`` (the layer being replaced) without calling it."""
        super().__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        """Record the call in the module-level ``inps`` / ``layer_kwargs`` and raise.

        ``inp`` is appended to ``inps``; ``kwargs`` are merged into
        ``layer_kwargs`` (later calls overwrite earlier values per key).

        Raises:
            ValueError: Always, so the surrounding forward pass stops here.
        """
        layer_kwargs.update(kwargs)
        inps.append(inp)
        raise ValueError  # early exit to break later inference

# Patch layer 0 with the Catcher so every forward pass below stops at the
# first layer, right after its input tensor and keyword arguments are recorded.
layers[0] = Catcher(layers[0])
capture_device = next(model.parameters()).device
# Inference only: no_grad keeps the captured activations from holding on to an
# autograd graph.
with torch.no_grad():
    for batch in trainloader:
        try:
            model(batch[0].to(capture_device))
        except ValueError:
            # Expected: Catcher.forward raises ValueError to abort the pass early.
            pass

del trainloader
layers[0] = layers[0].module  # restore

if not inps:
    raise RuntimeError("no calibration inputs were captured; check --n_sample")

# Keep EVERY captured sample (one list entry per forward pass above) instead
# of only the first one, so --n_sample actually controls how many samples end
# up in the saved activations. The samples wait on the CPU and are moved to
# the device one at a time, which bounds device memory by a single sample.
inps = [sample.cpu() for sample in inps]

layers[0] = layers[0].cpu()
move_embed(model, "cpu", args)

# Each layer is called once per sample below. A key/value cache captured from
# the model forward would accumulate entries across those calls, so neutralise
# it. NOTE(review): the kwarg names are assumed to follow Hugging Face decoder
# layers; each line is a no-op when the key was not captured.
for cache_key in ("past_key_value", "past_key_values"):
    if cache_key in layer_kwargs:
        layer_kwargs[cache_key] = None
if "use_cache" in layer_kwargs:
    layer_kwargs["use_cache"] = False

torch.cuda.empty_cache()
gc.collect()


def cache_input_hook(m, x, y, name, feat_dict):
    """Forward hook: append the first positional input, detached, on the CPU."""
    feat_dict[name].append(x[0].detach().cpu())


# solve layer by layer
with torch.no_grad():  # without this the autograd graph would span all layers
    for i, layer in enumerate(tqdm.tqdm(layers)):
        layer = layer.to(device)
        named_linears = get_named_linears(layer)

        # firstly, get input features of all linear layers
        input_feat = defaultdict(list)
        handles = [
            linear.register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
            )
            for name, linear in named_linears.items()
        ]
        try:
            # Replay the samples one by one; each output replaces its input so
            # the list feeds the next layer.
            for idx in range(len(inps)):
                out = layer(inps[idx].to(device), **layer_kwargs)
                # Layers are expected to return a tuple with the hidden states
                # first; a bare tensor is used as-is (assumed to be the hidden
                # states — confirm for the installed model implementation).
                if isinstance(out, tuple):
                    out = out[0]
                inps[idx] = out.cpu()
                del out
        finally:
            # Always detach the hooks, even if a forward call failed.
            for h in handles:
                h.remove()

        # One tensor per linear, samples stacked along dim 0. Popping each
        # list as it is concatenated frees it before the next one is built.
        stacked_feat = {}
        for name in list(input_feat):
            stacked_feat[name] = torch.cat(input_feat.pop(name), dim=0)
        torch.cuda.empty_cache()
        with open(f"{args.output_path}/l{i}.pkl", "wb") as f:
            print(f"Saving layer activation to {args.output_path}/l{i}.pkl")
            pickle.dump(stacked_feat, f)

        layer = layer.cpu()
        del stacked_feat
        del input_feat
        del layer
        torch.cuda.empty_cache()
        gc.collect()



