

import os
# These environment variables are set here, ahead of the `import torch`
# further down in this file, so they are already in place when torch loads.
# PYTORCH_CUDA_ALLOC_CONF is read by PyTorch's CUDA allocator.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# Restrict this process to the devices listed as "2,1".
# NOTE(review): presumably physical GPUs 2 and 1 (in that order) on the
# training host — confirm before running on a different machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,1"


import os
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import torch
# NOTE(review): this runs at import time, before this script has allocated
# any CUDA memory itself — presumably a no-op here; confirm it is needed.
torch.cuda.empty_cache()


from torch.utils.data import Dataset
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    BatchFeature,
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

# Label value written at positions that carry no training target: the prompt
# part of each sample and the padding added when samples are batched.
_IGNORE_INDEX = -100
# Maximum number of tokens kept per sample; longer samples are truncated.
_MAX_TRAINING_LENGTH = 8192


class TimeSeriesCaptionDataset(Dataset):
    """Image + conversation samples for supervised fine-tuning.

    Every non-blank line of the JSONL file is one record with an ``"image"``
    file name (resolved against ``image_folder``) and a ``"conversations"``
    list of ``{"from": ..., "value": ...}`` turns. The last turn is used as
    the target answer.

    Args:
        jsonl_path: Path of the JSONL annotation file.
        image_folder: Directory that contains the referenced images.
        processor: Multimodal processor exposing ``tokenizer`` and a
            ``__call__(prompt, images=..., return_tensors=...)`` interface.
    """

    def __init__(self, jsonl_path, image_folder, processor):
        # Close the file deterministically and skip blank lines (for example
        # a trailing newline), which would otherwise crash ``json.loads``.
        with open(jsonl_path, encoding="utf-8") as handle:
            self.data = [json.loads(line) for line in handle if line.strip()]
        self.image_folder = Path(image_folder)
        self.processor = processor

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    @staticmethod
    def _format_prompt_messages(messages):
        """Build the chat-template messages that form the prompt.

        The final turn is the answer the model must produce, so it is left
        out of the prompt when it is not a human turn; otherwise the target
        text would already be visible in the model's input. ``<image>``
        markers are stripped and the first user turn is prefixed with the
        ``<|image_1|>`` placeholder.
        """
        prompt_turns = messages
        if messages and messages[-1]["from"] != "human":
            prompt_turns = messages[:-1]

        formatted = []
        image_tag_pending = True
        for turn in prompt_turns:
            role = "user" if turn["from"] == "human" else "assistant"
            content = turn["value"].replace("<image>", "").strip()
            if image_tag_pending and role == "user":
                content = "<|image_1|>" + content
                image_tag_pending = False
            formatted.append({"role": role, "content": content})
        return formatted

    def __getitem__(self, idx):
        """Return the tokenized sample at ``idx``.

        Returns:
            dict with 1-D ``input_ids`` and ``labels`` (prompt positions set
            to ``_IGNORE_INDEX``) plus the processor's ``input_image_embeds``,
            ``image_attention_mask`` and ``image_sizes`` for the image.
        """
        item = self.data[idx]
        image_path = self.image_folder / item["image"]
        # ``convert`` returns a new in-memory image, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        messages = item["conversations"]
        formatted_messages = self._format_prompt_messages(messages)

        prompt = self.processor.tokenizer.apply_chat_template(
            formatted_messages, tokenize=False, add_generation_prompt=True
        )
        gt_answer = messages[-1]["value"].strip() + "<|end|><|endoftext|>"

        inputs = self.processor(prompt, images=[image], return_tensors="pt")
        answer_ids = self.processor.tokenizer(gt_answer, return_tensors="pt").input_ids
        input_ids = torch.cat([inputs.input_ids, answer_ids], dim=1)

        # Only the answer tokens are supervised; the prompt is ignored.
        labels = torch.full_like(input_ids, _IGNORE_INDEX)
        labels[:, -answer_ids.shape[1]:] = answer_ids

        if input_ids.shape[1] > _MAX_TRAINING_LENGTH:
            input_ids = input_ids[:, :_MAX_TRAINING_LENGTH]
            labels = labels[:, :_MAX_TRAINING_LENGTH]
            # If truncation removed every supervised token, keep one target
            # (EOS at the last position) so the sample still has a label.
            if torch.all(labels == _IGNORE_INDEX).item():
                labels[:, -1] = self.processor.tokenizer.eos_token_id

        return {
            "input_ids": input_ids[0],
            "labels": labels[0],
            "input_image_embeds": inputs.input_image_embeds,
            "image_attention_mask": inputs.image_attention_mask,
            "image_sizes": inputs.image_sizes,
        }


def pad_sequence(sequences, padding_value=0):
    """Right-pad 1-D tensors to a common length and stack them.

    Args:
        sequences: List of 1-D tensors; the dtype of the first one is used
            for the result.
        padding_value: Fill value for positions past each sequence's end.

    Returns:
        Tensor of shape ``(len(sequences), longest_length)``.
    """
    lengths = [seq.size(0) for seq in sequences]
    batch = torch.full(
        (len(sequences), max(lengths)),
        padding_value,
        dtype=sequences[0].dtype,
    )
    for row, (seq, length) in enumerate(zip(sequences, lengths)):
        batch[row, :length] = seq
    return batch


