import openai
from loguru import logger
import numpy as np
import json
import os
from tqdm import tqdm
import traceback
from typing import *

from evaluation.evaluate_qa import evaluate_batch_qa
from multiprocessing import Pool

# CONFIG
MODEL = 'gpt-5'                                  # chat model name passed to the API
EXP = 'gpt-5-text-mcq2'                          # experiment name; outputs are written under exp/<EXP>/
DATASET = 'evaluation/dataset/uts_reason.jsonl'  # .json (single array) or .jsonl (one object per line)
# Credentials are taken from the environment so secrets need not be committed;
# the bracketed placeholders remain as fallbacks for manual editing.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "[OPENAI_API_KEY]")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "[OPENAI_BASE_URL]")
NUM_WORKERS = 32                                 # process count for both generation and evaluation


def _build_prompt(question: str, timeseries: np.ndarray) -> str:
    """Replace each ``<ts><ts/>`` placeholder in ``question`` with one rendered series.

    Series are rendered as comma-separated values with two decimals and are
    substituted in order: the k-th series fills the k-th placeholder.

    Raises:
        ValueError: if the number of placeholders differs from the number of
            series; otherwise either trailing question text or some series
            would be silently lost.
    """
    segments = question.split('<ts><ts/>')
    n_placeholders = len(segments) - 1
    if n_placeholders != len(timeseries):
        raise ValueError(
            f"question has {n_placeholders} <ts><ts/> placeholder(s) "
            f"but {len(timeseries)} time series were provided"
        )

    parts = [segments[0]]
    for series, tail in zip(timeseries, segments[1:]):
        parts.append(','.join(f"{value:.2f}" for value in series))
        parts.append(tail)
    return ''.join(parts)


def ask_gpt_api_with_timeseries(case_idx: int, timeseries: np.ndarray, cols: List[str], question: str) -> Tuple[str, int]:
    """Ask the chat-completions API one question with time series inlined as text.

    Args:
        case_idx: Index of the sample (currently unused; kept for the call signature).
        timeseries: Sequence of numeric series, one per ``<ts><ts/>`` placeholder
            in ``question``.
        cols: Column names (currently unused; kept for the call signature).
        question: Question text containing ``<ts><ts/>`` placeholders.

    Returns:
        ``(answer, prompt_tokens)``. ``answer`` is ``""`` if the model returned no
        text content. After three failed API attempts, ``("", 0)`` is returned.

    Raises:
        ValueError: if the placeholder count does not match the number of series
            (raised before any API call is made).
    """
    # Build the prompt first so malformed samples fail fast without a network call.
    prompt = _build_prompt(question, timeseries)

    # Pass credentials explicitly instead of mutating global openai module state.
    client = openai.OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
            ]
        }
    ]

    max_attempts = 3
    for _ in range(max_attempts):
        try:
            response = client.chat.completions.create(
                model=MODEL,
                messages=messages,
                timeout=200
            )
            break
        except Exception as err:
            # Any API error (not only timeouts) is retried.
            logger.error(err)
            logger.error("API timeout, trying again...")
    else:
        logger.error("Too many timeout!")
        return "", 0

    # content may be None (e.g. no text output); normalise to "".
    answer = response.choices[0].message.content or ""
    # Only prompt (input) tokens are reported; some endpoints omit usage entirely.
    usage = response.usage
    prompt_tokens = usage.prompt_tokens if usage is not None else 0
    return answer, prompt_tokens

def process_sample(args):
    """Generate an answer for one dataset sample; used as a ``Pool.imap`` worker.

    Args:
        args: ``(sample, idx)`` where ``sample`` is a dataset record with at least
            ``'timeseries'``, ``'cols'`` and ``'question'`` keys, and ``idx`` is its
            position in the dataset.

    Returns:
        A dict with ``idx``, ``question_text``, ``response`` and ``num_tokens``.
        Returns ``None`` if the sample raised or produced an empty response. Such a
        sample is not saved, so a later run will retry it.
    """
    sample, idx = args
    try:
        question_text = sample['question']
        answer, total_tokens = ask_gpt_api_with_timeseries(idx, sample['timeseries'], sample['cols'], question_text)
    except Exception:
        # Worker boundary: log with traceback and skip the sample rather than
        # letting one bad record abort the whole pool.
        logger.exception(f"Sample {idx} failed")
        return None

    if not answer:
        # ask_gpt_api_with_timeseries returns ("", 0) once its retries are
        # exhausted; persisting that would mark the sample as done forever.
        logger.warning(f"Sample {idx} got an empty response; it will be retried on the next run")
        return None

    return {
        'idx': idx,
        'question_text': question_text,
        'response': answer,
        'num_tokens': total_tokens
    }


if __name__ == '__main__':
    dataset = json.load(open(DATASET)) if DATASET.endswith('.json') else [json.loads(line) for line in open(DATASET, 'rt')]

    generated_answer = []
    os.makedirs(f'exp/{EXP}', exist_ok=True)
    if os.path.exists(f"exp/{EXP}/generated_answer.json"):
        generated_answer = json.load(open(f"exp/{EXP}/generated_answer.json"))
    generated_idx = set([i['idx'] for i in generated_answer])

    # Generation
    logger.info("Start Generation...")
    idx_to_generate = [i for i in range(len(dataset)) if i not in generated_idx]
    with Pool(processes=NUM_WORKERS) as pool:
        results = list(tqdm(pool.imap(process_sample, [(dataset[idx], idx) for idx in idx_to_generate]), total=len(idx_to_generate)))

    # Filter out None results and update generated_answer
    generated_answer.extend([res for res in results if res is not None])
    json.dump(generated_answer, open(f"exp/{EXP}/generated_answer.json", "wt"), ensure_ascii=False, indent=4)

    # Evaluation
    evaluate_batch_qa(dataset, generated_answer, EXP, num_workers=NUM_WORKERS)
