import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import LlamaForCausalLM

from sparsellama import *
from modelutils import *
from datautils import get_tokenizer
from quant import *

class LLAMAClass(BaseLM):
    """Wrap a Hugging Face LLaMA causal LM for layer-wise pruning and evaluation.

    The decoder layers are moved to ``self.dev`` one at a time, so only a
    single layer plus the activation buffers have to fit on the device.

    ``args`` is an ``argparse.Namespace``-like object. The attributes read by
    this class are ``model``, ``batch_size``, ``nsamples``, ``true_sequential``,
    ``wbits``, ``gmp`` and ``sparsity``; the whole object is also forwarded to
    ``SparseLLAMA``.
    """

    def __init__(self, args):
        """Load the pretrained model and tokenizer named by ``args.model``."""
        super().__init__()
        self.args = args
        print("self.args: {}\n".format(self.args))
        self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = args.model
        self.batch_size_per_gpu = args.batch_size

        self.model = LlamaForCausalLM.from_pretrained(self.model_name, torch_dtype='auto')
        self.model.eval()
        # Fixed sequence length used for both calibration and evaluation windows.
        self.seqlen = 2048
        self.tokenizer = get_tokenizer(self.model_name)
        self.vocab_size = self.tokenizer.vocab_size

    @torch.no_grad()
    def _capture_first_layer_inputs(self, batches, nsamples):
        """Record the hidden states that reach decoder layer 0.

        Layer 0 is temporarily replaced by a wrapper that stores its input and
        aborts the forward pass, so only the modules in front of the first
        decoder layer are actually executed.

        Args:
            batches: iterable of token-id tensors; each is moved to
                ``self.dev`` and fed to the model. It must not yield more than
                ``nsamples`` items.
            nsamples: number of rows to allocate in the capture buffer.

        Returns:
            ``(inps, attention_mask)`` where ``inps`` has shape
            ``(nsamples, self.seqlen, hidden_size)`` and ``attention_mask`` is
            the ``attention_mask`` keyword seen on the last captured call.
        """
        layers = self.model.model.layers
        dtype = next(iter(self.model.parameters())).dtype
        inps = torch.zeros(
            (nsamples, self.seqlen, self.model.config.hidden_size),
            dtype=dtype,
            device=self.dev,
        )
        cache = {"i": 0, "attention_mask": None}

        class _InputCaptured(Exception):
            """Signals that layer 0's input was stored and the pass can stop."""

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps[cache["i"]] = inp
                cache["i"] += 1
                cache["attention_mask"] = kwargs["attention_mask"]
                # A dedicated exception type keeps genuine ValueErrors raised
                # by the model from being swallowed by the caller below.
                raise _InputCaptured

        layers[0] = Catcher(layers[0])
        try:
            for batch in batches:
                try:
                    self.model(batch.to(self.dev))
                except _InputCaptured:
                    pass
        finally:
            # Always put the real layer back, even if a forward pass failed.
            layers[0] = layers[0].module
        return inps, cache["attention_mask"]

    @torch.no_grad()
    def llama_sequential(self, dataloader):
        """Prune the model one decoder layer at a time using calibration data.

        Args:
            dataloader: iterable of calibration batches; element ``0`` of each
                batch is used as the token-id tensor. It should yield
                ``self.args.nsamples`` batches.

        Returns:
            The ``quantizers`` dict. It is created empty and never filled in
            this method, so it is currently always ``{}``.
        """
        print("Starting...")
        args = self.args

        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        layers = self.model.model.layers

        self.model.model.embed_tokens = self.model.model.embed_tokens.to(self.dev)
        self.model.model.norm = self.model.model.norm.to(self.dev)
        layers[0] = layers[0].to(self.dev)

        inps, attention_mask = self._capture_first_layer_inputs(
            (batch[0] for batch in dataloader), args.nsamples
        )

        layers[0] = layers[0].cpu()
        self.model.model.embed_tokens = self.model.model.embed_tokens.cpu()
        self.model.model.norm = self.model.model.norm.cpu()
        torch.cuda.empty_cache()

        # `inps` holds the current layer's inputs, `outs` receives its outputs.
        outs = torch.zeros_like(inps)

        print("Ready.")
        quantizers = {}
        for i in range(len(layers)):
            layer = layers[i].to(self.dev)
            full = find_layers(layer)

            if args.true_sequential:
                sequential = [
                    ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
                    ["self_attn.o_proj"],
                    ["mlp.up_proj", "mlp.gate_proj"],
                    ["mlp.down_proj"],
                ]
            else:
                sequential = [list(full.keys())]

            for names in sequential:
                subset = {n: full[n] for n in names}

                gpts = {}
                for name in subset:
                    gpts[name] = SparseLLAMA(subset[name], args)

                    if args.wbits < 16:
                        gpts[name].quantizer = Quantizer()
                        gpts[name].quantizer.configure(
                            args.wbits, perchannel=True, sym=False, mse=False
                        )

                def add_batch(name, nsamples):
                    # Forward hook that feeds each call's input/output to the pruner.
                    def tmp(_, layer_inputs, out):
                        gpts[name].add_batch(layer_inputs[0].data, nsamples, out.data)
                    return tmp

                handles = [
                    subset[name].register_forward_hook(add_batch(name, args.nsamples))
                    for name in subset
                ]

                # Every group reads the *same* layer inputs. Feeding a group's
                # outputs back in as the next group's inputs would apply this
                # decoder layer several times when there is more than one group.
                for j in range(args.nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

                for h in handles:
                    h.remove()

                for name in subset:
                    print(i, name)
                    print("Pruning ...")
                    gpts[name].fastprune()
                    gpts[name].free()

            layers[i] = layer.cpu()
            del layer
            del gpts
            torch.cuda.empty_cache()
            # This layer's outputs become the next layer's inputs.
            inps, outs = outs, inps
        self.model.config.use_cache = use_cache
        return quantizers

    @torch.no_grad()
    def llama_eval(self, testenc, dataset: str):
        """Compute and print perplexity over non-overlapping ``seqlen`` windows.

        Args:
            testenc: tokenized evaluation text exposing an ``input_ids`` tensor
                (indexed as ``[:, start:end]``, so a leading batch dimension
                is expected).
            dataset: dataset name; not used by the computation.

        Returns:
            The perplexity as a Python float (it is also printed).
        """
        print("Evaluating ...")
        args = self.args

        testenc = testenc.input_ids
        nsamples = testenc.numel() // self.seqlen

        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        layers = self.model.model.layers

        self.model.model.embed_tokens = self.model.model.embed_tokens.to(self.dev)
        layers[0] = layers[0].to(self.dev)

        windows = (
            testenc[:, (i * self.seqlen) : ((i + 1) * self.seqlen)]
            for i in range(nsamples)
        )
        inps, attention_mask = self._capture_first_layer_inputs(windows, nsamples)

        layers[0] = layers[0].cpu()
        self.model.model.embed_tokens = self.model.model.embed_tokens.cpu()
        torch.cuda.empty_cache()

        outs = torch.zeros_like(inps)

        for i in range(len(layers)):
            print(i)
            layer = layers[i].to(self.dev)

            if args.gmp:
                # Magnitude-pruning baseline: zero every weight whose absolute
                # value is at or below the `sparsity` quantile of the matrix.
                subset = find_layers(layer)
                for name in subset:
                    W = subset[name].weight.data
                    # Clamp so that sparsity == 1.0 does not index past the end.
                    k = min(int(W.numel() * args.sparsity), W.numel() - 1)
                    thresh = torch.sort(torch.abs(W.flatten()))[0][k]
                    W.data[torch.abs(W.data) <= thresh] = 0

            for j in range(nsamples):
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()
            inps, outs = outs, inps

        if self.model.model.norm is not None:
            self.model.model.norm = self.model.model.norm.to(self.dev)
        self.model.lm_head = self.model.lm_head.to(self.dev)

        testenc = testenc.to(self.dev)
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in range(nsamples):
            hidden_states = inps[i].unsqueeze(0)
            if self.model.model.norm is not None:
                hidden_states = self.model.model.norm(hidden_states)
            lm_logits = self.model.lm_head(hidden_states)
            # Position t predicts token t + 1, hence the one-step shift.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = testenc[:, (i * self.seqlen) : ((i + 1) * self.seqlen)][:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            # The loss is a per-token mean; scale it back to a per-window total.
            neg_log_likelihood = loss.float() * self.seqlen
            nlls.append(neg_log_likelihood)
        ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * self.seqlen))
        ppl_value = ppl.item()
        print(f"Perplexity: {ppl_value:3f}")

        self.model.config.use_cache = use_cache
        return ppl_value

