import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import json
import math
import random
import argparse
from typing import Dict, List, Tuple, Optional

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

from transformers import (
    AutoTokenizer,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
    TrainerCallback,
)
from peft import LoraConfig, TaskType, get_peft_model, PeftModel

from qwen_vl_utils import process_vision_info
import difflib

class ListDatasetDPO(TorchDataset):
    """Map-style dataset over an in-memory list of pre-built DPO records.

    Records are returned exactly as stored: no copy and no transform.
    """

    def __init__(self, records: List[Dict]):
        # The caller's list is referenced, not copied.
        self.records = records

    def __getitem__(self, idx):
        records = self.records
        return records[idx]

    def __len__(self):
        return len(self.records)


def resolve_image_path(img_path: str, img_root: str) -> str:
    """Turn a dataset image reference into a normalized filesystem path.

    Falsy ``img_path`` values ("" or None) are returned untouched. Absolute
    paths are only normalized. Relative paths have any leading '/' or '\\'
    characters stripped, are joined onto ``img_root``, and are then
    normalized.
    """
    if img_path:
        if not os.path.isabs(img_path):
            img_path = os.path.join(img_root, img_path.lstrip("/\\"))
        return os.path.normpath(img_path)
    return img_path


def build_instruction(processor, img_path: str, img_root: str,
                      txt_prefix: str = "COCO Yes:", resize_h: int = 280, resize_w: int = 280) -> Dict[str, torch.Tensor]:
    """Encode a single (image, text prompt) user turn into model inputs.

    The function builds a one-message chat: an image entry resized to
    ``resize_h`` x ``resize_w``, followed by ``txt_prefix``. The chat is
    rendered with the processor's chat template, with the generation prompt
    appended. Vision inputs are extracted with ``process_vision_info``.

    Returns:
        The processor outputs. Every tensor whose first dimension has size 1
        has that leading dimension removed; all other values are returned
        unchanged.
    """
    image_entry = {
        "type": "image",
        "image": resolve_image_path(img_path, img_root),
        "resized_height": resize_h,
        "resized_width": resize_w,
    }
    text_entry = {"type": "text", "text": txt_prefix}
    messages = [{"role": "user", "content": [image_entry, text_entry]}]

    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    images, videos = process_vision_info(messages)
    encoded = processor(text=[prompt], images=images, videos=videos,
                        padding=True, return_tensors="pt")

    def _unbatch(value):
        # Drop dim 0 only when it is a singleton (the batch of one prompt).
        singleton = isinstance(value, torch.Tensor) and value.dim() > 0 and value.size(0) == 1
        return value[0] if singleton else value

    return {key: _unbatch(value) for key, value in encoded.items()}


def load_dpo_dataset(gentle_json: str, processor, img_root: str,
                     resize_h: int = 280, resize_w: int = 280) -> TorchDataset:
    """Read a "gentle negative" JSON file and pre-encode every usable item.

    The file is a JSON list of items, each holding a ``"conversations"`` list
    of ``{"from": ..., "value": ...}`` turns:

    * ``"user"`` turns carry the image path between ``<|vision_start|>`` and
      ``<|vision_end|>``. The last user turn containing both markers wins.
    * The first ``"assistant"`` turn is the preferred response.
    * The first ``"gentle_negative"`` turn is the dispreferred response.

    Items without an image path or a preferred response are skipped. A
    missing negative is stored as ``""``. Every kept item is encoded eagerly
    with ``build_instruction``.
    """
    with open(gentle_json, "r", encoding="utf-8") as fh:
        items = json.load(fh)

    def _extract(turns):
        # Returns (image_path, chosen_text, rejected_text); each may be None.
        image, chosen, rejected = None, None, None
        for turn in turns:
            role = turn.get("from")
            if role == "user":
                text = turn.get("value", "")
                if "<|vision_start|>" not in text or "<|vision_end|>" not in text:
                    continue
                try:
                    image = text.split("<|vision_start|>")[1].split("<|vision_end|>")[0]
                except Exception:
                    image = None
            elif role == "assistant":
                if chosen is None:
                    chosen = turn.get("value", "")
            elif role == "gentle_negative":
                if rejected is None:
                    rejected = turn.get("value", "")
        return image, chosen, rejected

    records: List[Dict] = []
    for item in tqdm(items, desc="Building DPO dataset"):
        image, chosen, rejected = _extract(item.get("conversations", []))
        if image is None or chosen is None:
            continue

        enc = build_instruction(processor, image, img_root, resize_h=resize_h, resize_w=resize_w)
        grid = enc["image_grid_thw"]
        if grid.dim() == 3:
            grid = grid.squeeze(0)
        records.append({
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
            "pixel_values": enc["pixel_values"],
            "image_grid_thw": grid,
            "y_pos_text": chosen,
            "y_neg_text": "" if rejected is None else rejected,
            "img_rel_path": image,
        })
    return ListDatasetDPO(records)


