import json
from transformers import AutoTokenizer
import random
from utils import dataset_paths, sample_indices
from datasets import load_dataset
import os
import argparse
import json
import random
from datetime import date
import numpy as np

# Seed Python's global RNG at import time, before anything else runs.
# No code in this file calls `random` directly; presumably this makes
# `sample_indices` (imported from utils) reproducible -- TODO confirm.
random.seed(42)

# ERRORED = [322, 368, 401, 544, 967, 1054, 4986, 5018, 5081, 5085, 5091, 5117, 5142, 5151, 5155, 5163, 5183, 5225, 5235, 5254, 
#         5263, 5272, 5319, 5358, 5422, 5445, 5519, 5527, 5547, 5586, 5597, 5600, 5603, 5632, 5649, 5659, 5710, 5718, 5756,
#         5767, 5782, 5795, 5796, 5811, 5847, 5882, 6008, 6013, 6033, 6089, 6121, 6178, 6241, 6243, 6325, 6585, 6703, 6736,
#         6794, 6812, 6833, 6835, 6838, 6858, 7136, 7171, 7186, 7292, 8047, ]


def list_files(directory):
    """Return the paths of the files located directly inside ``directory``.

    Each entry of ``os.listdir(directory)`` is joined onto ``directory``; only
    entries for which ``os.path.isfile`` is true are kept, so subdirectories
    are skipped. The order is whatever ``os.listdir`` yields.
    """
    paths = []
    for name in os.listdir(directory):
        candidate = os.path.join(directory, name)
        if os.path.isfile(candidate):
            paths.append(candidate)
    return paths


def format_data(item, model):
    """Convert a problem/chosen/rejected record into a DPO pair (OpenRLHF).

    Both sides become a two-message chat: a user turn carrying
    ``item["problem"]`` followed by an assistant turn carrying
    ``item["chosen"]`` or ``item["rejected"]`` respectively.

    Args:
        item: dict with ``"problem"``, ``"chosen"`` and ``"rejected"`` keys.
        model: accepted for interface compatibility; currently unused.

    Returns:
        ``{"chosen": [...messages...], "rejected": [...messages...]}``.
    """
    def conversation(reply):
        # A fresh list/dicts per call so the two sides share no objects.
        return [
            {"role": "user", "content": item["problem"]},
            {"role": "assistant", "content": reply},
        ]

    return {side: conversation(item[side]) for side in ("chosen", "rejected")}


# Partial solutions containing any of these markers are considered degenerate.
_DEGENERATE_PARTIAL_MARKERS = (
    "<think> Hmm, I think this is enough to derive the final answer.\n\n**Final Answer**\n",
    "<think>\n\n</think>",
)


