import itertools
import logging
import os
from dataclasses import dataclass
from typing import Iterable, Iterator, overload

# import deepspeed
import torch
from transformers import BatchEncoding, PreTrainedModel
from transformers.utils import ModelOutput

from lib_llm.ops.batch_size import get_batch_size

from ._batch_mapping import map_to_model_batches, remap_to_input_batches


# from accelerate import PartialState


logger = logging.getLogger(__name__)
# Whether this process can see at least one usable CUDA device.
HAS_CUDA = torch.cuda.is_available()

# True when the LOCAL_RANK environment variable is set. The name keeps
# its historical spelling because other modules may import it.
# NOTE(review): presumably LOCAL_RANK is exported by the distributed
# launcher (e.g. DeepSpeed) -- confirm against the launch scripts.
USES_DEEPSEED = "LOCAL_RANK" in os.environ
# Rank of this process on the local machine; -1 means "not set".
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "-1"))


def _resolve_inference_device(local_rank: int, has_cuda: bool) -> torch.device:
    """Pick the device on which inference inputs are placed.

    Args:
        local_rank: Local process rank; a negative value means that no
            rank was configured.
        has_cuda: Whether CUDA is available.

    Returns:
        ``cpu`` without CUDA, ``cuda:<local_rank>`` for any configured
        rank (including rank 0), and the default ``cuda`` device
        otherwise.
    """
    if not has_cuda:
        return torch.device("cpu")
    # Rank 0 is a valid rank and must map to an explicit device index,
    # exactly like every other rank; only negative values mean "unset".
    if local_rank >= 0:
        return torch.device(f"cuda:{local_rank}")
    return torch.device("cuda")


INFERENCE_DEVICE = _resolve_inference_device(LOCAL_RANK, HAS_CUDA)
CPU_DEVICE = torch.device("cpu")


@dataclass
class PredictionConfig:
    """Configuration for prediction.

    The GPU index is not a field of this configuration; this module
    derives it from the ``LOCAL_RANK`` environment variable.

    Attributes:
        trim_last_token: Whether to omit the last token of each
            sequence. This is useful when only computing sequence
            probabilities, i.e. when no next-token prediction is
            required.
        batch_size: The batch size to use for prediction. If < 0,
            then the maximum batch size for the given model and the
            current environment is inferred.
        offload_outputs_to_cpu: Whether to move outputs from the GPU to
            the CPU after inference, to make space for consecutive
            inference calls.
    """

    trim_last_token: bool = True
    batch_size: int = -1
    offload_outputs_to_cpu: bool = True


@overload
def predict(
    model: PreTrainedModel,
    data: BatchEncoding,
    config: PredictionConfig = PredictionConfig(),
) -> ModelOutput: ...


@overload
def predict(
    model: PreTrainedModel,
    data: Iterable[BatchEncoding],
    config: PredictionConfig = PredictionConfig(),
) -> Iterator[ModelOutput]: ...


def predict(
    model: PreTrainedModel,
    data,
    config: PredictionConfig = PredictionConfig(),
):
    """Run efficient inference on one batch or on a stream of batches.

    Two optimizations are applied to speed up inference:
    - Inputs are rebatched into the largest batches that fit the model
    in the current environment, to maximize throughput.
    - Duplicate inputs are compressed into a single model input. This
    pays off most with config.trim_last_token set to True, since all
    sequences that agree up to the last token then collapse into one.
    Outputs for the last token are not needed when no next-token
    prediction is required.

    Args:
        model: The model to use for inference.
        data: Either a single ``BatchEncoding`` or an iterable of them.
        config: The configuration for inference.
    Returns:
        For a single ``BatchEncoding``, the first output produced for
        it; for an iterable, a lazy iterator over the outputs
        corresponding to the input batches.
    """
    # Streams are handed through untouched and evaluated lazily.
    if not isinstance(data, BatchEncoding):
        return _iter_predict(model, data, config)

    # A lone batch is wrapped in a list so both forms share one code
    # path; only its first output is returned.
    single_batch_outputs = _iter_predict(model, [data], config)
    return next(single_batch_outputs)