class DPOCollator:
    """Collate pre-encoded DPO records into one padded batch.

    Prompts are right-padded with ``pad_token_id``, and the attention mask is
    ``False`` on padding.

    Vision inputs are concatenated along dim 0, matching how the Qwen-VL
    image processor packs several images:

    * ``pixel_values`` becomes ``[total_patches, patch_dim]``.
    * ``image_grid_thw`` becomes ``[num_images, 3]``.

    Because of this, samples whose images yield different patch counts can
    share a batch. Stacking would require identical patch counts and would
    give ``image_grid_thw`` a spurious extra dimension for 2-D inputs.

    Raw response strings and image paths are passed through as lists.
    """

    def __init__(self, pad_token_id: int):
        # Token id written into every padded position of ``input_ids``.
        self.pad_token_id = pad_token_id

    def _to_tensor(self, x, dtype=None):
        """Return ``x`` as a tensor, cast to ``dtype`` when one is given."""
        if isinstance(x, torch.Tensor):
            return x.to(dtype=dtype) if dtype is not None else x
        return torch.tensor(x, dtype=dtype)

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        if not features:
            raise ValueError("DPOCollator received an empty batch")

        ids_list = [self._to_tensor(f["input_ids"], dtype=torch.long) for f in features]
        mask_list = [self._to_tensor(f["attention_mask"], dtype=torch.bool) for f in features]

        # Flatten each sample's patches to [num_patches, patch_dim] so samples
        # with different patch counts can be concatenated.
        pix_list = []
        for f in features:
            pv = self._to_tensor(f["pixel_values"], dtype=torch.float32)
            pix_list.append(pv.reshape(-1, pv.size(-1)))
        # Accept both [3] and [n_images, 3] per sample.
        thw_list = [self._to_tensor(f["image_grid_thw"], dtype=torch.long).reshape(-1, 3)
                    for f in features]

        bs = len(features)
        max_len = max(ids.size(0) for ids in ids_list)
        input_ids = torch.full((bs, max_len), self.pad_token_id, dtype=torch.long)
        attention_mask = torch.zeros((bs, max_len), dtype=torch.bool)
        for i, (ids, mask) in enumerate(zip(ids_list, mask_list)):
            n = ids.size(0)
            input_ids[i, :n] = ids
            attention_mask[i, :n] = mask

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": torch.cat(pix_list, dim=0),
            "image_grid_thw": torch.cat(thw_list, dim=0),
            "y_pos_texts": [f["y_pos_text"] for f in features],
            "y_neg_texts": [f["y_neg_text"] for f in features],
            "img_rel_paths": [f.get("img_rel_path", "") for f in features],
        }


