import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

from lm_eval.base import BaseLM


class LlamaLM(BaseLM):
    """lm-eval-harness ``BaseLM`` adapter around a Hugging Face LLaMA model.

    Loads a ``LlamaForCausalLM`` and ``LlamaTokenizer`` from ``pretrained`` and
    exposes the token-encoding, forward-pass and generation hooks that
    ``BaseLM`` calls during evaluation.
    """

    def __init__(
        self,
        device="cuda",
        pretrained="huggyllama/llama-7b",
        revision="main",
        subfolder=None,
        tokenizer=None,
        batch_size=1,
        load_8bit=True,
    ):
        """Load the model and tokenizer.

        Args:
            device: torch device string. Use ``"cuda"``, ``"cpu"``, a bare GPU
                index such as ``"0"``, or any other string ``torch.device``
                accepts (e.g. ``"cuda:1"``). An empty string auto-selects CUDA
                when available, otherwise CPU.
            pretrained: model name or path passed to ``from_pretrained``.
            revision: model revision passed to ``from_pretrained``.
            subfolder: optional subfolder, appended to ``revision`` (see TODO).
            tokenizer: optional tokenizer name or path; defaults to
                ``pretrained``.
            batch_size: per-device batch size reported via ``batch_size``.
            load_8bit: if True, load the weights with ``load_in_8bit=True`` and
                ``device_map="auto"``. Otherwise load full-precision weights
                and move them to the selected device.
        """
        super().__init__()

        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, int)
        self.batch_size_per_gpu = batch_size

        if device:
            # A bare numeric string is a GPU index; torch.device needs an int
            # for that. Other strings ("cuda", "cpu", "cuda:1", "mps", ...) are
            # passed through unchanged instead of being forced through int().
            if device not in ["cuda", "cpu"] and device.isdigit():
                device = int(device)
            self._device = torch.device(device)
            print(f"Using device '{device}'")
        else:
            print("Device not specified")
            print(f"Cuda Available? {torch.cuda.is_available()}")
            self._device = (
                torch.device("cuda")
                if torch.cuda.is_available()
                else torch.device("cpu")
            )

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        if load_8bit:
            # NOTE(review): with device_map="auto" the weights are placed by
            # transformers/accelerate, not explicitly on self._device -- confirm
            # that matches where BaseLM sends the input tensors.
            self.model = LlamaForCausalLM.from_pretrained(
                pretrained, revision=revision, device_map="auto", load_in_8bit=True
            )
        else:
            # Previously this path left self.model unset, so eval() below
            # raised AttributeError. Load normally and place on the chosen device.
            self.model = LlamaForCausalLM.from_pretrained(
                pretrained, revision=revision
            ).to(self._device)
        self.model.eval()

        self.tokenizer = LlamaTokenizer.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision,
        )
        self.vocab_size = len(self.tokenizer)

    @property
    def eot_token_id(self):
        """Token id used as end-of-text: the tokenizer's EOS id."""
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Maximum context length read from the model config.

        Uses ``n_ctx`` when the config has it and falls back to
        ``max_position_embeddings`` otherwise.
        """
        try:
            return self.model.config.n_ctx
        except AttributeError:
            return self.model.config.max_position_embeddings

    @property
    def max_gen_toks(self):
        """Fixed cap on the number of generated tokens."""
        return 256

    @property
    def batch_size(self):
        """Batch size passed at construction (not multiplied by GPU count)."""
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        """The ``torch.device`` selected in ``__init__``."""
        return self._device

    def tok_encode(self, string: str):
        """Encode ``string`` to token ids without adding BOS/EOS special tokens."""
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        """Decode a sequence of token ids back to text."""
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        # Inference only: disable autograd to avoid storing activations.
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        """Greedy-decode (``do_sample=False``) from ``context`` up to ``max_length`` tokens."""
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
