import os
# These variables must be set before torch / tokenizers are imported further
# down, because they are read when those libraries initialise.
# GPU 2 stays the default device, but setdefault() lets a launcher or job
# scheduler that already exported CUDA_VISIBLE_DEVICES keep its own choice.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")
# Turn off the tokenizers library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import logging
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, Subset
from transformers import (
    AutoProcessor,
    TrainingArguments, 
    Trainer,
    MllamaForConditionalGeneration,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Every log record goes both to the console and to a training.log file in
# the current working directory.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_log_handlers = [logging.StreamHandler(), logging.FileHandler("training.log")]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)

# Module-level logger shared by everything in this script.
logger = logging.getLogger(__name__)
# Report which CUDA devices this process can see. The per-device queries
# raise when no CUDA device is usable, so they only run when one is present;
# the script can then still fall back to the CPU device chosen below.
print(torch.cuda.device_count())
if torch.cuda.is_available():
    print(torch.cuda.current_device())
    print(torch.cuda.get_device_name(0))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "meta-llama/Llama-3.2-11B-Vision-Instruct"
# Hugging Face access token. Taken from the HF_TOKEN environment variable so
# the secret does not have to be written into the source; empty when unset.
TOKEN = os.environ.get("HF_TOKEN", "")
JSONL_PATH = "/home/ubuntu/projects/time_series_main/idefics2/data/tsqa_train_data_all.jsonl"
IMAGE_DIR = "/home/ubuntu/projects/time_series_main/data/plots"
OUTPUT_DIR = "/home/ubuntu/projects/time_series_main/llama/llama3_captioning_output"
BATCH_SIZE = 4  # per-device batch size handed to the Trainer
GRAD_ACCUM_STEPS = 12  # gradient accumulation steps handed to the Trainer
DEBUG_MODE = True  # when True, main() trains on at most the first 10 samples


class TimeSeriesVLMDataset(Dataset):
    """Image + conversation samples for fine-tuning a vision-language model.

    Every non-blank line of the JSONL file is expected to be an object with
    an ``"image"`` entry (a path relative to ``image_folder``) and a
    ``"conversations"`` list whose elements carry ``"from"`` and ``"value"``
    keys. Lines that are not valid JSON, lack one of those two entries, or
    point at an image file that does not exist are skipped with a warning.

    Indexing returns ``{"image": <RGB PIL image>, "text": <prompt string>}``
    where the prompt is produced by the processor's chat template.
    """

    # Literal placeholder found in (or previously added to) the user text.
    # It is not the image token of the Mllama processor, so it is stripped
    # and replaced by a structured ``{"type": "image"}`` content entry that
    # the chat template renders as the processor's real image token.
    _PLACEHOLDER = "<image>"

    def __init__(self, jsonl_path, image_folder, processor):
        """Read ``jsonl_path`` and keep the records that are usable.

        Args:
            jsonl_path: Path of the JSONL annotation file.
            image_folder: Directory that the ``"image"`` entries are relative to.
            processor: Object providing ``apply_chat_template``.
        """
        self.samples = []
        self.image_folder = Path(image_folder)
        self.processor = processor

        logger.info(f"Loading dataset from {jsonl_path}")
        # Context manager so the file handle is closed deterministically.
        with open(jsonl_path, encoding="utf-8") as handle:
            for line in handle:
                if not line.strip():
                    # Blank separator / trailing lines are not errors.
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse JSON line: {line[:50]}...")
                    continue

                if (
                    not isinstance(item, dict)
                    or "image" not in item
                    or "conversations" not in item
                ):
                    logger.warning(
                        f"Record lacks 'image' or 'conversations': {line[:50]}..."
                    )
                    continue

                image_path = self.image_folder / item["image"]
                if image_path.exists():
                    self.samples.append(item)
                else:
                    logger.warning(f"Image not found: {image_path}, skipping sample")

        logger.info(f"Successfully loaded {len(self.samples)} valid samples")

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the image and chat-formatted prompt for sample ``idx``.

        If the image cannot be read, the error is logged and a black
        560x560 RGB image is substituted so that one bad file does not abort
        the run. The single image is attached to the first user turn only.
        """
        item = self.samples[idx]
        image_path = self.image_folder / item["image"]

        try:
            # convert() returns a new, fully loaded image, so the file can be
            # closed as soon as the with-block ends.
            with Image.open(image_path) as raw:
                image = raw.convert("RGB")
        except OSError as e:
            logger.error(f"Error loading image {image_path}: {e}")
            image = Image.new("RGB", (560, 560), color=0)

        messages = []
        image_attached = False
        for msg in item["conversations"]:
            role = "user" if msg["from"] == "human" else "assistant"
            text = msg["value"]
            if role == "user" and self._PLACEHOLDER in text:
                text = text.replace(self._PLACEHOLDER, "").strip()

            if role == "user" and not image_attached:
                # Structured content: the chat template emits the image token
                # for the "image" entry, followed by the text.
                content = [{"type": "image"}, {"type": "text", "text": text}]
                image_attached = True
            else:
                content = text
            messages.append({"role": role, "content": content})

        if not image_attached:
            logger.warning(
                f"Sample {idx} has no user turn; its prompt carries no image token"
            )

        prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )

        return {
            "image": image,
            "text": prompt
        }


def process_batch(processor, images, prompts, max_length=512, device=None):
    """Turn images and chat-formatted prompts into one padded training batch.

    Args:
        processor: Multimodal processor. It is called with ``text=`` and
            ``images=`` and must expose ``.tokenizer``; an optional
            ``image_token_id`` attribute is used to mask image tokens in
            the labels.
        images: One image per prompt.
        prompts: Prompt strings, one per image. They are expected to come
            from the chat template and therefore to contain the BOS token.
        max_length: Every sequence is padded / truncated to this length.
        device: If given, every tensor in the result is moved to it.

    Returns:
        The processor output plus a ``"labels"`` tensor: a copy of
        ``input_ids`` with padding positions and image tokens set to -100.
        A plain ``dict`` is returned when ``device`` is given, otherwise the
        processor's own mapping object.
    """
    encoding = processor(
        text=prompts,
        images=images,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        # The prompts already contain BOS from the chat template; letting the
        # tokenizer add special tokens again would duplicate it.
        add_special_tokens=False,
    )

    for key, val in encoding.items():
        if isinstance(val, torch.Tensor):
            logger.info(f"{key} shape: {val.shape}")

    input_ids = encoding["input_ids"]

    if "cross_attention_mask" in encoding and 0 in encoding["cross_attention_mask"].shape:
        cross_attn_shape = encoding["cross_attention_mask"].shape
        logger.info(f"Found problematic cross_attention_mask with shape {cross_attn_shape}")

        # Fallback for a degenerate mask (normally a sign that the prompts
        # contain no image token). The model expects
        # (batch, seq_len, num_images, num_tiles); per the Mllama docs
        # pixel_values is (batch, num_images, num_tiles, C, H, W), so both
        # image dimensions are read from it. All ones = every text position
        # may attend to every tile.
        pixel_values = encoding["pixel_values"]
        num_images, num_tiles = pixel_values.shape[1], pixel_values.shape[2]
        new_cross_attn_mask = torch.ones(
            (input_ids.shape[0], input_ids.shape[1], num_images, num_tiles),
            dtype=torch.float32,
        )

        encoding["cross_attention_mask"] = new_cross_attn_mask
        logger.info(f"Fixed cross_attention_mask shape: {new_cross_attn_mask.shape}")

    labels = input_ids.clone()
    if "attention_mask" in encoding:
        # Mask by attention mask rather than by token id: if the pad token
        # is the same id as EOS, masking by id would also hide every real
        # end-of-turn token from the loss.
        labels[encoding["attention_mask"] == 0] = -100
    elif processor.tokenizer.pad_token_id is not None:
        labels[labels == processor.tokenizer.pad_token_id] = -100

    # Image placeholder tokens are inputs, not prediction targets; their id
    # may also lie outside the language-model head's output range.
    image_token_id = getattr(processor, "image_token_id", None)
    if image_token_id is not None:
        labels[labels == image_token_id] = -100

    encoding["labels"] = labels

    if device is not None:
        encoding = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in encoding.items()}

    return encoding

def collate_fn(batch):
    """Collate dataset items into a model-ready batch.

    Each item must provide an ``"image"`` and a ``"text"`` entry. The work is
    delegated to ``process_batch`` using the module-level ``processor``; any
    failure is logged together with its traceback and then re-raised.
    """
    import traceback

    images = []
    for sample in batch:
        images.append(sample["image"])

    prompts = []
    for sample in batch:
        prompts.append(sample["text"])

    try:
        return process_batch(processor, images, prompts)
    except Exception as err:
        logger.error(f"Error in collate_fn: {err}")
        logger.error(traceback.format_exc())
        raise

def _load_processor_and_model():
    """Load the processor and the LoRA-wrapped model.

    Returns:
        ``(processor, model)``; ``model`` is the PEFT wrapper, already moved
        to ``DEVICE``, with the KV cache disabled and gradient checkpointing
        enabled.
    """
    # An empty TOKEN means "no explicit token": pass None so the Hugging Face
    # libraries use their default credential lookup rather than being handed
    # an empty string.
    auth_token = TOKEN or None

    proc = AutoProcessor.from_pretrained(MODEL_NAME, token=auth_token)
    if proc.tokenizer.pad_token is None:
        proc.tokenizer.pad_token = proc.tokenizer.eos_token

    model = MllamaForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        token=auth_token,
        torch_dtype=torch.bfloat16,
    )
    # NOTE: prepare_model_for_kbit_training() is intentionally not used. The
    # model is not loaded in 8/4-bit, and on a non-quantized model that
    # helper upcasts every bf16 weight to fp32.
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    lora_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    model = model.to(DEVICE)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return proc, model


def _load_dataset(proc):
    """Build the training dataset, truncated to 10 samples in DEBUG_MODE.

    Raises:
        ValueError: If no usable sample was found.
    """
    full_dataset = TimeSeriesVLMDataset(JSONL_PATH, IMAGE_DIR, proc)

    if DEBUG_MODE:
        dataset = Subset(full_dataset, range(min(10, len(full_dataset))))
        logger.info(f"DEBUG MODE: Using {len(dataset)} samples for debugging")
    else:
        dataset = full_dataset

    if len(dataset) == 0:
        # Fail with a clear message instead of an IndexError further down.
        raise ValueError(f"No usable samples found in {JSONL_PATH}")
    return dataset


def _sanity_check_forward(model, proc, dataset):
    """Run one gradient-free forward pass on the first sample.

    Catches shape / device problems before the Trainer is started. The model
    is always switched back to training mode afterwards.
    """
    sample = dataset[0]
    encoded = process_batch(
        proc,
        [sample["image"]],
        [sample["text"]],
        device=DEVICE,
    )

    device_check = {k: (v.device if isinstance(v, torch.Tensor) else "not_tensor")
                    for k, v in encoded.items()}
    logger.info(f"device check for all tensors: {device_check}")

    model.eval()
    try:
        with torch.no_grad():
            model(**encoded)
        logger.info("forward pass successful with a single sample")
    finally:
        model.train()


def _build_training_args():
    """Return the TrainingArguments used for the fine-tuning run."""
    return TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=2e-5,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        num_train_epochs=3,
        # Mixed precision must match the bfloat16 weights loaded above.
        bf16=True,
        save_strategy="epoch",
        logging_dir=os.path.join(OUTPUT_DIR, "logs"),
        logging_steps=10,
        save_total_limit=2,
        # The dataset yields raw "image"/"text" entries that the collator
        # consumes; the Trainer must not drop them.
        remove_unused_columns=False,
        report_to="none",
        dataloader_num_workers=0,
        ddp_find_unused_parameters=False,
        optim="adamw_torch",
        local_rank=-1,
        deepspeed=None,
        use_cpu=not torch.cuda.is_available(),
    )


def main():
    """Fine-tune the vision-language model with LoRA and save the result.

    Steps: load processor and model, build the dataset, run a single-sample
    forward pass as a sanity check, train with the Hugging Face Trainer, and
    save the final model under ``OUTPUT_DIR/final_model``. Failures are
    logged with their traceback and re-raised.
    """
    global processor  # collate_fn reads the processor from module scope

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    logger.info("Loading processor and model...")
    processor, model = _load_processor_and_model()

    logger.info("Loading dataset...")
    dataset = _load_dataset(processor)
    logger.info(f"Dataset loaded with {len(dataset)} samples")

    try:
        _sanity_check_forward(model, processor, dataset)
    except Exception as e:
        logger.exception(f"Error in sanity check: {e}")
        raise

    trainer = Trainer(
        model=model,
        args=_build_training_args(),
        train_dataset=dataset,
        # Pass the processor instance (not its class) so it can be saved
        # alongside each checkpoint.
        processing_class=processor,
        # The collator returns CPU tensors; the Trainer moves each batch to
        # the training device itself, which also keeps pinned-memory loading
        # working.
        data_collator=collate_fn,
    )

    logger.info("Starting training...")
    try:
        trainer.train()
        trainer.save_model(os.path.join(OUTPUT_DIR, "final_model"))
        logger.info("Training complete")
    except Exception as e:
        logger.exception(f"Training failed {e}")
        raise

# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
