#!/usr/bin/env python
# coding=utf-8
from openai import OpenAI
import logging
import time
import json
import re
import os
import time
import warnings
from tqdm import tqdm
from logging import getLogger
from typing import Dict, List, Optional, Union
import shutil
import multiprocessing
from tqdm import tqdm
from functools import partial
import requests
from logging import getLogger

from http import HTTPStatus
import json
from tqdm import tqdm
import pandas as pd
import time
from openai import OpenAI
import json
import os
import numpy as np
import requests
from decimal import *
from typing import Dict, Any
import multiprocessing

from transformers import AutoTokenizer
import sys
import json


import argparse

def parse_args():
    """
    Parse this script's command-line options.

    Returns:
    argparse.Namespace with ``input_file``, ``output_dir``, ``model_name``,
    ``tokenizer_path``, ``port`` (int) and ``run_time`` (a string that the
    main block splits on commas into run ids).
    """
    option_specs = (
        ("--input_file", str, "data.jsonl"),
        ("--output_dir", str, "results"),
        ("--model_name", str, "default_model"),
        ("--tokenizer_path", str, "Qwen/Qwen2.5-7B-Instruct"),
        ("--port", int, 8000),
        ("--run_time", str, "1"),
    )
    cli = argparse.ArgumentParser()
    for flag, kind, fallback in option_specs:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli.parse_args()

# Command-line configuration, parsed once when the module is imported.
args = parse_args()

# Target vLLM server; requests go to http://IP:PORT/v1 under MODEL_NAME.
MODEL_NAME = args.model_name
PORT = args.port
IP = "localhost"

# Prompt truncation budget (in tokens) and sampling temperature.
MAX_LEN = 110000
TEMPERATURE = 0.1

# Tokenizer used to count and truncate prompt tokens.
TOKEN_LEN_TOKENIZER = AutoTokenizer.from_pretrained(args.tokenizer_path)

def token_length_tokenizer(text):
    """Return the number of token ids TOKEN_LEN_TOKENIZER produces for ``text``."""
    return len(TOKEN_LEN_TOKENIZER(text)["input_ids"])

def string_to_token(text):
    """Encode ``text`` with TOKEN_LEN_TOKENIZER and return its ``input_ids``."""
    return TOKEN_LEN_TOKENIZER(text)["input_ids"]

def token_to_string(tokens):
    """
    Decode token ids back to text, keeping only the first MAX_LEN ids.

    Note the silent truncation: any ids past MAX_LEN are discarded. Special
    tokens are dropped from the decoded text (``skip_special_tokens=True``).
    """
    kept_ids = tokens[:MAX_LEN]
    return TOKEN_LEN_TOKENIZER.decode(kept_ids, skip_special_tokens=True)


# Configure logging: timestamped INFO-and-above records through a single
# StreamHandler (stderr by default). basicConfig is a no-op if the root
# logger already has handlers.
_console_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_console_handler],
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

def get_vllm_response(prompt: Any = "",
                      output_path: str = "",
                      messages_tag: str = "",
                      response_tag: str = "",
                      is_save=True,
                      max_retries: int = 3,
                      retry_delay: float = 1.0):
    """
    Send one record's chat messages to the local vLLM server and store the reply.

    The request goes through the OpenAI-compatible client at
    ``http://{IP}:{PORT}/v1`` using MODEL_NAME and TEMPERATURE.

    Parameters:
    prompt (dict): Record holding the chat messages under ``messages_tag``.
        It is mutated in place: the reply is written to ``prompt[response_tag]``.
    output_path (str): JSONL file the record is appended to when ``is_save``.
    messages_tag (str): Key of the message list inside ``prompt``.
    response_tag (str): Key under which the model reply is stored.
    is_save (bool): If True, remove ``messages_tag`` from the record and append
        it to ``output_path`` after a successful request.
    max_retries (int): Number of request attempts before giving up.
    retry_delay (float): Seconds to wait after the first failed attempt; the
        wait doubles after each further failure.

    Returns:
    dict: The same ``prompt`` object. When every attempt fails it carries no
    ``response_tag`` key and nothing is saved, so a later run (which skips
    only ``_id`` values already in the output file) will retry it.
    """
    client = OpenAI(api_key="EMPTY", base_url=f"http://{IP}:{PORT}/v1")
    max_tokens = 8192  # generation budget per request

    for attempt in range(1, max_retries + 1):
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                temperature=TEMPERATURE,
                messages=prompt[messages_tag],
                max_tokens=max_tokens,
            )
            # `content` is optional in the OpenAI message schema; store "" rather
            # than None so later string handling (e.g. the CoT second pass) is safe.
            content = response.choices[0].message.content or ""
            break
        except Exception as error:
            logger.warning(f"Attempt {attempt}/{max_retries}: Error encountered: {error}")
            if attempt < max_retries:
                # Back off so a briefly overloaded server is not hammered.
                time.sleep(retry_delay * 2 ** (attempt - 1))
    else:
        record_id = prompt.get("_id") if isinstance(prompt, dict) else None
        logger.error(f"All {max_retries} attempts failed for record {record_id!r}; no response stored.")
        return prompt

    prompt[response_tag] = content

    if is_save:
        prompt.pop(messages_tag, None)
        # Saving is kept outside the retry loop: a local write failure must not
        # trigger another (expensive) model request.
        try:
            save_response(output_path, prompt)
        except (OSError, TypeError, ValueError) as error:
            logger.error(f"Failed to save response to '{output_path}': {error}")
    return prompt

