import pickle
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
from utils import *
from accelerate import Accelerator
import os


# ----- Process environment ---------------------------------------------------
# Set once at import time: tokenizer parallelism off, NCCL heartbeat timeout
# raised to 3600 seconds.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    'TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC': '3600',
})

# ----- Model / checkpoint ----------------------------------------------------
load_model_path = None  # optional checkpoint to start from; falsy -> base model
model_name = "Qwen2.5-14B-Instruct"  # used in experiment_name and forwarded to process_leaf
model_size = 7  # interpolated into the base-model path in the __main__ block
use_vllm = True

# ----- Rollout / optimisation ------------------------------------------------
num_epochs = 3
num_generations = 14
rollout_temperature = 1.0  # sampling temperature handed to GRPOConfig
beta = 0

# ----- Reward evaluation -----------------------------------------------------
max_workers = 8  # size of the process pool used by the reward function
ICL_number = 3  # forwarded to make_summarization_prompt
exec_temperature = 0.7  # starting point of the per-repetition temperature schedule
num_repetitions = 1
compress_ratio_threshold = 0.7  # cap (and normaliser) for the code_lines_ratio reward

version = 13

# ----- Dataset ---------------------------------------------------------------
dataset_path = 'input/GRPO_data.pkl'
filename = os.path.basename(dataset_path)
dataset_name = os.path.splitext(filename)[0]

# ----- Reward functions ------------------------------------------------------
agent_type_list = ["privileged_FullCodeRefl", "code_lines_ratio"]
# Weights are only kept when more than one reward function is configured.
reward_weights = [0.7, 0.3] if len(agent_type_list) != 1 else None

# ----- Derived names and output locations ------------------------------------
experiment_name = f'v{version}_INST{model_name}_numG{num_generations}_numP{num_repetitions}_RollT{rollout_temperature}_ExecT{exec_temperature}_beta{beta}_{agent_type_list}_{reward_weights}_{dataset_name}'
if not load_model_path:
    experiment_name += '_vanilla'
else:
    # Tag the run with the last 10 characters of the checkpoint's base name.
    filename = os.path.basename(load_model_path)
    ckpt_name = os.path.splitext(filename)[0]
    experiment_name += ckpt_name[-10:]
output_dir = f'metaflow_neurips/output/{experiment_name}'

EVAL_RESULTS_output_dir = f'metaflow_neurips/output/{experiment_name}/ALL_EVAL_RESULT.pkl'
# Filled by the reward function and pickled at the end of training.
ALL_EVAL_RESULTS = []


# =====================Config==========================






# Patterns that pull the meta query and the fenced workflow blocks out of a
# completion. Compiled once because the reward function runs on every batch.
_META_QUERY_RE = re.compile(r'\*\*Meta Query\*\*:\s*(.*?)\s*\*\*Meta Workflow\*\*:', re.DOTALL)
_CODE_BLOCK_RE = re.compile(r'```(json|python)\s*([\s\S]*?)```', re.DOTALL)
_COMPLETE_TASK_CALL = 'apis.supervisor.complete_task'
# Agent types whose reward comes from executing the meta flow via process_leaf.
_EXECUTION_AGENT_TYPES = ('privileged_FullCodeRefl', 'FullCodeRefl', 'react', "semi_privileged_FullCodeRefl")


def _parse_meta_flow(completion):
    """Extract ``(meta_query, meta_workflow)`` from one completion, raising if it is malformed.

    ``meta_workflow`` is a list with one entry per fenced block, in order of
    appearance: ``json`` blocks are decoded with ``json.loads`` and ``python``
    blocks are kept as their raw source string.

    Raises:
        Exception: if the "Meta Query"/"Meta Workflow" markers or the fenced
            blocks are missing, if the completion mentions
            ``apis.supervisor.complete_task`` more than once, if a python block
            is fully commented out, if the last block is a string that does not
            mention ``apis.supervisor.complete_task``, or if every block is a
            string. ``json.loads`` errors propagate unchanged.
    """
    meta_query_match = _META_QUERY_RE.search(completion)
    if meta_query_match is None:
        raise Exception('Meta flow is invalid.')
    meta_query = meta_query_match.group(1).strip().strip('-')

    if completion.count(_COMPLETE_TASK_CALL) > 1:
        raise Exception('Meta flow cannot normally complete task.')

    meta_workflow_raw = _CODE_BLOCK_RE.findall(completion)
    # Must be checked before meta_workflow[-1] is read below.
    if not meta_workflow_raw:
        raise Exception('Meta flow is invalid.')

    meta_workflow = []
    for block_type, block_content in meta_workflow_raw:
        if block_type == 'json':
            # `json` is not imported by name in this file; it is expected to
            # arrive through the wildcard import from utils.
            meta_workflow.append(json.loads(block_content))
        else:
            if is_fully_commented(block_content):
                raise Exception('Meta flow static code only comprises comments.')
            meta_workflow.append(block_content)

    last_step = meta_workflow[-1]
    if isinstance(last_step, str) and _COMPLETE_TASK_CALL not in last_step:
        raise Exception('Meta flow cannot normally complete task.')
    if all(isinstance(item, str) for item in meta_workflow):
        raise Exception('Meta flow only comprises strings.')

    return meta_query, meta_workflow


