# sft_stage1.py
import os
# This runs before `import torch` further down in the import block.
# setdefault() only applies the value when PYTORCH_CUDA_ALLOC_CONF is not
# already present in the environment, so a user-exported setting wins.
# NOTE(review): presumably intended to reduce CUDA allocator fragmentation
# during training — confirm against the PyTorch CUDA memory-management notes.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import json
import random
import argparse
from typing import Dict, List

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset

from transformers import (
    AutoTokenizer,                  
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
    TrainerCallback,
)
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm


from qwen_vl_utils import process_vision_info


class ListDataset(TorchDataset):
    """Map-style dataset that serves examples straight from an in-memory list."""

    def __init__(self, examples: List[Dict]):
        """Keep a reference to ``examples``; the list is not copied."""
        self.examples = examples

    def __len__(self) -> int:
        """Return how many examples are stored."""
        return len(self.examples)

    def __getitem__(self, idx) -> Dict:
        """Return the example stored at position ``idx``."""
        return self.examples[idx]


def resolve_image_path(img_path: str, img_root: str) -> str:
    """Turn an image reference into a normalized filesystem path.

    Args:
        img_path: Absolute or relative image path. Falsy values (``""`` or
            ``None``) are returned unchanged.
        img_root: Directory that relative paths are joined onto.

    Returns:
        ``img_path`` normalized when it is absolute; otherwise ``img_root``
        joined with ``img_path`` (leading ``/`` and ``\\`` characters removed
        first) and normalized.
    """
    if not img_path:
        return img_path
    # Leading separators are stripped so the join cannot discard img_root.
    candidate = (
        img_path
        if os.path.isabs(img_path)
        else os.path.join(img_root, img_path.lstrip("/\\"))
    )
    return os.path.normpath(candidate)


def build_instruction(processor, img_path: str, img_root: str,
                      txt_prefix: str = "COCO Yes:", resize_h: int = 280, resize_w: int = 280) -> Dict[str, torch.Tensor]:
    """Encode a single-image user prompt with the multimodal processor.

    Builds a one-turn chat message (image followed by ``txt_prefix``), renders
    it with the processor's chat template including the generation prompt, and
    runs the processor on the rendered text plus the extracted vision inputs.

    Args:
        processor: Object providing ``apply_chat_template`` and a ``__call__``
            that returns a mapping of model inputs.
        img_path: Image path; resolved against ``img_root`` when relative.
        img_root: Base directory for relative image paths.
        txt_prefix: Text placed after the image in the user turn.
        resize_h: Value passed as ``resized_height`` in the image entry.
        resize_w: Value passed as ``resized_width`` in the image entry.

    Returns:
        The processor outputs as a plain dict. Any tensor whose first
        dimension has size 1 is replaced by its first slice; everything else
        is passed through untouched.
    """
    image_part = {
        "type": "image",
        "image": resolve_image_path(img_path, img_root),
        "resized_height": resize_h,
        "resized_width": resize_w,
    }
    text_part = {"type": "text", "text": txt_prefix}
    messages = [{"role": "user", "content": [image_part, text_part]}]

    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    encoded = processor(text=[prompt], images=image_inputs, videos=video_inputs,
                        padding=True, return_tensors="pt")

    def _drop_unit_batch(value):
        # Only tensors with a leading dimension of exactly 1 are unwrapped.
        unit_batch = isinstance(value, torch.Tensor) and value.dim() > 0 and value.size(0) == 1
        return value[0] if unit_batch else value

    return {key: _drop_unit_batch(value) for key, value in encoded.items()}


