import os
import json
import argparse
from tqdm import tqdm
from vllm import LLM, SamplingParams

# Command-line interface for the batch inference run.
parser = argparse.ArgumentParser()
parser.add_argument("--data_name", type=str, default="//test.json", help="dataset path")
parser.add_argument("--n_sampling", type=int, default=1, help="Number of sampling")
parser.add_argument("--max_tokens", type=int, default=2048, help="Max tokens per generation")
parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
# The model path has no sensible default; requiring it makes a missing value
# fail here with a clear usage error instead of later, when None would be
# handed to the LLM constructor.
parser.add_argument("--model_name", type=str, required=True, help="Path to the model")
parser.add_argument("--output_dir", type=str, default="/output/", help="Output directory")
# store_true already defaults to False when the flag is absent.
parser.add_argument("--few_shot", action="store_true", help="Enable few-shot mode")
args = parser.parse_args()

# Derive the tensor-parallel degree from the devices listed in
# CUDA_VISIBLE_DEVICES. The variable may be unset or empty, so read it with a
# default and ignore blank entries rather than raising KeyError or counting
# an empty string as a device.
available_gpus = [
    device.strip()
    for device in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",")
    if device.strip()
]

llm = LLM(
    model=args.model_name,
    # Fall back to a single shard when no explicit device list is given.
    tensor_parallel_size=max(1, len(available_gpus)),
    trust_remote_code=False,
    gpu_memory_utilization=0.95,
    max_model_len=4096,
    dtype="bfloat16",
)

# Load the dataset into `examples`. Two layouts are accepted, chosen by the
# file extension:
#   *.jsonl    - JSON Lines, one JSON document per non-blank line
#   otherwise  - a single JSON document (the list of examples)
with open(args.data_name, "r", encoding="utf-8") as f:
    if args.data_name.lower().endswith(".jsonl"):
        examples = [json.loads(line) for line in f if line.strip()]
    else:
        examples = json.load(f)

# Record each example's question together with its position in the dataset;
# tqdm reports progress while the list is built.
# (For a quick smoke test, iterate over examples[:3] with total=3 instead.)
samples = [
    {"idx": position, "question": record["question"]}
    for position, record in tqdm(enumerate(examples), total=len(examples))
]

# Repeat every question n_sampling times, keeping the repeats adjacent so the
# generations for sample i occupy one contiguous slice of the prompt list.
input_prompts = []
for entry in samples:
    input_prompts.extend([entry["question"]] * args.n_sampling)


# if not args.few_shot:
#     input_prompts = [ 
#         [{"role": "system",
#          "content": "You are a SQL generator. Given a schema and question, output ONLY the SQL query ending with a semicolon. No explanation."},
#          {"role": "user",
#         "content": prompt}] for prompt in input_prompts
#        ]


# if not args.few_shot:
#     input_prompts = [ 
#         [{"role": "system",
#          "content": "You are a SQL generator. Given a schema and question, output ONLY the SQL query ending with a semicolon. No explanation."},
#          {"role": "user",
#         "content": "Task: TEXT_TO_SQL \n" + prompt}] for prompt in input_prompts
#        ]


# if not args.few_shot:
#     input_prompts = [
#             prompt for prompt in input_prompts]

# if not args.few_shot:
#     input_prompts = [
#             f"""System: You are a SQL query generator. Based on the given schema and question, you should only generate the corresponding SQL query. Do not include any additional text or debug information.\nHuman: {prompt}\nAssistant:""" for prompt in input_prompts]
# else:
#         input_prompts = [f"""Q: {prompt} A: Let's think step by step.""" for prompt in input_prompts]

# outputs = llm.generate(
#     input_prompts,
#     SamplingParams(
#         temperature=args.temperature,
#         max_tokens=args.max_tokens,
#         skip_special_tokens=False,
#         stop=['<|EOT|>','<|eot_id|>'],  # stop generation when any of these strings is generated
#         ),
#     )


# llm.chat() consumes conversations: each prompt must be a list of
# {"role", "content"} messages, so both modes have to wrap the raw question.
if not args.few_shot:
    # Default mode: prepend the task instruction as a system message.
    input_prompts = [
        [{"role": "system",
          "content": "Given the database schema and a question in natural language, generate the corresponding SQL query. No explanation."},
         {"role": "user",
          "content": prompt}]
        for prompt in input_prompts
    ]
else:
    # Few-shot mode: send the question as a lone user message. Previously the
    # prompts stayed plain strings here, which is not a valid conversation.
    # NOTE(review): assumes the few-shot question text already carries its own
    # instructions/demonstrations, hence no system message - confirm.
    input_prompts = [
        [{"role": "user", "content": prompt}] for prompt in input_prompts
    ]

# Read the Jinja chat template that is passed to llm.chat() below.
with open('/text2sql/tool_chat_template_llama3.1_json.jinja', encoding="utf-8") as f:
    chat_template = f.read()

outputs = llm.chat(
    input_prompts,
    SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        skip_special_tokens=False,
        stop=['<|EOT|>', '<|eot_id|>'],  # stop generation when any of these strings is generated
    ),
    chat_template=chat_template,
)

# for output in outputs:
#      print(output)
#      break



# Keep only the text of the first completion returned for each prompt.
outputs = [output.outputs[0].text for output in outputs]

# Each question was submitted n_sampling times in a row, so the generations
# belonging to sample i are the i-th contiguous slice of that length.
for i, sample in enumerate(samples):
    sample["outputs"] = outputs[i * args.n_sampling : (i + 1) * args.n_sampling]

# Create the destination directory if needed so a fresh output path does not
# fail only after all generation work has finished.
os.makedirs(args.output_dir, exist_ok=True)

with open(os.path.join(args.output_dir, 'pred.sql'), "w", encoding="utf-8") as f:
    for sample in samples:
        llm_outputs = sample["outputs"]
        # Only the first generation is written; an empty slice (n_sampling of
        # 0) becomes an empty line so the line count still matches samples.
        first = llm_outputs[0] if llm_outputs else ""
        # Exactly one prediction per line: join any line breaks inside the
        # generation with spaces so line k always corresponds to sample k.
        # The SQL is expected to end with a semicolon; that is not enforced.
        f.write(" ".join(first.strip().splitlines()) + "\n")

print("✅ success, save in pred.sql")