def correct_ratio_reward_func_v2(completions, fixed_agent_type, **kwargs) -> list[float]:
    """Score a batch of generated meta flows for GRPO training.

    Each completion is parsed with ``_parse_meta_flow``. Any exception raised
    while parsing or queueing work is printed, and that completion gets a
    reward of ``-1``.

    Args:
        completions: sequence where ``completion[0]["content"]`` is the
            generated text of each sample.
        fixed_agent_type: selects the reward.
            * ``'privileged_FullCodeRefl'``, ``'FullCodeRefl'``, ``'react'`` or
              ``'semi_privileged_FullCodeRefl'``: run ``process_leaf`` for every
              (completion, repetition, leaf id) in a process pool. The reward is
              the best pass ratio over repetitions, or ``-1`` when there are no
              results for that completion.
            * ``'code_lines_ratio'``: the larger of ``lines / workflow1_lines``
              and ``lines / workflow2_lines``, capped at
              ``compress_ratio_threshold`` and divided by it.
        **kwargs: must contain ``leaf_ids`` (one iterable of leaf ids per
            completion). The ``code_lines_ratio`` reward also reads
            ``workflow1_lines`` and ``workflow2_lines``, indexed per completion.

    Returns:
        One reward per completion, in input order.

    Raises:
        Exception: if ``fixed_agent_type`` is not one of the values above.

    Side effects:
        Execution-based rewards append every ``process_leaf`` result to the
        module-level ``ALL_EVAL_RESULTS`` list.
    """
    global ALL_EVAL_RESULTS

    completion_contents = [completion[0]["content"] for completion in completions]

    args_list = []  # one argument tuple per process_leaf call
    invalid_meta_ids = set()

    for meta_id, (completion, leaf_ids) in enumerate(zip(completion_contents, kwargs['leaf_ids'])):
        try:
            meta_query, meta_workflow = _parse_meta_flow(completion)

            # NOTE(review): the temperature is lowered *before* each repetition
            # is queued, so repetition 0 already runs one step below
            # exec_temperature and the last repetition runs at (about) 0. With
            # num_repetitions == 1 every execution uses temperature 0. Confirm
            # this schedule is intended.
            rep_exec_temperature = exec_temperature
            temperature_step = (exec_temperature - 0.0) / num_repetitions
            for rep_index in range(num_repetitions):
                rep_exec_temperature = max(rep_exec_temperature - temperature_step, 0)
                args_list += [
                    (meta_id, meta_query, meta_workflow, leaf_id, experiment_name, model_name,
                     fixed_agent_type, rep_index, rep_exec_temperature)
                    for leaf_id in leaf_ids
                ]
        except Exception as e:
            # Deliberate best-effort: a malformed generation is penalised, not fatal.
            print('Metaflow Exception:', e)
            invalid_meta_ids.add(meta_id)

    if fixed_agent_type in _EXECUTION_AGENT_TYPES:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            eval_results = list(tqdm(
                executor.map(process_leaf, args_list),
                total=len(args_list), desc="Processing Metaflows."
            ))

        ALL_EVAL_RESULTS += eval_results

        # Tally (passed, total) per (meta_id, repetition) in a single pass so
        # the result list is not rescanned for every completion.
        tallies = {}
        for sample in eval_results:
            key = (sample['meta_id'], sample['repe_index'])
            passed, total = tallies.get(key, (0, 0))
            # `== True` (not truthiness) is kept on purpose to match the
            # established scoring of non-boolean 'success' values.
            if sample['eval_result'].get('success', False) == True:  # noqa: E712
                passed += 1
            tallies[key] = (passed, total + 1)

        rewards = []
        for i in range(len(completion_contents)):
            if i in invalid_meta_ids:
                rewards.append(-1)
                continue
            max_reward_among_repetitions = -1
            for rep_index in range(num_repetitions):
                passed, total = tallies.get((i, rep_index), (0, 0))
                if total:
                    max_reward_among_repetitions = max(max_reward_among_repetitions, passed / total)
            rewards.append(max_reward_among_repetitions)
        return rewards

    if fixed_agent_type == 'code_lines_ratio':
        rewards = []
        for i, completion in enumerate(completion_contents):
            if i in invalid_meta_ids:
                rewards.append(-1)
                continue
            lines = float(count_python_lines_in_markdown(completion))
            # +0.001 keeps the division defined when a reference workflow has 0 lines.
            compress_ratio = max(lines / (kwargs['workflow1_lines'][i] + 0.001),
                                 lines / (kwargs['workflow2_lines'][i] + 0.001))
            # Cap at the threshold, then rescale so the maximum reward is 1.
            compress_ratio = min(compress_ratio, compress_ratio_threshold)
            rewards.append(compress_ratio / compress_ratio_threshold)
        return rewards

    raise Exception(f'{fixed_agent_type} is not supported.')




