import torch

import transformers
from transformers import AutoTokenizer
import sys
import os
sys.path.append(os.path.join(os.getcwd(), "mamba_peft/src/"))
from mamba_peft.src.peft import PeftModel
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import MambaForCausalLM

from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from lm_eval.__main__ import cli_evaluate
import lm_eval
from lm_eval.models.utils import stop_sequences_criteria


@register_model("MambaPEFT")
class MambaEvalWrapper(HFLM):
    """lm-evaluation-harness model wrapper for a HF Mamba checkpoint with PEFT weights.

    The base model is loaded with ``MambaForCausalLM`` and, when
    ``peft_weights`` is given, wrapped with ``PeftModel`` so the adapter is
    applied during evaluation. The GPT-NeoX-20B tokenizer is always used.
    The class is registered with lm-eval under the name ``"MambaPEFT"``.
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM

    def __init__(self,
                 pretrained="state-spaces/mamba-130m-hf",
                 peft_weights=None,
                 max_length=2048,
                 batch_size=None,
                 device="cuda",
                 dtype=torch.float32,
                 trust_remote_code=False):
        """Build the wrapper and load the (optionally adapted) model.

        Args:
            pretrained: Name or path of the base Mamba checkpoint.
            peft_weights: Name or path of the PEFT adapter to apply. When
                ``None`` the plain base model is evaluated.
            max_length: Maximum sequence length forwarded to ``HFLM``.
            batch_size: Evaluation batch size; defaults to 64 when ``None``.
            device: Device the model is placed on (e.g. ``"cuda"``, ``"cpu"``).
            dtype: Forwarded to ``HFLM.__init__``. Note that ``_create_model``
                currently ignores it: weights are loaded in float16 and then
                cast to float32.
            trust_remote_code: Forwarded to ``HFLM.__init__`` and honoured when
                the base model is loaded.
        """
        # Must exist before super().__init__(): the parent constructor is what
        # triggers _create_model(), which reads this attribute.
        self.peft_weights = peft_weights
        super().__init__(
            pretrained=pretrained,
            tokenizer="EleutherAI/gpt-neox-20b",
            max_length=max_length,
            # Forward the device so the parent resolves self._device *before*
            # the model is created. Previously the parent fell back to its own
            # default and self._device was only overwritten afterwards, leaving
            # the model and self.device out of sync for non-default devices.
            device=device if device is None else str(device),
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

        self._batch_size = int(batch_size) if batch_size is not None else 64
        self._max_length = max_length

        # Re-create the tokenizer and force pad == eos so padding is always
        # defined, regardless of what the parent configured.
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.vocab_size = self.tokenizer.vocab_size

    def _create_model(
        self,
        pretrained: str,
        dtype = "float32",
        # no `parallelize=True` options
        # no PEFT and quantization options
        # Mamba does not support arbitrary HF from_pretrained() args
        **kwargs,
    ) -> None:
        """Load the base Mamba model, attach PEFT weights, and store it in ``self._model``.

        Args:
            pretrained: Name or path of the base checkpoint.
            dtype: Accepted for interface compatibility; not used here. The
                model is loaded in float16 and cast to float32 below.
            **kwargs: Extra options passed by the parent class. Only
                ``trust_remote_code`` is read; everything else is ignored.
        """
        base_model = MambaForCausalLM.from_pretrained(
            pretrained,
            load_in_8bit=False,
            torch_dtype=torch.float16,
            # Respect the caller's choice instead of unconditionally enabling
            # remote code execution.
            trust_remote_code=bool(kwargs.get("trust_remote_code", False)),
        )

        if self.peft_weights is None:
            # No adapter requested: evaluate the base model as-is rather than
            # failing inside PeftModel.from_pretrained(model, None).
            self._model = base_model
        else:
            self._model = PeftModel.from_pretrained(
                base_model,
                self.peft_weights,
                torch_dtype=torch.float32,
            )
        print(base_model)
        self._model.config.use_cache = False  # Not fully implemented yet
        self._model.float()
        self._model.to(self._device)

    @property
    def batch_size(self):
        """Batch size used for evaluation (set in ``__init__``)."""
        return self._batch_size

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        """Generate continuations for ``context`` until ``stop`` or ``max_length``.

        Args:
            context: Batch of input token ids; indexed as
                ``(batch, sequence)`` when building the stopping criteria.
            max_length: Passed to ``generate`` as ``max_length``.
            stop: Stop sequences used to build the stopping criteria.
            **generation_kwargs: Extra arguments forwarded to ``generate``.

        Returns:
            Whatever ``self.model.generate`` returns.
        """
        # Temperature defaults to 0.0 when the caller did not set one.
        temperature = generation_kwargs.setdefault("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0,
        # fall back to greedy decoding.
        if temperature == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        # With do_sample=False a temperature of 0.0 is redundant; drop it so
        # HF does not emit a warning.
        if do_sample is False and temperature == 0.0:
            generation_kwargs.pop("temperature")

        # build stopping criteria
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=False,
            **generation_kwargs,
        )

    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
    #     raise NotImplementedError()

    # def _model_generate(self, context, max_length, stop, **generation_kwargs):
    #     # Remove problematic arguments for Mamba models
    #     remove_args = ["attention_mask"]
    #     for key in remove_args:
    #         if key in generation_kwargs:
    #             generation_kwargs.pop(key)
    #
    #     # Set up stopping criteria using the lm_eval utilities
    #     stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
    #         self.tokenizer,
    #         stop,
    #         context.shape[1],
    #         context.shape[0],
    #     )
    #
    #     # Handle temperature settings and do_sample
    #     temperature = generation_kwargs.get("temperature", 0.0)
    #     do_sample = generation_kwargs.get("do_sample", None)
    #
    #     # If temperature is 0.0 or do_sample is False, ensure proper greedy decoding
    #     if temperature == 0.0 or do_sample is False:
    #         generation_kwargs["do_sample"] = False
    #         generation_kwargs["top_k"] = 1
    #         generation_kwargs["top_p"] = 0.0
    #         generation_kwargs["min_p"] = 0.0
    #         # The model implementation will remove temperature if needed
    #     else:
    #         generation_kwargs["do_sample"] = True
    #         generation_kwargs["temperature"] = temperature
    #
    #     # Generate text with the model
    #     return self._model.generate(
    #         input_ids=context,
    #         max_length=max_length,
    #         stopping_criteria=stopping_criteria,
    #         pad_token_id=self.tokenizer.pad_token_id,
    #         use_cache=True,
    #         **generation_kwargs,
    #     )

    # def _model_generate(
    #         self,
    #         context,
    #         **kwargs,
    # ):
    #     output = self.model.generate(
    #         input_ids=context,
    #         use_cache=False,
    #         **kwargs,
    #     )
    #     return output


# Script entry point: hand control to lm-eval's command-line interface.
# Importing this module has already applied the @register_model("MambaPEFT")
# decorator on MambaEvalWrapper.
if __name__ == "__main__":
    cli_evaluate()