import os
from thinktime.utils.llm_utils import LLMClient
from vllm import LLM, SamplingParams
import json
from loguru import logger
from evaluation.evaluate_qa import evaluate_qa, evaluate_batch_qa
from transformers import AutoTokenizer


# CONFIG
# Experiment name; generation results and evaluation are written under exp/<EXP>/.
EXP = 'qwen2.5_14b_text_uts_reason'

# Input dataset: a single .json file, or JSON Lines (one sample per line) otherwise.
DATASET = 'evaluation/dataset/uts_reason.jsonl'
# Model weights / tokenizer location (placeholder to be filled in locally).
MODEL_PATH = '[LOCAL_LLM_PATH]'

# GPU layout passed to LLMClient (presumably total GPUs and GPUs per model replica).
NUM_GPUS = 8
GPUS_PER_MODEL = 2

# Decoding settings applied to every generation request.
sampling_params = SamplingParams(max_tokens=512, temperature=0.2)

def load_dataset(path):
    """Load QA samples from ``path``.

    A path ending in ``.json`` is parsed as one JSON document; any other path
    is read as JSON Lines (one JSON value per line). Blank lines in a JSON
    Lines file are skipped so a trailing newline does not break parsing.

    Args:
        path: Dataset file path.

    Returns:
        For ``.json``, the parsed document; otherwise a list with one parsed
        value per non-blank line.
    """
    with open(path, 'rt', encoding='utf-8') as f:
        if path.endswith('.json'):
            return json.load(f)
        return [json.loads(line) for line in f if line.strip()]


def build_prompt(sample):
    """Render a sample's question as plain text with its time series inlined.

    Each ``<ts><ts/>`` placeholder in ``sample['question']`` is replaced, in
    order, by the matching entry of ``sample['timeseries']``. The entry is
    written as comma-separated values, each passed through ``float()`` and
    rounded to one decimal place.

    Args:
        sample: Mapping with a ``'question'`` string and a ``'timeseries'``
            sequence of value sequences.

    Returns:
        The prompt string.

    Raises:
        ValueError: If the question has fewer placeholders than time series.
    """
    segments = sample['question'].split('<ts><ts/>')
    series_list = sample['timeseries']
    if len(segments) < len(series_list) + 1:
        raise ValueError(
            f"question has {len(segments) - 1} '<ts><ts/>' placeholders "
            f"but the sample has {len(series_list)} time series"
        )
    parts = [segments[0]]
    # NOTE(review): text after any placeholders beyond len(series_list) is
    # dropped, exactly as in the original loop. Confirm datasets never have extras.
    for series, tail in zip(series_list, segments[1:]):
        parts.append(','.join(str(round(float(v), 1)) for v in series))
        parts.append(tail)
    return ''.join(parts)


def main():
    """Generate answers for every not-yet-answered sample, save them, then evaluate.

    Answers already saved in ``exp/<EXP>/generated_answer.json`` are reloaded
    and their samples skipped, so an interrupted run can be resumed. Every
    saved record is keyed by its index in the dataset.
    """
    dataset = load_dataset(DATASET)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

    exp_dir = f"exp/{EXP}"
    answer_path = f"{exp_dir}/generated_answer.json"
    os.makedirs(exp_dir, exist_ok=True)
    generated_answer = []
    if os.path.exists(answer_path):
        with open(answer_path, 'rt', encoding='utf-8') as f:
            generated_answer = json.load(f)
    generated_idx = {item['idx'] for item in generated_answer}

    # Generation
    logger.info("Start Generation...")
    pending_idx = []  # dataset index of each entry in question_list
    question_list = []
    total_num_tokens = 0
    for idx, sample in enumerate(dataset):
        if idx in generated_idx:
            continue
        prompt = build_prompt(sample)
        pending_idx.append(idx)
        question_list.append(prompt)
        total_num_tokens += len(tokenizer(prompt)['input_ids'])

    print(f"--> {total_num_tokens=}")

    answer_list = []
    if question_list:
        client = LLMClient(model_path=MODEL_PATH, engine='vllm', num_gpus=NUM_GPUS,
                           gpus_per_model=GPUS_PER_MODEL, batch_size=16)
        # Release the model workers even if startup or generation raises.
        # NOTE(review): assumes kill() is safe after a failed wait_for_ready(); confirm.
        try:
            client.wait_for_ready()
            answer_list = client.llm_batch_generate(question_list, use_chat_template=True,
                                                    sampling_params=sampling_params)
        finally:
            client.kill()

    if len(answer_list) != len(question_list):
        logger.warning(f"Expected {len(question_list)} answers but got {len(answer_list)}; "
                       f"unmatched questions are left unanswered.")

    # Key each answer by its dataset index, not its position in question_list:
    # on a resumed run those differ because already-answered samples were skipped.
    for idx, question, ans in zip(pending_idx, question_list, answer_list):
        generated_answer.append({
            'idx': idx,
            'question_text': question,
            'response': ans,
        })
    # Keep the saved file in dataset order across resumed runs.
    generated_answer.sort(key=lambda item: item['idx'])

    # Save label
    with open(answer_path, 'wt', encoding='utf-8') as f:
        json.dump(generated_answer, f, ensure_ascii=False, indent=4)

    # Evaluation
    evaluate_batch_qa(dataset, generated_answer, EXP, num_workers=16)


if __name__ == '__main__':
    main()
