import os
import json
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    Idefics3ForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Respect values already exported by the shell or launcher; otherwise fall back
# to this script's defaults. These only take effect if set before CUDA is
# initialised, which is why they live at import time.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2,7")

# Label value excluded from the loss (matches the default ``ignore_index`` of
# ``torch.nn.CrossEntropyLoss``).
_IGNORE_INDEX = -100
MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
JSONL_PATH = "/home/ubuntu/projects/time_series_main/idefics2/data/tsqa_train_data_all.jsonl"
IMAGE_DIR = "/home/ubuntu/projects/time_series_main/data/plots"
OUTPUT_DIR = "./smolvlm_finetuned_output"

# System prompt prepended to every training conversation.
SYSTEM_MESSAGE = """You are a Vision Language Model specialized in analyzing time series charts.
Please provide concise interpretations of the data visualization shown in the image."""

class TimeSeriesVLMDataset(Dataset):
    """Map-style dataset turning JSONL conversation records into model inputs.

    Each JSONL record is read for an ``"image"`` filename (relative to
    ``image_folder``) and a ``"conversations"`` list of ``{"from", "value"}``
    turns. ``"from" == "human"`` marks a user turn; any other value is treated
    as an assistant turn.
    """

    def __init__(self, jsonl_path, image_folder, processor):
        """Eagerly load all records.

        Args:
            jsonl_path: Path to a JSON-Lines file; blank lines are skipped.
            image_folder: Directory that each record's ``"image"`` is relative to.
            processor: Processor whose tokenizer lists ``"<image>"`` among its
                additional special tokens.
        """
        # Close the file deterministically and tolerate blank/trailing lines.
        with open(jsonl_path, encoding="utf-8") as fh:
            self.data = [json.loads(line) for line in fh if line.strip()]
        self.image_folder = Path(image_folder)
        self.processor = processor
        tokenizer = processor.tokenizer
        # Id of the <image> placeholder; those positions are excluded from the loss.
        self.image_token_id = tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _build_chat(messages, system_message=None):
        """Convert ``{"from", "value"}`` turns into chat-template messages.

        Each sample carries exactly one image, so the image placeholder is
        attached to the first human turn only. Adding it to every human turn
        would put more image slots in the prompt than images passed to the
        processor for multi-turn conversations.

        Args:
            messages: Conversation turns from a JSONL record.
            system_message: System prompt text; defaults to ``SYSTEM_MESSAGE``.

        Returns:
            list of ``{"role", "content"}`` dicts starting with the system turn.
        """
        if system_message is None:
            system_message = SYSTEM_MESSAGE
        chat = [{"role": "system", "content": [{"type": "text", "text": system_message}]}]
        image_attached = False
        for turn in messages:
            if turn["from"] == "human":
                # Literal <image> markers are dropped; the placeholder is added structurally.
                content = [{"type": "text", "text": turn["value"].replace("<image>", "").strip()}]
                if not image_attached:
                    content.insert(0, {"type": "image"})
                    image_attached = True
                chat.append({"role": "user", "content": content})
            else:
                chat.append({
                    "role": "assistant",
                    "content": [{"type": "text", "text": turn["value"].strip()}],
                })
        return chat

    def __getitem__(self, idx):
        """Build one training sample.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels`` (first row
            of the processor output), plus ``pixel_values`` / ``image_features``
            (first row, or None when the processor did not return that key).
        """
        item = self.data[idx]
        # The context manager releases the file handle deterministically;
        # convert() returns a loaded copy that remains usable afterwards.
        with Image.open(self.image_folder / item["image"]) as img:
            image = img.convert("RGB")

        chat = self._build_chat(item["conversations"])
        text = self.processor.apply_chat_template(chat, tokenize=False)

        inputs = self.processor(
            text=text,
            images=[image],
            return_tensors="pt",
            padding=True,
        )

        # NOTE(review): only pad and <image> positions are masked, so the loss
        # also covers system/user prompt tokens; mask those as well if only the
        # assistant replies should be learned.
        labels = inputs.input_ids.clone()
        labels[labels == self.processor.tokenizer.pad_token_id] = _IGNORE_INDEX
        labels[labels == self.image_token_id] = _IGNORE_INDEX

        return {
            "input_ids": inputs.input_ids[0],
            "attention_mask": inputs.attention_mask[0],
            "labels": labels[0],
            "pixel_values": inputs.pixel_values[0] if hasattr(inputs, "pixel_values") else None,
            "image_features": inputs.image_features[0] if hasattr(inputs, "image_features") else None,
        }

