


import argparse
import json
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM

from hf_olmo import OLMoConfig, OLMoForCausalLM, OLMoTokenizerFast


def stdio_predictor_wrapper(predictor):
    """Serve ``predictor`` over a line-oriented JSON protocol on stdin/stdout.

    Each non-blank line read from stdin must be a JSON array of inputs. The
    array is passed to ``predictor.predict(inputs=...)`` and the predictions
    are written back to stdout as one JSON array on a single line.

    After stdin is exhausted, the last request line that was processed (if
    any) is written to ``inputs_file.txt`` in the current working directory.

    Args:
        predictor: Object exposing ``predict(inputs)`` that returns an
            iterable of JSON-serializable predictions.

    Raises:
        json.JSONDecodeError: If a request line is not valid JSON.
        TypeError: If a request line does not decode to a JSON array.
    """
    last_line = None
    for raw_line in sys.stdin:
        line = raw_line.rstrip()
        if not line:
            # Tolerate blank lines (e.g. an extra newline at EOF) instead of
            # crashing in json.loads("").
            continue
        last_line = line

        inputs = json.loads(line)
        if not isinstance(inputs, list):
            # Explicit check rather than `assert`, which is stripped under -O.
            raise TypeError(
                f"Expected a JSON array of inputs, got {type(inputs).__name__}"
            )

        # predict() may return a generator; materialize it for json.dumps.
        outputs = list(predictor.predict(inputs=inputs))
        sys.stdout.write(f"{json.dumps(outputs)}\n")
        # Flush per request so a reader on the other end of the pipe sees
        # each reply immediately instead of when the buffer fills.
        sys.stdout.flush()

    # Only record the last request if one was actually read; previously an
    # empty stdin left `line` unbound and raised UnboundLocalError here.
    if last_line is not None:
        with open("inputs_file.txt", "w", encoding="utf-8") as f:
            f.write(last_line)


class PredictWrapper:
    """Batch text generation with a Hugging Face causal language model.

    Loads the tokenizer and model from ``pretrained_model_dir`` and moves the
    model to ``cuda:0`` when CUDA is available, otherwise to the CPU.
    """

    def __init__(self, pretrained_model_dir):
        """Load the tokenizer and model.

        Args:
            pretrained_model_dir: Local path or model name accepted by
                ``from_pretrained``.
        """
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
        # Decoder-only models continue from the last position of each row, so
        # batched prompts must be padded on the left. The tokenizer attribute
        # is `padding_side`; the former `padding_size` had no effect.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            # Reuse EOS as the pad token so variable-length batches can be padded.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_dir)
        self.model = self.model.to(device)

    def predict(self, inputs):
        """Generate up to 256 new tokens for each prompt in ``inputs``.

        Args:
            inputs: List of prompt strings.

        Yields:
            One decoded, whitespace-stripped string per prompt, in input
            order. The text is the full sequence returned by ``generate``
            (for decoder-only models: prompt followed by the continuation),
            with special tokens, including padding, removed.
        """
        if not inputs:
            # Nothing to generate; skip the tokenizer/model round-trip.
            return
        encoded = self.tokenizer(
            inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)
        outputs = self.model.generate(
            input_ids=encoded.input_ids,
            # Without the mask the left padding (EOS when the model has no pad
            # token) is attended to as if it were part of the prompt.
            attention_mask=encoded.attention_mask,
            max_new_tokens=256,
        )
        for text in self.tokenizer.batch_decode(outputs, skip_special_tokens=True):
            yield text.strip()


class VLLMPredictWrapper:
    """Batch text generation through a vLLM ``LLM`` engine."""

    def __init__(
        self,
        pretrained_model_dir,
        tensor_parallel_size=2,
        gpu_memory_utilization=0.9,
        max_new_tokens=256,
    ):
        """Start the vLLM engine and build the sampling configuration.

        Args:
            pretrained_model_dir: Local path or model name passed to ``LLM``.
            tensor_parallel_size: Number of GPUs to shard the model across.
            gpu_memory_utilization: Fraction of GPU memory vLLM may reserve.
            max_new_tokens: Maximum number of tokens generated per prompt.
        """
        from vllm import SamplingParams

        self.model = LLM(
            pretrained_model_dir,
            gpu_memory_utilization=gpu_memory_utilization,
            tensor_parallel_size=tensor_parallel_size,
        )
        # vLLM's default SamplingParams stop after 16 tokens and sample at
        # temperature 1.0. Use greedy decoding with the same 256-new-token
        # budget as PredictWrapper so both backends do comparable work.
        self.sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)

    def predict(self, inputs):
        """Generate a completion for each prompt in ``inputs``.

        Args:
            inputs: List of prompt strings.

        Yields:
            The whitespace-stripped text of the first completion for each
            prompt (the continuation only, without the prompt).
        """
        if not inputs:
            # Nothing to generate; skip the engine call.
            return
        outputs = self.model.generate(inputs, sampling_params=self.sampling_params)
        for output in outputs:
            yield output.outputs[0].text.strip()


def get_args(argv=None):
    """Parse command-line options for the efficiency benchmark predictor.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with a ``pretrained_model`` string attribute.
    """
    parser = argparse.ArgumentParser(description="Run efficiency benchmark")
    parser.add_argument(
        "--pretrained-model",
        type=str,
        # Required: without it, None would reach from_pretrained and fail
        # with an unhelpful error far from the CLI.
        required=True,
        help="Path to the unquantized model / Name of the unquantized huggingface model.",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse CLI options, load the Hugging Face model,
    # then answer JSON requests from stdin until EOF.
    args = get_args()
    stdio_predictor_wrapper(PredictWrapper(args.pretrained_model))