def pack_sft_example(tokenizer, inst: Dict[str, torch.Tensor], target_text: str, max_len: int = 8192) -> Dict[str, torch.Tensor]:
    """Join an encoded prompt and a target response into one SFT example.

    The response is tokenized without special tokens and terminated with the
    ``<|im_end|>`` token (falling back to ``tokenizer.eos_token_id`` when that
    token is unknown to the tokenizer). Prompt positions are masked out of the
    loss with ``-100``; only response tokens carry labels.

    When prompt + response would exceed ``max_len`` the *response* is
    shortened (its final token remains the end-of-turn token). The prompt is
    never cut, because ``pixel_values`` / ``image_grid_thw`` are returned
    unchanged and must keep matching the prompt they were encoded with.

    Args:
        tokenizer: Callable tokenizer exposing ``convert_tokens_to_ids``,
            ``unk_token_id`` and ``eos_token_id``.
        inst: Prompt encoding with 1-D ``input_ids`` and ``attention_mask``
            plus ``pixel_values`` and ``image_grid_thw``.
        target_text: Response text to supervise.
        max_len: Maximum total sequence length of the packed example.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``labels`` (all 1-D long
        tensors of equal length), ``pixel_values`` and ``image_grid_thw``.

    Raises:
        ValueError: If the prompt alone leaves no room for even the
            end-of-turn token within ``max_len``.
    """
    resp_ids = list(tokenizer(target_text, add_special_tokens=False)["input_ids"])

    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    if im_end_id is None or im_end_id == tokenizer.unk_token_id:
        im_end_id = tokenizer.eos_token_id
    if not resp_ids or resp_ids[-1] != im_end_id:
        resp_ids.append(im_end_id)

    prompt_ids = inst["input_ids"].tolist()
    prompt_mask = inst["attention_mask"].tolist()
    prompt_len = len(prompt_ids)

    # Truncate the response tail, never the prompt: dropping prompt tokens
    # would remove the generation header and possibly image placeholders.
    room = max_len - prompt_len
    if len(resp_ids) > room:
        if room < 1:
            raise ValueError(
                f"max_len={max_len} leaves no room for a response after a "
                f"prompt of {prompt_len} tokens"
            )
        resp_ids = resp_ids[:room - 1] + [im_end_id]

    input_ids = torch.tensor(prompt_ids + resp_ids, dtype=torch.long)
    attention_mask = torch.tensor(prompt_mask + [1] * len(resp_ids), dtype=torch.long)

    # Supervise only the response span.
    labels = torch.full((input_ids.size(0),), -100, dtype=torch.long)
    labels[prompt_len:] = torch.tensor(resp_ids, dtype=torch.long)

    image_grid_thw = inst["image_grid_thw"]
    if image_grid_thw.dim() == 3:
        image_grid_thw = image_grid_thw.squeeze(0)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "pixel_values": inst["pixel_values"],
        "image_grid_thw": image_grid_thw,
    }


def load_sft_stage1_dataset(gentle_json_path: str, processor, tokenizer, img_root: str,
                            max_len: int = 8192, neg_ratio: float = 0.3) -> TorchDataset:
    """Build the in-memory stage-1 SFT dataset from a JSON conversation file.

    Each JSON item is expected to carry a ``conversations`` list of turns with
    ``from`` / ``value`` fields. From each item the function takes:

    * the image path found between ``<|vision_start|>`` and
      ``<|vision_end|>`` in a ``user`` turn (a later user turn overrides an
      earlier one),
    * the first ``assistant`` turn as the positive target,
    * the first ``gentle_negative`` turn as the optional negative target.

    Items without an image path or positive target are skipped. Every kept
    item yields one example tagged ``label_type="pos"``; when a non-empty
    negative exists, a ``"neg"`` example is added with probability
    ``neg_ratio`` (clamped to [0, 1]) using a generator seeded with 123.

    Args:
        gentle_json_path: Path to the UTF-8 JSON file.
        processor: Multimodal processor forwarded to ``build_instruction``.
        tokenizer: Tokenizer forwarded to ``pack_sft_example``.
        img_root: Base directory for relative image paths.
        max_len: Maximum packed sequence length.
        neg_ratio: Sampling probability for negative examples.

    Returns:
        A ``ListDataset`` holding all packed examples.
    """
    with open(gentle_json_path, "r", encoding="utf-8") as fh:
        records = json.load(fh)

    vision_open, vision_close = "<|vision_start|>", "<|vision_end|>"
    keep_neg_prob = max(0.0, min(1.0, neg_ratio))
    sampler = random.Random(123)
    packed: List[Dict] = []

    for record in tqdm(records, desc="Building SFT Stage-1 dataset"):
        image_ref = positive = negative = None
        for turn in record.get("conversations", []):
            speaker = turn.get("from")
            if speaker == "user":
                text = turn.get("value", "")
                if vision_open in text and vision_close in text:
                    try:
                        image_ref = text.split(vision_open)[1].split(vision_close)[0]
                    except Exception:
                        # A non-string value that merely contains both markers
                        # is treated as "no image".
                        image_ref = None
            elif speaker == "assistant" and positive is None:
                positive = turn.get("value", "")
            elif speaker == "gentle_negative" and negative is None:
                negative = turn.get("value", "")

        if image_ref is None or positive is None:
            continue

        # The prompt encoding is shared by the positive and negative example.
        prompt = build_instruction(processor, image_ref, img_root)

        pos_example = pack_sft_example(tokenizer, prompt, positive, max_len)
        pos_example["label_type"] = "pos"
        packed.append(pos_example)

        # The random draw happens only when a non-empty negative exists.
        if negative and sampler.random() < keep_neg_prob:
            neg_example = pack_sft_example(tokenizer, prompt, negative, max_len)
            neg_example["label_type"] = "neg"
            packed.append(neg_example)

    return ListDataset(packed)


