import argparse
import math
import os
from datetime import datetime

# import torch
from tqdm import tqdm
import json


from openrlhf.datasets import BenchmarkDataset
from openrlhf.utils import extract_last_answer, extract_all_answers, match_with_answer_labels_v3, match_with_answer_labels_v4

from datasets import load_dataset

from evaluation_scripts.evaluation.eval.eval_utils import math_equal

import torch.multiprocessing as mp


# Stop strings handed to the sampler, keyed by model family.
STOP_WORDS = dict(
    qwen2=["</s>", "<|im_end|>", "<|endoftext|>", "\n\nQuestion:", "<|end|>"],
    llama=["</s>", "<|endoftext|>", "\n\nQuestion:"],
)

def _flatten_token_logprobs(sample_logprobs):
    """Convert per-token logprob mappings into four parallel JSON-friendly lists.

    Args:
        sample_logprobs: Iterable with one mapping per generated token. The
            mapping values must expose ``rank``, ``logprob`` and
            ``decoded_token`` attributes.

    Returns:
        A dict with the keys ``tokens``, ``logprobs``, ``top_k_tokens`` and
        ``top_k_logprobs`` (in that order); each list has one element per
        generated token.
    """
    tokens = []
    logprobs = []
    top_k_tokens = []
    top_k_logprobs = []
    for candidates in sample_logprobs:
        ranked = sorted(candidates.values(), key=lambda cand: cand.rank)
        # NOTE(review): the best-ranked candidate is stored as the "token" and the
        # remaining candidates as the top-k list. Under sampling the emitted token
        # may not be the best-ranked one -- confirm this is the intended meaning.
        best, rest = ranked[0], ranked[1:]
        tokens.append(best.decoded_token)
        logprobs.append(best.logprob)
        top_k_tokens.append([cand.decoded_token for cand in rest])
        top_k_logprobs.append([cand.logprob for cand in rest])
    return {
        "tokens": tokens,
        "logprobs": logprobs,
        "top_k_tokens": top_k_tokens,
        "top_k_logprobs": top_k_logprobs,
    }


