import json
import os
import pickle
import time

import torch
import torch.nn as nn
import transformers
from quant import *





def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=""):
    """Recursively collect sub-modules whose exact type is listed in ``layers``.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Collection of module classes to match. Matching is done with
            ``type(module) in layers``, i.e. on the exact type, so subclasses
            of a listed class are not matched. The default is an immutable
            tuple so it cannot be altered between calls.
        name: Dotted-name prefix used for the returned keys; ``""`` for the
            root module.

    Returns:
        A dict mapping dotted module names to the matching modules, in
        ``named_children()`` depth-first order. If ``module`` itself matches,
        the result is ``{name: module}`` and its children are not searched.
    """
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        # Only insert the "." separator once there is a non-empty prefix.
        qualified = f"{name}.{child_name}" if name != "" else child_name
        found.update(find_layers(child, layers=layers, name=qualified))
    return found


@torch.no_grad()
def model_sequential(model, folder, include_sparse, model_type):
    """Load per-layer lookup tables (and optional outliers) into a dict.

    For every decoder layer index ``i`` this reads ``{folder}/lut/l{i}.pkl``
    and, when ``include_sparse`` is true, ``{folder}/outliers/l{i}.pkl``.

    Args:
        model: Model whose decoder layers are counted. Only the number of
            layers is used; the layers themselves are not touched here.
        folder: Directory containing the ``lut`` (and ``outliers``)
            sub-directories.
        include_sparse: Whether to also load the per-layer outlier lists.
        model_type: ``"opt"`` selects the OPT layout
            (``model.model.decoder.layers``); any other value is handled with
            the LLaMA layout (``model.model.layers``).

    Returns:
        A dict mapping the full dotted name of each projection module to
        ``[lut, bit_per_channel, outliers]``, where ``outliers`` is ``None``
        when ``include_sparse`` is false.
    """
    print("Starting ...")

    # Resolve everything that depends on the architecture once, up front, so
    # the layer container, the key prefix and the projection names always
    # agree with each other.
    if model_type == "opt":
        layers = model.model.decoder.layers
        key_prefix = "model.decoder.layers"
        projections = {
            "q": "self_attn.q_proj",
            "k": "self_attn.k_proj",
            "v": "self_attn.v_proj",
            "o": "self_attn.out_proj",
            "up": "fc1",
            "down": "fc2",
        }
    else:
        layers = model.model.layers
        key_prefix = "model.layers"
        projections = {
            "q": "self_attn.q_proj",
            "k": "self_attn.k_proj",
            "v": "self_attn.v_proj",
            "o": "self_attn.o_proj",
            "gate": "mlp.gate_proj",
            "up": "mlp.up_proj",
            "down": "mlp.down_proj",
        }

    quantizers = {}
    for i in range(len(layers)):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # `folder` at files produced by a trusted quantization run.
        with open(f"{folder}/lut/l{i}.pkl", "rb") as f:
            lut_layer = pickle.load(f)

        outlier_list_layer = None
        if include_sparse:
            with open(f"{folder}/outliers/l{i}.pkl", "rb") as f:
                outlier_list_layer = pickle.load(f)

        # The outlier list is indexed by the position of the projection in
        # the order above (q, k, v, o, ...), hence the enumerate().
        for position, (short_name, real_name) in enumerate(projections.items()):
            lut, bit_per_channel = lut_layer[short_name]
            outliers = outlier_list_layer[position] if include_sparse else None
            quantizers[f"{key_prefix}.{i}.{real_name}"] = [
                lut,
                bit_per_channel,
                outliers,
            ]

    return quantizers