def _pad_and_stack(tensors):
    """Zero-pad each tensor along dim 0 to the longest one, then stack into a batch.

    Used for per-sample visual inputs whose leading dimension can differ
    between samples. NOTE(review): padded entries are all-zero; this relies on
    the model treating all-zero images as padding — confirm for the model in use.
    """
    longest = max(t.shape[0] for t in tensors)
    padded = []
    for t in tensors:
        shortfall = longest - t.shape[0]
        if shortfall > 0:
            # new_zeros keeps the source tensor's dtype and device.
            t = torch.cat([t, t.new_zeros((shortfall, *t.shape[1:]))], dim=0)
        padded.append(t)
    return torch.stack(padded)

def collate_fn(batch, pad_token_id=None, label_pad_id=None):
    """Pad a list of dataset samples into a single batch dict.

    Args:
        batch: Non-empty list of sample dicts with ``input_ids``,
            ``attention_mask``, ``labels`` and optional ``pixel_values`` /
            ``image_features`` entries.
        pad_token_id: Padding value for ``input_ids``. Defaults to the pad id of
            the module-level ``processor`` (assigned in ``main``).
        label_pad_id: Padding value for ``labels``. Defaults to ``_IGNORE_INDEX``.

    Returns:
        dict with right-padded ``input_ids``, ``attention_mask`` (padded with 0)
        and ``labels``. If the first sample has ``pixel_values``, they are added
        (zero-padded along dim 0 and stacked); otherwise ``image_features`` are
        added the same way if present; text-only batches get neither key.
    """
    if pad_token_id is None:
        pad_token_id = processor.tokenizer.pad_token_id
    if label_pad_id is None:
        label_pad_id = _IGNORE_INDEX

    def pad(key, value):
        return torch.nn.utils.rnn.pad_sequence(
            [item[key] for item in batch],
            batch_first=True,
            padding_value=value,
        )

    collated = {
        "input_ids": pad("input_ids", pad_token_id),
        "attention_mask": pad("attention_mask", 0),
        "labels": pad("labels", label_pad_id),
    }

    # Prefer raw pixels, fall back to precomputed features; the previous code
    # crashed with AttributeError when both were None.
    for key in ("pixel_values", "image_features"):
        if batch[0].get(key) is not None:
            collated[key] = _pad_and_stack([item[key] for item in batch])
            break
    return collated

def main():
    """Fine-tune SmolVLM with 4-bit quantized weights and DoRA adapters.

    Reads ``LOCAL_RANK`` / ``WORLD_SIZE`` from the environment (as exported by
    ``torchrun``), so the same script runs either as a single process or under
    DDP. The trained adapter is saved to ``OUTPUT_DIR`` by global rank 0 only.
    """
    import torch.distributed as dist

    # Fixed seed for the train/eval split so every rank (and every rerun)
    # holds out exactly the same samples.
    split_seed = 42

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    print("Loading model and processor ...")
    # collate_fn reads this module-level name when no pad id is passed to it.
    global processor
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    torch.cuda.set_device(local_rank)

    # Only join a process group when started by a distributed launcher. A plain
    # `python script.py` run with several visible GPUs has no RANK/MASTER_ADDR,
    # so env:// initialisation would fail.
    if world_size > 1 and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl")

    model = Idefics3ForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
        device_map={"": local_rank},
        _attn_implementation="eager",
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=8,
        lora_dropout=0.05,
        target_modules=["o_proj", "k_proj", "q_proj", "v_proj"],
        use_dora=True,
        init_lora_weights="gaussian",
    )
    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    print("Loading dataset...")
    dataset = TimeSeriesVLMDataset(JSONL_PATH, IMAGE_DIR, processor)
    train_size = int(0.9 * len(dataset))
    eval_size = len(dataset) - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, eval_size],
        generator=torch.Generator().manual_seed(split_seed),
    )

    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=12,
        learning_rate=2e-5,
        num_train_epochs=3,
        bf16=True,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        remove_unused_columns=False,  # keep pixel_values etc. for collate_fn
        report_to="tensorboard",
        dataloader_pin_memory=False,
        gradient_checkpointing=True,
        ddp_find_unused_parameters=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collate_fn,
    )

    print("Starting training...")
    trainer.train()

    # Write once per job (global rank 0), not once per node.
    if trainer.is_world_process_zero():
        model.save_pretrained(OUTPUT_DIR)
        print(f"Model adapter saved to {OUTPUT_DIR}")

if __name__ == "__main__":
    main()