import os
import torch
import warnings
import json
import random
from PIL import Image
from peft import LoraConfig, get_peft_model, TaskType
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor
)
from trl import GRPOTrainer, GRPOConfig
from datasets import Dataset
# from utils.dataset_process import dataset_gen_train
from utils.reward import format_reward_func, thinking_reward_func, accuracy_reward_type, accuracy_reward_action


# --- Run configuration --------------------------------------------------------
# Empty-string values below are placeholders that must be filled in before a run.
DS_CONFIG = "ds_z3_offload_config.json"  # DeepSpeed config handed to GRPOConfig
# NOTE: an empty CUDA_VISIBLE_DEVICES makes no GPU visible to CUDA; set it to the
# device ids to use (e.g. "0,1") before launching.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
os.environ["SWANLAB_PROJECT"] = ""
run_name_swanlab = ""
output_dir_check = ""  # checkpoint directory (GRPOConfig.output_dir)
output_dir_final = ""  # directory for the final trainer.save_model() call
dataset_complete_path = ""  # JSON file containing a list of step records
model_id = ""
# Each process loads the whole model onto the GPU matching its LOCAL_RANK.
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
compute_dtype = torch.bfloat16
min_pixels = 65536
max_pixels = 6553600
num_epochs = 1
# NOTE(review): GRPO normalises rewards within each group of completions for a
# prompt; with a single generation the group-relative advantage is always zero,
# and recent TRL versions reject num_generations < 2. Confirm this value.
grpo_generations = 1
# Kept equal to grpo_generations (see the GRPOConfig comment); TRL requires the
# effective batch to be divisible by num_generations.
per_device_batch_size = grpo_generations
gradient_steps = 1
# Number of training processes. Distributed launchers (torchrun / deepspeed)
# export WORLD_SIZE; a plain single-process run falls back to 1.
gpu_num = int(os.environ.get("WORLD_SIZE") or 1)
# Presumably the number of distinct prompts consumed per optimizer step across
# all processes (each prompt expands into grpo_generations completions).
rollout_n = (per_device_batch_size / grpo_generations) * gradient_steps * gpu_num

# Processor handed to GRPOTrainer as ``processing_class`` further down. Despite
# the variable name it is a full AutoProcessor, loaded without pixel bounds.
tokenizer = AutoProcessor.from_pretrained(
    model_id,
    fix_mistral_regex=True)
# NOTE(review): this processor carries the min/max pixel bounds, but the trainer
# below receives ``tokenizer`` instead, so these bounds are not applied there.
# Confirm which processor is intended for training.
processor = AutoProcessor.from_pretrained(
    model_id,
    fix_mistral_regex=True,
    min_pixels=min_pixels,
    max_pixels=max_pixels)
# Load the base model in ``compute_dtype`` onto this process's GPU (``device_map``).
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map=device_map,
    dtype=compute_dtype,
    attn_implementation="flash_attention_2")

# The dataset file holds a list of step records that is split by index.
with open(dataset_complete_path, 'r', encoding='utf-8') as f:
    data = json.load(f)
if len(data) < 2500:
    # Slicing never raises, so a short file would silently shrink the splits.
    warnings.warn(
        f"Dataset has only {len(data)} records; the index split expects at least 2500 "
        f"(train = [0, 2000), eval = [2100, 2500)).")
# Records 2000-2099 belong to neither split (presumably a buffer between the
# train and eval ranges -- confirm).
train_ds = data[:2000]
train_eval_ds = data[2100:2500]

# System prompt shared by every train and eval sample.
_SYSTEM_PROMPT = r'''You are an intelligent agent designed to operate a smartphone interface to complete specific tasks.'''
# Only this many of the most recent history entries are placed in the prompt.
_MAX_HISTORY = 10
# Every screenshot is resized to this (width, height) before being stored.
_IMAGE_SIZE = (1000, 1000)
# Prefix joined with each record's 'image_path'; empty means paths are used as-is.
_IMAGE_ROOT = ''


def _action_info(example):
    """Return the reference ``action_info`` for one record.

    * ``CLICK`` with a non-empty ``sam2_bbox``: the flat ``sam2_bbox`` list is
      regrouped into consecutive 2-element chunks (presumably ``[x, y]`` points
      -- confirm against the reward functions).
    * ``LONG_PRESS`` with a non-empty ``sam2_bbox``: the first element of the
      record's ``action_info``.
    * Anything else: the record's ``action_info`` unchanged.

    ``sam2_bbox`` is only read for CLICK / LONG_PRESS records, so other records
    need not carry that key.
    """
    action_type = example['action_type']
    if action_type == 'CLICK' and example['sam2_bbox'] != []:
        bbox = example['sam2_bbox']
        return [bbox[i:i + 2] for i in range(0, len(bbox), 2)]
    if action_type == 'LONG_PRESS' and example['sam2_bbox'] != []:
        return example['action_info'][0]
    return example['action_info']