class DataCollatorVL:
    """Pad variable-length vision-language examples into one batch."""

    def __init__(self, pad_token_id: int):
        """Remember the token id used to right-pad ``input_ids``."""
        self.pad_token_id = pad_token_id

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into padded batch tensors.

        ``input_ids`` are right-padded with ``pad_token_id``, ``labels`` with
        ``-100`` and ``attention_mask`` (returned as bool) with ``False``.
        ``pixel_values`` and ``image_grid_thw`` are stacked along a new first
        dimension, so they must have identical shapes across features.
        ``label_type`` becomes a long tensor: 1 where the feature's
        ``label_type`` is ``"pos"``, 0 otherwise (including when it is absent).
        """
        lengths = [feat["input_ids"].size(0) for feat in features]
        batch_shape = (len(features), max(lengths))

        padded_ids = torch.full(batch_shape, self.pad_token_id, dtype=torch.long)
        padded_mask = torch.zeros(batch_shape, dtype=torch.bool)
        padded_labels = torch.full(batch_shape, -100, dtype=torch.long)

        for row, (feat, n_tok) in enumerate(zip(features, lengths)):
            padded_ids[row, :n_tok] = feat["input_ids"]
            padded_mask[row, :n_tok] = feat["attention_mask"].bool()
            padded_labels[row, :n_tok] = feat["labels"]

        is_pos = [int(feat.get("label_type") == "pos") for feat in features]

        return {
            "input_ids": padded_ids,
            "attention_mask": padded_mask,
            "labels": padded_labels,
            "pixel_values": torch.stack([feat["pixel_values"] for feat in features], dim=0),
            "image_grid_thw": torch.stack([feat["image_grid_thw"] for feat in features], dim=0),
            "label_type": torch.tensor(is_pos, dtype=torch.long),
        }


@torch.no_grad()
def seq_logprob_from_batch(model, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Return, per example, the summed log-probability of its label tokens.

    The model is run in eval mode without gradients on ``input_ids``,
    ``attention_mask``, ``pixel_values`` and ``image_grid_thw``. Logits at
    position ``t`` are scored against ``labels[:, t + 1]``; positions whose
    shifted label is ``-100`` are ignored.

    Only the logit rows that are actually scored are upcast to float32 before
    the log-softmax. Half-precision logits would otherwise not match the
    float32 accumulator's dtype, and a full ``[B, L, V]`` float copy is
    avoided.

    Args:
        model: Callable module returning an object with a ``logits`` attribute.
        batch: Collated batch containing the keys above plus ``labels``.

    Returns:
        Float32 tensor of shape ``[B]`` on ``labels.device``; zero for
        examples without any supervised token.
    """
    was_train = model.training
    model.eval()
    try:
        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
        )
    finally:
        # Restore the caller's mode even if the forward pass raises.
        if was_train:
            model.train()

    labels = batch["labels"]                        # [B, L]
    shift_labels = labels[:, 1:]
    mask = shift_labels != -100

    lp = torch.zeros(labels.size(0), dtype=torch.float32, device=labels.device)
    if mask.any():
        sel = out.logits[:, :-1, :][mask].float()                   # [Ntok, V]
        tgt = shift_labels[mask]                                    # [Ntok]
        token_logp = F.log_softmax(sel, dim=-1).gather(-1, tgt.unsqueeze(-1)).squeeze(-1)
        # Row index of every selected token, in the same order as `sel`.
        row_idx = mask.nonzero(as_tuple=True)[0]
        lp.index_add_(0, row_idx, token_logp)
    return lp


