import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import LlamaForCausalLM

from pgz_prune import *
from modelutils import *
from datautils import *
from quant import *

class LLAMAClass(BaseLM):
    def __init__(self, args):
        """Load the causal LM and tokenizer named by ``args.model``.

        Args:
            args: run configuration. ``args.model`` is handed to
                ``LlamaForCausalLM.from_pretrained`` and ``get_tokenizer``;
                ``args.batch_size`` becomes the per-GPU batch size. The whole
                object is kept on ``self.args`` for the other methods.
        """
        super().__init__()
        self.args = args

        # Use CUDA when it is available, otherwise fall back to the CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.dev = torch.device(device_name)

        self.model_name = args.model
        self.batch_size_per_gpu = args.batch_size

        model = LlamaForCausalLM.from_pretrained(self.model_name, torch_dtype='auto')
        self.model = model
        self.model.eval()

        # Sequence length used when slicing calibration / evaluation data.
        self.seqlen = 2048

        tokenizer = get_tokenizer(self.model_name)
        self.tokenizer = tokenizer
        self.vocab_size = tokenizer.vocab_size

    @property
    def eot_token_id(self):
        """Id of the end-of-text token, taken from the tokenizer's EOS id.

        "End of *text*" is the more accurate name for how this token is used
        here than "end of *sentence*".
        """
        tokenizer = self.tokenizer
        return tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum sequence length reported by this wrapper (fixed at 2048)."""
        limit = 2048
        return limit

    @property
    def max_gen_toks(self):
        """Maximum number of tokens to generate (fixed at 256).

        Every access also prints a trace line to stdout.
        """
        print('max_gen_toks fn')
        limit = 256
        return limit

    @property
    def batch_size(self):
        """Batch size to use; currently just the per-GPU batch size."""
        # TODO: fix multi-gpu (the value is not yet multiplied by the GPU count)
        per_gpu = self.batch_size_per_gpu
        return per_gpu

    @property
    def device(self):
        """Device chosen in ``__init__`` (stored on ``self.dev``)."""
        # TODO: fix multi-gpu
        selected = self.dev
        return selected

    def tok_encode(self, string: str):
        """Encode ``string`` with the tokenizer without adding special tokens."""
        encode = self.tokenizer.encode
        return encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode ``tokens`` back to text with the tokenizer."""
        decode = self.tokenizer.decode
        return decode(tokens)

    def _model_call(self, inps):
        """Run the model on ``inps`` without gradients and return its logits.

        inps: a torch tensor of shape [batch, sequence]; the sequence size may
        differ between calls.

        returns: a torch tensor of shape [batch, sequence, vocab] holding the
        model's logits, with the last axis cut to its first 32000 entries.
        """
        with torch.no_grad():
            outputs = self.model(inps)
            # Element 0 of the model output holds the logits.
            logits = outputs[0]
            return logits[:, :, :32000]

    @torch.no_grad()
    def _model_logits_on_dataset(self, dataset_inps):
        """Compute log-probabilities for every batch, one decoder layer at a time.

        The inputs that reach decoder layer 0 are captured for each batch, then
        every decoder layer is moved to the device in turn, applied to all
        captured activations, and moved back to the CPU. Finally the model's
        norm and ``lm_head`` produce the logits. Intermediate activations are
        kept on the CPU between layers.

        When ``self.args.gmp`` is truthy, each layer is magnitude-pruned in
        place before it is applied: for every module returned by
        ``find_layers`` the weights whose absolute value is at or below the
        entry at index ``int(numel * self.args.sparsity)`` of the sorted
        absolute values are set to zero. This permanently modifies the model.

        Args:
            dataset_inps: sized, iterable collection of input tensors; each one
                is moved to the device and passed to ``self.model`` as-is
                (presumably token ids of shape [batch, sequence] -- confirm
                against the caller).

        Returns:
            list: one CPU tensor per input, the ``log_softmax`` over the first
            32000 vocabulary entries of that input's logits.
        """
        nsamples = len(dataset_inps)
        dev = self.device

        print('Evaluation...')

        class _InputsCaptured(Exception):
            """Raised by ``Catcher`` to abort the forward pass once layer 0 is reached."""

        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        try:
            layers = self.model.model.layers

            self.model.model.embed_tokens = self.model.model.embed_tokens.to(dev)
            self.model.model.norm = self.model.model.norm.to(dev)
            layers[0] = layers[0].to(dev)

            inps = []
            attention_masks = []

            class Catcher(nn.Module):
                """Stand-in for layer 0 that records its inputs and stops the forward pass."""

                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, inp, **kwargs):
                    # Keep captured activations on the CPU so the whole
                    # dataset does not have to fit in device memory.
                    inps.append(inp.detach().cpu())
                    mask = kwargs.get('attention_mask')
                    # The mask may be absent/None; store it untouched then.
                    attention_masks.append(None if mask is None else mask.detach().cpu())
                    # A private exception type is used so that a genuine
                    # ValueError raised by the model is not silently swallowed.
                    raise _InputsCaptured

            layers[0] = Catcher(layers[0])
            try:
                for batch in dataset_inps:
                    try:
                        self.model(batch.to(dev))
                    except _InputsCaptured:
                        pass
            finally:
                # Always put the real layer back, even if capturing failed.
                layers[0] = layers[0].module

            layers[0] = layers[0].cpu()
            self.model.model.embed_tokens = self.model.model.embed_tokens.cpu()
            self.model.model.norm = self.model.model.norm.cpu()
            torch.cuda.empty_cache()

            outs = [None] * nsamples

            for i in range(len(layers)):
                print(i)
                layer = layers[i].to(dev)

                if self.args.gmp:
                    subset = find_layers(layer)
                    for name in subset:
                        W = subset[name].weight.data
                        thresh = torch.sort(torch.abs(W.flatten()))[0][
                            int(W.numel() * self.args.sparsity)
                        ]
                        W.data[torch.abs(W.data) <= thresh] = 0

                for j in range(nsamples):
                    mask = attention_masks[j]
                    if mask is not None:
                        mask = mask.to(dev)
                    outs[j] = layer(inps[j].to(dev), attention_mask=mask)[0].detach().cpu()

                layers[i] = layer.cpu()
                del layer
                torch.cuda.empty_cache()
                # This layer's outputs are the next layer's inputs.
                inps, outs = outs, inps

            if self.model.model.norm is not None:
                self.model.model.norm = self.model.model.norm.to(dev)
            self.model.lm_head = self.model.lm_head.to(dev)

            dataset_logits = []
            for i in range(nsamples):
                hidden_states = inps[i].unsqueeze(0).to(dev)
                if self.model.model.norm is not None:
                    hidden_states = self.model.model.norm(hidden_states)
                lm_logits = self.model.lm_head(hidden_states)
                # Index 0 drops the axis added by unsqueeze above.
                batch_logits = F.log_softmax(lm_logits[0][:, :, :32000], dim=-1).cpu()
                dataset_logits.append(batch_logits)

            return dataset_logits
        finally:
            # Restore the caller's cache setting on success and on failure.
            self.model.config.use_cache = use_cache

    @torch.no_grad()
    def llama_sequential(self, dataloader):
        """Prune decoder layer 0 with ``PGZ`` using calibration data, then evaluate.

        The inputs reaching decoder layer 0 are captured for every batch of
        ``dataloader``. For each group of sub-modules (all modules returned by
        ``find_layers`` at once, or four fixed groups when
        ``self.args.true_sequential`` is truthy) a ``PGZ`` object is created
        per module, forward hooks feed it the module's inputs and outputs while
        the layer runs on the captured data, and ``fastprune()`` is called.
        Afterwards ``llama_eval`` is run on the "wikitext2" and "ptb" test
        sets.

        NOTE(review): the layer loop is deliberately limited to ``range(1)``.
        The loop body never feeds a layer's outputs forward as the next
        layer's inputs, so widening the range to ``len(layers)`` needs more
        than changing the bound.

        Args:
            dataloader: iterable of calibration batches; ``batch[0]`` of each
                is passed to the model. The capture buffer has
                ``self.args.nsamples`` slots, so it is expected to yield that
                many batches.

        Returns:
            dict: ``quantizers``; nothing is ever stored in it, so it is
            always empty.
        """
        print("Starting...")

        class _InputsCaptured(Exception):
            """Raised by ``Catcher`` to abort the forward pass once layer 0 is reached."""

        nsamples = self.args.nsamples

        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        try:
            layers = self.model.model.layers

            self.model.model.embed_tokens = self.model.model.embed_tokens.to(self.dev)
            self.model.model.norm = self.model.model.norm.to(self.dev)
            layers[0] = layers[0].to(self.dev)

            dtype = next(iter(self.model.parameters())).dtype
            inps = torch.zeros(
                (nsamples, self.seqlen, self.model.config.hidden_size),
                dtype=dtype,
                device=self.dev
            )
            cache = {"i": 0, "attention_mask": None}

            class Catcher(nn.Module):
                """Stand-in for layer 0 that records its inputs and stops the forward pass."""

                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, inp, **kwargs):
                    inps[cache["i"]] = inp
                    cache["i"] += 1
                    cache["attention_mask"] = kwargs["attention_mask"]
                    # A private exception type is used so that a genuine
                    # ValueError raised by the model is not silently swallowed.
                    raise _InputsCaptured

            layers[0] = Catcher(layers[0])
            try:
                for batch in dataloader:
                    try:
                        self.model(batch[0].to(self.dev))
                    except _InputsCaptured:
                        pass
            finally:
                # Always put the real layer back, even if capturing failed.
                layers[0] = layers[0].module

            layers[0] = layers[0].cpu()
            self.model.model.embed_tokens = self.model.model.embed_tokens.cpu()
            self.model.model.norm = self.model.model.norm.cpu()
            torch.cuda.empty_cache()

            fouts = torch.zeros_like(inps)
            attention_mask = cache["attention_mask"]

            print("Ready.")
            quantizers = {}
            # Only the first decoder layer is processed (see NOTE in docstring).
            for i in range(1):
                layer = layers[i].to(self.dev)
                full = find_layers(layer)

                if self.args.true_sequential:
                    sequential = [
                        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
                        ["self_attn.o_proj"],
                        ["mlp.up_proj", "mlp.gate_proj"],
                        ["mlp.down_proj"],
                    ]
                else:
                    sequential = [list(full.keys())]

                for names in sequential:
                    subset = {n: full[n] for n in names}
                    gpts = {name: PGZ(module, self.args) for name, module in subset.items()}

                    def add_batch(name, nsamples):
                        # Build a forward hook that forwards the module's
                        # first input and its output to the matching PGZ.
                        def tmp(_, hook_inputs, out):
                            gpts[name].add_batch(hook_inputs[0].data, nsamples, out.data)
                        return tmp

                    handles = [
                        module.register_forward_hook(add_batch(name, nsamples))
                        for name, module in subset.items()
                    ]

                    if i == 0:
                        # Start every group from the captured layer-0 inputs.
                        fouts.copy_(inps)

                    for j in range(nsamples):
                        fouts[j] = layer(fouts[j].unsqueeze(0), attention_mask=attention_mask)[0]

                    for h in handles:
                        h.remove()

                    for name in subset:
                        print(i, name)
                        print("Pruning ...")
                        gpts[name].fastprune()
                        gpts[name].free()

                # True only for i == 0 within range(len(layers)).
                if i % len(layers) == 0:
                    for dataset in ["wikitext2", "ptb"]:
                        # The calibration loader returned here is not needed.
                        _, testloader = get_loaders(
                            dataset,
                            seed=self.args.seed,
                            model=self.args.model,
                            seqlen=self.seqlen
                        )
                        print("Dataset:", dataset)
                        self.llama_eval(testloader, dataset)

                layers[i] = layer.cpu()
                del layer
                del gpts
                torch.cuda.empty_cache()

            return quantizers
        finally:
            # Restore the caller's cache setting on success and on failure.
            self.model.config.use_cache = use_cache

    @torch.no_grad()
    def llama_eval(self, testenc, dataset: str):
        """Compute and print perplexity on ``testenc``, one decoder layer at a time.

        ``testenc.input_ids`` is cut into consecutive windows of
        ``self.seqlen`` tokens (a trailing partial window is dropped). The
        inputs reaching decoder layer 0 are captured for each window, every
        decoder layer is then moved to the device in turn and applied to all
        windows, and finally the norm and ``lm_head`` produce logits from
        which the next-token cross-entropy is computed.

        When ``self.args.gmp`` is truthy, each layer is magnitude-pruned in
        place before it is applied: for every module returned by
        ``find_layers`` the weights whose absolute value is at or below the
        entry at index ``int(numel * self.args.sparsity)`` of the sorted
        absolute values are set to zero. This permanently modifies the model.

        Args:
            testenc: object exposing ``input_ids``, a tensor sliced along
                axis 1 (presumably shape [1, tokens] -- confirm against
                ``get_loaders``).
            dataset: dataset name; not used inside this method.

        Returns:
            float: the perplexity that is also printed.

        Raises:
            ValueError: if ``testenc`` holds fewer than ``self.seqlen`` tokens,
                so that not a single window can be evaluated.
        """
        print("Evaluating ...")

        class _InputsCaptured(Exception):
            """Raised by ``Catcher`` to abort the forward pass once layer 0 is reached."""

        testenc = testenc.input_ids
        nsamples = testenc.numel() // self.seqlen
        if nsamples == 0:
            # Fail before touching (and possibly pruning) the model rather
            # than later inside torch.stack on an empty list.
            raise ValueError(
                f"need at least {self.seqlen} tokens to evaluate, got {testenc.numel()}"
            )

        use_cache = self.model.config.use_cache
        self.model.config.use_cache = False
        try:
            layers = self.model.model.layers

            self.model.model.embed_tokens = self.model.model.embed_tokens.to(self.dev)
            layers[0] = layers[0].to(self.dev)

            dtype = next(iter(self.model.parameters())).dtype
            inps = torch.zeros(
                (nsamples, self.seqlen, self.model.config.hidden_size),
                dtype=dtype,
                device=self.dev
            )
            cache = {"i": 0, "attention_mask": None}

            class Catcher(nn.Module):
                """Stand-in for layer 0 that records its inputs and stops the forward pass."""

                def __init__(self, module):
                    super().__init__()
                    self.module = module

                def forward(self, inp, **kwargs):
                    inps[cache["i"]] = inp
                    cache["i"] += 1
                    cache["attention_mask"] = kwargs["attention_mask"]
                    # A private exception type is used so that a genuine
                    # ValueError raised by the model is not silently swallowed.
                    raise _InputsCaptured

            layers[0] = Catcher(layers[0])
            try:
                for i in range(nsamples):
                    batch = testenc[:, (i * self.seqlen) : ((i + 1) * self.seqlen)].to(self.dev)
                    try:
                        self.model(batch)
                    except _InputsCaptured:
                        pass
            finally:
                # Always put the real layer back, even if capturing failed.
                layers[0] = layers[0].module

            layers[0] = layers[0].cpu()
            self.model.model.embed_tokens = self.model.model.embed_tokens.cpu()
            torch.cuda.empty_cache()

            outs = torch.zeros_like(inps)
            attention_mask = cache["attention_mask"]

            for i in range(len(layers)):
                print(i)
                layer = layers[i].to(self.dev)

                if self.args.gmp:
                    subset = find_layers(layer)
                    for name in subset:
                        W = subset[name].weight.data
                        thresh = torch.sort(torch.abs(W.flatten()))[0][
                            int(W.numel() * self.args.sparsity)
                        ]
                        W.data[torch.abs(W.data) <= thresh] = 0

                for j in range(nsamples):
                    outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
                layers[i] = layer.cpu()
                del layer
                torch.cuda.empty_cache()
                # This layer's outputs are the next layer's inputs.
                inps, outs = outs, inps

            if self.model.model.norm is not None:
                self.model.model.norm = self.model.model.norm.to(self.dev)
            self.model.lm_head = self.model.lm_head.to(self.dev)

            testenc = testenc.to(self.dev)
            loss_fct = nn.CrossEntropyLoss()
            nlls = []
            for i in range(nsamples):
                hidden_states = inps[i].unsqueeze(0)
                if self.model.model.norm is not None:
                    hidden_states = self.model.model.norm(hidden_states)
                lm_logits = self.model.lm_head(hidden_states)
                # Position t predicts token t + 1: drop the last logit and
                # the first label so the two line up.
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = testenc[:, (i * self.seqlen) : ((i + 1) * self.seqlen)][:, 1:]
                loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                )
                # Scale the mean loss by the window length.
                neg_log_likelihood = loss.float() * self.seqlen
                nlls.append(neg_log_likelihood)
            ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * self.seqlen))
            ppl_value = ppl.item()
            # NOTE(review): ':3f' is a minimum width of 3 with the default six
            # decimals; '.3f' may have been intended. Left unchanged so the
            # printed output stays the same.
            print(f"Perplexity: {ppl_value:3f}")
            return ppl_value
        finally:
            # Restore the caller's cache setting on success and on failure.
            self.model.config.use_cache = use_cache

# Module-level alias: ``LLAMA`` refers to the same class as ``LLAMAClass``.
LLAMA = LLAMAClass