def _read_json(path):
    """Load and return the JSON document stored at ``path`` (UTF-8)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _rollout_ids(rollout_dir, limit):
    """Return the sorted integer ids of files in ``rollout_dir`` that lie in ``[0, limit)``.

    The id is the part of the file name before ``".json"``; a name that is not
    an integer raises ValueError, as before. Sorting makes downstream
    tie-breaking independent of ``os.listdir`` ordering.
    """
    ids = (
        int(os.path.basename(path).split(".json")[0])
        for path in list_files(rollout_dir)
    )
    return sorted(i for i in ids if 0 <= i < limit)


def _extract_partial_solution(text):
    """Return ``text`` from its first ``<think>`` tag onward, or None if there is none."""
    index = text.find("<think>")
    return text[index:] if index != -1 else None


def _best_conclude_candidate(partial_solution, conclude_data, get_token_length):
    """Pick the shortest acceptable, correct completion of ``partial_solution``.

    Only considered when every completion in ``conclude_data["correct_list"]``
    is correct (mean == 1). A completion is acceptable when the combined text
    contains ``</think>``, the thinking part has fewer than two
    ``**Final Answer**`` / ``boxed{`` occurrences, and the whole text has at
    least two ``\\boxed`` occurrences.

    Returns:
        ``(num_tokens, solution, correct)`` for the first shortest acceptable
        correct completion, or None.
    """
    if np.mean(conclude_data["correct_list"]) != 1.:
        return None
    # Tokenize the shared prefix only once we know it will be used.
    num_tokens_prev = get_token_length(partial_solution)
    best = None
    for k, sol in enumerate(conclude_data["solution_list"]):
        solution = partial_solution + sol
        if "</think>" not in solution:
            continue
        before = solution.split("</think>")[0]
        if before.count("**Final Answer**") >= 2 or before.count("boxed{") >= 2:
            continue
        if solution.count("\\boxed") < 2:
            continue
        correct = conclude_data["correct_list"][k]
        if correct != 1:
            continue
        num_tokens = num_tokens_prev + conclude_data["num_tokens_list"][k]
        # Strict `<` keeps the first minimum on ties.
        if best is None or num_tokens < best[0]:
            best = (num_tokens, solution, correct)
    return best


def _prune_candidate(prune_data):
    """Return ``(num_tokens, summary, 1)`` for a usefully shorter pruned solution, else None.

    Rejected when every ``keep_remove2`` decision is ``"KEEP AS IS"`` (nothing
    pruned), or when the summary is not shorter than the original or is
    shorter by 1000 tokens or more.
    """
    summary = prune_data["summary"]
    summary_num_tokens = prune_data["num_tokens"]
    original_num_tokens = prune_data["original_num_tokens"]
    if "keep_remove2" in prune_data and all(
        decision == "KEEP AS IS" for decision in prune_data["keep_remove2"]
    ):
        return None
    if 0 < original_num_tokens - summary_num_tokens < 1000:
        return (summary_num_tokens, summary, 1)
    return None


def _build_pairs(problem, candidates, used_count):
    """Build DPO records for one problem from ``(num_tokens, solution, correct)`` candidates.

    The rejected response is the longest candidate (first one on ties),
    whether correct or not. The chosen responses are the ``used_count``
    shortest correct candidates whose text differs from the rejected one.
    An identical pair carries no preference signal.
    """
    if not candidates:
        return []
    rejected = max(candidates, key=lambda c: c[0])[1]
    if not rejected:
        return []
    correct = sorted(
        (c for c in candidates if c[2] == 1 and c[1] != rejected),
        key=lambda c: c[0],
    )
    return [
        {"problem": problem, "chosen": solution, "rejected": rejected}
        for _, solution, _ in correct[:used_count]
    ]


def _save_dataset(args, dataset_data):
    """Write ``dataset_data`` as raw JSON and as a saved HF dataset under a dated directory."""
    today_str = date.today().strftime("%Y-%m-%d")
    save_dir = f"data/my_dataset/{today_str}_{args.dataset}_model={args.model_name}_num_raw={args.num_raw}_used={args.used_count_per_problem}"

    print(f">>> Save dir = [{save_dir}]")
    os.makedirs(save_dir, exist_ok=True)

    filename = f"{save_dir}/raw.json"
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(dataset_data, f)

    dataset = load_dataset("json", data_files=filename)
    dataset.save_to_disk(f"{save_dir}")


def generate_dataset(args, sampled_indices, tokenizer, verbose=False):
    """Assemble DPO preference pairs for ``sampled_indices`` and save them to disk.

    For each problem id ``qid`` the candidate pool is built from:

    * raw rollouts ``<raw_subdir>/<qid>.json``, restricted to the ids that have
      a file in ``<rollout_subdir>/<qid>/`` and are below
      ``min(args.num_raw, len(solution_list))``;
    * per rollout, the best "conclude" continuation from
      ``<conclude_subdir>/<qid>/<yid>.json``;
    * per rollout, the pruned summary from ``<prune_subdir>/<qid>/<yid>.json``.

    Problems without a raw file or rollout directory are skipped. The pairs
    are formatted with ``format_data`` and saved via ``_save_dataset``.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        sampled_indices: iterable of problem ids.
        tokenizer: callable returning a mapping with ``"input_ids"``.
        verbose: when True, print a summary of how many pairs were built.
    """
    def get_token_length(solution):
        return len(tokenizer(solution)['input_ids'])

    selected_items = []
    qids = set()
    for qid in sampled_indices:
        raw_file_path = os.path.join(args.load_dir, args.raw_subdir, f"{qid}.json")
        if not os.path.exists(raw_file_path):
            continue
        raw_data = _read_json(raw_file_path)

        rollout_file_dir = os.path.join(args.load_dir, args.rollout_subdir, f"{qid}")
        if not os.path.isdir(rollout_file_dir):
            # Without rollouts there are no candidates; skipping here also
            # prevents reusing the previous problem's ids/text.
            continue

        problem = raw_data['problem']
        limit = min(args.num_raw, len(raw_data["solution_list"]))
        yids = _rollout_ids(rollout_file_dir, limit)
        candidates = [
            (raw_data["num_tokens_list"][y], raw_data["solution_list"][y], raw_data["correct_list"][y])
            for y in yids
        ]

        for yid in yids:
            conclude_file_path = os.path.join(args.load_dir, args.conclude_subdir, f"{qid}/{yid}.json")
            if os.path.exists(conclude_file_path):
                conclude_data = _read_json(conclude_file_path)
                partial_solution = _extract_partial_solution(conclude_data["partial_solution"])
                if partial_solution is None:
                    print(f">>> Warning: no <think> tag in partial solution of [{conclude_file_path}]; skipping its conclusions")
                elif any(marker in partial_solution for marker in _DEGENERATE_PARTIAL_MARKERS):
                    # NOTE(review): kept from the original -- this also skips the
                    # prune candidate for this rollout; confirm that is intended.
                    continue
                else:
                    best = _best_conclude_candidate(partial_solution, conclude_data, get_token_length)
                    if best is not None:
                        candidates.append(best)

            prune_file_path = os.path.join(args.load_dir, args.prune_subdir, f"{qid}/{yid}.json")
            if os.path.exists(prune_file_path):
                pruned = _prune_candidate(_read_json(prune_file_path))
                if pruned is not None:
                    candidates.append(pruned)

        pairs = _build_pairs(problem, candidates, args.used_count_per_problem)
        if pairs:
            selected_items.extend(pairs)
            qids.add(qid)

    if verbose:
        print(f">>> Built {len(selected_items)} pairs from {len(qids)} problems")

    dataset_data = [format_data(item, args.model_name) for item in selected_items]
    _save_dataset(args, dataset_data)


def parse_args():
    """Parse the command-line options for dataset generation.

    Every option is optional; the table below lists each flag with its
    type and default, registered in this order.
    """
    options = (
        ("--num_raw", int, 4),
        ("--dataset", str, 'math_train'),
        ("--num_inst", int, 5000),
        ("--model_path", str, "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--model_name", str, "DeepSeek-R1-Distill-Qwen-1.5B"),
        ("--load_dir", str, ""),
        ("--raw_subdir", str, "raw"),
        ("--rollout_subdir", str, "rollout"),
        ("--conclude_subdir", str, "conclude"),
        ("--prune_subdir", str, "prune"),
        ("--used_count_per_problem", int, 2),
    )
    parser = argparse.ArgumentParser()
    for flag, converter, default_value in options:
        parser.add_argument(flag, type=converter, default=default_value)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)

    input_path = dataset_paths[args.dataset]
    # Close the input file deterministically; JSON is read as UTF-8.
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    sampled_indices = sample_indices(data, args.dataset, args.num_inst)

    generate_dataset(args, sampled_indices, tokenizer, verbose=False)