import os

# Must stay above the numpy/torch/sklearn imports below: it is set before those
# libraries load (presumably so each Pool worker's KMeans runs single-threaded
# instead of every worker spawning one thread per core).
os.environ.update({"OMP_NUM_THREADS": "1"})

import argparse
import json
import pickle, random

import numpy as np
import torch
from sklearn.cluster import KMeans
from model_parse import get_module_names, parse_model, get_modules, load_model, get_layers
from tqdm import tqdm
from multiprocessing import Pool
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from datasets import load_dataset

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, help="path of the hf model")
parser.add_argument("--model_type", type=str, default=None, help="model type", choices=["llama", "opt", "qwen"])
parser.add_argument("--bit", type=float, default=3, help="bitwidth")
# Directory holding the per-layer activation pickles (l{layer_id}.pkl).
parser.add_argument("--layerout", type=str, help="path to layer activation")
# Percentage (0-100) of input channels treated as activation outliers; 0 = none.
parser.add_argument("--sensitivity", type=float, default=0, help="sensitivity for outlier extraction")
# nargs="+" so `--tasks wiki` / `--tasks wiki c4` always yield a list; with a
# plain type=str a single value was a string and the eval loop iterated over
# its characters, matching no task.
parser.add_argument("--tasks", default=['wiki', 'c4'], type=str, nargs="+")
parser.add_argument('--device', type=int, default=0, help="GPU device ID, use 0 for 'cuda:0', 1 for 'cuda:1', etc.")


def remove_outliers_by_activation(
    model,
    activation,
    sensitivity,
    name_dict
):
    """Split each weight matrix into inlier and outlier parts by input-channel activation norm.

    For every entry of ``model``, the activation ``activation[name_dict[name]]``
    is flattened to ``(-1, C)`` and the L2 norm of each of its ``C`` channels is
    computed. The ``int(C * sensitivity / 100)`` channels with the largest norm
    are outliers: their weight columns are copied into a separate outlier tensor
    and zeroed in ``model``.

    ``model`` is modified in place: each entry is replaced by a float32 copy
    with the outlier columns set to 0.

    Args:
        model: dict mapping module name -> weight tensor, assumed 2-D with
            ``C`` columns (the channel mask is expanded over ``weight.shape[0]``
            rows).
        activation: dict of activation tensors whose last dim is ``C``.
        sensitivity: percentage (0-100) of input channels to treat as
            outliers. 0 selects none; values above 100 select all channels.
        name_dict: maps each key of ``model`` to its key in ``activation``.

    Returns:
        ``[[outlier_0, outlier_1, ...]]`` — a one-element list holding, in
        ``model`` key order, one float32 tensor per module that is zero
        everywhere except in the outlier columns.
    """
    module_names = list(model.keys())
    outlier_weights = [[0 for _ in range(len(module_names))]]

    def _body(act, weight):
        n_channels = act.shape[-1]
        # Clamp so sensitivity > 100 cannot ask topk for more channels than exist.
        num_channel = min(int(n_channels * sensitivity / 100), n_channels)
        act_norm = torch.norm(act.reshape(-1, n_channels), dim=0)
        channel_mask = torch.zeros(n_channels, dtype=torch.bool, device=act_norm.device)
        if num_channel > 0:
            # Mark exactly the top-k channels. Comparing `> k-th largest value`
            # dropped the k-th channel, and indexing values[-1] crashed for k == 0.
            channel_mask[act_norm.topk(k=num_channel).indices] = True
        # The activation may live on another device than the weight.
        channel_mask = channel_mask.to(weight.device)
        t_expanded = channel_mask.unsqueeze(0).expand(weight.shape[0], -1)
        outlier_weight = weight * t_expanded
        inlier_weight = weight * ~t_expanded
        return inlier_weight, outlier_weight

    for i, _name in enumerate(module_names):
        weight = model[_name].to(torch.float)
        act = activation[name_dict[_name]]
        new_weight, outlier_weight = _body(act, weight)
        model[_name] = new_weight
        outlier_weights[0][i] += outlier_weight

    return outlier_weights
