import time

import torch
import torch.nn as nn

from gptq import *
from modelutils import *
from quant import *
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_llama(model):
    """Load a causal-LM checkpoint and tag it with a 2048-token window.

    Before loading, ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
    ``normal_`` are replaced process-wide with no-ops, so freshly built
    modules skip random initialisation. The patch is NOT undone afterwards.

    Args:
        model: name or path passed to ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        The loaded model (dtype chosen with ``torch_dtype='auto'``), with a
        ``seqlen`` attribute set to 2048.
    """
    import torch

    def _noop_init(*_args, **_kwargs):
        return None

    for init_fn_name in ('kaiming_uniform_', 'uniform_', 'normal_'):
        setattr(torch.nn.init, init_fn_name, _noop_init)

    loaded = AutoModelForCausalLM.from_pretrained(model, torch_dtype='auto')
    loaded.seqlen = 2048
    return loaded
def print_layer_bits(i, name, wbits):
    """Print the bit-width used for sub-module ``name`` of decoder layer ``i``.

    Expert sub-modules (names containing ``block_sparse_moe.experts``) are
    reported by the last two dot-separated components of ``name``: the expert
    index and the projection name. Every other module is reported by its full
    name.
    """
    if 'block_sparse_moe.experts' not in name:
        print(f"Layer {i}, {name}: {wbits} bits")
        return
    *_, expert_idx, proj = name.split('.')
    print(f"Layer {i}, Expert {expert_idx}, {proj}: {wbits} bits")
def _mixtral_sequential_groups(num_experts=8):
    """Return the ``--true-sequential`` quantization groups for one decoder layer.

    Each inner list holds sub-module names (relative to the decoder layer)
    that are calibrated together. Groups are processed one after another:
    attention q/k/v, attention output, all expert ``w1``/``w3``, then all
    expert ``w2``. ``block_sparse_moe.gate`` is not listed, so it is not
    quantized in true-sequential mode.
    """
    experts = range(num_experts)
    return [
        ['self_attn.k_proj', 'self_attn.v_proj', 'self_attn.q_proj'],
        ['self_attn.o_proj'],
        [f'block_sparse_moe.experts.{e}.{w}' for w in ('w1', 'w3') for e in experts],
        [f'block_sparse_moe.experts.{e}.w2' for e in experts],
    ]