# Building blocks shared by the prompt templates below.
_TPL_INTRO_QUESTIONS = "Please read the following text and answer the questions below.\n\n"
_TPL_INTRO_QUESTION = "Please read the following text and answer the question below.\n\n"
_TPL_DOC = "<text>\n$DOC$\n</text>\n\n"
_TPL_CHOICES = (
    "What is the correct answer to this question: $Q$\n"
    "Choices:\n"
    "(A) $C_A$\n"
    "(B) $C_B$\n"
    "(C) $C_C$\n"
    "(D) $C_D$\n"
    "\n"
)
_TPL_ANSWER_FORMAT = 'Format your response as follows: "The correct answer is (insert answer here)".'

# (infer_mode, template) pairs, checked in order; any other mode falls back
# to _DIRECT_TEMPLATE.
_PROMPT_TEMPLATES = (
    ("cot",
     _TPL_INTRO_QUESTIONS + _TPL_DOC + _TPL_CHOICES + "Let’s think step by step:"),
    ("cot_ans",
     _TPL_INTRO_QUESTIONS
     + "The text is too long and omitted here.\n\n"
     + _TPL_CHOICES
     + "Let’s think step by step: $COT$\n\n"
     + "Based on the above, what is the single, most likely answer choice? "
     + _TPL_ANSWER_FORMAT),
    ("direct_cot",
     _TPL_INTRO_QUESTION + _TPL_DOC + _TPL_CHOICES
     + "Let’s think step by step. And format your final answer choice as follows: "
     + '"The correct answer is (insert answer here)".'),
)
_DIRECT_TEMPLATE = _TPL_INTRO_QUESTION + _TPL_DOC + _TPL_CHOICES + _TPL_ANSWER_FORMAT


def get_prompt_tpl(infer_mode):
    """
    Return the prompt template for ``infer_mode``.

    Templates contain the placeholders $Q$, $C_A$..$C_D$, and (except
    "cot_ans") $DOC$; "cot_ans" additionally contains $COT$ for the reasoning
    produced by a previous "cot" request. Unknown modes get the direct-answer
    template.
    """
    for mode, template in _PROMPT_TEMPLATES:
        if infer_mode == mode:
            return template
    return _DIRECT_TEMPLATE


