import os
import copy
import json
import argparse
import random
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Per-dataset size constants. For the sharded datasets ("convfinqa", "gsm8k")
# load_doc uses the value as the exclusive upper bound of the shard-file index
# range (step 10). The "TabMWP" entry is not read anywhere in the code visible
# here -- NOTE(review): presumably these are training-set example counts; confirm.
DATASET_SIZES = dict(
    TabMWP=5000,
    gsm8k=7470,
    convfinqa=2000,
)

def parse_table(table):
    """Render a table as pipe-delimited text, one row per line.

    Each row becomes ``| cell1 | cell2 | ... |`` and rows are joined with
    ``"\\n"``. No header-separator row (``|---|``) is emitted.

    Args:
        table: An iterable of rows; each row is an iterable of cells. Cells
            are converted with ``str`` so numeric (or other non-string) values
            no longer raise ``TypeError`` in ``str.join``.

    Returns:
        The rendered table; an empty string for an empty table.
    """
    return "\n".join(
        "| " + " | ".join(str(cell) for cell in row) + " |" for row in table
    )

def load_doc(dataset):
    """Load training examples for ``dataset`` as a list of dicts.

    Supported datasets:

    * ``"convfinqa"`` / ``"gsm8k"``: reads the JSONL shards
      ``data/{dataset}/train_soln_{i}.jsonl`` for
      ``i in range(0, DATASET_SIZES[dataset], 10)``. Blank lines are skipped.
    * ``"TabMWP"``: reads ``data/TabMWP_train.json`` (a JSON object whose
      values are examples). Examples with ``grade <= 4`` are randomly
      down-sampled (each kept with probability ~0.2); all others are kept.

    Every returned dict has ``question`` and ``answer``; convfinqa examples
    also carry ``passage`` and ``table`` (rendered via ``parse_table``), and
    TabMWP examples carry ``table`` and ``grade``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names
            (previously an empty list was returned silently).
    """
    examples = []
    if dataset in ("convfinqa", "gsm8k"):
        # Shard files are indexed in steps of 10 (train_soln_0, train_soln_10, ...).
        # NOTE(review): presumably each shard holds up to 10 examples -- verify.
        for shard_idx in range(0, DATASET_SIZES[dataset], 10):
            shard_path = f"data/{dataset}/train_soln_{shard_idx}.jsonl"
            with open(shard_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        # json.loads would raise on an empty line.
                        continue
                    example = json.loads(line)
                    if dataset == "convfinqa":
                        examples.append({
                            "question": example["qa"]["question"],
                            # Both the "exe_ans" value and the "answer" field are
                            # kept (presumably executed numeric vs. raw answer).
                            "answer": [example["qa"]["exe_ans"], example["qa"]["answer"]],
                            "passage": example["annotation"]["amt_post_text"],
                            "table": parse_table(example["table"]),
                        })
                    else:
                        # gsm8k: the final answer follows the last "####" marker.
                        answer = example["answer"].split("####")[-1].strip()
                        examples.append({"question": example["question"], "answer": answer})
    elif dataset == "TabMWP":
        with open("data/TabMWP_train.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        for example in data.values():
            # Down-sample grade <= 4 problems. random.random() is only drawn for
            # those examples (short-circuit), keeping the RNG sequence unchanged.
            if example["grade"] <= 4 and random.random() > 0.2:
                continue
            examples.append({
                "question": example["question"],
                "answer": example["answer"],
                "table": example["table"],
                "grade": example["grade"],
            })
    else:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of 'convfinqa', 'gsm8k', 'TabMWP'"
        )
    return examples

def call_llm(prompt, model, tokenizer, sampling_params):
    """Run one single-turn chat completion and return its stripped text.

    The (stripped) ``prompt`` is wrapped as a single user message, rendered to
    a string with the tokenizer's chat template (generation prompt appended),
    and passed to ``model.generate`` as a one-element batch. Only the first
    completion of the first request is returned, so if ``sampling_params``
    asks for several samples per prompt the extras are discarded.
    """
    conversation = [dict(role="user", content=prompt.strip())]
    rendered_prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    request_outputs = model.generate([rendered_prompt], sampling_params)
    first_completion = request_outputs[0].outputs[0]
    return first_completion.text.strip()

# ---- Command-line configuration -------------------------------------------
parser = argparse.ArgumentParser()
# Tokenizer whose chat template is used to render prompts (see call_llm).
# NOTE(review): should match the model at --model_path; a mismatched template
# silently degrades generations.
parser.add_argument("--tokenizer", type=str, default="meta-llama/Llama-3.1-8B-Instruct")
# Required: the previous default "" cannot be loaded by vLLM's LLM(...).
parser.add_argument("--model_path", type=str, required=True)
# Tag embedded in the output filename and the debug prints.
parser.add_argument("--expname", type=str, default="")
parser.add_argument("--save_dir", type=str, default="outputs")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--top_p", type=float, default=0.995)
parser.add_argument("--tensor_parallel_size", type=int, default=1)
# Round index; embedded in the output filename ("_r{round}") and debug prints.
parser.add_argument("--round", type=int, default=2)
# Maximum number of sampling attempts per prompt for each stage.
parser.add_argument("--n_decomp", type=int, default=4)
parser.add_argument("--n_code", type=int, default=4)
# Optional seed for Python's `random` (dataset shuffling and TabMWP
# down-sampling). Default None keeps the previous unseeded behavior.
# It does not seed vLLM's sampling.
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()

if args.seed is not None:
    random.seed(args.seed)

# ---- Model / tokenizer setup -----------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
sampling_params = SamplingParams(
    temperature=args.temperature,
    top_p=args.top_p,
    repetition_penalty=1.05,
    max_tokens=1024
)
model = LLM(model=args.model_path, tensor_parallel_size=args.tensor_parallel_size, gpu_memory_utilization=0.89, trust_remote_code=True)

# Stage-1 templates: ask the model to decompose a question into sub-questions.
# Every template uses {question}; TabMWP also uses {table}; convfinqa also uses
# {passage} and {table}. Unused keyword arguments to str.format are ignored,
# so callers may always pass question/table/passage.
PROMPT_PLAN = dict(
    gsm8k="Please break down the question '{question}' into sub-questions. Use ###.",
    TabMWP=(
        "You have the following table:\n{table}\n"
        "Question: '{question}'. Break it down into sub-questions. Use ###."
    ),
    convfinqa=(
        "Passages:\n{passage}\nTable:\n{table}\n"
        "Question: '{question}'. Break it down into sub-questions. Use ###."
    ),
)

# Stage-2 templates: given the question and one decomposition ({decompose}),
# ask for a Python program that stores its result in a variable named 'ans'.
PROMPT_CODE = dict(
    gsm8k=(
        "Question: '{question}'\n"
        "Breakdown:\n{decompose}\n"
        "Write a Python program to solve it. Store result in 'ans'."
    ),
    TabMWP=(
        "Table:\n{table}\n"
        "Question: '{question}'\n"
        "Breakdown:\n{decompose}\n"
        "Write a Python program to solve it. Store result in 'ans'."
    ),
    convfinqa=(
        "Passages:\n{passage}\n"
        "Table:\n{table}\n"
        "Question: '{question}'\n"
        "Breakdown:\n{decompose}\n"
        "Write a Python program to solve it. Store result in 'ans'."
    ),
)

# Caps on how many *distinct* completions are kept per prompt. These reproduce
# the original stop conditions exactly (decompositions: stop at 3 unique;
# programs: stop at 4 unique). NOTE(review): the asymmetry looks accidental
# (the two loops used `>= 3` vs `> 3`) -- confirm which cap is intended.
MAX_UNIQUE_DECOMPOSITIONS = 3
MAX_UNIQUE_PROGRAMS = 4
# Number of (shuffled) examples processed per dataset.
MAX_EXAMPLES_PER_DATASET = 2500


def _sample_unique(prompt, max_samples, max_unique):
    """Sample up to ``max_samples`` completions of ``prompt``, de-duplicated.

    Stops early once ``max_unique`` distinct completions have been collected.
    Returns the distinct completions in first-seen order.
    """
    unique = {}  # dict used as an insertion-ordered set
    for _ in range(max_samples):
        unique[call_llm(prompt, model, tokenizer, sampling_params)] = None
        if len(unique) >= max_unique:
            break
    return list(unique)


# Create the output directory up front so an unwritable save_dir fails before
# any (expensive) generation happens.
os.makedirs(f"{args.save_dir}/code", exist_ok=True)

for dataset in ["convfinqa", "gsm8k", "TabMWP"]:
    examples = load_doc(dataset)
    random.shuffle(examples)
    examples = examples[:MAX_EXAMPLES_PER_DATASET]
    output = []

    for ex in tqdm(examples):
        # Template fields shared by both stages; missing keys become "".
        fields = {
            "question": ex.get("question", ""),
            "table": ex.get("table", ""),
            "passage": ex.get("passage", ""),
        }
        decompositions = _sample_unique(
            PROMPT_PLAN[dataset].format(**fields),
            args.n_decomp,
            MAX_UNIQUE_DECOMPOSITIONS,
        )
        for i, decompose in enumerate(decompositions):
            code_prompt = PROMPT_CODE[dataset].format(decompose=decompose.strip(), **fields)
            programs = _sample_unique(code_prompt, args.n_code, MAX_UNIQUE_PROGRAMS)
            for j, code in enumerate(programs):
                output.append({
                    "question": ex["question"],
                    "answer": ex["answer"],
                    "passage": fields["passage"],
                    "table": fields["table"],
                    "decompose_id": i,
                    "decompose": decompose,
                    "code_id": j,
                    "code": code,
                })
                # Debug: show the first (decomposition, program) pair per example.
                if i == 0 and j == 0:
                    print(args.expname, args.round, output[-1])

    out_path = f"{args.save_dir}/code/prompts_code_{dataset}_{args.expname}_r{args.round}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)