def _build_sample(example):
    """Convert one raw step record into a GRPO sample dict.

    Reads the record keys ``image_path``, ``task``, ``description``, ``history``,
    ``intention``, ``instruction``, ``action_type`` and ``action_info`` (plus
    ``sam2_bbox`` for CLICK / LONG_PRESS steps).

    Returns:
        dict with
          * ``prompt``   -- system + user chat messages. The user turn holds an
            image placeholder followed by the task text and the JSON-encoded
            history (last ``_MAX_HISTORY`` entries only).
          * ``image``    -- the screenshot converted to RGB and resized to
            ``_IMAGE_SIZE``.
          * ``solution`` -- the reference completion: a ``<thinking>`` section
            (analysis / reasoning / instruction) followed by an ``<answer>``
            JSON object with ``action_type`` and ``action_info``.
    """
    history = example['history']
    if len(history) > _MAX_HISTORY:
        # Keep only the most recent steps so the prompt length stays bounded.
        history = history[len(history) - _MAX_HISTORY:]
    history_text = json.dumps(history, ensure_ascii=False)

    answer_json = json.dumps(
        {"action_type": example['action_type'], "action_info": _action_info(example)},
        ensure_ascii=False,
    )
    solution = (
        "<thinking>"
        + "<analysis>" + example['description'] + "</analysis>"
        + "<reasoning>" + example['intention'] + "</reasoning>"
        + "<instruction>" + example['instruction'] + "</instruction>"
        + "</thinking>\n"
        + "<answer>" + answer_json + "</answer>"
    )

    image_path = os.path.join(_IMAGE_ROOT, example['image_path'])
    with Image.open(image_path) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # resize() returns a new in-memory image, so it outlives the file handle.
        image = img.resize(_IMAGE_SIZE)

    return {
        'prompt': [
            {'role': 'system', 'content': [{"type": "text", "text": _SYSTEM_PROMPT}]},
            {'role': 'user', 'content': [
                {"type": "image"},
                # NOTE(review): task text and history JSON are concatenated with
                # no separator between them.
                {"type": "text",
                 "text": "Your task is: " + example['task'] + history_text},
            ]},
        ],
        'image': image,
        'solution': solution,
    }


def dataset_gen_train():
    """Yield one GRPO sample per record in ``train_ds`` (for ``Dataset.from_generator``)."""
    for example in train_ds:
        yield _build_sample(example)


def dataset_gen_eval():
    """Yield one GRPO sample per record in ``train_eval_ds``.

    Uses exactly the same sample construction as training, including the
    ``_MAX_HISTORY`` cap on history, so eval prompts match the training format.
    """
    for example in train_eval_ds:
        yield _build_sample(example)


# Prepare the base model for adapter training with gradient checkpointing.
# ``enable_input_require_grads`` marks the input-embedding outputs as requiring
# grad, which is commonly needed so gradients still reach adapter weights when
# the base weights are frozen and checkpointing is on.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

# LoRA adapter configuration; applied with ``get_peft_model`` further down.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=32,  # adapter rank
    lora_alpha=64,  # LoRA scaling is lora_alpha / r = 2
    # PEFT shortcut for every linear layer except the output head; for this VL
    # model that presumably includes the vision tower too -- confirm intended.
    target_modules="all-linear",
    bias="none",
    lora_dropout=0.05,  # dropout applied on the adapter path
)

# GRPO / Trainer hyper-parameters, grouped by concern.
training_args = GRPOConfig(
    # Output, logging and checkpointing.
    output_dir=output_dir_check,
    run_name=run_name_swanlab,
    report_to="swanlab",
    logging_steps=1,
    disable_tqdm=False,  # keep the progress bar visible
    save_steps=250,
    # Periodic evaluation on the eval dataset.
    do_eval=True,
    eval_strategy="steps",
    eval_steps=250,
    # Rollout generation: vLLM in "colocate" mode shares the training GPUs.
    use_vllm=True,
    vllm_mode="colocate",
    num_generations=grpo_generations,
    max_prompt_length=6144,
    max_completion_length=1024,
    # GRPO objective: KL coefficient and upper clipping bound of the ratio.
    beta=0.001,
    epsilon_high=0.3,
    # importance_sampling_level="sequence",
    # Batching and schedule length.
    num_train_epochs=num_epochs,
    per_device_train_batch_size=per_device_batch_size,  # must stay equal to num_generations
    gradient_accumulation_steps=gradient_steps,  # raise (e.g. to 4) for smoother updates
    # Optimizer and learning-rate schedule.
    optim="adamw_torch_8bit",  # alternatives: adamw_torch, adamw_8bit
    learning_rate=2e-5,
    adam_beta1=0.9,
    adam_beta2=0.95,
    weight_decay=0.1,
    max_grad_norm=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # Precision and memory.
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    deepspeed=DS_CONFIG,
)

# Wrap the base model with the LoRA adapters configured above. The model is
# already a PEFT model afterwards, so the trainer gets ``peft_config=None``.
model = get_peft_model(model, peft_config)

# Materialise the sample generators into datasets; every screenshot is opened
# and resized while these calls run.
dataset_train = Dataset.from_generator(dataset_gen_train)
dataset_eval = Dataset.from_generator(dataset_gen_eval)
reward_funcs = [format_reward_func, thinking_reward_func, accuracy_reward_type, accuracy_reward_action]
# epochs * examples / prompts-per-optimizer-step: presumably the total number of
# optimizer steps. It is only consumed by the ``num_dataset`` argument below.
num_train_dataset = int(num_epochs * len(dataset_train) / rollout_n)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,
    args=training_args,
    peft_config=None,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
    # NOTE(review): ``num_dataset`` is not a documented GRPOTrainer parameter in
    # upstream TRL; this presumably targets a patched/forked trainer. Stock TRL
    # would reject it with a TypeError -- confirm.
    num_dataset=num_train_dataset
)

print("Starting training...")
trainer.train()

print("Training finished. Saving model...")
# For a PEFT-wrapped model this presumably saves adapter weights rather than a
# merged full checkpoint -- confirm if merged weights are needed.
trainer.save_model(output_dir_final)
