import json
import logging
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from time import sleep

import numpy as np
import openai
import pandas as pd

from tools.openai import run_batch_instance, batch_create

logger = logging.getLogger("rich")


def run_batch_vllm(
    input_file: str,
    output_file: str | None = None,
    model: str | None = "google/Gemma-2-2B-it",
    client: openai.OpenAI | None = None,
    **kwargs
) -> None:
    """
    Run batch inference by launching ``vllm.entrypoints.openai.run_batch``.

    The request file is processed either in one run or, when ``chunksize`` is
    truthy, split into temporary chunk files that are run one after another.
    In chunked mode each chunk's results are appended to ``output_file`` and
    the temporary chunk files are always removed, even when a step fails.

    Args:
        input_file: Path to the JSON-lines request file.
        output_file: Path of the results file. Defaults to the input path
            with its last extension replaced by ``_results.jsonl``.
        model: Value passed as ``--served-model-name``; also used as the
            model to load unless ``model_path`` is supplied.
        client: Not used by this function.
        **kwargs: Optional settings:
            ``model_path`` (defaults to ``model``),
            ``chunksize`` (default 50000; falsy disables chunking),
            ``tensor_parallel_size`` (default 1),
            ``gpu_memory_utilization`` (default 0.8),
            ``max_model_len`` (default 8192).

    Raises:
        subprocess.CalledProcessError: If the vLLM process exits with a
            non-zero status.
    """
    model_path = kwargs.get("model_path", model)
    chunksize = kwargs.get("chunksize", 50000)

    if not output_file:
        output_file = input_file.rsplit(".", 1)[0] + "_results.jsonl"

    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact. The shell used to expand "~" and "$VAR" in the
    # model path, so that expansion is done explicitly here.
    command = [
        "python", "-m", "vllm.entrypoints.openai.run_batch",
        "--model", os.path.expandvars(os.path.expanduser(str(model_path))),
        "--served-model-name", str(model),
        "--trust-remote-code",
        "--tensor-parallel-size", str(kwargs.get("tensor_parallel_size", 1)),
        "--gpu-memory-utilization", str(kwargs.get("gpu_memory_utilization", 0.8)),
        "--max-model-len", str(kwargs.get("max_model_len", 8192)),
        "--disable-log-requests",
        "--dtype", "auto",
        "--seed", "42",
        "--enable-prefix-caching",
    ]

    def run_vllm(src: str, dst: str) -> None:
        # check=True surfaces a failed run instead of silently continuing.
        subprocess.run(
            command + ["--input-file", src, "--output-file", dst],
            check=True,
        )

    def remove_if_exists(path: str) -> None:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    if not chunksize:
        run_vllm(input_file, output_file)
        return

    # Derive chunk names from the real extension. A plain
    # str.replace(".jsonl", ...) is a no-op for any other extension, which
    # made the "chunk" path equal to the input path: the input was then
    # overwritten and deleted.
    in_root, in_ext = os.path.splitext(input_file)
    out_root, out_ext = os.path.splitext(output_file)

    # dtype=False stops pandas from coercing values while round-tripping,
    # e.g. turning numeric-looking string ids into numbers.
    with pd.read_json(
        input_file, lines=True, chunksize=chunksize, dtype=False
    ) as reader:
        for ind, input_chunk in enumerate(reader):
            input_chunk_file = f"{in_root}_chunk-{ind}{in_ext}"
            output_chunk_file = f"{out_root}_chunk-{ind}{out_ext}"
            try:
                input_chunk.to_json(input_chunk_file, orient="records", lines=True)
                run_vllm(input_chunk_file, output_chunk_file)
                pd.read_json(output_chunk_file, lines=True, dtype=False).to_json(
                    output_file, orient="records", lines=True, mode="a"
                )
            finally:
                remove_if_exists(input_chunk_file)
                remove_if_exists(output_chunk_file)
    