class LearningDynamicsCallback(TrainerCallback):
    """Log how one optimizer step changes sequence log-probs on a probe batch.

    Every ``log_every_steps`` steps a batch is sampled from ``observer_ds``
    and scored before the step (``on_step_begin``) and after it
    (``on_step_end``). The per-example differences are logged through
    ``trainer.log`` under the ``ld/`` prefix, overall and split by
    ``label_type`` (1 = positive, 0 = negative).
    """

    def __init__(self, trainer: Trainer, observer_ds: TorchDataset, collator: DataCollatorVL,
                 sample_size: int = 8, log_every_steps: int = 10, seed: int = 42):
        self.trainer = trainer
        self.observer_ds = observer_ds
        self.collator = collator
        self.sample_size = sample_size
        self.log_every = log_every_steps
        self.rng = random.Random(seed)
        # State carried from on_step_begin to the matching on_step_end.
        self.prev_lp = None
        self.prev_batch = None
        self._was_training = True

    def _sample_batch(self) -> Dict[str, torch.Tensor]:
        """Draw a random probe batch and move its tensors to the model device."""
        pool_size = len(self.observer_ds)
        picked = self.rng.sample(range(pool_size), k=min(self.sample_size, pool_size))
        collated = self.collator([self.observer_ds[i] for i in picked])
        device = next(self.trainer.model.parameters()).device
        return {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in collated.items()
        }

    def on_step_begin(self, args, state, control, **kwargs):
        """Score a fresh probe batch before the step when this step is logged."""
        if state.global_step % self.log_every == 0:
            self.prev_batch = self._sample_batch()
            self._was_training = self.trainer.model.training
            self.prev_lp = seq_logprob_from_batch(self.trainer.model, self.prev_batch)

    def on_step_end(self, args, state, control, **kwargs):
        """Re-score the stored probe batch and log the log-prob deltas."""
        if self.prev_lp is None or self.prev_batch is None:
            return

        after = seq_logprob_from_batch(self.trainer.model, self.prev_batch)
        delta = after - self.prev_lp
        label_type = self.prev_batch["label_type"]

        # std of a single element is reported as 0.0 instead of being computed.
        spread = float(delta.std(unbiased=False).item()) if delta.numel() > 1 else 0.0
        logs = {
            "ld/delta_mean": float(delta.mean().item()),
            "ld/delta_std": spread,
        }
        for log_key, flag in (("ld/pos_delta_mean", 1), ("ld/neg_delta_mean", 0)):
            subset = delta[label_type == flag]
            if subset.numel() > 0:
                logs[log_key] = float(subset.mean().item())

        self.trainer.log(logs)
        self.prev_lp, self.prev_batch = None, None
        if self._was_training:
            self.trainer.model.train()


class WeightedTrainer(Trainer):
    """Trainer whose loss down-weights examples flagged as negatives."""

    def __init__(self, *args, neg_weight: float = 0.2, **kwargs):
        """Forward all arguments to ``Trainer`` and store ``neg_weight``."""
        super().__init__(*args, **kwargs)
        self.neg_weight = neg_weight

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute a per-example mean token loss, weighted by ``label_type``.

        Token cross-entropy is averaged over each example's supervised tokens
        (labels != -100). Examples with ``label_type == 1`` get weight 1, all
        others ``neg_weight``; without ``label_type`` every example gets
        weight 1. The weighted sum is divided by the weight total, which is
        clamped to at least 1.0. ``num_items_in_batch`` and extra keyword
        arguments are accepted but not used.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        forward_keys = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
        outputs = model(**{key: inputs[key] for key in forward_keys})

        # Next-token alignment: logits at t are scored against labels at t + 1.
        preds = outputs.logits[:, :-1, :].contiguous()      # [B, L-1, V]
        targets = inputs["labels"][:, 1:].contiguous()      # [B, L-1]

        per_token = F.cross_entropy(
            preds.view(-1, preds.size(-1)),
            targets.view(-1),
            ignore_index=-100,
            reduction="none",
        ).view(targets.size())
        keep = (targets != -100).float()
        per_example = (per_token * keep).sum(dim=1) / keep.sum(dim=1).clamp_min(1.0)  # [B]

        label_type = inputs.get("label_type")  # 1=pos, 0=neg
        if label_type is not None:
            weights = torch.where(label_type == 1, torch.ones_like(per_example),
                                  torch.full_like(per_example, self.neg_weight))
        else:
            weights = torch.ones_like(per_example)

        loss = (per_example * weights).sum() / weights.sum().clamp_min(1.0)
        return (loss, outputs) if return_outputs else loss