def make_correct_ratio_reward_func(agent_type_inner):
    """Build a reward function bound to one agent type.

    The returned callable forwards ``completions`` and any keyword arguments to
    ``correct_ratio_reward_func_v2`` with ``fixed_agent_type`` set to
    ``agent_type_inner``. Its ``__name__`` and ``__qualname__`` carry the agent
    type, so functions built for different agent types have distinct names.
    """
    bound_name = f"correct_ratio_reward_func_v2_{agent_type_inner}"

    def reward_func(completions, **kwargs):
        return correct_ratio_reward_func_v2(
            completions, fixed_agent_type=agent_type_inner, **kwargs
        )

    for attribute in ("__name__", "__qualname__"):
        setattr(reward_func, attribute, bound_name)
    return reward_func



if __name__ == "__main__":
    # Training entry point: load the pickled dataset, build prompts, run GRPO
    # with the configured reward functions, then dump the collected results.
    print('experiment_name:', experiment_name)
    print('dataset_path:', dataset_path)

    # NOTE: pickle.load can execute arbitrary code; only point dataset_path at
    # files from a trusted source.
    with open(dataset_path, 'rb') as fp:
        train_dataset = pickle.load(fp)

    # Pre-compute the reference line counts that the 'code_lines_ratio' reward
    # reads from its kwargs.
    for sample in train_dataset:
        sample['workflow1_lines'] = count_python_lines_in_markdown(sample['workflow1'])
        sample['workflow2_lines'] = count_python_lines_in_markdown(sample['workflow2'])

    train_dataset = Dataset.from_list(train_dataset)

    # The literal version=2 here is independent of the module-level `version`.
    train_dataset = train_dataset.map(
        lambda example: make_summarization_prompt(example, ICL_number=ICL_number, version=2))

    # NOTE(review): this path is derived from model_size (7), while model_name
    # (used in experiment_name and passed to process_leaf) says 14B. Confirm
    # the mismatch is intended.
    model_path = f'/share_data/data1/models/Qwen/Qwen2.5-{model_size}B-Instruct'
    accelerator = Accelerator(mixed_precision='bf16', log_with="wandb")
    accelerator.init_trackers("MetaFlow", config={}, init_kwargs={"wandb": {"name": experiment_name}})

    # Fall back to the base model when no checkpoint was configured.
    if not load_model_path:
        load_model_path = model_path
    print('Load from:', load_model_path)

    model = AutoModelForCausalLM.from_pretrained(
        load_model_path,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        attn_implementation="flash_attention_2",
    )
    model.gradient_checkpointing_enable()

    training_args = GRPOConfig(
        output_dir=output_dir,
        # resume_from_checkpoint=load_model_path,
        learning_rate=1e-6,
        # Keep the extra dataset columns: the reward function reads leaf_ids,
        # workflow1_lines and workflow2_lines from its kwargs.
        remove_unused_columns=False,
        gradient_accumulation_steps=1,
        beta=beta,
        num_train_epochs=num_epochs,
        bf16=True,
        # Parameters that control the data preprocessing
        max_completion_length=8192,  # default: 256
        lr_scheduler_type="constant",
        num_generations=num_generations,  # default: 8
        temperature=rollout_temperature,
        max_prompt_length=8192,  # default: 512
        # Parameters related to reporting and saving
        logging_steps=1,
        push_to_hub=False,
        save_strategy="steps",
        save_steps=100,
        per_device_train_batch_size=2,
        epsilon_high=0.28,
        log_completions=True,
        gradient_checkpointing=True,
        report_to="wandb",
        save_on_each_node=False,
        use_vllm=use_vllm,
        scale_rewards=False,
        reward_weights=reward_weights,
        num_iterations=1,
        overwrite_output_dir=True
    )

    # A single reward function is passed bare; several are passed as a list
    # whose order matches reward_weights.
    if len(agent_type_list) == 1:
        reward_funcs = make_correct_ratio_reward_func(agent_type_list[0])
    else:
        reward_funcs = [make_correct_ratio_reward_func(agent_type) for agent_type in agent_type_list]

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
    )
    trainer.train()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # NOTE(review): only this process's ALL_EVAL_RESULTS is written. If the
        # reward functions also run on other ranks, their results are not
        # gathered here. Confirm whether that is acceptable.
        # Make sure the target directory exists so the results of a finished
        # run are not lost to a FileNotFoundError.
        os.makedirs(os.path.dirname(EVAL_RESULTS_output_dir), exist_ok=True)
        with open(EVAL_RESULTS_output_dir, 'wb') as fp:
            pickle.dump(ALL_EVAL_RESULTS, fp)