def round_to_nearest_pole_sim(w, poles):
    """Snap every element of ``w`` to the closest value in ``poles``.

    Args:
        w: tensor of values to round.
        poles: sequence of candidate values (e.g. k-means centers).

    Returns:
        A tensor shaped like ``w`` holding, per element, the nearest pole.
        On a tie the pole that appears first in ``poles`` wins.
    """
    # One row of absolute distances per pole; argmin down the rows picks the
    # index of the nearest pole for each element.
    distances = torch.stack([(w - pole).abs() for pole in poles])
    nearest = distances.argmin(axis=0)

    snapped = 0
    for k, pole in enumerate(poles):
        snapped += (nearest == k) * pole
    return snapped

def per_bit_channel(act, avg_bit):
    """Assign a bit width (2, 3 or 4) to each input channel from its activation norm.

    ``act`` is flattened to ``(-1, C)`` and each channel is ranked by its L2
    norm. Every channel starts at 3 bits, then:

    * ``avg_bit == 3``: channels below the 1st percentile get 2 bits and
      channels above the 98.5th percentile get 4 bits;
    * ``avg_bit > 3``: channels above the ``1 - (avg_bit - 3) - 0.05``
      quantile get 4 bits;
    * ``avg_bit < 3``: channels below the ``3 - avg_bit`` quantile get 2 bits.

    Quantile levels are clamped into ``[0, 1]``, so out-of-range targets
    (e.g. ``avg_bit < 2``) saturate instead of raising inside
    ``torch.quantile``.

    Returns:
        A length-``C`` tensor of per-channel bit widths.
    """
    n_channels = act.shape[-1]
    act_max = torch.norm(act.reshape(-1, n_channels), dim=0)
    norms = act_max.float()  # torch.quantile needs a float32/float64 input
    bit_per_channel = torch.full_like(act_max, 3)
    if avg_bit == 3:
        s_thr = torch.quantile(norms, 0.01).item()
        b_thr = torch.quantile(norms, 0.985).item()
        bit_per_channel[act_max > b_thr] = 4
        bit_per_channel[act_max < s_thr] = 2
    elif avg_bit > 3:
        q = min(max(1 - (avg_bit - 3) - 0.05, 0.0), 1.0)
        b_thr = torch.quantile(norms, q).item()
        bit_per_channel[act_max > b_thr] = 4
    else:
        q = min(max(3 - avg_bit, 0.0), 1.0)
        s_thr = torch.quantile(norms, q).item()
        bit_per_channel[act_max < s_thr] = 2
    return bit_per_channel

def extract(sensitivity, weight):
    """Return a boolean mask marking the largest ``sensitivity`` percent of ``weight``.

    Exactly ``int(weight.numel() * sensitivity / 100)`` elements are marked,
    ranked by value (callers pass absolute errors). The count is clamped to
    ``numel``. A count of 0 yields an all-False mask rather than an error.

    Args:
        sensitivity: percentage (0-100) of elements to mark.
        weight: tensor of scores; the mask has the same shape and device.
    """
    flat = weight.reshape(-1)
    num_outliers = min(int(flat.numel() * sensitivity / 100), flat.numel())
    mask = torch.zeros(flat.numel(), dtype=torch.bool, device=weight.device)
    if num_outliers > 0:
        # Select by index: `> k-th largest value` dropped the k-th element,
        # and values[-1] crashed when k == 0 (small tensors).
        mask[flat.topk(k=num_outliers).indices] = True
    return mask.view(weight.shape)

def kmeans_fit(row_data):
    """Fit k-means to one weight column.

    Args:
        row_data: ``(weights_np, sample_weight, n_cluster)`` — the column as a
            2-D ``(N, 1)`` array, per-sample weights of length ``N``, and the
            number of clusters. These are packed into one tuple so the function
            can be mapped with ``Pool.imap``.

    Returns:
        ``(centers, labels)``: the cluster centers flattened to 1-D, and each
        sample's cluster index as an int8 (``np.byte``) array. Because of the
        int8 labels, ``n_cluster`` must stay below 128.
    """
    weights_np, sample_weight, n_cluster = row_data
    kmeans = KMeans(n_clusters=n_cluster, random_state=0, n_init="auto", max_iter=50).fit(
        weights_np, sample_weight=sample_weight
    )
    # `np.cast` was removed in NumPy 2.0; astype(np.byte) yields the same int8 array.
    return kmeans.cluster_centers_.reshape(-1), kmeans.labels_.astype(np.byte)


