import os
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor
)
import torch
from datasets import Dataset
from PIL import Image
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model, TaskType
import warnings
import json
import io
import random

warnings.filterwarnings("ignore", category=UserWarning)

# Run configuration. The empty strings are placeholders to fill in before launch.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["SWANLAB_PROJECT"] = ""
run_name_swanlab = ""
output_dir_check = ""  # passed to SFTConfig(output_dir=...) below
output_dir_final = ""  # passed to trainer.save_model(...) after training
model_id = ""
# NOTE(review): weights are loaded directly in float16 and the SFTConfig below
# enables no mixed precision, so the optimizer updates fp16 parameters; with
# learning_rate=1e-6 many updates may round away. torch.bfloat16 is likely the
# safer choice on GPUs that run FlashAttention-2 — confirm before changing.
compute_dtype = torch.float16
# Place the whole model on the GPU indexed by LOCAL_RANK (presumably set by a
# distributed launcher); defaults to GPU 0 when the variable is unset or empty.
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

# Resolution bounds handed to the image processor.
min_pixels = 65536
max_pixels = 6553600
# Load the processor exactly once, with the pixel bounds applied. The trainer is
# given `tokenizer`, so alias that name to this configured processor instead of
# loading a second, unconfigured copy that would silently drop the bounds.
processor = AutoProcessor.from_pretrained(model_id,
                                          fix_mistral_regex=True,
                                          min_pixels=min_pixels,
                                          max_pixels=max_pixels, use_fast=True)
tokenizer = processor
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map=device_map,  # always a non-empty dict, never falls back to "auto"
    dtype=compute_dtype,
    attn_implementation="flash_attention_2",
)
def _read_json(path):
    """Parse and return the UTF-8 JSON document stored at *path*."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Training annotations: the whole file is parsed, then capped to 1600 records.
dataset_path_complete = ""
data_0 = _read_json(dataset_path_complete)
data_set = data_0
# Evaluation annotations: the whole file is parsed, then capped to 160 records.
dataset_path_eval = ""
data_2 = _read_json(dataset_path_eval)
data_set_eval = data_2
ds = data_set[:1600]
ds_eval = data_set_eval[:160]


def get_prompt_two_step_sft(example, *, max_history=10, image_size=(1000, 1000)):
    """Turn one annotation record into a chat-format SFT sample.

    Keys read from ``example``: ``image_path``, ``task``, ``description``,
    ``history``, ``intention``, ``instruction``, ``action_type`` and
    ``action_info``. The text fields are concatenated with ``+``, so they must
    be ``str``.

    The sample has three turns:
      * system: a fixed agent prompt;
      * user: the screenshot (converted to RGB, resized to ``image_size``) plus
        the task text followed by the JSON-encoded last ``max_history``
        history entries;
      * assistant (target): ``<thinking>`` containing ``<analysis>``,
        ``<reasoning>`` and ``<instruction>`` sections, then ``<answer>``
        holding ``{"action_type", "action_info"}`` as JSON.

    Args:
        example: One annotation record (dict-like).
        max_history: How many trailing ``history`` entries to keep (default 10).
        image_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``;
            the aspect ratio is not preserved.

    Returns:
        A one-element list ``[{"messages": [...]}]``, so callers can iterate
        over the result uniformly.
    """
    SYSTEM_PROMPT = r'''You are an intelligent agent designed to operate a smartphone interface to complete specific tasks. '''
    relative_path = ''  # prefix joined onto example['image_path']; '' uses it as-is
    image_path = os.path.join(relative_path, example['image_path'])
    task = example['task']

    # Keep only the most recent `max_history` steps (presumably to bound the
    # prompt length). max(..., 0) keeps the whole list when it is short enough.
    history_entries = example['history']
    history_entries = history_entries[max(len(history_entries) - max_history, 0):]
    history = json.dumps(history_entries, ensure_ascii=False)

    answer_json = json.dumps(
        {"action_type": example['action_type'], "action_info": example['action_info']},
        ensure_ascii=False,
    )
    answer = "".join([
        "<thinking>",
        "<analysis>", example['description'], "</analysis>",
        "<reasoning>", example['intention'], "</reasoning>",
        "<instruction>", example['instruction'], "</instruction>",
        "</thinking>\n",
        "<answer>", answer_json, "</answer>",
    ])

    with Image.open(image_path) as img:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # resize() returns a new image holding its own pixel data, so it remains
        # usable after the `with` block closes the source file.
        img = img.resize(image_size)

    # NOTE(review): the task text and the history JSON are joined with no
    # separator; left unchanged to stay consistent with any inference prompt
    # built the same way — confirm this is intended.
    return [{
        "messages": [
            {'role': 'system', 'content': [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [
                {"type": "image_pil", "image": img},
                {"type": "text", "text": "Your task is: " + task + history},
            ]},
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
    }]


def _iter_samples(records):
    """Yield every chat sample that get_prompt_two_step_sft builds from *records*."""
    for record in records:
        yield from get_prompt_two_step_sft(record)


def dataset_gen():
    """Yield the training samples built from ``ds`` (for Dataset.from_generator)."""
    yield from _iter_samples(ds)


def dataset_gen_eval():
    """Yield the evaluation samples built from ``ds_eval`` (for Dataset.from_generator)."""
    yield from _iter_samples(ds_eval)


dataset_train = Dataset.from_generator(dataset_gen)
dataset_eval = Dataset.from_generator(dataset_gen_eval)

training_args = SFTConfig(
    learning_rate=1e-6,
    adam_beta1=0.9,
    adam_beta2=0.95,
    weight_decay=0.1,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    optim="adamw_torch_8bit",  # {adamw_torch, adamw_torch_8bit, adamw_8bit}
    logging_steps=1,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # raise for a larger effective batch / smoother loss
    num_train_epochs=2,
    # Do not truncate sequences: each 1000x1000 screenshot presumably expands to
    # well over 1k visual tokens for Qwen2.5-VL, and truncating them would break
    # the image-feature / image-token alignment.
    max_length=None,
    save_steps=3000,
    max_grad_norm=1,
    report_to="swanlab",  # Weights & Biases ("wandb") also works
    run_name=run_name_swanlab,
    output_dir=output_dir_check,
    disable_tqdm=False,  # keep the progress bar
    # NOTE(review): with at most 1600 training records, batch 16 per device and 2
    # epochs, the run has at most 200 optimizer steps (fewer with several
    # processes). eval_steps=200 and save_steps=3000 may therefore fire only at
    # the very end or never — confirm this is intended.
    eval_steps=200,
    eval_strategy="steps",
    do_eval=True,
)

trainer = SFTTrainer(
    model=model,
    # The processor built with min_pixels/max_pixels, so those bounds apply to
    # the training inputs.
    processing_class=processor,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_eval,
)

trainer.train()
trainer.save_model(output_dir_final)