import copy
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import torch
import torch.nn.functional as F
import transformers
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import Collator, pad_and_concat
from tqdm import tqdm

from oe_eval.components.requests import GenerateUntilRequest, LoglikelihoodRequest
from oe_eval.utils import cut_at_stop_sequence

# Minimally modified version of model inference code from lm_eval, (models/huggingface.py)
# adding _verbose versions of various methods to return additional information

eval_logger = utils.eval_logger


class HFLM_Verbose(HFLM):
    def __init__(
        self,
        pretrained: Optional[Union[str, transformers.PreTrainedModel]] = None,
        device: Optional[str] = None,
        device_map_option: Optional[str] = "auto",
        dtype: Optional[Union[str, torch.dtype]] = None,
        **kwargs,
        # pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
        # backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
        # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
        # revision: Optional[str] = "main",
        # subfolder: Optional[str] = None,
        # tokenizer: Optional[
        #    Union[
        #        str,
        #        transformers.PreTrainedTokenizer,
        #        transformers.PreTrainedTokenizerFast,
        #    ]
        # ] = None,
        # truncation: Optional[bool] = False,
        # logits_cache: bool = True,
        # max_length: Optional[int] = None,
        # device: Optional[str] = "cuda",
        # dtype: Optional[Union[str, torch.dtype]] = "auto",
        # batch_size: Optional[Union[int, str]] = 1,
        # max_batch_size: Optional[int] = 64,
        # trust_remote_code: Optional[bool] = False,
        # use_fast_tokenizer: Optional[bool] = True,
        # add_bos_token: Optional[bool] = False,
        # prefix_token_id: Optional[int] = None,
        # arguments used for splitting a model across GPUs naively.
        # only used if `parallelize=True`.
        # parallelize: Optional[bool] = False,
        # device_map_option: Optional[str] = "auto",
        # max_memory_per_gpu: Optional[Union[int, str]] = None,
        # max_cpu_memory: Optional[Union[int, str]] = None,
        # offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
        # PEFT, delta weights and quantization options
        # peft: Optional[str] = None,
        # delta: Optional[str] = None,
        # autogptq: Optional[Union[bool, str]] = False,
        # **kwargs,
    ) -> None:
        if device is None:
            device = "cuda" if torch.cuda.device_count() > 0 else "cpu"
        if device_map_option is None:
            if torch.cuda.device_count() > 1:
                device_map_option = "balanced_low_0"
            elif torch.cuda.device_count() > 0:
                device_map_option = "auto"
        if "revision" in kwargs and kwargs["revision"] is None:
            # Hack to deal with Eleuther using "main" as a default for "revision"
            kwargs = kwargs.copy()
            del kwargs["revision"]
        if torch.cuda.device_count() > 1:
            kwargs["parallelize"] = True
        super().__init__(
            pretrained, device=device, dtype=dtype, device_map_option=device_map_option, **kwargs
        )

    def loglikelihood_rolling_verbose(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        loglikelihoods = []

        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size

        for (string,) in tqdm(
            [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
        ):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.prefix_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0
            if self.world_size > 1:
                # We pad out the external document-level iterator so the inner iterator doesn't hang
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                requests=rolling_token_windows,
                disable_tqdm=True,
                override_bs=adaptive_batch_size,
            )

            if (self.world_size > 1) and (pad_amnt > 0):
                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
            else:
                # discard is_greedy
                string_nll = [x[0] for x in string_nll]

            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def loglikelihood_verbose(
        self,
        requests: List[LoglikelihoodRequest],
        override_bs: Optional[int] = None,
        disable_tqdm: bool = False,
    ) -> List[dict]:
        new_reqs = []
        for req in requests:
            context = cast(str, req.context)
            continuation = req.continuation
            if context == "":
                # BOS or EOS as context
                context_enc, continuation_enc = (
                    [self.prefix_token_id],
                    self.tok_encode(continuation),
                )
            else:
                context_enc, continuation_enc = self._encode_pair(context, continuation)
            new_reqs.append(((context, continuation), context_enc, continuation_enc))
        return self._loglikelihood_tokens_verbose(
            new_reqs, disable_tqdm=disable_tqdm, override_bs=override_bs
        )

    def _loglikelihood_tokens_verbose(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: Optional[int] = None,
        verbose_token_logits: bool = False,
    ) -> List[dict]:
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        verbose_res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by=(
                "contexts"
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM and self.logits_cache
                else None
            ),
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs if override_bs is not None else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = 0
            padding_len_cont = 0
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
                batched_conts = pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1
            )  # [batch, padding_length (inp or cont), vocab]

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(
                        0
                    )  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    verbose_answer: Dict[str, Any] = {
                        "sum_logits": answer[0],
                        "num_tokens": len(cont_toks[0]),
                        "num_tokens_all": len(ctx_tokens) + len(cont_toks[0]),
                        "is_greedy": answer[1],
                    }
                    if verbose_token_logits:
                        instance_tokens = [
                            self.tok_decode(x, skip_special_tokens=False) for x in cont_toks
                        ]
                        verbose_answer["tokens"] = instance_tokens
                        verbose_answer["logits"] = logits.tolist()
                    verbose_res.append(verbose_answer)

                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(verbose_res)

    # arguments = (ctx, deepcopy(self.config.generation_kwargs))
    def generate_until_verbose(
        self, requests: List[GenerateUntilRequest], disable_tqdm: bool = False
    ) -> List[dict]:
        # Convert to request_args used in original lm_eval generate_until for minimal diff
        request_args: List[Tuple[str, dict]] = []
        for request in requests:
            kwargs = request.generation_kwargs
            if kwargs is None:
                kwargs = {}
            else:
                kwargs = kwargs.copy()
            kwargs["until"] = request.stop_sequences
            request_args.append((cast(str, request.context), kwargs))
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(request_args),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size if adaptive_batch_size is not None else 0
        )
        batch_fn = (
            self._batch_scheduler if self.batch_size == "auto" and not adaptive_batch_size else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            request_args,
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs.keys():
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs.keys():
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks
            truncate_context = kwargs.pop("truncate_context", True)

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                if truncate_context:
                    # max len for inputs = max length, minus room to generate the max new tokens
                    max_ctx_len = self.max_length - max_gen_toks
                    if max_ctx_len <= 0:
                        raise ValueError(
                            f"max_gen_toks ({max_gen_toks}) is greater than max_length ({self.max_length}), "
                            + '"consider using "truncate_context": False'
                        )
                else:
                    # If truncate_context is False we modify max_gen_toks instead, to keep full context
                    max_ctx_len = self.max_length
                    # The first context is the longest because of ordering above
                    max_batch_toks = len(self.tok_encode(contexts[0]))
                    max_gen_toks = XXXX-11(max_gen_toks, self.max_length - max_batch_toks)
                    max_gen_toks = max(max_gen_toks, 0)
                    if max_gen_toks == 0:
                        raise ValueError(
                            f"Effective max_gen_toks is 0 (max_batch_toks = {max_batch_toks}, max_ctx_len = {max_ctx_len}) "
                            + '"consider using "truncate_context": True'
                        )
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                list(contexts),  # oe-eval: XXXX
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )
            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            # perform batched generation, with args for logit scores
            kwargs["output_scores"] = True
            kwargs["return_dict_in_generate"] = True
            output = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )
            # Extract generated sequences and corresponding logits
            cont_toks_list = output["sequences"].tolist()
            gen_sequences = output["sequences"][:, context_enc.shape[1] :]

            # stack scores generated at each step
            scores = torch.stack(output["scores"], dim=1)  # shape [batch_size, seq_len, vocab_size]
            log_probs = F.log_softmax(scores, dim=-1)

            # collect scores/logits of the generated token
            gen_scores_list = torch.gather(scores, 2, gen_sequences[:, :, None]).squeeze(
                -1
            )  # shape [batch_size, seq_len]
            gen_log_probs_list = torch.gather(log_probs, 2, gen_sequences[:, :, None]).squeeze(-1)

            for cont_toks, gen_scores, gen_log_probs, context in zip(
                cont_toks_list, gen_scores_list, gen_log_probs_list, contexts
            ):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s_raw = self.tok_decode(cont_toks)
                s = cut_at_stop_sequence(s_raw, until)
                no_pad_len = len(cont_toks)
                for idx, tok in enumerate(cont_toks):
                    # Keep the first eot_token encountered
                    if tok == self.eot_token_id:
                        no_pad_len = idx + 1
                        break
                    if tok == self.tokenizer.pad_token_id:
                        no_pad_len = idx
                        break
                cont_toks_no_pad = cont_toks[:no_pad_len]

                logits = gen_log_probs[: len(cont_toks_no_pad)]
                sum_logits = logits.sum().item()

                res1 = {
                    "continuation": s,
                    "sum_logits": sum_logits,
                    "num_tokens": len(cont_toks_no_pad),
                    # "tokens": cont_toks_no_pad,
                    # "logits": logits.tolist(),
                }
                if s_raw != s:
                    res1["continuation_raw"] = s_raw

                res.append(res1)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res