def _kmeans_dequantize(pool, weights, bit_per_channel, template):
    """Quantize each column of ``weights`` with k-means and return the de-quantized matrix.

    Column ``c`` is clustered into ``2 ** int(bit_per_channel[c])`` centers by
    ``kmeans_fit`` (run in parallel on ``pool``). Every entry is then replaced
    by its cluster center.

    Args:
        pool: multiprocessing pool used to run ``kmeans_fit``.
        weights: 2-D CPU float tensor, quantized column by column.
        bit_per_channel: per-column bit widths (length = number of columns).
        template: tensor that is cloned to hold the result, so the output takes
            its dtype/device.

    Returns:
        ``(dequantized, centers)``, where ``centers[c]`` is the 1-D numpy
        array of cluster centers for column ``c``.
    """
    weights_np = weights.numpy()
    tasks = []
    for col in range(weights_np.shape[1]):
        column = weights_np[:, col]
        tasks.append((column.reshape(-1, 1), np.ones_like(column), 2 ** int(bit_per_channel[col])))
    results = list(tqdm(pool.imap(kmeans_fit, tasks), total=len(tasks)))

    dequantized = template.clone()
    for col, (centers, labels) in enumerate(results):
        # Vectorized center lookup instead of a per-element Python loop.
        dequantized[:, col] = torch.from_numpy(centers[labels])
    return dequantized, [centers for centers, _ in results]


def _load_wikitext2(enc):
    """Tokenize the WikiText-2 raw test split joined with blank lines; return its input ids."""
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    return enc("\n\n".join(testdata["text"]), return_tensors="pt").input_ids


def _load_c4(enc, seqlen, nsamples=256, seed=0):
    """Sample ``nsamples`` random ``seqlen``-token windows from one C4 validation shard.

    Only documents longer than ``seqlen`` tokens are used. This reseeds the
    global ``random`` module with ``seed``. Returns the windows concatenated
    along dim 1, i.e. shape ``(1, nsamples * seqlen)``.
    """
    testdata = load_dataset(
        'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
        split='validation')

    random.seed(seed)
    windows = []
    for _ in range(nsamples):
        while True:
            i = random.randint(0, len(testdata) - 1)
            tmp = enc(testdata[i]["text"], return_tensors="pt")
            if tmp.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
        windows.append(tmp.input_ids[:, i:i + seqlen])
    return torch.hstack(windows)


def _eval_ppl(model, input_ids, seqlen, device):
    """Return the perplexity of ``model`` over non-overlapping ``seqlen`` windows.

    ``input_ids`` is expected to be ``(1, total_tokens)``. Trailing tokens
    that do not fill a whole window are dropped. Each window's mean loss is
    scaled by ``seqlen`` (not ``seqlen - 1``), as in the original script.

    Raises:
        ValueError: if ``input_ids`` holds fewer than ``seqlen`` tokens.
    """
    model = model.to(device).eval()
    input_ids = input_ids.to(device)
    nsamples = input_ids.numel() // seqlen
    if nsamples == 0:
        raise ValueError(f"need at least {seqlen} tokens for evaluation, got {input_ids.numel()}")

    loss_fct = torch.nn.CrossEntropyLoss()
    nlls = []
    for i in tqdm(range(nsamples), desc="evaluating..."):
        batch = input_ids[:, (i * seqlen): ((i + 1) * seqlen)]
        with torch.no_grad():
            lm_logits = model(batch).logits
        shift_logits = lm_logits[:, :-1, :].contiguous().float()
        shift_labels = batch[:, 1:]
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
        nlls.append(loss.float() * seqlen)

    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)).item()