def process_line_worker(line, existing_keys_dict, messages_tag, infer_mode):
    """
    Worker function: turn one input record into a chat request.

    Parameters:
    line (str | dict): One raw JSONL line, or an already-parsed record.
    existing_keys_dict: Container of ``_id`` values that already have saved
        results; matching records are skipped (only ``in`` is used on it).
    messages_tag (str): Key under which the chat messages are stored.
    infer_mode (str): Template selector forwarded to ``get_prompt_tpl``.

    Returns:
    dict | None: The record with ``record[messages_tag]`` set to a single
    user message, or None if the line is blank, malformed, already processed,
    or lacks one of the required string fields.
    """
    required_fields = ("question", "choice_A", "choice_B", "choice_C", "choice_D", "context")
    try:
        if isinstance(line, str):
            if not line.strip():
                return None  # blank line, e.g. a trailing newline in the JSONL
            record = json.loads(line.strip())
        else:
            record = line

        record_id = record.get('_id')
        if record_id and record_id in existing_keys_dict:
            return None

        missing = [name for name in required_fields if not isinstance(record.get(name), str)]
        if missing:
            logger.warning(f"Skipping record {record_id!r}: missing or non-string fields {missing}")
            return None

        context = record['context']
        # Cheap character-length pre-check so most contexts skip tokenization.
        # NOTE(review): assumes a text never has more tokens than characters;
        # byte-level BPE can split rare characters into several tokens — confirm.
        if len(context) > MAX_LEN:
            tokens = string_to_token(context)
            if len(tokens) > MAX_LEN:
                # token_to_string keeps only the first MAX_LEN tokens.
                context = token_to_string(tokens)

        # Same substitution order as before: choices, question, then the
        # document last so placeholder-like text inside it is left untouched.
        substitutions = (
            ("$C_A$", record['choice_A']),
            ("$C_B$", record['choice_B']),
            ("$C_C$", record['choice_C']),
            ("$C_D$", record['choice_D']),
            ("$Q$", record['question']),
            ("$DOC$", context),
        )
        prompt = get_prompt_tpl(infer_mode)
        for placeholder, value in substitutions:
            prompt = prompt.replace(placeholder, value)

        record[messages_tag] = [{"role": "user", "content": prompt}]
        return record

    except json.JSONDecodeError as e:
        logger.warning(f"JSON decode error: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error processing line: {e}")
        return None

def process_line(input_path: str, existing_keys_dict: Dict[str, bool], messages_tag: str, infer_mode: str) -> list:
    """
    Load input records and build their chat requests in parallel.

    ``.jsonl`` inputs are read line by line; any other file is parsed with
    ``json.load`` and must hold a JSON list of records (presumably dicts with
    the fields read by ``process_line_worker`` — confirm against the data).

    Parameters:
    input_path (str): Path to the input file.
    existing_keys_dict: ``_id`` values to skip (the caller passes a set).
    messages_tag (str): Key under which each record's chat messages are stored.
    infer_mode (str): Prompt template selector.

    Returns:
    list: Records ready for inference, in input order; skipped or malformed
    records are dropped. An empty list is returned when the input cannot be
    loaded or processing fails (never None).
    """
    try:
        with open(input_path, "r", encoding="utf-8") as f:
            lines = f.readlines() if input_path.endswith(".jsonl") else json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        logger.error(f"Failed to load input file '{input_path}': {e}")
        return []

    if not isinstance(lines, list):
        logger.error(f"Expected a JSON list of records in '{input_path}', got {type(lines).__name__}.")
        return []

    prompts = []
    worker = partial(process_line_worker, existing_keys_dict=existing_keys_dict,
                     messages_tag=messages_tag, infer_mode=infer_mode)
    try:
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            for result in tqdm(pool.imap(worker, lines), total=len(lines), desc="Processing lines"):
                if result:
                    prompts.append(result)
    except Exception as e:
        logger.exception(f"Unexpected error in process_line: {e}")
        return []
    return prompts


def save_response(output_file: str, line: dict):
    """
    Append ``line`` to ``output_file`` as one JSON object per line (JSONL).

    The file is opened in append mode with UTF-8 encoding, and non-ASCII
    characters are written verbatim (``ensure_ascii=False``). The record and
    its trailing newline go out in a single ``write`` call.
    """
    with open(output_file, mode="a", encoding="utf-8") as sink:
        sink.write(f"{json.dumps(line, ensure_ascii=False)}\n")


def load_existing_keys(outfile_path: str) -> set:
    """
    Collect the ``_id`` of every record already written to ``outfile_path``.

    Lets an interrupted run resume without re-querying finished records.
    A missing file yields an empty set; lines that are not valid JSON are
    skipped with a warning; records with a falsy ``_id`` are ignored.
    """
    seen_ids = set()
    if os.path.exists(outfile_path):
        with open(outfile_path, 'r', encoding='utf-8') as handle:
            for raw in handle:
                try:
                    parsed = json.loads(raw.strip())
                except json.JSONDecodeError:
                    logger.warning("Skipping malformed line while loading existing keys.")
                    continue
                record_id = parsed.get('_id')
                if record_id:
                    seen_ids.add(record_id)
    logger.info(f"Loaded {len(seen_ids)} existing keys from '{outfile_path}'.")
    return seen_ids

