import torch
import transformers
from typing import Optional, Union
from lm_eval.base import BaseLM
import os

from .gpt2_custom import GPT, GPTConfig

# Point the Hugging Face cache (HF_HOME) at this machine's scratch space.
# The assignment is unconditional: it replaces any HF_HOME already present in
# the environment.
# NOTE(review): ``transformers`` is imported above this line; if it resolves
# its cache directories at import time, this may not change where it caches.
# Confirm, and move this before the imports if needed.
os.environ.update(HF_HOME='/scratch/07946/ss95332/huggingface')


def _get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype


class HFLM(BaseLM):
    """lm-eval-harness ``BaseLM`` wrapper around a causal language model.

    Two construction modes are supported:

    * ``pretrained`` is a ``transformers.PreTrainedModel``: it is used as-is,
      and the tokenizer is taken from ``tokenizer`` or loaded from the
      model's ``name_or_path``.
    * ``pretrained`` is a string: the weights are NOT fetched from the Hub.
      A custom ``GPT`` is rebuilt from the local checkpoint at ``ckpt_path``.
      ``pretrained`` (or ``tokenizer``, if given) is only used to load the
      tokenizer.
    """

    _DEFAULT_MAX_LENGTH = 2048
    # Checkpoint used when ``pretrained`` is a string and ``ckpt_path`` is not
    # given. Machine-specific; override through the ``ckpt_path`` argument.
    _DEFAULT_CKPT_PATH = '/scratch/07946/ss95332/sophia/medium/ckpt_ly16.pt'
    # State-dict key prefix stripped before loading (presumably left behind by
    # saving a ``torch.compile``-wrapped model).
    _UNWANTED_PREFIX = '_orig_mod.'

    def __init__(
            self,
            device="cuda",
            pretrained="gpt2",
            revision="main",
            low_cpu_mem_usage=None,
            subfolder=None,
            tokenizer=None,
            batch_size=1,
            max_batch_size=512,
            max_length=None,
            load_in_8bit: Optional[bool] = False,
            trust_remote_code: Optional[bool] = False,
            dtype: Optional[Union[str, torch.dtype]] = "auto",
            ckpt_path: Optional[str] = None,
            n_layer: Optional[int] = 16,
    ):
        """Build the wrapper.

        Args:
            device: Requested device string. It is used only in the string
                ``pretrained`` mode. If it is not ``"cuda"``, ``"cpu"`` or an
                existing ``"cuda:N"``, the model falls back to CUDA when
                available, otherwise CPU.
            pretrained: A ``PreTrainedModel`` instance, or a tokenizer
                name/path (string mode; see the class docstring).
            revision: Tokenizer revision. In string mode ``subfolder`` is
                appended as ``revision/subfolder``.
            low_cpu_mem_usage, load_in_8bit, dtype: Accepted for interface
                compatibility. They are not used by the custom-checkpoint
                loader.
            subfolder: Optional subfolder appended to ``revision`` (string mode).
            tokenizer: Tokenizer instance (model mode) or tokenizer name/path
                (string mode) overriding the default.
            batch_size: Int, or a string starting with ``"auto"``, optionally
                followed by ``":<schedule>"``.
            max_batch_size: Upper bound stored for automatic batch sizing.
            max_length: Explicit context length; overrides ``max_length`` lookup.
            trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.
            ckpt_path: Checkpoint to load in string mode. Defaults to
                ``_DEFAULT_CKPT_PATH``.
            n_layer: Value forced onto the checkpoint config's ``n_layer``
                before building the model. ``None`` keeps the checkpoint's
                own value.

        Raises:
            TypeError: ``pretrained`` is neither a string nor a ``PreTrainedModel``.
            KeyError: The checkpoint has neither a ``'model'`` nor a ``'net'`` entry.
        """
        super().__init__()

        if isinstance(pretrained, transformers.PreTrainedModel):
            self.model = pretrained
            self._device = self.model.device

            if tokenizer:
                assert isinstance(
                    tokenizer,
                    (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast),
                )
                self.tokenizer = tokenizer
            else:
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    self.model.name_or_path,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                )

        elif isinstance(pretrained, str):
            assert isinstance(device, str)
            self._device = self._resolve_device(device)
            revision = revision + ("/" + subfolder if subfolder is not None else "")

            if ckpt_path is None:
                ckpt_path = self._DEFAULT_CKPT_PATH
            print("LOADING ON: ", ckpt_path)
            # Load onto the *resolved* device. The raw ``device`` string may
            # name a device that does not exist on this machine.
            self.model = self._load_custom_gpt(ckpt_path, n_layer, self._device)

            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                tokenizer if tokenizer else pretrained,
                revision=revision,
                trust_remote_code=trust_remote_code,
            )

        else:
            raise TypeError(
                "Parameter pretrained should be of type str or transformers.PreTrainedModel"
            )

        self.model.eval()

        self.vocab_size = self.tokenizer.vocab_size

        # Validate batch_size
        assert isinstance(batch_size, (int, str))

        # Automatic batch size: "auto" or "auto:<schedule>".
        if str(batch_size).startswith("auto"):
            parts = batch_size.split(":")
            self.batch_size_per_gpu = parts[0]
            self.batch_schedule = float(parts[1]) if len(parts) > 1 else 1
        else:
            self.batch_size_per_gpu = int(batch_size)
        self.max_batch_size = max_batch_size

        self._max_length = max_length

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Return ``torch.device(device)`` if it names an available device.

        Otherwise fall back to CUDA when available, else CPU.
        """
        valid_devices = {"cuda", "cpu"} | {
            f"cuda:{i}" for i in range(torch.cuda.device_count())
        }
        if device and device in valid_devices:
            print(f"Using device '{device}'")
            return torch.device(device)
        print("Device not specified")
        print(f"Cuda Available? {torch.cuda.is_available()}")
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    @classmethod
    def _load_custom_gpt(cls, ckpt_path, n_layer, device):
        """Rebuild a ``GPT`` from ``ckpt_path`` and move it to ``device``.

        The checkpoint must hold ``'model_args'`` (kwargs for ``GPTConfig``)
        and the weights under ``'model'`` or, failing that, ``'net'``.
        """
        # NOTE(review): ``torch.load`` unpickles the file; only load trusted
        # checkpoints.
        checkpoint = torch.load(ckpt_path, map_location=device)
        gptconf = GPTConfig(**checkpoint['model_args'])
        if n_layer is not None:
            gptconf.n_layer = n_layer  # override the checkpoint's layer count
        model = GPT(gptconf)

        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'net' in checkpoint:
            state_dict = checkpoint['net']
        else:
            raise KeyError(
                f"Could not load checkpoint {ckpt_path!r}: "
                "expected a 'model' or 'net' entry"
            )

        prefix = cls._UNWANTED_PREFIX
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(state_dict)
        return model.to(device)

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Context length: explicit value, else model config, else tokenizer, else default."""
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        # getattr: the custom GPT may not expose an HF-style ``config``; then
        # ``hasattr(None, ...)`` is False and we fall through to the tokenizer.
        # NOTE(review): a nanoGPT-style config usually names this ``block_size``,
        # which is not consulted here. Confirm the tokenizer value fits the model.
        model_config = getattr(self.model, "config", None)
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(model_config, attr):
                return getattr(model_config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            # Presumably transformers' "no limit" sentinel (int(1e30)).
            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                return self._DEFAULT_MAX_LENGTH
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return 256

    @property
    def batch_size(self):
        # TODO: fix multi-gpu
        return self.batch_size_per_gpu  # * gpus

    @property
    def device(self):
        # TODO: fix multi-gpu
        return self._device

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            # NOTE(review): ``targets=0`` is passed to the custom GPT's forward,
            # presumably so it returns logits for every position rather than
            # just the last one. Confirm against gpt2_custom.GPT.forward.
            return self.model(inps, targets=0)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        """Generate a continuation of ``context`` up to ``max_length`` total tokens.

        The custom ``GPT.generate`` is called with
        ``max_new_tokens = max_length - context length``. ``eos_token_id`` is
        accepted for interface compatibility but is not passed to the model.
        """
        generation_kwargs = {"max_new_tokens": max_length - int(context.shape[-1])}
        return self.model.generate(context, **generation_kwargs)


# Backwards-compatibility alias: importing ``GPT2LM`` from this module yields
# the same class object as ``HFLM``.
GPT2LM = HFLM
