import json
import os
import subprocess
import sys
from collections import defaultdict
from typing import Any

import yaml

from data_generation.utils import (
    calc_openai_cost,
    check_overwrite,
    print_generation,
)
from data_generation.wikihow.generate_data import is_qualified_api_call


def _build_request_body(
    example: dict[str, Any],
    prompt: dict[str, Any],
    model: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> dict[str, Any]:
    """Build one chat-completions request body for a single example.

    The prompt's ``user_message`` template gets its placeholders filled in
    this order: ``__task__`` from ``example["task"]``, ``__past_actions__``
    from ``example["prev_actions"]``, and ``__next_action__`` from the part of
    ``example["next_action"]`` before the first ``# step summary`` marker.
    """
    next_action = example["next_action"].split("# step summary")[0].strip()
    content = (
        prompt["user_message"]
        .replace("__task__", example["task"])
        .replace("__past_actions__", example["prev_actions"])
        .replace("__next_action__", next_action)
        .strip()
    )
    # Key order is kept stable so the serialized request lines are unchanged.
    return {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }


def _api_endpoint(model: str) -> tuple[str, str]:
    """Return the ``(request_url, api_key)`` pair to use for *model*.

    Models whose name starts with ``"vijay"`` are routed to an Azure OpenAI
    endpoint built from the ``VIJAY_RESOURCE_NAME`` / ``VIJAY_VERSION`` /
    ``VIJAY_API_KEY`` environment variables; all others use the public OpenAI
    endpoint with ``OPENAI_API_KEY``.

    Raises:
        KeyError: if a required environment variable is not set.
    """
    if model.startswith("vijay"):
        resource = os.environ["VIJAY_RESOURCE_NAME"]
        # NOTE(review): the deployment path segment reuses the resource name;
        # confirm the Azure deployment is really named after the resource.
        url = (
            f"https://{resource}.openai.azure.com/openai/deployments/"
            f"{resource}/chat/completions"
            f"?api-version={os.environ['VIJAY_VERSION']}"
        )
        return url, os.environ["VIJAY_API_KEY"]
    return (
        "https://api.openai.com/v1/chat/completions",
        os.environ["OPENAI_API_KEY"],
    )


def generate_data(
    data: list[dict[str, Any]],
    save_file: str,
    model: str,
    rate_limit: int,
    token_limit: int,
    max_tokens: int,
    prompt_file: str,
    prompt_version: str,
    stop_token: str,
    temperature: float = 0.0,
    top_p: float = 1.0,
) -> None:
    """Write one chat request per example and run them through the parallel requester.

    Requests are written as JSON lines to a file derived from *save_file* by
    replacing ``.response.jsonl`` with ``.step1.request.jsonl``. That file and
    *save_file* are then passed to ``llms/providers/openai_request_parallel.py``,
    which is run with the current Python interpreter.

    Args:
        data: Examples with ``task``, ``prev_actions`` and ``next_action`` keys.
        save_file: Path passed to the request script as ``--save_filepath``.
            Must contain ``.response.jsonl``.
        model: Model name placed in every request body; a ``vijay`` prefix
            selects the Azure endpoint.
        rate_limit: Passed as ``--max_requests_per_minute``.
        token_limit: Passed as ``--max_tokens_per_minute``.
        max_tokens: ``max_tokens`` field of every request.
        prompt_file: YAML file holding the prompt templates.
        prompt_version: Top-level key in *prompt_file* to use.
        stop_token: Currently unused; it is not added to the request body.
        temperature: Sampling temperature for every request.
        top_p: Nucleus-sampling ``top_p`` for every request.

    Raises:
        ValueError: if *save_file* does not contain ``.response.jsonl`` (the
            request file would otherwise be the same path as *save_file*).
        KeyError: if a required API environment variable is missing.
        subprocess.CalledProcessError: if the request script exits non-zero.
    """
    request_file = save_file.replace(".response.jsonl", ".step1.request.jsonl")
    # Without the expected suffix the replace is a no-op, and requests and
    # responses would be read from and written to the same path.
    if request_file == save_file:
        raise ValueError(
            f"save_file must contain '.response.jsonl', got {save_file!r}"
        )

    with open(prompt_file, "r", encoding="utf-8") as f:
        prompt = yaml.safe_load(f)[prompt_version]

    batch = [
        _build_request_body(example, prompt, model, max_tokens, temperature, top_p)
        for example in data
    ]
    print(f"Number of examples: {len(batch)}")

    # json.dumps escapes non-ASCII by default, so the file content is ASCII.
    with open(request_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(body) + "\n" for body in batch)

    check_overwrite(save_file)

    request_url, api_key = _api_endpoint(model)
    # sys.executable runs the helper under this interpreter rather than
    # whatever "python" resolves to on PATH. check=True surfaces failures
    # instead of letting the caller continue on a missing/partial output.
    # NOTE(review): the API key is visible in the process list via argv.
    subprocess.run(
        [
            sys.executable,
            "llms/providers/openai_request_parallel.py",
            "--request_url",
            request_url,
            "--api_key",
            api_key,
            "--requests_filepath",
            request_file,
            "--save_filepath",
            save_file,
            "--max_requests_per_minute",
            str(rate_limit),
            "--max_tokens_per_minute",
            str(token_limit),
        ],
        check=True,
    )


def get_data(
    index_file: str = "refine_index.json",
    data_dir: str = "data/b",
) -> list[dict[str, Any]]:
    """Load the indexed examples whose previous actions are mostly hovers.

    *index_file* is a JSON object mapping a data file name to the 0-based
    line numbers to consider in that file. For each of the fixed file names
    below that appears in the index, the selected lines are parsed and kept
    when their ``prev_actions`` contain:

    * more than two qualified API calls, and
    * if exactly three, all three are hovers, and
    * at least half of them are hovers (failures of this last check only are
      tallied and printed as "False negative").

    Only selected lines are JSON-decoded; other lines are skipped unparsed.

    Args:
        index_file: Path of the JSON line index.
        data_dir: Directory containing the ``parsed_popular_*.jsonl`` files.

    Returns:
        The kept examples, in file order then line order.
    """
    raw_files = (
        "parsed_popular_06-43.jsonl",
        "parsed_popular_20-28_map.jsonl",
        "parsed_popular_35-04.jsonl",
        "parsed_popular_35-05.jsonl",
        "parsed_popular_35-06.jsonl",
        "parsed_popular_35-07.jsonl",
        "parsed_popular_35-08.jsonl",
        "parsed_popular_35-09.jsonl",
        "parsed_popular_35-10.jsonl",
        "parsed_popular_35-11.jsonl",
        "parsed_popular_42-34.jsonl",
        "parsed_popular_42-97.jsonl",
    )
    with open(index_file, "r", encoding="utf-8") as f:
        refine_index = json.load(f)

    data = []
    false_neg = 0
    for file_name in raw_files:
        if file_name not in refine_index:
            continue
        # Built once per file: O(1) membership per line instead of scanning
        # the index list for every line.
        wanted = set(refine_index[file_name])
        path = os.path.join(data_dir, file_name)
        with open(path, "r", encoding="utf-8") as f:
            for l_idx, raw_line in enumerate(f):
                if l_idx not in wanted:
                    continue
                line = json.loads(raw_line)

                api_lines = [
                    x
                    for x in line["prev_actions"].split("\n")
                    if is_qualified_api_call(x)
                ]
                num_api = len(api_lines)
                num_hover = sum(
                    1 for x in api_lines if x.strip().startswith("hover")
                )

                if num_api <= 2:
                    continue
                # Exactly three calls: keep only all-hover histories.
                if num_api == 3 and num_hover < 3:
                    continue
                # num_api >= 3 here, so the division is safe.
                if num_hover / num_api < 0.5:
                    false_neg += 1
                    continue

                data.append(line)
    print(f"False negative: {false_neg}")
    print(f"Number of data: {len(data)}")
    return data


if __name__ == "__main__":
    # Select the hover-heavy examples, send them for refinement, then report cost.
    examples = get_data()
    response_path = "./data/b/refined_data.response.jsonl"
    generate_data(
        examples,
        save_file=response_path,
        model="gpt-4o",
        prompt_file="data_generation/refinement/prompts/prompt.yaml",
        prompt_version="v1",
        rate_limit=1500,
        token_limit=200_000,
        max_tokens=4096,
        temperature=1.0,
        stop_token="```",
    )
    total_cost = calc_openai_cost(response_path)
    # To inspect the outputs, call:
    #   print_generation(response_path, print_query=True)
    print(total_cost)