def _iter_predict(
    model: PreTrainedModel,
    data: Iterable[BatchEncoding],
    config: PredictionConfig,
) -> Iterator[ModelOutput]:
    """Lazily run ``model`` over ``data`` and yield the remapped outputs.

    The input batches are regrouped by ``map_to_model_batches`` into
    model batches of the chosen batch size. Each model output is passed
    to the generator returned by ``remap_to_input_batches`` and whatever
    that generator produces is yielded.

    Side effects: ``model`` is put into eval mode and, when CUDA is
    available and the model still lives on the CPU, moved to
    ``INFERENCE_DEVICE``.

    Args:
        model: The model to run.
        data: The input batches; consumed lazily in a single pass.
        config: Inference settings. A negative ``batch_size`` makes the
            batch size be inferred from the first input batch's sequence
            length; ``offload_outputs_to_cpu`` selects whether outputs
            end up on the CPU or stay on ``INFERENCE_DEVICE``.

    Yields:
        The outputs produced by the remapping generator. Nothing is
        yielded for empty ``data``.
    """
    data_iter = iter(data)
    if config.batch_size < 0:
        # Peek at the first batch to learn the sequence length. A bare
        # next() on an exhausted iterator would raise StopIteration
        # inside this generator, which Python turns into a RuntimeError.
        try:
            first_batch = next(data_iter)
        except StopIteration:
            return
        batch_size = get_batch_size(
            model,
            first_batch.input_ids.shape[1],
            local_rank=max(LOCAL_RANK, 0),
            action="inference",
        )
        # Put the peeked batch back in front. Unlike itertools.tee with
        # a lagging second iterator, this does not keep every consumed
        # batch buffered for the lifetime of the generator.
        data_iter = itertools.chain((first_batch,), data_iter)
    else:
        batch_size = config.batch_size

    model.eval()
    # The model's placement does not change between batches, so check
    # it once instead of on every iteration.
    if HAS_CUDA and model.device == CPU_DEVICE:
        model = model.to(INFERENCE_DEVICE)  # type: ignore[reportGeneralTypeIssues]

    # Offloading to the CPU keeps outputs from accumulating in GPU
    # memory and eventually causing an OOM error.
    # TODO: don't move the outputs if all inputs are already
    # processed, i.e. if no more additional CPU memory is needed.
    output_device = (
        "cpu" if config.offload_outputs_to_cpu else str(INFERENCE_DEVICE)
    )

    output_batch_generator = remap_to_input_batches()
    for model_batch, input_to_model_batch_indices in map_to_model_batches(
        data_iter,
        model_batch_size=batch_size,
        trim_last_column=config.trim_last_token,
    ):
        # Grad mode is disabled only around the forward pass. Wrapping
        # the `yield` as well would leave gradients disabled in the
        # consumer's own code while this generator is suspended.
        with torch.no_grad():
            model_batch = model_batch.to(INFERENCE_DEVICE)
            output = model(**model_batch)
            output = _move_output_to_device(output, output_device)

        yield from output_batch_generator(
            output,
            input_to_model_batch_indices,
        )


def _move_output_to_device(
    model_output: ModelOutput,
    device: str,
) -> ModelOutput:
    """
    Copy the tensor fields of a Huggingface ModelOutput object to the
    specified device.

    Only two kinds of field are carried over: tensors, and flat tuples
    whose elements are all tensors. Every carried-over tensor is
    detached and cast to float32. All other fields (for example
    non-tensor values or tuples of tuples) are left out of the result.

    Parameters:
    - model_output: The ModelOutput object containing tensors.
    - device: Target device e.g. "cpu", "cuda:0".

    Returns:
    - A new ModelOutput holding the moved tensors, with the fields in
      the same order as in `model_output`.
    """
    moved = ModelOutput()
    # Iterate the mapping's own fields rather than dir(): that avoids
    # touching every method and dunder attribute and keeps the original
    # field order instead of an alphabetical one.
    for name, value in model_output.items():
        if torch.is_tensor(value):
            moved[name] = value.detach().to(device).float()
        elif isinstance(value, tuple) and all(
            torch.is_tensor(item) for item in value
        ):
            moved[name] = tuple(
                item.detach().to(device).float() for item in value
            )

    # Drop this function's reference to the source output, then ask the
    # CUDA caching allocator to release memory it no longer uses.
    del model_output
    torch.cuda.empty_cache()

    return moved