def seq_logprob_for_targets(model,
                            text_tokenizer,
                            inst_ids: torch.Tensor,
                            inst_attn: torch.Tensor,
                            pixel_values: torch.Tensor,
                            image_grid_thw: torch.Tensor,
                            target_texts: List[str],
                            max_len: int = 8192) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sum of per-token log-probabilities of ``target_texts`` under ``model``.

    Each row ``i`` of ``inst_ids`` is treated as a right-padded prompt whose
    real length is ``inst_attn[i].sum()``. The tokenized target (no special
    tokens) is appended directly after the real prompt, and the combined
    sequence is truncated to ``max_len``. A single forward pass then scores
    each target token from the logits at the preceding position.

    Returns:
        ``(sum_logp, ntoks)``, both of shape ``[B]``:

        * ``sum_logp``: float32 sum of the scored target-token log-probs.
        * ``ntoks``: how many target tokens were scored.

        Rows with nothing to score get 0 / 0. If no row has anything to
        score, the model is not called at all.
    """
    device = inst_ids.device
    B = inst_ids.size(0)

    tgt_id_lists = text_tokenizer(list(target_texts), add_special_tokens=False)["input_ids"]
    Lp = inst_attn.long().sum(dim=1)  # real prompt length per row (right padding assumed)

    T = torch.tensor([len(x) for x in tgt_id_lists], device=device, dtype=torch.long)
    total_len = torch.clamp(Lp + T, max=max_len)
    Lmax = int(total_len.max().item())

    pad_id = text_tokenizer.pad_token_id
    x_ids = torch.full((B, Lmax), pad_id, dtype=torch.long, device=device)
    x_mask = torch.zeros((B, Lmax), dtype=torch.bool, device=device)

    batch_idx_list, prev_pos_idx_list, tok_idx_list = [], [], []

    for i in range(B):
        Lp_i = int(Lp[i].item())
        tgt_i = torch.tensor(tgt_id_lists[i], dtype=torch.long, device=device)
        avail = min(Lp_i + tgt_i.numel(), Lmax)
        Lp_use = min(Lp_i, avail)
        T_use = max(0, avail - Lp_use)

        if Lp_use > 0:
            x_ids[i, :Lp_use] = inst_ids[i, :Lp_use]
            x_mask[i, :Lp_use] = inst_attn[i, :Lp_use]
        if T_use > 0:
            x_ids[i, Lp_use:Lp_use + T_use] = tgt_i[:T_use]
            x_mask[i, Lp_use:Lp_use + T_use] = True

            # Target token t is predicted by the logits at position t - 1.
            start_prev = Lp_use - 1
            if start_prev >= 0:
                batch_idx_list.append(torch.full((T_use,), i, device=device, dtype=torch.long))
                prev_pos_idx_list.append(
                    torch.arange(start_prev, start_prev + T_use, device=device, dtype=torch.long))
                tok_idx_list.append(tgt_i[:T_use])
            elif T_use > 1:
                # Empty prompt: the first target token has no predecessor.
                batch_idx_list.append(torch.full((T_use - 1,), i, device=device, dtype=torch.long))
                prev_pos_idx_list.append(torch.arange(0, T_use - 1, device=device, dtype=torch.long))
                tok_idx_list.append(tgt_i[1:T_use])

    if not batch_idx_list:
        return (torch.zeros(B, device=device, dtype=torch.float32),
                torch.zeros(B, dtype=torch.long, device=device))

    batch_idx = torch.cat(batch_idx_list)
    prev_pos = torch.cat(prev_pos_idx_list)
    tok_idx = torch.cat(tok_idx_list)

    out = model(input_ids=x_ids, attention_mask=x_mask,
                pixel_values=pixel_values, image_grid_thw=image_grid_thw)
    # Softmax only the scored positions. Gathering first avoids a full
    # [B, Lmax, V] log-softmax, and upcasting to float32 keeps the
    # normalisation precise when the model runs in bf16/fp16.
    step_logp = F.log_softmax(out.logits[batch_idx, prev_pos].float(), dim=-1)  # [Ntok, V]
    selected = step_logp.gather(1, tok_idx.unsqueeze(1)).squeeze(1)            # [Ntok]

    sum_logp = torch.zeros(B, device=device, dtype=selected.dtype)
    sum_logp.scatter_add_(0, batch_idx, selected)

    ntoks = torch.zeros(B, device=device, dtype=torch.long)
    ntoks.scatter_add_(0, batch_idx, torch.ones_like(batch_idx))
    return sum_logp, ntoks





def jaccard_overlap_ids(a: List[int], b: List[int]) -> float:
    """Jaccard similarity |A & B| / |A | B| of the token-id *sets* of ``a`` and ``b``.

    Order and duplicates are ignored. Two empty inputs count as identical
    and give 1.0.
    """
    set_a = set(a)
    set_b = set(b)
    union = set_a | set_b
    if not union:
        return 1.0
    return len(set_a & set_b) / len(union)

def build_diff_masks_for_batch(
    pos_id_lists: List[List[int]],
    neg_id_lists: List[List[int]],
    T_pos_max: int,
    T_neg_max: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Mark the tokens where each positive/negative pair differs.

    For every (positive, negative) pair, ``difflib.SequenceMatcher`` aligns
    the two token-id sequences. Tokens covered by a non-``equal`` opcode
    (replace/delete/insert) get mask 1; aligned shared tokens get 0. Masks
    are truncated or zero-padded to ``T_pos_max`` / ``T_neg_max`` columns.

    Returns:
        ``(pos_mask, neg_mask, sim_decay)``, all float32 on ``device``:

        * ``pos_mask``: shape ``[B, T_pos_max]``.
        * ``neg_mask``: shape ``[B, T_neg_max]``.
        * ``sim_decay``: shape ``[B]``, equal to ``1 - jaccard`` of the two
          id sets, floored at 0.1.

        ``B`` is the number of pairs zipped from the two lists.
    """
    pos_masks, neg_masks, decays = [], [], []
    for pos_ids, neg_ids in zip(pos_id_lists, neg_id_lists):
        pm = [0] * len(pos_ids)
        nm = [0] * len(neg_ids)
        matcher = difflib.SequenceMatcher(a=pos_ids, b=neg_ids, autojunk=False)
        for tag, i1, i2, j1, j2 in matcher.get_opcodes():
            if tag != "equal":
                pm[i1:i2] = [1] * (i2 - i1)
                nm[j1:j2] = [1] * (j2 - j1)
        pos_masks.append(pm[:T_pos_max] + [0] * max(0, T_pos_max - len(pm)))
        neg_masks.append(nm[:T_neg_max] + [0] * max(0, T_neg_max - len(nm)))
        decays.append(max(0.1, 1.0 - jaccard_overlap_ids(pos_ids, neg_ids)))

    # Build directly on the requested device. The explicit reshape keeps the
    # [B, T] layout even for an empty batch, where torch.tensor([]) would be 1-D.
    n_pairs = len(pos_masks)
    pos_mask = torch.tensor(pos_masks, dtype=torch.float32, device=device).reshape(n_pairs, T_pos_max)
    neg_mask = torch.tensor(neg_masks, dtype=torch.float32, device=device).reshape(n_pairs, T_neg_max)
    sim_decay = torch.tensor(decays, dtype=torch.float32, device=device)
    return pos_mask, neg_mask, sim_decay


