# Usage: `VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 -m thinktime.utils.inference_tool_vllm`

import thinktime.vllm.chatts_vllm
from vllm import LLM, SamplingParams
from thinktime.utils.llm_utils_tools import LLMClientWithTools
import torch
import os
import json
from loguru import logger
import numpy as np


# CONFIG
# Experiment name; results are written under <WORKDIR>/exp/<EXP>.
EXP = 'thinktime_uts_reason'
# Absolute path of the checkpoint directory handed to the LLM client as `model_path`.
MODEL_PATH = os.path.abspath('ckpt')
# Evaluation set. A path ending in '.json' is read as a single JSON document;
# any other path is read as one JSON object per line.
DATASET = './evaluation/dataset/uts_reason.jsonl'
# Base directory for experiment outputs: the working directory at the time
# this module is executed, made absolute.
WORKDIR = os.path.abspath('./')
# Forwarded to the LLM client as its `tensor_parallel` argument.
NUM_GPUS_PER_PROCESS = 8


def answer_question_list(question_list, ts_list):
    """Run batch generation and index the returned answers by position.

    A new ``LLMClientWithTools`` is constructed on every call, using
    ``MODEL_PATH`` and ``NUM_GPUS_PER_PROCESS`` from the module config.

    Args:
        question_list: Questions, forwarded unchanged to
            ``llm_batch_generate``.
        ts_list: Time-series inputs, forwarded unchanged alongside the
            questions (presumably one entry per question -- confirm against
            the client implementation).

    Returns:
        A dict mapping each answer's position in the returned batch to
        ``{"response": answer}``.
    """
    client = LLMClientWithTools(
        model_path=MODEL_PATH,
        tensor_parallel=NUM_GPUS_PER_PROCESS,
    )
    responses = client.llm_batch_generate(question_list, ts_list)
    return {
        position: {"response": text}
        for position, text in enumerate(responses)
    }


if __name__ == '__main__':
    # Load the evaluation set: a single JSON document when the path ends in
    # '.json', otherwise one JSON object per line. The file is opened in a
    # `with` block so the handle is closed deterministically, and decoded
    # as UTF-8 rather than whatever the locale default happens to be.
    with open(DATASET, 'rt', encoding='utf-8') as dataset_file:
        if DATASET.endswith('.json'):
            dataset = json.load(dataset_file)
        else:
            # Blank lines (e.g. a trailing newline at end of file) are not
            # records and would make json.loads raise, so skip them.
            dataset = [json.loads(line) for line in dataset_file if line.strip()]

    exp_dir = os.path.join(WORKDIR, f"exp/{EXP}")
    logger.info(f"Experiment directory: {exp_dir}")
    os.makedirs(exp_dir, exist_ok=True)

    # Generation
    logger.info("Start Generation...")
    question_list = []
    ts_list = []
    for sample in dataset:
        question_list.append(sample['question'])
        # Every entry of sample['timeseries'] becomes its own numpy array.
        ts_list.append([np.array(item) for item in sample['timeseries']])

    # Generate answers; the result maps each sample index to {"response": ...}.
    answer = answer_question_list(question_list, ts_list)

    # Prepare results: pair each response with the question that produced it.
    generated_answer = [
        {
            'idx': idx,
            'question_text': question_list[idx],
            'response': ans['response'],
        }
        for idx, ans in answer.items()
    ]

    # Save results. ensure_ascii=False emits non-ASCII characters verbatim,
    # so the file must be written as UTF-8; the `with` block guarantees the
    # output is flushed and closed before the success message is logged.
    output_file = os.path.join(exp_dir, "generated_answer.json")
    with open(output_file, "wt", encoding="utf-8") as result_file:
        json.dump(generated_answer, result_file, ensure_ascii=False, indent=4)
    logger.info(f"Results saved to {output_file}")