def cat_with_pad(tensors, dim, padding_value=0):
    """Concatenate tensors along ``dim``, padding every other dimension.

    All tensors must have the same number of dimensions. Along ``dim`` the
    inputs are laid out one after another; along every other dimension the
    output is as large as the largest input and smaller inputs are padded
    at the end with ``padding_value``.

    Args:
        tensors: Non-empty list of tensors with equal ``dim()``.
        dim: Dimension along which to concatenate.
        padding_value: Fill value for the padded regions.

    Returns:
        A new tensor created from ``tensors[0]`` (same dtype and device).
    """
    ndim = tensors[0].dim()
    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)
    index = 0
    for t in tensors:
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(index, index + t.shape[dim])
        # Index with a tuple: a plain list of slices relies on legacy
        # sequence-as-tuple indexing, which is not guaranteed to keep working.
        output[tuple(slices)] = t
        index += t.shape[dim]
    return output


def collate_fn(batch):
    """Collate dataset samples into one padded ``BatchFeature``.

    Args:
        batch: List of dicts with the keys ``input_ids``, ``labels``,
            ``input_image_embeds``, ``image_attention_mask`` and
            ``image_sizes``.

    Returns:
        ``BatchFeature`` with right-padded ``input_ids`` (pad value 0),
        ``labels`` (pad value ``_IGNORE_INDEX``), the matching
        ``attention_mask``, the concatenated image tensors, and
        ``input_mode`` set to 1.
    """
    token_seqs = [b["input_ids"] for b in batch]
    input_ids = pad_sequence(token_seqs, padding_value=0)
    labels = pad_sequence([b["labels"] for b in batch], padding_value=_IGNORE_INDEX)
    # Build the mask from the real sequence lengths. Comparing ``input_ids``
    # against the pad value 0 would also mask any genuine token whose id
    # happens to be 0.
    attention_mask = pad_sequence(
        [torch.ones_like(seq) for seq in token_seqs], padding_value=0
    ).long()
    input_image_embeds = cat_with_pad([b["input_image_embeds"] for b in batch], dim=0)
    image_attention_mask = cat_with_pad([b["image_attention_mask"] for b in batch], dim=0)
    image_sizes = torch.cat([b["image_sizes"] for b in batch])

    return BatchFeature({
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "input_image_embeds": input_image_embeds,
        "image_attention_mask": image_attention_mask,
        "image_sizes": image_sizes,
        "input_mode": 1,
    })


def main(
    model_name="microsoft/Phi-4-multimodal-instruct",
    jsonl_path="/home/ubuntu/projects/time_series_main/phi4/data/tsqa_train_data_all.jsonl",
    image_dir="/home/ubuntu/projects/time_series_main/data/plots",
    output_dir="./phi4_captioning_output",
):
    """Fine-tune the model with LoRA adapters on the captioning dataset.

    Args:
        model_name: Hub id or local path of the model and processor.
        jsonl_path: JSONL annotation file for ``TimeSeriesCaptionDataset``.
        image_dir: Directory containing the images the annotations refer to.
        output_dir: Where checkpoints and the final model are written.

    The defaults reproduce the previously hard-coded configuration, so
    calling ``main()`` without arguments behaves as before.
    """
    # LOCAL_RANK is normally provided by the distributed launcher; fall back
    # to device 0 so the function can also be called without one.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    print(f"[Rank {local_rank}] Using GPU: {torch.cuda.get_device_name(local_rank)}")

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=[
            "qkv_proj", "o_proj", "gate_up_proj", "down_proj"
        ],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # NOTE(review): the model is loaded in bfloat16 without any quantization
    # config, yet it is passed through the k-bit preparation helper — confirm
    # that this is intended for a non-quantized model.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    model = model.cuda(local_rank)

    dataset = TimeSeriesCaptionDataset(jsonl_path, image_dir, processor)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=12,
        learning_rate=2e-5,
        num_train_epochs=3,
        bf16=True,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        # The collator consumes custom keys, so the Trainer must not drop
        # dataset columns that do not match the model's forward signature.
        remove_unused_columns=False,
        report_to="none",
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
        ddp_find_unused_parameters=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collate_fn,
    )

    trainer.train()
    # Only the main process writes the final model.
    if trainer.is_world_process_zero():
        trainer.save_model(output_dir)


if __name__ == "__main__":
    # Script entry point: set up the NCCL process group, run training, and
    # tear the group down again once training has finished successfully.
    import torch.distributed as dist
    dist.init_process_group(backend="nccl")
    main()
    # Release the communicator explicitly instead of leaving it to
    # interpreter shutdown. Done only on the success path so a failing rank
    # exits immediately rather than waiting on its peers.
    if dist.is_initialized():
        dist.destroy_process_group()