def run_batch_openai(
    input_file: str,
    output_file: str | None = None,
    model: str | None = "google/Gemma-2-2B-it",
    client: openai.OpenAI | None = None,
    **kwargs
) -> None:
    """
    Run batch inference on OpenAI API batch mode.

    The request file is split into chunk files, each chunk is submitted with
    ``batch_create``, and the resulting batches are polled until all of them
    have completed. Each batch's output is downloaded to a temporary chunk
    file; once every batch is done the chunk outputs are appended to
    ``output_file`` and removed. The temporary input chunk files are removed
    when the function exits, whether it succeeds or fails.

    Args:
        input_file: Path to the JSON-lines request file.
        output_file: Path of the results file. Defaults to the input path
            with its last extension replaced by ``_results.jsonl``.
        model: Not used by this function.
        client: Client used for ``batches.retrieve`` and ``files.content``;
            also forwarded to ``batch_create``.
        **kwargs: Optional settings:
            ``chunksize`` (rows per submitted batch, default 10000; a falsy
            value falls back to the default),
            ``update_interval`` (seconds between polling rounds, default 60).

    Raises:
        Exception: If a batch reaches any status other than ``validating``,
            ``in_progress``, ``finalizing`` or ``completed``, or completes
            without an output file id.
    """
    # A falsy chunksize would make pd.read_json return a DataFrame instead of
    # an iterator of chunks, so fall back to the default in that case.
    chunksize = kwargs.get("chunksize") or 10000
    update_interval = kwargs.get("update_interval", 60)

    if not output_file:
        output_file = input_file.rsplit(".", 1)[0] + "_results.jsonl"

    # Derive chunk names from the real extension. str.replace(".jsonl", ...)
    # is a no-op for other extensions and would make the chunk path equal to
    # the original file, overwriting it.
    in_root, in_ext = os.path.splitext(input_file)
    out_root, out_ext = os.path.splitext(output_file)

    task_ids = {}
    input_chunk_files = []

    try:
        # dtype=False stops pandas from coercing values while round-tripping,
        # e.g. turning numeric-looking string ids into numbers.
        with pd.read_json(
            input_file, lines=True, chunksize=chunksize, dtype=False
        ) as reader:
            for ind, input_chunk in enumerate(reader):
                input_chunk_file = f"{in_root}_chunk-{ind}{in_ext}"
                output_chunk_file = f"{out_root}_chunk-{ind}{out_ext}"

                input_chunk_files.append(input_chunk_file)
                input_chunk.to_json(input_chunk_file, orient='records', lines=True)

                task_id = batch_create(client, input_chunk_file)
                task_ids[task_id] = {
                    "status": False,
                    "output_file": output_chunk_file
                }

        while True:
            for task_id, info in task_ids.items():
                if info["status"]:
                    continue
                task = client.batches.retrieve(task_id)
                if task.status in ("validating", "in_progress", "finalizing"):
                    continue
                if task.status != "completed":
                    raise Exception(f"Task {task_id} failed with status {task.status}")
                if not task.output_file_id:
                    raise Exception(f"Task {task_id} completed without an output file")

                file_response = client.files.content(task.output_file_id)
                with open(info["output_file"], "wb") as f:
                    # files.content() returns a response wrapper, not bytes;
                    # the payload is exposed through its .content attribute.
                    f.write(file_response.content)
                info["status"] = True

            num_finished = sum(info["status"] for info in task_ids.values())
            num_running = len(task_ids) - num_finished
            if num_running == 0:
                break

            logger.info(f"{num_finished} tasks finished, {num_running} tasks running")
            sleep(update_interval)
    finally:
        # The input chunks are only needed for submission; never leave them
        # behind next to the caller's input file.
        for path in input_chunk_files:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    logger.info("All tasks finished, concatenating results...")
    for info in task_ids.values():
        pd.read_json(info["output_file"], lines=True, dtype=False).to_json(
            output_file, orient='records', lines=True, mode='a'
        )
        os.remove(info["output_file"])
    logger.info(f"""Results concatenated to "{output_file}" """)

  
def run_batch_instant_api(
    input_file: str,
    output_file: str | None = None,
    model: str | None = "google/Gemma-2-2B-it",
    client: openai.OpenAI | None = None,
    **kwargs
) -> None:
    """
    Run batch inference on OpenAI API instant mode.

    Every non-blank line of ``input_file`` is parsed as JSON and passed to
    ``run_batch_instance(row, client)``; the JSON-encoded result is appended
    to ``output_file`` as one line. With ``parallel_size > 1`` the requests
    are submitted to a thread pool and results are written in completion
    order, which may differ from the input order.

    Args:
        input_file: Path to the JSON-lines request file.
        output_file: Path of the results file (opened in append mode).
            Defaults to the input path with its last extension replaced by
            ``_results.jsonl``.
        model: Not used by this function.
        client: Forwarded to ``run_batch_instance``.
        **kwargs: Optional settings:
            ``parallel_size`` (number of worker threads, default 1).
    """
    parallel_size = kwargs.get("parallel_size", 1)
    if not output_file:
        output_file = input_file.rsplit(".", 1)[0] + "_results.jsonl"

    def process_line(line):
        row = json.loads(line)
        result = run_batch_instance(row, client)
        return json.dumps(result) + "\n"

    # Skip blank lines (e.g. a trailing newline at the end of the file),
    # which json.loads would otherwise reject.
    with open(input_file, "r", encoding="utf-8") as f:
        lines = [line for line in f if line.strip()]

    # The context manager guarantees the output handle is flushed and closed
    # even if a request raises part-way through.
    with open(output_file, "a", encoding="utf-8") as output:
        if parallel_size > 1:
            with ThreadPoolExecutor(max_workers=parallel_size) as executor:
                futures = [executor.submit(process_line, line) for line in lines]
                for future in as_completed(futures):
                    output.write(future.result())
        else:
            for line in lines:
                output.write(process_line(line))