def main():
    """Parse CLI arguments, build the LoRA model and dataset, then train.

    Steps: load tokenizer/processor and the base Qwen2.5-VL model, wrap it
    with LoRA adapters, build the stage-1 SFT dataset, train with
    ``WeightedTrainer`` while ``LearningDynamicsCallback`` probes a fixed
    observer pool, and finally save the trained model to ``--output_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    parser.add_argument("--gentle_json", type=str, required=True)
    parser.add_argument("--img_root", type=str, default=".")
    parser.add_argument("--output_dir", type=str, default="output/Qwen2.5-VL-7B-SFT1")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--grad_accum", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--lora_rank", type=int, default=64)
    parser.add_argument("--lora_alpha", type=int, default=16)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--max_len", type=int, default=8192)
    parser.add_argument("--ld_every", type=int, default=10)
    parser.add_argument("--ld_size", type=int, default=8)
    parser.add_argument("--neg_weight", type=float, default=0.2)
    parser.add_argument("--neg_ratio", type=float, default=0.3)
    args = parser.parse_args()

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(local_rank)

    # Half precision is only requested when a CUDA device is present; on CPU
    # the model is loaded and trained in float32.
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    use_fp16 = has_cuda and not use_bf16
    if use_bf16:
        dtype = torch.bfloat16
    elif use_fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model, trust_remote_code=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    processor = AutoProcessor.from_pretrained(args.pretrained_model)

    base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.pretrained_model,
        torch_dtype=dtype,
        device_map={"": local_rank} if has_cuda else None,
        low_cpu_mem_usage=True,
    )

    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        inference_mode=False, r=args.lora_rank, lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout, bias="none",
    )
    model = get_peft_model(base, lora_cfg)

    train_ds = load_sft_stage1_dataset(args.gentle_json, processor, tokenizer,
                                       img_root=args.img_root, max_len=args.max_len, neg_ratio=args.neg_ratio)
    collator = DataCollatorVL(pad_token_id=tokenizer.pad_token_id)

    targs = TrainingArguments(
        output_dir=args.output_dir,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        logging_steps=10,
        save_steps=200,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=True,
        ddp_find_unused_parameters=False,
        # Keep `label_type` on each feature: with the default (True) Trainer
        # drops keys that are not parameters of the model's forward() before
        # the collator runs, so every example would be tagged as negative and
        # the pos/neg loss weighting would silently have no effect.
        remove_unused_columns=False,
    )

    trainer = WeightedTrainer(
        model=model,
        args=targs,
        train_dataset=train_ds,
        data_collator=collator,
        neg_weight=args.neg_weight,
    )

    # Fixed, seeded subset of the training set used as the probe pool.
    n_observe = min(256, len(train_ds))
    rng = random.Random(123)
    idxs = list(range(len(train_ds)))
    rng.shuffle(idxs)
    observer_pool = ListDataset([train_ds[i] for i in idxs[:n_observe]])

    ld_cb = LearningDynamicsCallback(
        trainer=trainer,
        observer_ds=observer_pool,
        collator=collator,
        sample_size=args.ld_size,
        log_every_steps=args.ld_every,
    )
    trainer.add_callback(ld_cb)
    trainer.train()

    # Persist the final weights explicitly rather than relying on the last
    # step happening to coincide with `save_steps`.
    trainer.save_model(args.output_dir)


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()