def main(args):
    """Sample responses with vLLM and save, batch by batch, the mixed-outcome queries.

    For every prompt ``args.n_samples_per_prompt`` responses are generated and
    each one is checked against the gold answer with
    ``match_with_answer_labels_v4``. Only queries whose per-query accuracy is
    neither 0.0 nor 1.0 are written out. Each batch of up to 1000 prompts goes
    to ``<args.generation_log_path>/batch_<id>.jsonl``; a batch whose file
    already exists is skipped.

    Args:
        args: Parsed command-line namespace (model path, sampling settings,
            dataset path, output directory, ...).
    """
    # AutoTokenizer is otherwise only imported under the ``__main__`` guard;
    # importing it here keeps ``main`` usable when the module is imported.
    from transformers import AutoTokenizer

    # Minimal stand-in exposing the attributes this script sets on ``strategy``.
    class Empty:
        pass

    strategy = Empty()
    strategy.print = print
    strategy.is_rank_0 = lambda: True
    strategy.args = args

    strategy.print(f"Generate data from {args.pretrain} with {args.prompt_type} prompt type")
    strategy.print(f"args: {args}")
    # configure tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.pretrain, trust_remote_code=True)

    from vllm import LLM, SamplingParams

    # configure model
    llm = LLM(
        model=args.pretrain,
        tensor_parallel_size=args.tp_size,
        trust_remote_code=True,
        seed=args.seed,
        max_num_seqs=args.max_num_seqs,
        enable_prefix_caching=args.enable_prefix_caching,
        enforce_eager=True,
    )

    model_name = args.pretrain.lower()
    # NOTE(review): stop strings test for "qwen" while stop token ids test for
    # "qwen2"; a model path containing only "qwen" gets the strings but no ids.
    sampling_params = SamplingParams(
        max_tokens=args.max_new_tokens,
        top_p=args.top_p,
        top_k=args.top_k,
        temperature=args.temperature,
        stop=STOP_WORDS["qwen2" if "qwen" in model_name else "llama"],
        stop_token_ids=[151645, 151643] if "qwen2" in model_name else None,
        repetition_penalty=args.repetition_penalty,
        skip_special_tokens=True,
        truncate_prompt_tokens=args.prompt_max_len,
        include_stop_str_in_output=False,
        n=args.n_samples_per_prompt,
        logprobs=args.n_logprobs,
    )

    if args.dataset.endswith(".jsonl"):
        query_dataset = load_dataset("json", data_files=args.dataset)["train"]
    else:
        query_dataset = load_dataset(args.dataset)["train"]

    # ``max_samples`` may arrive as a float (its CLI default is written as 1e8),
    # and ``range`` only accepts integers.
    sample_cap = int(min(args.max_samples, len(query_dataset)))
    query_data = query_dataset.select(range(sample_cap))

    print(f"Loaded {len(query_data)} samples from {args.dataset} for generation.")
    # prepare dataset
    query_dataset = BenchmarkDataset(
        query_data,
        tokenizer,
        strategy,
        input_template=None,
        prompt_type=args.prompt_type,
    )

    strategy.print(f"Prepared generation dataset with {query_dataset.__len__()} samples.")

    strategy.print(f"Start generation for {args.dataset}...")

    data_ids = []
    prompts = []
    labels = []
    for data_id, prompt, label in query_dataset:
        data_ids.append(data_id)
        prompts.append(prompt)
        labels.append(label)

    # batch size
    BATCH_SIZE = 1000
    os.makedirs(strategy.args.generation_log_path, exist_ok=True)

    for batch_idx in range(0, len(prompts), BATCH_SIZE):
        batch_queries = prompts[batch_idx: batch_idx + BATCH_SIZE]
        batch_data_ids = data_ids[batch_idx: batch_idx + BATCH_SIZE]
        batch_labels = labels[batch_idx: batch_idx + BATCH_SIZE]
        batch_id = batch_idx // BATCH_SIZE
        save_path = os.path.join(strategy.args.generation_log_path, f"batch_{batch_id}.jsonl")

        # skip if already processed
        if os.path.exists(save_path):
            print(f"Skipping batch {batch_id}, already exists.")
            continue

        print(f"Processing batch {batch_id} ({len(batch_queries)} queries) ...")

        # inference
        outputs = llm.generate(batch_queries, sampling_params)

        # JSON friendly conversion
        batch_results = []
        for data_id, query, output, answer in zip(batch_data_ids, batch_queries, outputs, batch_labels):
            query_result = {"data_id": data_id, "query": query, "responses": []}

            num_correct = 0
            for resp in output.outputs:  # n responses
                response_text = resp.text
                is_correct = bool(
                    match_with_answer_labels_v4(
                        [response_text], [answer], answer_trigger=strategy.args.answer_trigger
                    )
                )
                num_correct += is_correct

                response_record = {
                    "text": response_text,
                    "gold_answer": answer,
                    "is_correct": is_correct,
                }
                # Token-level fields are only attached when logprobs were returned.
                if resp.logprobs:
                    response_record.update(_flatten_token_logprobs(resp.logprobs))
                query_result["responses"].append(response_record)

            acc_per_query = num_correct / len(output.outputs)
            query_result["acc_per_query"] = acc_per_query
            # Keep only queries with mixed outcomes (some right, some wrong).
            if acc_per_query != 1.0 and acc_per_query != 0.0:
                batch_results.append(query_result)

        # save batch
        with open(save_path, "w", encoding="utf-8") as f:
            for res in batch_results:
                line = json.dumps(res, ensure_ascii=False)
                # Escape U+2028/U+2029 so every record stays on one physical line
                # for readers that treat them as line breaks.
                line = line.replace("\u2028", "\\u2028").replace("\u2029", "\\u2029")
                f.write(line + "\n")

        print(f"Saved batch {batch_id} → {save_path}")

    print("✅ All batches processed and saved.")

    # # Below is old code
    # generate_samples(llm, sampling_params, tokenizer, strategy, query_dataset, args.dataset)


    # # loop for all benchmarks
    # from pathlib import Path
    # path = Path(args.dataset)
    
    # ignore_benchmark_list = [""]
    # for benchmark_file in path.glob("*.jsonl"):
    #     benchmark_name = benchmark_file.stem
    #     benchmark_name = benchmark_name.strip()
        
    #     benchmark_file = str(benchmark_file)
    #     if benchmark_name in ignore_benchmark_list:
    #         strategy.print(f"Skipping benchmark: {benchmark_name}")
    #         continue
        
    #     strategy.print(f"Generate with {benchmark_name}")
        
    #     # # prepare for data and dataset
    #     # eval_data, _ = blending_datasets(
    #     #     benchmark_file,
    #     #     "1.0",
    #     #     strategy,
    #     #     args.seed,
    #     #     return_eval=False,
    #     #     max_count=args.max_samples,
    #     #     train_split="None",
    #     #     eval_split="None",
    #     # )
    #     query_data = load_dataset("json", data_files=benchmark_file)["train"]
    #     query_data = query_data.select(range(min(args.max_samples, len(query_data))))
        
    #     print(f"Loaded {len(query_data)} samples from {benchmark_file} for evaluation.")
    #     # prepare dataset
    #     query_dataset = BenchmarkDataset(
    #         query_data,
    #         tokenizer,
    #         strategy,
    #         input_template=None,
    #         prompt_type=args.prompt_type,
    #     )

    #     strategy.print(f"Prepared evaluation dataset with {query_dataset.__len__()} samples.")

    #     strategy.print(f"Start generation for {benchmark_name}...")
    #     evaluate(llm, sampling_params, tokenizer, strategy, query_dataset, benchmark_name)
    