def seq_token_logprobs_for_targets(
    model,
    text_tokenizer,
    inst_ids: torch.Tensor,
    inst_attn: torch.Tensor,
    pixel_values: torch.Tensor,
    image_grid_thw: torch.Tensor,
    target_id_lists: List[List[int]],
    max_len: int = 8192,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token log-probabilities of pre-tokenized targets under ``model``.

    Each row ``i`` of ``inst_ids`` is treated as a right-padded prompt whose
    real length is ``inst_attn[i].sum()``. ``target_id_lists[i]`` is appended
    right after it, and the combined sequence is truncated to ``max_len``.
    A single forward pass then scores each target token from the logits at
    the preceding position.

    Returns:
        ``(tok_logp, valid_mask)``, both float32 of shape ``[B, T_max]``
        where ``T_max`` is the longest target length. Column ``j`` refers to
        the ``j``-th target token.

        * ``valid_mask`` is 1.0 where a log-prob was actually computed.
        * Truncated tokens, padding columns, and the first token of a row
          with an empty prompt stay 0.

        ``tok_logp`` carries gradients when ``model`` does. If nothing can
        be scored, the model is not called at all.
    """
    device = inst_ids.device
    B = inst_ids.size(0)
    Lp = inst_attn.long().sum(dim=1)  # real prompt length per row (right padding assumed)
    T_lens = torch.tensor([len(t) for t in target_id_lists], device=device, dtype=torch.long)
    T_max = int(T_lens.max().item()) if T_lens.numel() > 0 else 1

    total_len = torch.clamp(Lp + T_lens, max=max_len)
    Lmax = int(total_len.max().item()) if total_len.numel() > 0 else int(Lp.max().item())

    pad_id = text_tokenizer.pad_token_id
    x_ids = torch.full((B, Lmax), pad_id, dtype=torch.long, device=device)
    x_mask = torch.zeros((B, Lmax), dtype=torch.bool, device=device)

    batch_idx_list, prev_pos_list, tok_idx_list, col_idx_list = [], [], [], []

    for i in range(B):
        Lp_i = int(Lp[i].item())
        tgt = torch.tensor(target_id_lists[i], dtype=torch.long, device=device)
        avail = min(Lp_i + tgt.numel(), Lmax)
        Lp_use = min(Lp_i, avail)
        T_use = max(0, avail - Lp_use)

        if Lp_use > 0:
            x_ids[i, :Lp_use] = inst_ids[i, :Lp_use]
            x_mask[i, :Lp_use] = inst_attn[i, :Lp_use]
        if T_use > 0:
            x_ids[i, Lp_use:Lp_use + T_use] = tgt[:T_use]
            x_mask[i, Lp_use:Lp_use + T_use] = True

            # Target token t is predicted by the logits at position t - 1.
            start_prev = Lp_use - 1
            if start_prev >= 0:
                batch_idx_list.append(torch.full((T_use,), i, device=device, dtype=torch.long))
                prev_pos_list.append(torch.arange(start_prev, start_prev + T_use, device=device, dtype=torch.long))
                tok_idx_list.append(tgt[:T_use])
                col_idx_list.append(torch.arange(0, T_use, device=device, dtype=torch.long))
            elif T_use > 1:
                # Empty prompt: the first target token has no predecessor.
                batch_idx_list.append(torch.full((T_use - 1,), i, device=device, dtype=torch.long))
                prev_pos_list.append(torch.arange(0, T_use - 1, device=device, dtype=torch.long))
                tok_idx_list.append(tgt[1:T_use])
                col_idx_list.append(torch.arange(1, T_use, device=device, dtype=torch.long))

    tok_logp = torch.zeros(B, T_max, dtype=torch.float32, device=device)
    valid_mask = torch.zeros(B, T_max, dtype=torch.float32, device=device)

    if not batch_idx_list:
        # Nothing to score: skip the (expensive) forward pass entirely.
        return tok_logp, valid_mask

    batch_idx = torch.cat(batch_idx_list)  # [Ntok]
    prev_pos = torch.cat(prev_pos_list)    # [Ntok]
    tok_idx = torch.cat(tok_idx_list)      # [Ntok]
    col_idx = torch.cat(col_idx_list)      # [Ntok]

    out = model(input_ids=x_ids, attention_mask=x_mask,
                pixel_values=pixel_values, image_grid_thw=image_grid_thw)
    # Gather the needed positions before the log-softmax. Only [Ntok, V]
    # float32 activations are then materialised (and kept for backward),
    # instead of a float32 copy of the full [B, Lmax, V] logits.
    step_logp = F.log_softmax(out.logits[batch_idx, prev_pos].float(), dim=-1)
    selected = step_logp.gather(1, tok_idx.unsqueeze(1)).squeeze(1)  # [Ntok] float32

    tok_logp.index_put_((batch_idx, col_idx), selected)
    valid_mask.index_put_((batch_idx, col_idx), torch.ones_like(selected))
    return tok_logp, valid_mask




class DpoLearningDynamicsCallback(TrainerCallback):
    """Log how one optimizer step moves the policy's log-probs on a probe batch.

    Every ``log_every_steps`` global steps, a random batch of ``sample_size``
    examples is drawn from ``observer_ds``. Length-normalised log-probs of
    the positive and negative responses are measured twice: just before the
    step (``on_step_begin``) and just after it (``on_step_end``). The
    batch-mean differences are logged as ``ld/pos_dlogp_mean`` and
    ``ld/neg_dlogp_mean``.

    Both measurements run with the model in eval mode, so dropout (e.g. LoRA
    dropout) cannot bias the delta. The model's previous train/eval mode is
    restored after each measurement.
    """

    def __init__(self, trainer: Trainer, observer_ds: TorchDataset, collator: DPOCollator,
                 text_tokenizer, sample_size: int = 8, log_every_steps: int = 20):
        self.trainer = trainer
        self.observer_ds = observer_ds
        self.collator = collator
        self.sample_size = min(sample_size, len(observer_ds))
        # Clamp to >= 1 so the modulo in on_step_begin can never divide by zero.
        self.log_every = max(1, int(log_every_steps))
        self.text_tokenizer = text_tokenizer
        self.prev = None  # (pos_stats, neg_stats, batch) captured before the step
        self._was_training = True
        self.rng = random.Random(123)

    def _sample_batch(self) -> Dict[str, torch.Tensor]:
        """Draw distinct probe examples, collate them, and move tensors to the model device."""
        idxs = self.rng.sample(range(len(self.observer_ds)), k=self.sample_size)
        feats = [self.observer_ds[i] for i in idxs]
        batch = self.collator(feats)
        dev = next(self.trainer.model.parameters()).device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(dev)
        return batch

    @torch.no_grad()
    def _calc_lp(self, batch):
        """Return ((pos_sum, pos_ntok), (neg_sum, neg_ntok)) for ``batch``."""
        model = self.trainer.model
        pos_sum, pos_nt = seq_logprob_for_targets(
            model, self.text_tokenizer,
            batch["input_ids"], batch["attention_mask"],
            batch["pixel_values"], batch["image_grid_thw"],
            batch["y_pos_texts"]
        )
        neg_sum, neg_nt = seq_logprob_for_targets(
            model, self.text_tokenizer,
            batch["input_ids"], batch["attention_mask"],
            batch["pixel_values"], batch["image_grid_thw"],
            batch["y_neg_texts"]
        )
        return (pos_sum, pos_nt), (neg_sum, neg_nt)

    def _measure(self, batch):
        """Run ``_calc_lp`` in eval mode, then restore the model's previous mode."""
        model = self.trainer.model
        was_training = model.training
        model.eval()
        try:
            return self._calc_lp(batch)
        finally:
            if was_training:
                model.train()

    def on_step_begin(self, args, state, control, **kwargs):
        if self.sample_size <= 0 or state.global_step % self.log_every != 0:
            return
        self._was_training = self.trainer.model.training
        batch = self._sample_batch()
        self.prev = (*self._measure(batch), batch)

    def on_step_end(self, args, state, control, **kwargs):
        if self.prev is None:
            return
        (pos_sum0, pos_nt0), (neg_sum0, neg_nt0), batch = self.prev
        self.prev = None
        # Measure again in eval mode. By now training_step has switched the
        # model back to train mode, which would otherwise enable dropout here
        # but not in the pre-step measurement.
        (pos_sum1, pos_nt1), (neg_sum1, neg_nt1) = self._measure(batch)

        def _avg(total, count):
            return total.float() / count.clamp_min(1)

        pos_delta = _avg(pos_sum1, pos_nt1) - _avg(pos_sum0, pos_nt0)
        neg_delta = _avg(neg_sum1, neg_nt1) - _avg(neg_sum0, neg_nt0)
        logs = {
            "ld/pos_dlogp_mean": float(pos_delta.mean().item()),
            "ld/neg_dlogp_mean": float(neg_delta.mean().item()),
        }
        self.trainer.log(logs)


class GentleDPOTrainer(Trainer):
    """``Trainer`` implementing a "gentle" DPO-style objective.

    For each prompt, the policy is compared with a frozen reference model on
    a positive and a negative response::

        s    = beta * (delta_pos - w_neg * delta_neg)
        loss = -logsigmoid(s).mean()

    Definitions:

    * ``delta_* = mean_logp_policy - mean_logp_ref`` over the selected
      target tokens.
    * For the positive, every scored token is averaged.
    * For the negative, only scored tokens that differ from the positive
      (per ``build_diff_masks_for_batch``) are averaged, falling back to all
      scored tokens when there are none.
    * ``w_neg = sigmoid((mean_logp_neg_policy - cooling_floor_logp) /
      cooling_tau)`` is a stop-gradient gate. It shrinks the negative term
      as the policy's negative log-prob drops below the floor.

    Negatives come from one of three sources, chosen by ``neg_mode``:

    * ``'dataset'``: the dataset's negative text.
    * ``'onpolicy'``: on-policy sampling.
    * ``'mixed'``: per example, on-policy with probability ``onpolicy_prob``,
      otherwise the dataset text.

    Any other ``neg_mode`` value behaves like ``'dataset'``.
    """

    def __init__(self, *args, text_tokenizer, ref_model,
                 beta: float = 0.1,
                 neg_mode: str = "mixed",      # 'dataset' | 'onpolicy' | 'mixed'
                 onpolicy_prob: float = 0.5,
                 sample_temp: float = 0.7,
                 sample_top_p: float = 0.9,
                 cooling_floor_logp: float = -8.0,
                 cooling_tau: float = 2.0,
                 neg_clip_abs: Optional[float] = None,
                 max_len: int = 8192,
                 onpolicy_max_new_tokens: int = 64,
                 processing_class=None, **kwargs):
        if processing_class is not None:
            kwargs["processing_class"] = processing_class
        super().__init__(*args, **kwargs)

        # compute_loss returns its own per-example-mean loss and ignores
        # num_items_in_batch. Recent Trainer versions skip the division by
        # gradient_accumulation_steps when model_accepts_loss_kwargs is True
        # (i.e. the model's forward takes **kwargs). Force it off so the loss
        # is scaled for gradient accumulation.
        # NOTE(review): on versions without this attribute it should simply
        # be unused; confirm against the installed transformers version.
        self.model_accepts_loss_kwargs = False

        self.text_tokenizer = text_tokenizer
        self.ref_model = ref_model.eval()
        for p in self.ref_model.parameters():
            p.requires_grad_(False)

        self.beta = beta
        self.neg_mode = neg_mode
        self.onpolicy_prob = onpolicy_prob
        self.sample_temp = sample_temp
        self.sample_top_p = sample_top_p
        self.cooling_floor_logp = cooling_floor_logp
        self.cooling_tau = cooling_tau
        self.neg_clip_abs = neg_clip_abs
        self.max_len = max_len
        self.onpolicy_max_new_tokens = onpolicy_max_new_tokens

        # generate() needs pad/eos ids; fall back to the tokenizer's when unset.
        if getattr(self.model, "generation_config", None) is not None:
            if self.model.generation_config.pad_token_id is None:
                self.model.generation_config.pad_token_id = self.text_tokenizer.pad_token_id
            if self.model.generation_config.eos_token_id is None:
                self.model.generation_config.eos_token_id = self.text_tokenizer.eos_token_id

        self._ref_on_device = False

    @torch.no_grad()
    def _onpolicy_sample(self, batch) -> List[str]:
        """Sample one continuation per prompt from the current policy.

        Generation runs with the policy in eval mode (dropout off), and the
        previous train/eval mode is restored afterwards. Returns the decoded,
        whitespace-stripped continuations with the prompt tokens removed.
        """
        model = self.model
        was_training = model.training
        model.eval()
        try:
            out_ids = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                image_grid_thw=batch["image_grid_thw"],
                do_sample=True,
                temperature=self.sample_temp,
                top_p=self.sample_top_p,
                max_new_tokens=self.onpolicy_max_new_tokens,
                use_cache=True,
            )
        finally:
            if was_training:
                model.train()
        # Keep only the newly generated tokens after the prompt.
        gen_trim = out_ids[:, batch["input_ids"].size(1):]
        texts = self.text_tokenizer.batch_decode(gen_trim, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return [t.strip() for t in texts]

    def _avg_logp(self, sum_logp: torch.Tensor, ntoks: torch.Tensor) -> torch.Tensor:
        """Length-normalised log-prob, clamped to [-100, 10]."""
        return (sum_logp / ntoks.clamp_min(1)).clamp(min=-100, max=10)

    def _cooling_weight(self, avg_logp_neg: torch.Tensor) -> torch.Tensor:
        """Sigmoid gate in (0, 1) that shrinks as ``avg_logp_neg`` falls below the floor."""
        x = (avg_logp_neg - self.cooling_floor_logp) / max(self.cooling_tau, 1e-6)
        return torch.sigmoid(x)

    def _ensure_ref_on_device(self) -> None:
        """Move the reference model onto the policy's device once, on first use."""
        if self._ref_on_device:
            return
        cur_dev = next(self.model.parameters()).device
        if next(self.ref_model.parameters()).device != cur_dev:
            self.ref_model.to(cur_dev)
        self._ref_on_device = True

    def _select_negatives(self, inputs) -> List[str]:
        """Choose the negative response text for every example according to ``neg_mode``."""
        y_neg_texts = inputs["y_neg_texts"]
        if self.neg_mode == "onpolicy":
            return self._onpolicy_sample(inputs)
        if self.neg_mode == "mixed":
            # Generate for the whole batch, then pick per example.
            onp_texts = self._onpolicy_sample(inputs)
            use_onp = torch.rand(len(y_neg_texts), device=inputs["input_ids"].device) < self.onpolicy_prob
            return [onp if bool(flag) else ds_text
                    for flag, onp, ds_text in zip(use_onp.tolist(), onp_texts, y_neg_texts)]
        return y_neg_texts

    @staticmethod
    def _choose_mask(diff_mask: torch.Tensor, valid_mask: torch.Tensor) -> torch.Tensor:
        """Per row: ``diff_mask`` if it selects any token, otherwise ``valid_mask``."""
        use_diff = (diff_mask.sum(dim=1, keepdim=True) > 0).float()
        return use_diff * diff_mask + (1.0 - use_diff) * valid_mask

    @staticmethod
    def _masked_mean(x: torch.Tensor, m: torch.Tensor) -> torch.Tensor:
        """Row-wise mean of ``x`` over positions where ``m`` is set (0 for empty rows)."""
        s = (x * m).sum(dim=1)
        n = m.sum(dim=1).clamp_min(1.0)
        return s / n

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Gentle-DPO loss for one batch (see the class docstring).

        ``num_items_in_batch`` is accepted for Trainer compatibility but
        unused; the loss is a per-example mean. When ``return_outputs`` is
        set, the batch means of ``delta_pos``, ``delta_neg`` and ``w_neg``
        are returned alongside the loss.
        """
        self._ensure_ref_on_device()

        inst_ids, inst_attn = inputs["input_ids"], inputs["attention_mask"]
        pix, thw = inputs["pixel_values"], inputs["image_grid_thw"]
        y_pos_texts = inputs["y_pos_texts"]
        y_neg_now = self._select_negatives(inputs)

        pos_tok = self.text_tokenizer(list(y_pos_texts), add_special_tokens=False)["input_ids"]
        neg_tok = self.text_tokenizer(list(y_neg_now), add_special_tokens=False)["input_ids"]
        T_pos_max = max((len(t) for t in pos_tok), default=1)
        T_neg_max = max((len(t) for t in neg_tok), default=1)

        # Policy per-token log-probs (with grad) and their validity masks, [B, T].
        pos_logp_cur, pos_valid = seq_token_logprobs_for_targets(
            model, self.text_tokenizer, inst_ids, inst_attn, pix, thw, pos_tok, self.max_len
        )
        neg_logp_cur, neg_valid = seq_token_logprobs_for_targets(
            model, self.text_tokenizer, inst_ids, inst_attn, pix, thw, neg_tok, self.max_len
        )

        # Reference log-probs. The validity masks match the policy's because
        # the inputs are identical.
        with torch.no_grad():
            pos_logp_ref, _ = seq_token_logprobs_for_targets(
                self.ref_model, self.text_tokenizer, inst_ids, inst_attn, pix, thw, pos_tok, self.max_len
            )
            neg_logp_ref, _ = seq_token_logprobs_for_targets(
                self.ref_model, self.text_tokenizer, inst_ids, inst_attn, pix, thw, neg_tok, self.max_len
            )

        dev = neg_valid.device
        _, neg_mask_diff, _ = build_diff_masks_for_batch(
            pos_tok, neg_tok, T_pos_max, T_neg_max, device=dev
        )
        neg_mask_diff = neg_mask_diff.to(dev)

        pos_eff_mask = pos_valid
        # Restrict the diff tokens to positions actually scored. Otherwise
        # truncated tokens (log-prob 0) would dilute the mean.
        neg_eff_mask = self._choose_mask(neg_mask_diff * neg_valid, neg_valid)

        pos_avg_cur = self._masked_mean(pos_logp_cur, pos_eff_mask)
        pos_avg_ref = self._masked_mean(pos_logp_ref, pos_eff_mask)
        neg_avg_cur = self._masked_mean(neg_logp_cur, neg_eff_mask)
        neg_avg_ref = self._masked_mean(neg_logp_ref, neg_eff_mask)

        delta_pos = pos_avg_cur - pos_avg_ref
        delta_neg = neg_avg_cur - neg_avg_ref

        # The cooling weight is a gate, not a quantity to optimise. Detach it
        # so gradients reach the negative only through delta_neg.
        w_neg = self._cooling_weight(neg_avg_cur.detach())

        if self.neg_clip_abs is not None:
            clip = abs(self.neg_clip_abs)
            delta_neg = torch.clamp(delta_neg, min=-clip, max=clip)

        s = self.beta * (delta_pos - w_neg * delta_neg)
        loss = -F.logsigmoid(s).mean()

        if return_outputs:
            outputs = {
                "delta_pos_mean": delta_pos.mean().detach(),
                "delta_neg_mean": delta_neg.mean().detach(),
                "w_neg_mean": w_neg.mean().detach()
            }
            return loss, outputs
        return loss



def load_ref_model(
    pretrained_model: str,
    ref_path: str,
    dtype,
    local_rank: int,
    set_eval: bool = True,   # True -> model.eval(); False -> model.train()
):
    """Load a Qwen2.5-VL checkpoint that is either a LoRA adapter or a full model.

    If ``ref_path`` contains ``adapter_config.json``, it is attached as a PEFT
    adapter to a freshly loaded ``pretrained_model``. Otherwise ``ref_path``
    itself is loaded as a full model. Weights go to GPU ``local_rank`` when
    CUDA is available.

    Args:
        set_eval: If ``True``, the model is put in eval mode and every
            parameter is frozen (reference model). If ``False``, it is put in
            train mode, and a LoRA adapter is loaded trainable (PEFT loads
            adapters frozen by default, which would leave nothing to train).

    Returns:
        The loaded model (``PeftModel`` for adapters).
    """
    load_kwargs = dict(
        torch_dtype=dtype,
        device_map={"": local_rank} if torch.cuda.is_available() else None,
        low_cpu_mem_usage=True,
    )

    is_lora = os.path.exists(os.path.join(ref_path, "adapter_config.json"))
    if is_lora:
        print("is_lora")
        base = Qwen2_5_VLForConditionalGeneration.from_pretrained(pretrained_model, **load_kwargs)
        ref_model = PeftModel.from_pretrained(base, ref_path, is_trainable=not set_eval)
    else:
        print("not_is_lora")
        ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(ref_path, **load_kwargs)

    if set_eval:
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad_(False)
    else:
        ref_model.train()

    return ref_model



def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the gentle-DPO training script."""
    parser = argparse.ArgumentParser()

    # Model / data locations
    parser.add_argument("--pretrained_model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    parser.add_argument("--ref_model", type=str, required=True, help="SFT ckpt or LoRA ")
    parser.add_argument("--gentle_json", type=str, required=True)
    parser.add_argument("--img_root", type=str, default=".")
    parser.add_argument("--output_dir", type=str, default="output/Qwen2.5-VL-7B-DPO2")

    # LoRA & basic options, then the first DPO knob.
    # NOTE: --lora_rank / --lora_alpha / --lora_dropout are parsed but not read
    # by main(); the trainable adapter is loaded from --ref_model as stored.
    for flag, kind, default in (
        ("--lora_rank", int, 64),
        ("--lora_alpha", int, 16),
        ("--lora_dropout", float, 0.05),
        ("--batch_size", int, 4),
        ("--grad_accum", int, 8),
        ("--epochs", int, 1),
        ("--lr", float, 5e-5),
        ("--resize_h", int, 280),
        ("--resize_w", int, 280),
        ("--max_len", int, 8192),
        ("--beta", float, 0.1),
    ):
        parser.add_argument(flag, type=kind, default=default)

    parser.add_argument("--neg_mode", type=str, default="mixed", choices=["dataset", "onpolicy", "mixed"])

    # Remaining DPO knobs and dataloader options.
    for flag, kind, default in (
        ("--onpolicy_prob", float, 0.5),
        ("--sample_temp", float, 0.7),
        ("--sample_top_p", float, 0.9),
        ("--cooling_floor_logp", float, -8.0),
        ("--cooling_tau", float, 2.0),
        ("--neg_clip_abs", float, None),
        ("--num_workers", int, 4),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument("--pin_mem", action="store_true", default=False)
    return parser


def main():
    """Train a Qwen2.5-VL LoRA policy with the gentle-DPO objective.

    The trainable policy is built by loading ``--ref_model`` as a PEFT
    adapter (``is_trainable=True``) on a fresh ``--pretrained_model`` base.
    A separately loaded, frozen copy from ``load_ref_model`` serves as the
    reference. A learning-dynamics callback probes up to 128 fixed training
    examples (8 per measurement, every 20 steps).
    """
    opts = _build_arg_parser().parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(local_rank)

    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    dtype = torch.bfloat16 if use_bf16 else torch.float16
    device_map = {"": local_rank} if has_cuda else None

    tokenizer = AutoTokenizer.from_pretrained(opts.pretrained_model, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    processor = AutoProcessor.from_pretrained(opts.pretrained_model)

    # Trainable policy: base weights plus the --ref_model adapter, unfrozen.
    base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        opts.pretrained_model,
        torch_dtype=dtype,
        device_map=device_map,
        low_cpu_mem_usage=True,
    )
    policy = PeftModel.from_pretrained(base, opts.ref_model, is_trainable=True)

    # Frozen reference: a separately loaded copy of the same checkpoint.
    ref_model = load_ref_model(opts.pretrained_model, opts.ref_model, dtype, local_rank, set_eval=True)

    train_ds = load_dpo_dataset(opts.gentle_json, processor, opts.img_root, opts.resize_h, opts.resize_w)
    collator = DPOCollator(pad_token_id=tokenizer.pad_token_id)

    training_args = TrainingArguments(
        output_dir=opts.output_dir,
        per_device_train_batch_size=opts.batch_size,
        gradient_accumulation_steps=opts.grad_accum,
        num_train_epochs=opts.epochs,
        learning_rate=opts.lr,
        logging_steps=10,
        save_steps=100,
        remove_unused_columns=False,
        fp16=not use_bf16,
        bf16=use_bf16,
        report_to="none",
        dataloader_num_workers=opts.num_workers,
        dataloader_pin_memory=opts.pin_mem,
        ddp_find_unused_parameters=False,
    )

    trainer = GentleDPOTrainer(
        model=policy,
        args=training_args,
        train_dataset=train_ds,
        data_collator=collator,
        text_tokenizer=tokenizer,
        processing_class=processor,
        ref_model=ref_model,
        beta=opts.beta,
        neg_mode=opts.neg_mode,
        onpolicy_prob=opts.onpolicy_prob,
        sample_temp=opts.sample_temp,
        sample_top_p=opts.sample_top_p,
        cooling_floor_logp=opts.cooling_floor_logp,
        cooling_tau=opts.cooling_tau,
        neg_clip_abs=opts.neg_clip_abs,
        max_len=opts.max_len,
    )

    # Fixed, seeded probe pool for the learning-dynamics callback.
    order = list(range(len(train_ds)))
    random.Random(123).shuffle(order)
    observer_pool = ListDatasetDPO([train_ds[i] for i in order[:min(128, len(train_ds))]])

    trainer.add_callback(DpoLearningDynamicsCallback(
        trainer=trainer,
        observer_ds=observer_pool,
        collator=collator,
        text_tokenizer=tokenizer,
        sample_size=min(8, len(observer_pool)),
        log_every_steps=20,
    ))

    trainer.train()


if __name__ == "__main__":
    main()
