import os

# These variables are written before `torch` is imported further down, so they
# are already present in the process environment when the CUDA runtime and the
# PyTorch caching allocator read their configuration.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0,1,7",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

import json
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
from transformers import (
    Idefics2Processor,
    Idefics2ForConditionalGeneration,
    TrainingArguments,
    Trainer,
    BatchEncoding
)
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training

# Sentinel written into `labels` in place of pad and <image> token ids by the
# dataset, and used as the label padding value in `collate_fn`. -100 is the
# default `ignore_index` of `torch.nn.CrossEntropyLoss`.
_IGNORE_INDEX = -100

from torchvision import transforms

class TimeSeriesChatDataset(Dataset):
    """Dataset pairing one image with a chat-style conversation per record.

    Every non-blank line of the JSONL file is parsed as one JSON object.
    ``__getitem__`` reads two keys from a record: ``"image"`` (a file name
    resolved relative to ``image_folder``) and ``"conversations"`` (a list of
    dicts carrying ``"from"`` and ``"value"`` keys).

    Args:
        jsonl_path: Path of the JSONL annotation file.
        image_folder: Directory that contains the referenced images.
        processor: Object exposing ``apply_chat_template``, ``__call__`` and a
            ``tokenizer`` attribute (an ``Idefics2Processor`` in this script).
    """

    def __init__(self, jsonl_path, image_folder, processor):
        # The context manager closes the file deterministically, the encoding
        # is pinned so parsing does not depend on the platform default, and
        # blank lines (e.g. a trailing empty line) are skipped instead of
        # crashing json.loads.
        with open(jsonl_path, encoding="utf-8") as handle:
            self.data = [json.loads(line) for line in handle if line.strip()]
        self.image_folder = Path(image_folder)
        self.processor = processor

    def __len__(self) -> int:
        """Return the number of parsed records."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Build the model inputs for record ``idx``.

        Returns:
            A dict with ``"input_ids"``, ``"labels"``, ``"attention_mask"`` and
            ``"pixel_values"``, each being the first (only) batch element of
            the corresponding processor output. ``labels`` is a copy of
            ``input_ids`` with pad and ``<image>`` token ids replaced by
            ``_IGNORE_INDEX``.
        """
        item = self.data[idx]
        image_path = self.image_folder / item["image"]
        # convert() returns an independent image, so the file handle opened by
        # Image.open can be released right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        chat = []
        for m in item["conversations"]:
            if m["from"] == "human":
                # The image is supplied as a structured content entry, so the
                # literal "<image>" placeholder is stripped from the text.
                clean_text = m["value"].replace("<image>", "").strip()
                chat.append({
                    "role": "user",
                    "content": [{"type": "image"}, {"type": "text", "text": clean_text}]
                })
            else:
                # Any speaker other than "human" is treated as the assistant.
                chat.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": m["value"].strip()}]
                })

        text = self.processor.apply_chat_template(chat, add_generation_prompt=False)

        # Tokenize the text and preprocess the image in a single processor
        # call. A fixed 224x224 size is requested, presumably so every sample
        # yields identically shaped pixel_values for stacking in collate_fn.
        inputs = self.processor(
            text=text,
            images=[image],
            return_tensors="pt",
            padding=True,
            do_resize=True,
            size={"height": 224, "width": 224}
        )

        tokenizer = self.processor.tokenizer
        labels = inputs.input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = _IGNORE_INDEX
        labels[labels == tokenizer.convert_tokens_to_ids("<image>")] = _IGNORE_INDEX

        return {
            "input_ids": inputs.input_ids[0],
            "labels": labels[0],
            "attention_mask": inputs.attention_mask[0],
            "pixel_values": inputs.pixel_values[0],
        }

def collate_fn(batch):
    """Merge a list of per-sample dicts into one padded ``BatchEncoding``.

    The three sequence fields are right-padded to the longest sample in the
    batch; ``pixel_values`` tensors are stacked along a new leading dimension
    and therefore must all share one shape. The pad token id is read from the
    module-level ``processor`` that ``main`` assigns.
    """

    def _padded(field, fill_value):
        # Pad the given field of every sample to a common length.
        sequences = [sample[field] for sample in batch]
        return pad_sequence(sequences, batch_first=True, padding_value=fill_value)

    input_ids = _padded("input_ids", processor.tokenizer.pad_token_id)
    labels = _padded("labels", _IGNORE_INDEX)
    attention_mask = _padded("attention_mask", 0)
    pixel_values = torch.stack([sample["pixel_values"] for sample in batch])

    return BatchEncoding({
        "input_ids": input_ids,
        "labels": labels,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
    })

def main(
    model_name="HuggingFaceM4/idefics2-8b",
    jsonl_path="/home/ubuntu/projects/time_series_main/idefics2/data/tsqa_train_data_all.jsonl",
    image_dir="/home/ubuntu/projects/time_series_main/data/plots",
    output_dir="./idefics2_captioning_output",
):
    """Fine-tune Idefics2 with a LoRA/DoRA adapter and save the result.

    Loads the processor and model, wraps the model with a PEFT adapter, builds
    the ``TimeSeriesChatDataset`` and runs ``Trainer.train``. The model is
    saved to ``output_dir`` from the world-process-zero rank only.

    Args:
        model_name: Model identifier passed to ``from_pretrained`` for both
            the processor and the model.
        jsonl_path: JSONL annotation file handed to ``TimeSeriesChatDataset``.
        image_dir: Directory containing the images referenced by the JSONL.
        output_dir: Directory for checkpoints and the final saved model.
    """
    # collate_fn looks the processor up at module scope, so it is published as
    # a global rather than kept local to this function.
    global processor
    processor = Idefics2Processor.from_pretrained(model_name)

    model = Idefics2ForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="eager",
    )

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        init_lora_weights="gaussian",
        use_dora=True,
        # Regex selecting the attention and MLP projection layers inside the
        # text model, the modality projection and the perceiver resampler.
        target_modules=r".*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$",
        task_type=TaskType.CAUSAL_LM,
    )

    # NOTE(review): the model above is loaded without a quantization config.
    # prepare_model_for_kbit_training targets k-bit (quantized) models and may
    # upcast the bf16 weights to fp32 here, increasing memory use -- confirm
    # that this call is intended for a non-quantized model.
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    dataset = TimeSeriesChatDataset(jsonl_path, image_dir, processor)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=12,
        learning_rate=2e-5,
        num_train_epochs=3,
        bf16=True,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        # The dataset yields exactly the keys the collator needs; keep them all.
        remove_unused_columns=False,
        report_to="none",
        dataloader_num_workers=4,
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
        ddp_find_unused_parameters=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collate_fn,
    )

    trainer.train()
    if trainer.is_world_process_zero():
        trainer.save_model(output_dir)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()