def req_diy(prompt: Any = "",
            output_path: str = "",
            messages_tag: str = "",
            response_tag: str = "",
            infer_mode: str = ""):
    """
    Run inference for one prepared record according to ``infer_mode``.

    - "direct" / "direct_cot": one request; the reply is stored under
      ``response_tag`` and the record is saved to ``output_path``.
    - "cot": two requests. The first reply (the reasoning) is stored under
      "response_cot" without saving. It is then inserted into the "cot_ans"
      template, and a second request produces the final answer. That answer
      is stored under ``response_tag`` (or "response" when that is empty),
      and the record is saved.

    Parameters:
    prompt (dict): Record produced by ``process_line_worker``.
    output_path (str): JSONL file results are appended to.
    messages_tag (str): Key of the chat messages inside the record.
    response_tag (str): Key for the final model reply.
    infer_mode (str): One of "direct", "direct_cot", "cot".

    Returns:
    dict: The (mutated) record. In "cot" mode, if the reasoning request
    failed, the record is returned unsaved so a later run can retry it.

    Raises:
    ValueError: If ``infer_mode`` is not one of the modes above (a subclass
        of Exception, so existing ``except Exception`` handlers still apply).
    """
    if infer_mode in ("direct", "direct_cot"):
        return get_vllm_response(prompt, output_path, messages_tag, response_tag)
    if infer_mode != "cot":
        raise ValueError(f"infer_mode error! Unsupported infer_mode: {infer_mode!r}")

    record = get_vllm_response(prompt, output_path, messages_tag, "response_cot", is_save=False)
    reasoning = record.get("response_cot")
    if not isinstance(reasoning, str):
        # The reasoning request failed after all retries. Building the second
        # prompt would raise inside str.replace and abort the whole worker pool.
        logger.error(f"Skipping record {record.get('_id')!r}: chain-of-thought request failed.")
        return record

    answer_prompt = get_prompt_tpl("cot_ans")
    substitutions = (
        ("$C_A$", record.get('choice_A')),
        ("$C_B$", record.get('choice_B')),
        ("$C_C$", record.get('choice_C')),
        ("$C_D$", record.get('choice_D')),
        ("$Q$", record.get('question')),
        ("$DOC$", record.get('context')),
        ("$COT$", reasoning),
    )
    for placeholder, value in substitutions:
        answer_prompt = answer_prompt.replace(placeholder, value)

    record[messages_tag] = [{"role": "user", "content": answer_prompt}]
    return get_vllm_response(record, output_path, messages_tag, response_tag or "response", is_save=True)
    


if __name__ == '__main__':
    # Keys used inside each record for the chat messages and the model reply.
    messages_tag = "messages"
    response_tag = "response"
    infer_mode = "direct_cot"  # one of "direct_cot", "cot", "direct" (see req_diy)
    num_workers = 16  # concurrent requests sent to the vLLM server

    # --run_time is a comma-separated list of run ids; each run writes to its
    # own output file.
    for run_idx in (run_id.strip() for run_id in args.run_time.split(",")):
        output_file = f"{args.output_dir}/{MODEL_NAME}_{infer_mode}_{run_idx}.jsonl"
        # Results are appended to this file by the workers; make sure its
        # directory exists (MODEL_NAME may itself contain "/", e.g. "org/model").
        os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

        existing_keys = load_existing_keys(output_file)
        prompts = process_line(args.input_file, existing_keys, messages_tag=messages_tag, infer_mode=infer_mode)
        if not prompts:
            logger.info(f"Run {run_idx}: no new records to process for '{output_file}'.")
            continue
        logger.info(f"Run {run_idx}: {len(prompts)} records to process -> '{output_file}'.")

        batch = partial(req_diy, output_path=output_file, messages_tag=messages_tag,
                        response_tag=response_tag, infer_mode=infer_mode)
        with multiprocessing.Pool(processes=num_workers) as pool:
            # Workers save their own results, so completion order is irrelevant.
            for _ in tqdm(pool.imap_unordered(batch, prompts), total=len(prompts)):
                pass