import json
import os
import pprint
import asyncio
from datetime import datetime
from time import sleep
from tqdm import tqdm
import argparse
from collections import Counter
from distutils.util import strtobool
import multiprocessing as mp


from src.evol.data_utils import load_data
from src.evol.openai_backend import call_chatgpt, LLM
from src.utils.data_utils import extract_answer_math, extract_answer_number
from src.utils.code_utils import execute_tora
from src.utils.math_utils import compare_ans, vote
from src.utils.file_utils import load_jsonl


def parse_args():
    """Parse the command-line options for this script.

    Returns:
        argparse.Namespace holding ``verbose`` plus one attribute per typed
        option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--verbose", action="store_true")
    # (flag, value type, default) for every value-taking option, in declaration order.
    typed_options = (
        ("--prompt_path", str, None),
        ("--strategy_path", str, None),
        ("--dataset", str, "gsm"),
        ("--data_path", str, None),
        ("--model", str, "gpt-3.5-turbo"),
        ("--temperature", float, 0.0),
        ("--top_p", float, 1.0),
        ("--max_tokens", int, 1024),
        ("--num_seqs", int, 1),
        ("--num_skips", int, 0),
        ("--input_col", str, "question"),
        ("--output_col", str, "answer"),
        ("--max_iter", int, 3),
        ("--num_process", int, 1),
        ("--output_path", str, None),
        ("--batch_size", int, 10),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, default=default_value, type=value_type)
    return parser.parse_args()


def load_prompt(prompt_path):
    """Return the UTF-8 text of the prompt template at *prompt_path*, with surrounding whitespace trimmed."""
    with open(prompt_path, encoding="utf-8") as template_file:
        raw_template = template_file.read()
    return raw_template.strip()


def parse_output(output):
    """Parse a step-dependency listing into a dict.

    Each non-blank line has the form ``<key>: <item>, <item>, ...``. Example::

        S1: C3
        S2: C3, S1
        S3: C2
        S4: S2, S3

    parses to
    ``{"S1": ["C3"], "S2": ["C3", "S1"], "S3": ["C2"], "S4": ["S2", "S3"]}``.

    Blank (or whitespace-only) lines are ignored, so trailing newlines in
    model output do not break parsing. Keys and items are stripped. Only the
    text between the first colon and any second colon is read as the item list.
    If a key repeats, the later line wins.

    Raises:
        ValueError: if a non-blank line contains no colon.
    """
    parsed = {}
    for line in output.split("\n"):
        if not line.strip():
            continue
        parts = line.split(":")
        if len(parts) < 2:
            raise ValueError(f"Malformed dependency line (missing ':'): {line!r}")
        parsed[parts[0].strip()] = [item.strip() for item in parts[1].split(",")]
    return parsed


def stop_tora(result, parsed_question, solution):
    """Return True if a model output is an acceptable dependency listing.

    The output is accepted only when all of the following hold:
      * it contains a "C" and parses with ``parse_output``;
      * its keys are exactly S1, S2, S3, ... in order;
      * every item starting with "S" names an earlier step and occurs
        (as a substring) in ``solution``;
      * every item starting with "C" occurs (as a substring) in
        ``parsed_question``;
      * at least one "S" item and one "C" item appear overall.
    Items starting with neither letter are ignored.

    Args:
        result: raw model output text.
        parsed_question: text the "C" references are checked against.
        solution: text the "S" references are checked against.
    """
    if "C" not in result:
        return False
    try:
        output = parse_output(result)
    except Exception:  # any unparseable output is simply rejected
        return False
    # Position of every key. Keys are verified below to be S1, S2, ... in order,
    # so comparing positions compares step numbers numerically (a string compare
    # would wrongly treat "S2" as later than "S10").
    position = {key: n for n, key in enumerate(output)}
    has_s = has_c = False
    for n, key in enumerate(output):
        if key != f"S{n + 1}":
            return False
        for item in output[key]:
            if item.startswith("S"):
                has_s = True
                # Must reference an existing, strictly earlier step.
                if item not in position or position[item] >= n:
                    return False
                # NOTE(review): substring test, so "S1" also matches "S10".
                if item not in solution:
                    return False
            elif item.startswith("C"):
                has_c = True
                if item not in parsed_question:
                    return False
    return has_s and has_c


def batch_tora(llm, batch_messages_base, parsed_questions, solutions, gen_args=None):
    """Query the LLM for a batch, re-asking only items whose output is rejected.

    Each round sends the still-unaccepted items to ``llm.achat``. It keeps the
    first generation of each item that ``stop_tora`` accepts, and raises the
    sampling temperature by 0.2 per round.

    Args:
        llm: client exposing an async ``achat(messages, **options)``.
        batch_messages_base: one chat-message list per item.
        parsed_questions: per-item parsed question text, aligned with the messages.
        solutions: per-item solution text, aligned with the messages.
        gen_args: options namespace providing ``model``, ``max_tokens``,
            ``temperature``, ``num_seqs`` and optionally ``max_iter`` (default 3).
            Defaults to the module-level ``args`` created under ``__main__``.

    Returns:
        A list aligned with ``batch_messages_base``. Each entry is the accepted
        generation for that item, or None if no round produced one.
    """
    if gen_args is None:
        gen_args = args  # module-level namespace set in the __main__ block
    max_iters = getattr(gen_args, "max_iter", 3)
    final_outputs = [None] * len(batch_messages_base)
    # Indices into the ORIGINAL batch that still lack an accepted output.
    pending = list(range(len(batch_messages_base)))
    for itr in range(max_iters):
        if not pending:
            break
        batch_outputs = asyncio.run(
            llm.achat(
                [batch_messages_base[i] for i in pending],
                model=gen_args.model,
                stop=["## Question"],
                max_tokens=gen_args.max_tokens,
                temperature=gen_args.temperature + itr * 0.2,
                num_beams=gen_args.num_seqs,
            )
        )
        # Results are positional w.r.t. the sub-batch sent, so map them back
        # through ``pending`` to the original item indices.
        for i, outputs in zip(pending, batch_outputs):
            output = outputs[0] if outputs else None
            if output is not None and stop_tora(output, parsed_questions[i], solutions[i]):
                final_outputs[i] = output
        pending = [i for i in pending if final_outputs[i] is None]
    return final_outputs


def _resolve_output_path(args, idx):
    """Return the output file path for this process.

    Uses ``args.output_path`` when given, otherwise a default under
    ``result/<model>/<dataset>/``. Worker processes (``idx != -1``) get
    ``_<idx>`` inserted before the file extension so workers never share a file.
    """
    if args.output_path is None:
        output_path = f"result/{args.model}/{args.dataset}/t{args.temperature}_n{args.num_seqs}-faith.jsonl"
    else:
        output_path = args.output_path
    if idx != -1:
        root, ext = os.path.splitext(output_path)
        output_path = f"{root}_{idx}{ext}"
    return output_path


def _build_messages(prompt, sample):
    """Return the system + user chat messages for one sample.

    The ``{question}``, ``{parsed_question}`` and ``{solution}`` placeholders in
    ``prompt`` are filled from ``sample`` in a single pass. Placeholder-like
    text that appears inside one field is therefore never substituted again.
    """
    import re

    fields = {
        "question": sample["question"],
        "parsed_question": sample["parsed_question"],
        "solution": sample["solution"],
    }
    content = re.sub(
        r"\{(question|parsed_question|solution)\}",
        lambda match: fields[match.group(1)],
        prompt,
    )
    return [
        {
            "role": "system",
            "content": "You are a helpful expert for math problem solving.",
        },
        {"role": "user", "content": content},
    ]


def main(args, samples, idx):
    """Generate an output for every sample and append the records to a file.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        samples: list of dicts with at least "question", "parsed_question" and
            "solution" keys. Each dict is mutated in place to gain a
            "generation" key (None when no output was accepted).
        idx: worker index, or -1 for a single-process run. Workers write to a
            per-index file, and only idx <= 0 prints the prompt.
    """
    prompt = load_prompt(args.prompt_path)
    if idx <= 0:
        print(prompt)
    output_path = _resolve_output_path(args, idx)
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    print("%" * 30, "Tora", "%" * 30)
    print("Start PID %d and save to %s" % (os.getpid(), output_path))

    # Skip samples already written by a previous (resumed) run.
    samples = samples[args.num_skips :]
    print(len(samples))
    if not samples:
        print("No samples to process.")
        return
    print(samples[0]["question"])

    llm = LLM()
    batch_size = args.batch_size
    # Resumed runs (num_skips > 0) append instead of truncating.
    mode = "w" if args.num_skips == 0 else "a"
    with open(output_path, mode, encoding="utf-8") as f:
        for start in tqdm(range(0, len(samples), batch_size)):
            batch_samples = samples[start : start + batch_size]
            batch_messages = [_build_messages(prompt, s) for s in batch_samples]
            parsed_questions = [s["parsed_question"] for s in batch_samples]
            solutions = [s["solution"] for s in batch_samples]
            batch_outputs = batch_tora(llm, batch_messages, parsed_questions, solutions)
            if getattr(args, "verbose", False):
                print(batch_outputs)
            for sample, generation in zip(batch_samples, batch_outputs):
                sample["generation"] = generation
                # NOTE(review): indent=4 spreads each record over several lines,
                # so the file is not line-delimited JSONL despite its extension.
                f.write(json.dumps(sample, ensure_ascii=False, indent=4) + "\n")
                f.flush()  # keep partial results on disk if the run dies
    # No per-sample scoring is performed here, so no accuracy is reported.


if __name__ == "__main__":
    args = parse_args()
    # samples = load_data(args.dataset, args.data_path)
    samples = load_jsonl(args.data_path)
    if args.num_process == 1:
        main(args, samples, idx=-1)
    else:
        num_each_split = len(samples) // args.num_process
        pool = mp.Pool(args.num_process)
        jobs = []
        for idx in range(args.num_process):
            start = idx * num_each_split
            # The last worker also takes the remainder of an uneven split.
            if idx == args.num_process - 1:
                end = len(samples)
            else:
                end = start + num_each_split
            split_data = samples[start:end]
            if not split_data:
                print(f"Skip worker {idx}: no samples assigned")
                continue
            jobs.append((idx, pool.apply_async(main, args=(args, split_data, idx))))
        pool.close()
        pool.join()
        # apply_async hides worker exceptions unless .get() is called; surface them.
        failed = []
        for idx, job in jobs:
            try:
                job.get()
            except Exception as err:
                failed.append(idx)
                print(f"Worker {idx} failed: {err!r}")
        print("All of the child processes over!")
        if failed:
            raise SystemExit(f"{len(failed)} worker(s) failed: {failed}")