def model_pack(
    model,
    quantizers,
    wbits,
    include_sparse,
    balanced,
    num_nonzero_per_thread,
):
    """Pack every layer named in ``quantizers`` into its LUT-quantized form.

    Args:
        model: Model to pack; it is passed to ``make_quant_lut`` and returned.
        quantizers: Dict keyed by dotted module name; each value is handed to
            the matching layer's ``pack2`` call as its lookup table.
        wbits: Bit width forwarded to ``make_quant_lut``.
        include_sparse: Forwarded to ``make_quant_lut`` and ``pack2``; when
            true, the last dimension of each packed layer's ``vals`` is
            recorded.
        balanced: Forwarded to ``make_quant_lut``.
        num_nonzero_per_thread: Forwarded to ``pack2``.

    Returns:
        ``(model, sparsedict)`` where ``sparsedict`` maps layer names to
        ``vals.shape[-1]`` (empty when ``include_sparse`` is false).
    """
    # Grab references to the current layers *before* make_quant_lut runs;
    # presumably it swaps them for QuantLinearLUT modules — TODO confirm.
    discovered = find_layers(model)
    layers = {}
    for key in quantizers:
        layers[key] = discovered[key]

    make_quant_lut(
        model, quantizers, wbits, include_sparse=include_sparse, balanced=balanced
    )

    qlayers = find_layers(model, [QuantLinearLUT])
    print("Packing ...")

    sparsedict = {}
    for layer_name, packed_layer in qlayers.items():
        print(layer_name)
        lookup_table = quantizers[layer_name]
        source_layer = layers[layer_name]
        source_layer.cpu()
        packed_layer.pack2(
            source_layer,
            lookup_table,
            include_sparse,
            num_nonzero_per_thread=num_nonzero_per_thread,
        )
        if not include_sparse:
            continue
        sparsedict[layer_name] = packed_layer.vals.shape[-1]

    print("Done.")
    return model, sparsedict


if __name__ == "__main__":
    # Command-line entry point: load a pretrained model, attach the lookup
    # tables (and optional outliers) found under --folder, pack the layers and
    # write the resulting state dict plus a small quant_config.json.
    #
    # This stays a flat script on purpose: the star import below is only
    # legal at module level, not inside a function.
    import argparse

    from squeezellm.datautils import *

    parser = argparse.ArgumentParser()

    parser.add_argument("--model", type=str, help="llama model to load")
    parser.add_argument(
        "--model_type", type=str, default="llama", help="model type", choices=["llama", "opt"]
    )
    parser.add_argument(
        "--wbits",
        type=int,
        default=3,
        choices=[2, 3, 4],
        help="#bits to use for quantization.",
    )
    parser.add_argument(
        "--save",
        type=str,
        required=True,
        help="Save quantized checkpoint under this name.",
    )

    # sparse args
    parser.add_argument(
        "--folder",
        type=str,
        default="",
        help="Path to folder containing luts and outliers.",
    )
    parser.add_argument(
        "--include_sparse",
        action="store_true",
        help="Whether loaded checkpoint has sparse matrix.",
    )

    # balanced kernel arguments
    parser.add_argument(
        '--balanced', action='store_true',
        help='Whether to use balanced sparse kernel.'
    )
    parser.add_argument(
        '--num_nonzero_per_thread', type=int, default=10,
        help='Num nonzeros assigned to each thread.'
    )

    args = parser.parse_args()

    # Honor the --folder value given on the command line and fail fast,
    # before the (slow) model load, if it does not point at a directory.
    if not os.path.isdir(args.folder):
        parser.error(
            f"--folder must be an existing directory containing luts "
            f"(and outliers); got {args.folder!r}"
        )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.model, trust_remote_code=True, torch_dtype="auto"
    )
    model.eval()

    quantizers = model_sequential(
        model=model,
        folder=args.folder,
        include_sparse=args.include_sparse,
        model_type=args.model_type,
    )

    model, numvals = model_pack(
        model=model,
        quantizers=quantizers,
        wbits=args.wbits,
        include_sparse=args.include_sparse,
        balanced=args.balanced,
        num_nonzero_per_thread=args.num_nonzero_per_thread,
    )

    model_dict = model.state_dict()

    if args.include_sparse:
        # need to merge in sparse dict
        for k, v in numvals.items():
            model_dict["sparse_threshold." + k] = v

    # Make sure the directory that will actually receive the checkpoint
    # exists. dirname() is "" when --save is a bare file name, in which case
    # the current directory is used and there is nothing to create.
    directory = os.path.dirname(args.save)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # save model
    torch.save(model_dict, args.save)

    # save quant_config next to the checkpoint
    data = {"wbits": args.wbits}
    output_fn = os.path.join(directory, "quant_config.json")
    with open(output_fn, "w") as f:
        json.dump(data, f, indent=4)