if __name__ == "__main__":
    args = parser.parse_args()
    device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

    model_type = args.model_type
    if model_type in ("llama", "qwen"):
        name_dict = {"q": 'self_attn.q_proj', "k": 'self_attn.k_proj', "v": 'self_attn.v_proj',
                     "o": 'self_attn.o_proj', "gate": "mlp.gate_proj", "up": "mlp.up_proj", "down": "mlp.down_proj"}
    else:
        name_dict = {"q": 'self_attn.q_proj', "k": 'self_attn.k_proj', "v": 'self_attn.v_proj',
                     "o": 'self_attn.out_proj', "up": 'fc1', "down": 'fc2'}
    model = load_model(args.model_path, model_type)
    enc = AutoTokenizer.from_pretrained(
        args.model_path, use_fast=False, trust_remote_code=True
    )

    # Dtype the quantized weights are written back in: bf16 for llama/qwen
    # unless the model path names llama2, fp16 otherwise. (There is no
    # `--model` argument; the path comes from --model_path.)
    if model_type in ("qwen", "llama") and "llama2" not in args.model_path:
        out_dtype = torch.bfloat16
    else:
        out_dtype = torch.float16

    model = model.cpu()
    layers = get_layers(model, model_type)

    print("Quantizing layers")
    # `with` tears the worker processes down even if a layer fails.
    with Pool(os.cpu_count()) as pool:
        for layer_id, layer in enumerate(layers):
            modules = get_modules(layer, model_type)
            module_names = get_module_names(model_type)
            model_layer = {name: lin.weight.data for lin, name in zip(modules, module_names)}

            act_path = f"{args.layerout}/l{layer_id}.pkl"
            try:
                with open(act_path, "rb") as f:
                    # NOTE(review): pickle.load can execute arbitrary code;
                    # only point --layerout at trusted files.
                    layer_act = pickle.load(f)
            except (OSError, EOFError, pickle.UnpicklingError) as err:
                raise Exception(f"Needs layer activation file at {act_path}") from err

            # Replaces each model_layer entry (in place) by a float copy with the
            # activation-outlier columns zeroed; those columns are returned here.
            outliers = remove_outliers_by_activation(
                model=model_layer,
                activation=layer_act,
                sensitivity=args.sensitivity,
                name_dict=name_dict)[0]

            for name_idx, name in enumerate(module_names):
                module = modules[name_idx]
                module_weight = model_layer[name].float()
                bit_per_channel = per_bit_channel(layer_act[name_dict[name]], args.bit)

                # Pass 1: quantize, then mark the 0.05% of weights with the
                # largest quantization error as additional outliers.
                intweight, _ = _kmeans_dequantize(pool, module_weight, bit_per_channel, module.weight.data)
                out_idx = extract(0.05, (module_weight - intweight).abs())
                outliers[name_idx] += module_weight * out_idx
                remain_weight = module_weight * ~out_idx

                # Pass 2: re-quantize with every outlier zeroed.
                intweight, centers_per_col = _kmeans_dequantize(
                    pool, remain_weight, bit_per_channel, module.weight.data)

                # Outlier positions are 0 in remain_weight, so intweight holds
                # (presumably) the center nearest 0 there; subtract it so that
                # outliers + intweight restores each outlier's value.
                outlier_mat = outliers[name_idx]
                for col, centers in enumerate(centers_per_col):
                    zero_mapping = round_to_nearest_pole_sim(torch.zeros(1), centers)
                    column = outlier_mat[:, col]
                    column[column != 0] -= zero_mapping

                module.weight.data = (outlier_mat + intweight).to(out_dtype)

    if args.tasks is not None:
        # A bare string (e.g. "wiki" or "wiki,c4") would otherwise be iterated
        # character by character and match no task.
        tasks = args.tasks.split(",") if isinstance(args.tasks, str) else args.tasks
        model.seqlen = 2048
        qt_result = {}
        for task in tasks:
            if "wiki" in task:
                qt_result["wiki_ppl"] = _eval_ppl(model, _load_wikitext2(enc), model.seqlen, device)
            elif "c4" in task:
                qt_result["c4_ppl"] = _eval_ppl(model, _load_c4(enc, model.seqlen), model.seqlen, device)
        print(qt_result)