import json
import pickle
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from accelerate import Accelerator
import os
import sys
# Add the inference directory to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../inference')))

from scripts.inference.metaflow_workbench import *

# Raise the NCCL heartbeat timeout to one hour (value is in seconds).
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "3600"

# ===================== Config ==========================
# Checkpoint to start from. When this is falsy, the entry point below falls
# back to the base model under models/Qwen2.5-{model_size}B-Instruct.
load_model_path = "metaflow_neurips/output/WorkBench_SFT_v2/checkpoint-79"

# NOTE(review): model_name only feeds the "INST..." part of experiment_name in
# this file, while model_size selects the base-model directory; they name
# different sizes (14b vs 7) -- confirm this is intentional.
model_name = "qwen2.5-14b-instruct"
max_workers = 8
ICL_number = 0
model_size = 7
use_vllm = True
num_epochs = 2
num_generations = 14
rollout_temperature = 1.0

version = 1
beta = 0

# The dataset name is the pickle's file name without its extension.
dataset_path = "./input/GRPO_data.pkl"
filename = os.path.basename(dataset_path)
dataset_name = os.path.splitext(filename)[0]

# One reward function is built per entry; weights only apply when several
# reward functions are combined, so a single entry means no weights.
agent_type_list = ["react", "code_lines_ratio"]
reward_weights = None if len(agent_type_list) == 1 else [0.7, 0.3]

# NOTE(review): num_repetitions and exec_temperature are not defined in this
# file; presumably they come from the star import of
# scripts.inference.metaflow_workbench -- confirm.
experiment_name = "_".join(
    [
        f"WorkBench_v{version}",
        f"INST{model_name}",
        f"numG{num_generations}",
        f"numP{num_repetitions}",
        f"RollT{rollout_temperature}",
        f"ExecT{exec_temperature}",
        f"beta{beta}",
        "_".join(agent_type_list),
        dataset_name,
    ]
)

if load_model_path:
    # Tag the run with the last 10 characters of the checkpoint name; the
    # suffix is appended directly, without a separator.
    filename = os.path.basename(load_model_path)
    ckpt_name = os.path.splitext(filename)[0]
    experiment_name = experiment_name + ckpt_name[-10:]

output_dir = "metaflow_neurips/output/" + experiment_name



if __name__ == "__main__":
    # Entry point: load the pickled training samples, annotate them, build
    # prompts, load the policy model and run GRPO training.
    print('experiment_name:', experiment_name)
    print('dataset_path:', dataset_path)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only point dataset_path at files produced by a trusted pipeline.
    with open(dataset_path, 'rb') as fp:
        train_dataset = pickle.load(fp)

    # Annotate every sample in place with a per-workflow count.
    # Presumably train_dataset is a list of dicts carrying 'workflow1' and
    # 'workflow2' keys, and count_static_tool_calls comes from the star
    # import of metaflow_workbench -- confirm.
    for sample in train_dataset:
        sample['workflow1_lines'] = count_static_tool_calls(sample['workflow1'])
        sample['workflow2_lines'] = count_static_tool_calls(sample['workflow2'])

    # =======================================

    train_dataset = Dataset.from_list(train_dataset)

    # Build the prompt for every example. The number of in-context examples
    # follows the module-level ICL_number setting instead of a hard-coded 0,
    # so changing the config actually takes effect.
    # leaf_nodes is not defined in this file; presumably provided by the
    # star import -- confirm.
    train_dataset = train_dataset.map(
        lambda example: make_summarization_prompt(example, ICL_number=ICL_number, leaf_nodes=leaf_nodes))

    model_path = f'models/Qwen2.5-{model_size}B-Instruct'

    # Start the wandb tracker under the experiment name before the trainer
    # is created.
    accelerator = Accelerator(mixed_precision='bf16', log_with="wandb")
    accelerator.init_trackers("MetaFlow", config={}, init_kwargs={"wandb": {"name": experiment_name}})

    # Fall back to the base model when no checkpoint is configured.
    if not load_model_path:
        load_model_path = model_path
    print('Load from:', load_model_path)

    model = AutoModelForCausalLM.from_pretrained(
        load_model_path,
        torch_dtype=torch.bfloat16,
        use_cache=False,  # KV cache is disabled alongside gradient checkpointing
        attn_implementation="flash_attention_2",
    )
    model.gradient_checkpointing_enable()

    # NOTE(review): TRL requires the effective generation batch to be
    # divisible by num_generations; with per_device_train_batch_size=2 and
    # num_generations=14 this presumably assumes 7 training processes --
    # confirm against the launch configuration.
    training_args = GRPOConfig(
        output_dir=output_dir,
        learning_rate=1e-6,
        remove_unused_columns=False,  # keep the extra dataset columns available
        gradient_accumulation_steps=1,
        beta=beta,
        num_train_epochs=num_epochs,
        bf16=True,
        max_completion_length=8192,
        lr_scheduler_type="constant",
        num_generations=num_generations,
        temperature=rollout_temperature,
        max_prompt_length=8192,
        logging_steps=1,
        push_to_hub=False,
        save_strategy="steps",
        save_steps=100,
        per_device_train_batch_size=2,
        epsilon_high=0.28,
        log_completions=True,
        gradient_checkpointing=True,
        report_to="wandb",
        save_on_each_node=False,
        use_vllm=use_vllm,
        scale_rewards=False,
        reward_weights=reward_weights,
        num_iterations=1,
        overwrite_output_dir=True
    )

    # A single agent type is passed as one reward callable; several agent
    # types become a list whose order matches reward_weights.
    if len(agent_type_list) == 1:
        reward_funcs = make_correct_ratio_reward_func(agent_type_list[0])
    else:
        reward_funcs = [make_correct_ratio_reward_func(agent_type) for agent_type in agent_type_list]

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()