@torch.no_grad()
def llama_sequential(model, dataloader, dev, bits_config_str):
    """Quantize the linear sub-modules of every decoder layer with GPTQ.

    Decoder layers are moved to ``dev`` one at a time. For each layer:
    - calibration statistics are collected with forward hooks;
    - each sub-module is quantized with ``GPTQ.fasterquant``;
    - the layer's outputs (recomputed after quantization) become the next
      layer's inputs.

    Reads the module-level ``args`` namespace: ``nsamples``, ``wbits``,
    ``sym``, ``true_sequential``, ``percdamp``, ``groupsize``, ``act_order``,
    ``static_groups``.

    Args:
        model: causal-LM exposing ``model.model.layers``,
            ``model.model.embed_tokens``, ``model.model.norm``, ``config`` and
            a ``seqlen`` attribute.
        dataloader: iterable of calibration batches. ``batch[0]`` is fed to
            ``model`` (presumably token ids of length ``seqlen``). At most
            ``args.nsamples`` batches fit in the activation buffer.
        dev: device each layer is run on.
        bits_config_str: per-expert bit-width spec. NOTE(review): currently
            unused; every sub-module is quantized to ``args.wbits``.

    Returns:
        dict mapping ``'model.layers.<i>.<name>'`` to the quantizer fitted for
        that sub-module.
    """
    print('Starting ...')

    # TODO: per-expert bit-widths parsed from bits_config_str are not
    # implemented yet; all modules use args.wbits.

    class _StopForward(Exception):
        """Raised by the catcher to abort the forward pass after layer 0's input."""

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None, 'position_ids': None}

    class Catcher(nn.Module):
        """Wrap layer 0: record its hidden-state input and kwargs, then abort."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            try:
                model(batch[0].to(dev))
            except _StopForward:
                pass
    finally:
        layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    # Mixtral configs expose the expert count; fall back to the Mixtral-8x7B value.
    num_experts = getattr(model.config, 'num_local_experts', 8)

    quantizers = {}
    for i in range(len(layers)):
        layer = layers[i].to(dev)
        full = find_layers(layer)

        if args.true_sequential:
            sequential = _mixtral_sequential_groups(num_experts)
        else:
            sequential = [list(full.keys())]

        for names in sequential:
            subset = {n: full[n] for n in names}

            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer = Quantizer()
                # Log and configure from the same value so the printout matches
                # the bit-width actually used.
                wbits = args.wbits
                print_layer_bits(i, name, wbits)
                gptq[name].quantizer.configure(
                    wbits, perchannel=True, sym=args.sym, mse=False
                )

            def add_batch(name):
                def tmp(_, inp, out):
                    gptq[name].add_batch(inp[0].data, out.data)
                return tmp

            # NOTE(review): a sub-module that the forward pass never reaches
            # (e.g. an expert routed no calibration tokens) never gets
            # add_batch; confirm GPTQ.fasterquant copes with that.
            handles = [
                subset[name].register_forward_hook(add_batch(name)) for name in subset
            ]
            try:
                for j in range(args.nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
            finally:
                for h in handles:
                    h.remove()

            for name in subset:
                print(i, name)
                print('Quantizing ...')
                gptq[name].fasterquant(
                    percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order, static_groups=args.static_groups
                )
                quantizers['model.layers.%d.%s' % (i, name)] = gptq[name].quantizer
                gptq[name].free()

        # Recompute this layer's outputs with the quantized weights; they feed
        # the next layer.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer.cpu()
        del layer
        del gptq
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache

    return quantizers

@torch.no_grad()
def llama_eval(model, testenc, dev):
    """Compute, print and return the perplexity of ``model`` on ``testenc``.

    The token stream is cut into ``numel // model.seqlen`` non-overlapping
    windows; a trailing partial window is dropped. Decoder layers are moved to
    ``dev`` one at a time.

    When the module-level ``args.nearest`` is set, each layer's linear
    sub-modules are round-to-nearest quantized in place to ``args.wbits``
    bits before that layer is evaluated.

    Args:
        model: causal-LM exposing ``model.model.layers``,
            ``model.model.embed_tokens``, ``model.model.norm``,
            ``model.lm_head``, ``config`` and a ``seqlen`` attribute.
        testenc: tokenizer output. Only ``testenc.input_ids`` is used; it is
            assumed to be shaped ``(1, total_tokens)``.
        dev: device to run on.

    Returns:
        The perplexity as a Python float. It is also printed.

    Raises:
        ValueError: if ``testenc`` holds fewer than ``model.seqlen`` tokens.
    """
    print('Evaluating ...')

    class _StopForward(Exception):
        """Raised by the catcher to abort the forward pass after layer 0's input."""

    testenc = testenc.input_ids
    nsamples = testenc.numel() // model.seqlen
    if nsamples == 0:
        raise ValueError(
            f'need at least {model.seqlen} tokens to evaluate, got {testenc.numel()}'
        )

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None, 'position_ids': None}

    class Catcher(nn.Module):
        """Wrap layer 0: record its hidden-state input and kwargs, then abort."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs.get('attention_mask')
            cache['position_ids'] = kwargs.get('position_ids')
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for i in range(nsamples):
            batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
            try:
                model(batch)
            except _StopForward:
                pass
    finally:
        layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    for i in range(len(layers)):
        print(i)
        layer = layers[i].to(dev)

        if args.nearest:
            # NOTE(review): RTN always uses asymmetric quantization (sym=False),
            # regardless of --sym; the GPTQ path honours args.sym.
            subset = find_layers(layer)
            layer_dtype = next(iter(layer.parameters())).dtype
            for name in subset:
                quantizer = Quantizer()
                quantizer.configure(
                    args.wbits, perchannel=True, sym=False, mse=False
                )
                W = subset[name].weight.data
                quantizer.find_params(W, weight=True)
                subset[name].weight.data = quantize(
                    W, quantizer.scale, quantizer.zero, quantizer.maxq
                ).to(layer_dtype)

        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()
        inps, outs = outs, inps

    if model.model.norm is not None:
        model.model.norm = model.model.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    testenc = testenc.to(dev)
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in range(nsamples):
        hidden_states = inps[i].unsqueeze(0)
        if model.model.norm is not None:
            hidden_states = model.model.norm(hidden_states)
        lm_logits = model.lm_head(hidden_states)
        # Predict token t+1 from position t.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = testenc[
            :, (i * model.seqlen):((i + 1) * model.seqlen)
        ][:, 1:]
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        # `loss` is the mean over the window. Scaling by seqlen and then
        # dividing by nsamples * seqlen below gives ppl = exp(mean window loss).
        neg_log_likelihood = loss.float() * model.seqlen
        nlls.append(neg_log_likelihood)
    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
    print(ppl.item())

    model.config.use_cache = use_cache
    return ppl.item()

def llama_pack3(model, quantizers):
    """Swap quantized linears in ``model`` for ``Quant3Linear`` and pack weights.

    The original float layers are looked up BEFORE ``make_quant3`` replaces
    them, so their weights can be packed into the new modules. Each quantizer
    in ``quantizers`` is moved to CPU and written back into the dict (the dict
    is mutated in place).

    Args:
        model: model whose layers were quantized.
        quantizers: mapping from full layer name to its fitted quantizer.

    Returns:
        The same ``model`` object, now containing packed ``Quant3Linear`` layers.
    """
    all_linear = find_layers(model)
    float_layers = {key: all_linear[key] for key in quantizers}
    make_quant3(model, quantizers)
    packed = find_layers(model, [Quant3Linear])
    print('Packing ...')
    for key, qlayer in packed.items():
        print(key)
        quantizers[key] = quantizers[key].cpu()
        qlayer.pack(float_layers[key], quantizers[key].scale, quantizers[key].zero)
    print('Done.')
    return model


if __name__ == '__main__':
    # Command-line entry point: load the model, GPTQ-quantize it, optionally
    # run lm-eval tasks, and optionally save the (fake-quantized) checkpoint.
    import argparse
    import fnmatch
    from datautils import *

    parser = argparse.ArgumentParser()

    parser.add_argument(
        'model', type=str,
        help='LlaMa model to load; pass location of hugginface converted checkpoint.'
    )
    parser.add_argument(
        'dataset', type=str, choices=['wikitext2', 'ptb', 'c4'],
        help='Where to extract calibration data from.'
    )
    parser.add_argument(
        '--seed',
        type=int, default=0, help='Seed for sampling the calibration data.'
    )
    parser.add_argument(
        '--nsamples', type=int, default=128,
        help='Number of calibration data samples.'
    )
    parser.add_argument(
        '--percdamp', type=float, default=.01,
        help='Percent of the average Hessian diagonal to use for dampening.'
    )
    # NOTE(review): RTN is only applied inside llama_eval, which this script
    # never calls, so --nearest currently just skips GPTQ.
    parser.add_argument(
        '--nearest', action='store_true',
        help='Whether to run the RTN baseline.'
    )
    parser.add_argument(
        '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
        help='#bits to use for quantization; use 16 for evaluating base model.'
    )
    parser.add_argument(
        '--groupsize', type=int, default=-1,
        help='Groupsize to use for quantization; default uses full row.'
    )
    parser.add_argument(
        '--sym', action='store_true',
        help='Whether to perform symmetric quantization.'
    )
    parser.add_argument(
        '--save', type=str, default='',
        help='Save quantized checkpoint under this name.'
    )
    parser.add_argument(
        '--new-eval', action='store_true',
        help='Whether to use the new PTB and C4 eval.'
    )
    parser.add_argument(
        '--act-order', action='store_true',
        help='Whether to apply the activation order GPTQ heuristic'
    )
    parser.add_argument(
        '--true-sequential', action='store_true',
        help='Whether to run in true sequential model.'
    )
    parser.add_argument(
        '--static-groups', action='store_true',
        help='Whether to use static groups; recommended when using `--act-order` for more efficient inference.'
    )
    # NOTE(review): llama_sequential accepts this string but does not apply it yet.
    parser.add_argument(
        '--bits-config', type=str, default='main_2',
        help='Bit configuration string for different experts'
    )

    args = parser.parse_args()

    model = get_llama(args.model)
    model.eval()

    dataloader, testloader = get_loaders(
        args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
    )

    if args.wbits < 16 and not args.nearest:
        tick = time.time()
        quantizers = llama_sequential(model, dataloader, DEV, args.bits_config)
        print(time.time() - tick)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    if args.new_eval:
        import lm_eval
        from lm_eval import tasks, simple_evaluate
        from lm_eval.models.huggingface import HFLM
        from lm_eval.tasks import initialize_tasks

        def pattern_match(patterns, source_list):
            """Return the names in ``source_list`` matching any glob in ``patterns``."""
            task_names = set()
            for pattern in patterns:
                for matching in fnmatch.filter(source_list, pattern):
                    task_names.add(matching)
            return list(task_names)

        def update_results(results, new_result):
            """Merge ``new_result`` into ``results``, updating nested dicts per key."""
            for key, value in new_result.items():
                if key in results:
                    results[key].update(value)
                else:
                    results.update({key: value})

        results = {}
        # task_list = ["boolq", "piqa","hellaswag","winogrande", "arc_easy","arc_challenge", "openbookqa", "copa"]
        task_list = ['copa']
        # task_names = pattern_match(task_list, tasks.TaskManager().all_tasks)
        initialize_tasks(verbosity="ERROR")
        lm = HFLM(
            pretrained=model,
            backend="causal",
            device="cuda",
            batch_size=32,
            tokenizer=tokenizer,
            max_length=2048,
        )
        # simple_evaluate takes a model-name string or an lm_eval LM instance,
        # so pass the HFLM wrapper rather than the bare transformers model.
        t_results = simple_evaluate(
            model=lm,
            tasks=task_list,
            num_fewshot=0,
            batch_size=32,
        )
        update_results(results, t_results)
        print(lm_eval.utils.make_table(results))

    if args.save:
        model.save_pretrained(args.save)
        tokenizer.save_pretrained(args.save)