def generate_samples(llm, sampling_params, tokenizer, strategy, eval_dataset, dataset_name: str):
    """Generate once for every prompt in ``eval_dataset`` and log results and accuracy.

    Side effects:
        * writes every response (with correctness flag and token logprobs) to
          ``llm_outputs.json`` in the current working directory;
        * appends one JSON line per prompt, built from the first response only,
          to ``<generation_log_path>/<pretrain>_<dataset_name>_raw_data.jsonl``,
          followed by a final summary line with accuracy and timings.

    Args:
        llm: Engine exposing ``generate(prompts, sampling_params)``.
        sampling_params: Sampling configuration forwarded to ``llm.generate``.
        tokenizer: Unused here; kept for signature compatibility.
        strategy: Object providing ``print`` and ``args`` (``generation_log_path``,
            ``pretrain``, ``answer_trigger``).
        eval_dataset: Iterable yielding ``(data_id, prompt, label)`` triples.
        dataset_name: Name embedded in the output file name.
    """
    import time

    start_time = time.time()

    # Read settings from ``strategy.args`` rather than a module-level ``args``,
    # which only exists when this file is executed as a script.
    run_args = strategy.args

    os.makedirs(run_args.generation_log_path, exist_ok=True)
    save_path = os.path.join(run_args.generation_log_path, f"{run_args.pretrain}_{dataset_name}_raw_data.jsonl")
    # ``pretrain`` may contain path separators (e.g. "org/model"), which puts the
    # file in a sub-directory that has to exist before it can be opened.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    data_ids = []
    prompts = []
    labels = []
    for data_id, prompt, label in eval_dataset:
        data_ids.append(data_id)
        prompts.append(prompt)
        labels.append(label)

    outputs = llm.generate(prompts, sampling_params)
    generation_time = time.time()
    strategy.print(f"Generation finished, time: {generation_time - start_time:.2f}s")

    results = []
    for query, output, answer in zip(prompts, outputs, labels):
        query_result = {"query": query, "responses": []}

        for resp in output.outputs:
            response_text = resp.text
            token_logprobs = []

            # ``resp.logprobs`` is empty/None when no logprobs were requested.
            # Each entry is handled as a mapping of candidates, the same way
            # ``main`` in this file reads it.
            for candidates in resp.logprobs or []:
                ranked = sorted(candidates.values(), key=lambda cand: cand.rank)
                # NOTE(review): the best-ranked candidate is stored as the token;
                # under sampling it may differ from the emitted token -- confirm.
                token_logprobs.append({
                    "token": ranked[0].decoded_token,
                    "top_logprobs": [
                        {"token": cand.decoded_token, "logprob": cand.logprob}
                        for cand in ranked
                    ],
                })

            is_correct = bool(match_with_answer_labels_v4([response_text], [answer], answer_trigger=run_args.answer_trigger))

            query_result["responses"].append({
                "text": response_text,
                "gold_answer": answer,
                "is_correct": is_correct,
                "logprobs": token_logprobs,
            })

        results.append(query_result)

    with open("llm_outputs.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    # Only the first response of every prompt feeds the per-line log and accuracy.
    output_list = [request_output.outputs[0].text for request_output in outputs]

    with open(save_path, 'a', encoding="utf-8") as f:
        for data_id, prompt, generation, answer in zip(data_ids, prompts, output_list, labels):
            generation_dict = {
                "data_id": data_id,
                "prompt": prompt,
                "generation": generation,
                "extracted_last_answers": extract_last_answer(generation, run_args.answer_trigger),
                "extracted_all_answers": extract_all_answers(generation, run_args.answer_trigger),
                "gold_answers": answer,
                "is_correct": bool(match_with_answer_labels_v4([generation], [answer], answer_trigger=run_args.answer_trigger)),
            }
            f.write(json.dumps(generation_dict, ensure_ascii=False) + "\n")

    evaluation_time = time.time()
    acc = match_with_answer_labels_v4(output_list, labels, answer_trigger=run_args.answer_trigger)
    strategy.print(f"Evaluation finished, accuracy: {acc:.4f}")
    with open(save_path, 'a', encoding="utf-8") as f:
        f.write(json.dumps(f"Evaluation finished, accuracy: {acc:.4f}, generation time: {generation_time - start_time:.2f}s, evaluation time: {evaluation_time - generation_time:.2f}s\n") + "\n")





# Script entry point: build the CLI, normalise a few options, then run ``main``.
if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    from transformers import AutoTokenizer
    parser = argparse.ArgumentParser()

    # VLLM
    parser.add_argument("--tp_size", type=int, default=1)
    parser.add_argument("--max_num_seqs", type=int, default=256)
    parser.add_argument("--enable_prefix_caching", action="store_true", default=False)
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=40)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.05)
    parser.add_argument("--prompt_max_len", type=int, default=1024)
    parser.add_argument("--n_samples_per_prompt", type=int, default=8)
    parser.add_argument("--n_logprobs", type=int, default=None)
    # Checkpoint
    parser.add_argument("--save_hf_ckpt", action="store_true", default=False)
    parser.add_argument("--disable_ds_ckpt", action="store_true", default=False)
    parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_sft")
    parser.add_argument("--max_ckpt_num", type=int, default=3)
    # argparse does not apply ``type`` to non-string defaults, so the default is
    # spelled as a real int instead of the float literal 1e8.
    parser.add_argument("--max_ckpt_mem", type=int, default=int(1e8))
    parser.add_argument("--load_checkpoint", action="store_true", default=False)
    parser.add_argument("--use_ds_universal_ckpt", action="store_true", default=False)

    # DeepSpeed
    parser.add_argument("--micro_train_batch_size", type=int, default=8, help="batch size per GPU")
    parser.add_argument("--micro_eval_batch_size", type=int, default=1, help="batch size per GPU")
    parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--torch_compile", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--full_determinism",
        action="store_true",
        default=False,
        help="Enable reproducible behavior during distributed training",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
    parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage")
    parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
    parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size")
    parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer")
    parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2")
    parser.add_argument("--use_liger_kernel", action="store_true", default=False, help="Enable Liger Kernel")
    parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type")
    parser.add_argument("--overlap_comm", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
    parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)

    # SFT

    parser.add_argument("--pretrain", type=str, default=None)

    # ring-attention
    parser.add_argument("--ring_attn_size", type=int, default=1, help="Ring attention group size")
    parser.add_argument(
        "--ring_head_stride",
        type=int,
        default=1,
        help="the number of heads to do ring attention each time. "
        "It should be a divisor of the number of heads. "
        "A larger value may results in faster training but will consume more memory.",
    )

    # LoRA
    parser.add_argument("--load_in_4bit", action="store_true", default=False)
    parser.add_argument("--lora_rank", type=int, default=0)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
    parser.add_argument("--lora_dropout", type=float, default=0)
    # NOTE(review): ``type=bool`` turns every non-empty string (including "False")
    # into True; left unchanged to keep the CLI as it is.
    parser.add_argument("--save_merged", type=bool, default=True)

    # packing SFT samples without CrossAttention
    parser.add_argument("--packing_samples", action="store_true", default=False)

    # custom dataset
    parser.add_argument("--dataset", type=str, default=None)
    parser.add_argument("--dataset_probs", type=str, default="1.0", help="sampling probs for datasets")
    parser.add_argument("--train_split", type=str, default="train", help="train split of the HF dataset")
    parser.add_argument("--eval_split", type=str, default="test", help="test split of the dataset")
    parser.add_argument("--multiturn", action="store_true", default=False, help="Use compacted multiturn dataset")

    parser.add_argument("--input_key", type=str, default="question", help="JSON dataset key")
    parser.add_argument("--output_key", type=str, default=None, help="JSON dataset key")
    parser.add_argument("--answer_key", type=str, default="answer", help="JSON dataset key")
    parser.add_argument("--data_id_key", type=str, default=None, help="JSON dataset key")
    parser.add_argument("--answer_trigger", type=str, default="The answer is:", help="Trigger for answer extraction")
    parser.add_argument("--input_template", type=str, default="User: {}\nAssistant: ")
    parser.add_argument("--prompt_type", type=str, default="qwen2-boxed")
    parser.add_argument(
        "--aRePOy_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template"
    )
    parser.add_argument("--tokenizer_chat_template", type=str, default=None)
    # Real int default: this value is passed to ``range`` via ``min`` in ``main``.
    parser.add_argument("--max_samples", type=int, default=int(1e8), help="Max number of samples")
    parser.add_argument("--max_len", type=int, default=2048, help="Max tokens for the samples")

    # Generation configs
    # parser.add_argument("--max_len", type=int, default=512)
    # parser.add_argument("--top_p", type=float, default=0.95, help="Top-p sampling")
    # parser.add_argument("--temperature", type=float, default=1.0)
    # parser.add_argument("--repetition_penalty", type=float, default=1.2, help="Repetition penalty for generation")
    parser.add_argument("--generation_log_path", type=str, default="./generation_logs/evals/RePO")

    # wandb parameters
    parser.add_argument("--use_wandb", type=str, default=None)
    parser.add_argument("--wandb_org", type=str, default=None)
    parser.add_argument("--wandb_group", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default="openrlhf_train_sft")
    parser.add_argument(
        "--wandb_run_name",
        type=str,
        default="sft_%s" % datetime.now().strftime("%m%dT%H:%M"),
    )

    # TensorBoard parameters
    parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path")

    # ModelScope parameters
    parser.add_argument("--use_ms", action="store_true", default=False)

    args = parser.parse_args()

    # A timestamp suffix is appended to the output directory on every launch.
    args.generation_log_path = args.generation_log_path + f"_{datetime.now().strftime('%m%dT%H%M')}"

    if args.multiturn:
        assert args.aRePOy_chat_template, "aRePOy_chat_template must be enabled when using multiturn format"

    if args.input_template and "{}" not in args.input_template:
        print("[Warning] {} not in args.input_template, set to None")
        args.input_template = None

    if args.input_template and "\\n" in args.input_template:
        print(
            "[Warning] input_template contains \\n chracters instead of newline. "
            "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell."
        )
        # Turn literal backslash escapes typed on the command line into real characters.
        args.input_template = args.input_template.encode().decode('unicode_escape')

    if args.packing_samples and not args.flash_attn:
        print("[Warning] Please --flash_attn to accelerate when --packing_samples is enabled.")
        args.flash_attn = True

    if args.ring_attn_size > 1:
        assert args.packing_samples, "packing_samples must be enabled when using ring attention"

    if args.use_ms:
        from modelscope.utils.hf_util import patch_hub

        # Patch hub to download models from modelscope to speed up.
        patch_hub()

    main(args)
