import logging
import math
from typing import Literal, cast

import torch
from transformers import PreTrainedModel


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Whether torch reports a usable CUDA runtime; evaluated once at import time.
HAS_CUDA = torch.cuda.is_available()
# Reference GPU memory in bytes. get_batch_size scales its tuned batch sizes
# by (actual GPU memory / A40_GPU_MEMORY).
A40_GPU_MEMORY = 47658106880


# Allowed values for the `action` argument of get_batch_size.
Action = Literal["inference", "training"]


# noqa: C901
def get_batch_size(
    model: PreTrainedModel,
    sequence_length: int,
    action: Action,
    local_rank: int,
) -> int:
    """
    Get the maximum batch size for a given model and sequence length in
    the current environment.

    Each known model has a hand-tuned token budget (in units of 1024
    tokens). The batch size is that budget divided by ``sequence_length``
    (rounded down), then scaled linearly by the ratio of the available GPU
    memory to ``A40_GPU_MEMORY`` and truncated to an integer.

    Args:
        model (:obj:`PreTrainedModel`):
            The model to get the batch size for. The last ``/``-separated
            component of ``model.name_or_path`` selects the budget.
        sequence_length (:obj:`int`):
            The length of the token sequences passed to the model.
        action (:obj:`Action`):
            Either ``"inference"`` or ``"training"``. Currently unused.
        local_rank (:obj:`int`):
            Index of the CUDA device whose total memory is queried. Ignored
            when CUDA is not available.

    Returns:
        :obj:`int`: The maximum batch size for the given model and sequence
        on the current environment. May be ``0`` when the sequence is long
        relative to the budget or the device has little memory.

    Raises:
        ValueError: If the model has no name or path, or if it is an
            unrecognised Llama-2, Phi or OPT variant.
        NotImplementedError: If the model family is not known at all.
    """
    if HAS_CUDA:
        gpu_memory = torch.cuda.get_device_properties(local_rank).total_memory
    else:
        # Without CUDA, assume 16 GiB of memory.
        gpu_memory = 16 * 1024 * 1024 * 1024
    # The budgets below were tuned on an A40 GPU (A40_GPU_MEMORY bytes).
    # We scale the batch size linearly with the GPU memory.
    # TODO: this assumption might be too conservative
    multiplier = gpu_memory / A40_GPU_MEMORY

    # TODO: use the action
    model_path = model.name_or_path
    if model_path is None:
        raise ValueError(
            "Cannot determine the batch size, no model path or name set."
        )
    model_id = cast(str, model_path).split("/")[-1]

    # `budget` is the number of 1024-token units that fit in one batch on
    # the reference GPU. Prefix matching lets suffixed variants of a model
    # id share the entry of their base model.
    if model_id.startswith("pythia"):
        if model_id.startswith("pythia-12b"):
            budget = 4
        elif model_id.startswith("pythia-6.9b"):
            budget = 8
        elif model_id.startswith("pythia-2.8b"):
            budget = 16
        elif model_id.startswith("pythia-1.4b"):
            budget = 24
        elif model_id.startswith("pythia-1b"):
            # 48 runs out of memory on the reference A40 GPU.
            # 64 might work with deepspeed inference
            budget = 28
        else:
            # Any other pythia variant falls back to this budget.
            budget = 64
    elif model_id.startswith("Llama-2"):
        if model_id.startswith("Llama-2-70b"):
            budget = 2
        elif model_id.startswith("Llama-2-13b"):
            if sequence_length > 1024:
                budget = 12
            else:
                # 18 might also work
                budget = 16
        elif model_id.startswith("Llama-2-7b"):
            budget = 24
        else:
            raise ValueError(f"Unknown Llama2 model: {model_id}")
    elif model_id.startswith("gpt2"):
        # The larger variants must be matched before the base model: every
        # id in this branch starts with "gpt2", so the base-model case has
        # to be the final fallback rather than a separate `if` that would
        # overwrite the budgets chosen above.
        if model_id.startswith(("gpt2-xl", "gpt2-1.5b")):
            budget = 32
        elif model_id.startswith("gpt2-large"):
            budget = 48
        elif model_id.startswith("gpt2-medium"):
            budget = 64
        else:
            # Base "gpt2" / "gpt2-124m" and any other gpt2-prefixed id.
            budget = 80
    elif model_id.startswith("phi"):
        if model_id in ("phi-1", "phi-1_5"):
            budget = 28
        elif model_id == "phi-2":
            budget = 22
        else:
            raise ValueError(f"Unknown Phi model: {model_id}")
    elif model_id.startswith("opt"):
        if model_id == "opt-350m":
            budget = 48
        else:
            raise ValueError(f"Unknown OPT model: {model_id}")
    else:
        raise NotImplementedError(
            f"Cannot determine the batch size for model {model_id}."
        )

    batch_size = math.floor((budget * 1024) / sequence_length)
    return int(batch_size * multiplier)