# Alias: expose the wrapper class under the shorter name `LLAMA` as well.
LLAMA = LLAMAClass

if __name__ == "__main__":
    # Command-line entry point: parse options, optionally prune the model,
    # report perplexity on each evaluation set, then optionally save it.
    import argparse
    from datautils import *

    cli = argparse.ArgumentParser()

    # Positional arguments.
    cli.add_argument("model", type=str, help="LlaMA model to load")
    cli.add_argument(
        "dataset",
        type=str,
        choices=["wikitext2", "ptb", "c4"],
        help="Where to extract calibration data from.",
    )

    # Calibration sampling.
    cli.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for sampling the calibration data.",
    )
    cli.add_argument(
        "--nsamples",
        type=int,
        default=320,
        help="Number of calibration data samples.",
    )

    # Pruning hyper-parameters.
    cli.add_argument("--sparsity", type=float, default=0.5, help="Target sparsity")
    cli.add_argument("--prunen", type=int, default=0, help="N for N:M pruning.")
    cli.add_argument("--prunem", type=int, default=0, help="M for N:M pruning.")
    cli.add_argument(
        "--prune_ratio_per_iter",
        type=float,
        default=0.05,
        help="Pruning ratio per iteration.",
    )
    cli.add_argument("--interval", type=int, default=256)
    cli.add_argument("--step", type=float, default=0.0008, help="learning rate. ")

    # Mode switches and output.
    cli.add_argument(
        "--gmp",
        action="store_true",
        help="Whether to run the GMP baseline.",
    )
    cli.add_argument(
        "--wbits",
        type=int,
        default=16,
        help="Whether to quantize as well.",
    )
    cli.add_argument("--save", type=str, default="", help="Path to saved model.")
    cli.add_argument(
        "--true-sequential",
        action="store_true",
        help="Whether to run in true sequential model.",
    )

    # NOTE: `args` has to stay a module-level name; the model class reads it.
    args = cli.parse_args()
    args.batch_size = 1

    model = LLAMAClass(args)

    calib_loader, _ = get_loaders(
        args.dataset,
        nsamples=args.nsamples,
        seed=args.seed,
        model=args.model,
        seqlen=model.seqlen,
    )

    # Pruning is skipped for the GMP baseline, which is applied during evaluation.
    wants_pruning = bool(args.sparsity or args.prunen)
    if wants_pruning and not args.gmp:
        started = time.time()
        model.llama_sequential(calib_loader)
        # Print the fraction of zeros per parameter, stopping after the first
        # `down_proj` entry has been shown.
        for param_name, param in model.model.named_parameters():
            zero_fraction = (param == 0).float().mean()
            print(param_name, zero_fraction)
            if 'down_proj' in param_name:
                break
        print(f" sparse cost time: {time.time() - started} \n")

    for eval_name in ("wikitext2", "ptb", "c4"):
        _, eval_data = get_loaders(
            eval_name,
            seed=args.seed,
            model=args.model,
            seqlen=model.seqlen,
        )
        print("Dataset:", eval_name)
        model.llama_eval(eval_data, eval_name)

    if args.save:
        model.model.save_pretrained(args.save)
