import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from IPython.display import display
import traceback
import random
import numpy as np
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import LlavaForConditionalGeneration, AutoProcessor
import torch
import torch.nn.functional as F
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from tqdm import tqdm
from transformers import TrainingArguments, Trainer, default_data_collator
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from IPython.display import display
from datasets import load_dataset
import nltk
import re
import json
from transformers import TrainerCallback
nltk.download('punkt')

class BestCheckpointTracker(TrainerCallback):
    """Track the lowest evaluation loss seen so far and optionally stop early.

    Nothing is saved by this callback; it only records the global step at
    which the best ``eval_loss`` was reported and, when a patience is
    configured, asks the trainer to stop once that many training steps have
    passed without improvement.
    """

    def __init__(self, early_stopping_patience=None):
        """
        Args:
            early_stopping_patience: Number of training steps without
                improvement after which training is asked to stop.
                ``None`` disables early stopping.
        """
        self.best_eval_loss = float('inf')
        self.best_checkpoint_step = None
        self.eval_losses = []  # (global_step, eval_loss) history, for debugging
        self.early_stopping_patience = early_stopping_patience
        self.steps_since_improvement = 0
        # Global step of the previous evaluation; used to measure elapsed
        # steps when ``args.eval_steps`` is not a usable step count.
        self._last_eval_step = 0

    def _steps_between_evaluations(self, args, current_step):
        """Return how many training steps one evaluation interval represents.

        ``args.eval_steps`` is used when it is a step count (>= 1). It can
        also be ``None`` or a ratio below 1; in those cases the distance to
        the previously observed evaluation step is used instead of failing
        or adding a fraction.
        """
        eval_steps = getattr(args, "eval_steps", None)
        if eval_steps is not None and eval_steps >= 1:
            return eval_steps
        return max(current_step - self._last_eval_step, 0)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Called after evaluation to track the best loss.

        Reads ``metrics["eval_loss"]`` (treated as ``inf`` when absent),
        appends it to the history and either records a new best or advances
        the no-improvement counter. Sets ``control.should_training_stop``
        when the counter reaches ``early_stopping_patience``.
        """
        if metrics is None:
            return

        current_step = state.global_step
        current_eval_loss = metrics.get("eval_loss", float('inf'))
        self.eval_losses.append((current_step, current_eval_loss))

        interval = self._steps_between_evaluations(args, current_step)
        self._last_eval_step = current_step

        if current_eval_loss < self.best_eval_loss:
            improvement = self.best_eval_loss - current_eval_loss
            self.best_eval_loss = current_eval_loss
            self.best_checkpoint_step = current_step
            self.steps_since_improvement = 0  # Reset counter

            print(f"🎉 New best evaluation loss: {current_eval_loss:.4f} at step {state.global_step}")
            print(f"📈 Improvement: {improvement:.4f}")
            print(f"💾 Best checkpoint will be: checkpoint-{state.global_step}")
            return

        self.steps_since_improvement += interval
        print(f"📊 Current loss: {current_eval_loss:.4f} (best: {self.best_eval_loss:.4f} at step {self.best_checkpoint_step})")
        print(f"📉 No improvement for {self.steps_since_improvement} steps")

        # Early stopping only applies when a patience has been configured.
        if self.early_stopping_patience is not None and self.steps_since_improvement >= self.early_stopping_patience:
            print(f"🛑 Early stopping triggered! No improvement for {self.early_stopping_patience} steps.")
            print(f"🏆 Best loss achieved: {self.best_eval_loss:.4f} at step {self.best_checkpoint_step}")
            control.should_training_stop = True  # Signal to stop training

    def get_best_checkpoint_info(self):
        """Return the best loss, the step it occurred at and the full loss history."""
        return {
            'best_eval_loss': self.best_eval_loss,
            'best_checkpoint_step': self.best_checkpoint_step,
            'eval_losses': self.eval_losses
        }

class Tee:
    """File-like object that duplicates everything written to it.

    Each ``write`` is forwarded to every wrapped stream, and that stream is
    flushed right after receiving the data.
    """

    def __init__(self, *streams):
        # Kept as the tuple of positional arguments, in the order given.
        self.streams = streams

    def write(self, data):
        """Write ``data`` to each stream in turn, flushing as we go."""
        for target in self.streams:
            target.write(data)
            target.flush()

    def flush(self):
        """Flush every wrapped stream."""
        for target in self.streams:
            target.flush()
        


class CosMixer(nn.Module):
    """Build LLM input embeddings by mixing image-patch and text-token embeddings.

    Depending on the mode flags, projected image patches are moved in front
    of (or copied in front of) the prompt text token they are most
    cosine-similar to, or prompt text tokens are copied after their most
    similar image patch.

    The four configuration attributes start as ``None`` and must be assigned
    by the owner of the module before ``forward`` is used with
    ``use_permutation=True`` (``similarity_threshold`` in particular is
    compared against the similarity scores).
    """

    # Token ids hard-coded for the tokenizer this module was written against.
    _BOS_TOKEN_ID = 1        # BOS token (<s>)
    _SPACE_TOKEN_ID = 29871  # space token

    def __init__(self):
        super().__init__()
        # Minimum cosine similarity required to move/copy a token.
        self.similarity_threshold = None
        # If think mode is on, the padding tokens are used as pause tokens
        # during training to encourage thinking.
        self.think_mode = None
        # If enhance mode is on, the image tokens are copied instead of being
        # moved to the similar text tokens.
        self.enhance_mode = None
        # If text enhance mode is on, text tokens are copied next to similar
        # image tokens.
        self.text_enhance_mode = None

    @staticmethod
    def _embed_in_place(
        pixel_values,
        input_ids,
        attention_mask,
        labels,
        tokenizer,
        language_model_embedding,
        vision_tower,
        projector,
    ):
        """Replace every ``<image>`` placeholder by a projected image patch.

        No reordering happens; the attention mask and labels are returned
        unchanged.

        Raises:
            ValueError: if the number of ``<image>`` tokens in the batch does
                not equal the number of projected image patches.
        """
        txt_embeds = language_model_embedding(input_ids)  # [batch, seq_len, hidden]

        # Second-to-last vision layer, first (CLS) position dropped.
        img_hidden = vision_tower(pixel_values, output_hidden_states=True).hidden_states[-2][:, 1:, :]
        img_embeds = projector(img_hidden).to(dtype=txt_embeds.dtype)
        img_embeds_flat = img_embeds.reshape(-1, img_embeds.shape[-1])  # [batch * patches, hidden]

        image_token_id = tokenizer.convert_tokens_to_ids("<image>")
        is_image = input_ids == image_token_id
        n_image_tokens = int(is_image.sum())
        n_image_features = img_embeds_flat.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features}"
            )

        special_image_mask = is_image.unsqueeze(-1).expand_as(txt_embeds).to(txt_embeds.device)
        inputs_embeds = txt_embeds.masked_scatter(special_image_mask, img_embeds_flat)
        return inputs_embeds, attention_mask, labels

    def _prompt_and_response_masks(self, sample_ids, sample_attention, sample_labels, sep_id):
        """Return ``(prompt_mask, resp_mask)`` for one sample.

        Training (labels given): the prompt is where the label is -100. With
        think mode off, padding (attention 0) is excluded from both masks;
        with think mode on it stays in the prompt.
        Inference (no labels): the prompt is everything up to the last
        ``sep_id``, or every attended position when ``sep_id`` is absent;
        ``resp_mask`` is ``None``.
        """
        if sample_labels is not None:
            prompt_mask = sample_labels == -100
            if self.think_mode:
                return prompt_mask, ~prompt_mask
            attended = sample_attention == 1
            return prompt_mask & attended, (sample_labels != -100) & attended

        hits = (sample_ids == sep_id).nonzero(as_tuple=True)[0]
        if hits.numel() == 0:
            # No separator in this prompt: treat the whole sequence as prompt.
            return sample_attention.bool(), None
        positions = torch.arange(sample_ids.size(0), device=sample_ids.device)
        return positions <= hits.max(), None

    def forward(
        self,
        pixel_values,
        input_ids,
        attention_mask,
        labels,
        tokenizer,
        language_model_embedding,
        vision_tower,
        projector,
        llama_hidden_size=2048,
        use_permutation=True, # TODO adjust this to debug
    ):
        """
        Build input embeddings, optionally reordering tokens by cosine similarity.

        With ``use_permutation=False`` the ``<image>`` placeholders are simply
        replaced by the projected image patches.

        Otherwise each sample is rebuilt as: BOS (if present in the prompt),
        image patches, the space token (if present in the prompt), prompt
        text, and - during training - the response tokens. How patches and
        text interleave depends on the mode attributes:

        * default: a patch whose best similarity to any prompt text token
          reaches ``self.similarity_threshold`` is moved in front of that
          text token; the other patches stay at the start.
        * ``enhance_mode``: as above, but all patches also stay at the start
          (qualifying patches are copied rather than moved).
        * ``text_enhance_mode``: all patches stay at the start; a text token
          whose best similarity to any patch reaches the threshold is copied
          right after that patch.

        The first two non-image prompt tokens are skipped when collecting the
        prompt text (the code assumes they are the BOS and space tokens).

        Args:
            pixel_values: [batch, 3, H, W]
            input_ids: [batch, seq_len]
            attention_mask: [batch, seq_len]
            labels: [batch, seq_len], with -100 for prompt tokens, or None during inference
            tokenizer: Processor's tokenizer
            language_model_embedding: Language model's embedding layer
            vision_tower: Vision encoder; called with ``output_hidden_states=True``
            projector: Multimodal projector
            llama_hidden_size: Width of the padded output buffer; must equal
                the embedding width (default: 2048)
            use_permutation: Whether to apply token reordering (default: True)
        Returns:
            inputs_embeds: [batch, new_seq_len, hidden]
            new_attention_mask: [batch, new_seq_len]
            new_labels: [batch, new_seq_len] (None during inference)
        Raises:
            ValueError: only with ``use_permutation=False``, when image tokens
                and image features do not match in number.
        """
        if not use_permutation:
            return self._embed_in_place(
                pixel_values, input_ids, attention_mask, labels,
                tokenizer, language_model_embedding, vision_tower, projector,
            )

        bsz = input_ids.size(0)
        device = input_ids.device

        # ---------- 1. Image patch embeddings (dtype matches the LM) ----------
        img_hidden = vision_tower(pixel_values, output_hidden_states=True).hidden_states[-2][:, 1:, :]
        img_embeds = projector(img_hidden).to(dtype=language_model_embedding.weight.dtype)
        embed_dtype = img_embeds.dtype

        def embed_token(token_id):
            """Embedding vector of a single token id."""
            return language_model_embedding(torch.tensor([token_id], device=device)).squeeze(0)

        def empty_embeds():
            return torch.tensor([], device=device, dtype=embed_dtype)

        image_token_id = tokenizer.convert_tokens_to_ids("<image>")
        # Last token id of the "ASSISTANT:" separator (used for inference only).
        sep_id = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids[-1]

        reordered_seqs, reordered_atts, reordered_labs = [], [], []

        for b in range(bsz):
            # ---------- 2. Prompt mask (TRAIN vs INFER) ----------
            prompt_mask, resp_mask = self._prompt_and_response_masks(
                input_ids[b],
                attention_mask[b],
                None if labels is None else labels[b],
                sep_id,
            )
            prompt_ids = input_ids[b][prompt_mask]
            # Drop image placeholders, then the first two tokens (1 and 29871).
            prompt_text = prompt_ids[prompt_ids.ne(image_token_id)][2:]

            patches = img_embeds[b]  # [num_patches, hidden]
            n_patches = patches.size(0)

            # ---------- 3. Group by cosine similarity ----------
            # All per-sample state is (re)initialised here so that nothing
            # leaks from the previous batch element.
            groups = []                                   # per text token: patches placed before it
            text_groups = [[] for _ in range(n_patches)]  # per patch: text tokens copied after it
            unmoved_images = patches                      # patches kept at the start

            if prompt_text.numel() == 0:
                # Corner case: image-only prompt - nothing to compare against.
                txt_embeds = patches.new_zeros((0, patches.size(-1)))
            else:
                txt_embeds = language_model_embedding(prompt_text.unsqueeze(0)).squeeze(0)  # [T, hidden]
                sim = F.cosine_similarity(
                    txt_embeds.unsqueeze(1),  # [T, 1, hidden]
                    patches.unsqueeze(0),     # [1, P, hidden]
                    dim=2,
                )  # [T, P]

                if self.text_enhance_mode:
                    # For each text token: its most similar patch.
                    best_patch = sim.argmax(dim=1).tolist()
                    qualifies = (sim.max(dim=1).values >= self.similarity_threshold).tolist()
                    for txt_idx, (img_idx, ok) in enumerate(zip(best_patch, qualifies)):
                        if ok:
                            text_groups[img_idx].append(txt_embeds[txt_idx])
                else:
                    # For each patch: its most similar text token.
                    best_text = sim.argmax(dim=0).tolist()
                    qualifies = (sim.max(dim=0).values >= self.similarity_threshold).tolist()
                    groups = [[] for _ in range(txt_embeds.size(0))]
                    leftover = []
                    for img_idx, (txt_idx, ok) in enumerate(zip(best_text, qualifies)):
                        if ok:
                            groups[txt_idx].append(patches[img_idx])
                        elif not self.enhance_mode:
                            leftover.append(patches[img_idx])
                    if not self.enhance_mode:
                        unmoved_images = torch.stack(leftover) if leftover else empty_embeds()

            num_txt = txt_embeds.size(0)

            # ---------- 4. Assemble the new sequence ----------
            reordered, lab = [], []

            if (prompt_ids == self._BOS_TOKEN_ID).any():
                reordered.append(embed_token(self._BOS_TOKEN_ID))
                lab.append(-100)

            has_space = bool((prompt_ids == self._SPACE_TOKEN_ID).any())

            if self.text_enhance_mode:
                # Every patch, each followed by the text tokens copied to it.
                for img_idx in range(len(unmoved_images)):
                    reordered.append(unmoved_images[img_idx])
                    lab.append(-100)
                    copies = text_groups[img_idx]
                    if copies:
                        reordered.extend(copies)
                        lab.extend([-100] * len(copies))

                if has_space:
                    reordered.append(embed_token(self._SPACE_TOKEN_ID))
                    lab.append(-100)

                # Prompt text tokens at the end.
                for t_idx in range(num_txt):
                    reordered.append(txt_embeds[t_idx])
                    lab.append(-100)
            else:
                # Unassigned (or, in enhance mode, all) patches at the start.
                if unmoved_images.numel() > 0:
                    reordered.extend(unmoved_images)
                    lab.extend([-100] * len(unmoved_images))

                if has_space:
                    reordered.append(embed_token(self._SPACE_TOKEN_ID))
                    lab.append(-100)

                # Each text token preceded by the patches assigned to it.
                for t_idx in range(num_txt):
                    if groups[t_idx]:
                        reordered.extend(groups[t_idx])
                        lab.extend([-100] * len(groups[t_idx]))
                    reordered.append(txt_embeds[t_idx])
                    lab.append(-100)

            if labels is not None:  # Training: add answer tokens
                reordered.extend(language_model_embedding(input_ids[b])[resp_mask])
                lab.extend(labels[b][resp_mask].tolist())

            reordered_seqs.append(torch.stack(reordered) if reordered else empty_embeds())
            reordered_atts.append(torch.ones(len(reordered), device=device, dtype=torch.long))
            reordered_labs.append(torch.tensor(lab, device=device, dtype=torch.long))

        # ---------- 5. Pad to equal length ----------
        max_len = max(seq.size(0) for seq in reordered_seqs) if reordered_seqs else 1
        pad_seq = torch.zeros(bsz, max_len, llama_hidden_size, device=device, dtype=embed_dtype)
        pad_att = torch.zeros(bsz, max_len, device=device, dtype=torch.long)
        pad_lab = None if labels is None else torch.full(
            (bsz, max_len), -100, device=device, dtype=torch.long
        )

        # The pad embedding is the same for every row, so look it up once and
        # only when some row actually needs padding. Without a pad token the
        # padded positions stay zero (they are masked out by pad_att anyway).
        needs_padding = any(seq.numel() > 0 and seq.size(0) < max_len for seq in reordered_seqs)
        pad_embed = None
        if needs_padding and tokenizer.pad_token_id is not None:
            pad_embed = embed_token(tokenizer.pad_token_id)

        for b, (seq, att, lab) in enumerate(zip(reordered_seqs, reordered_atts, reordered_labs)):
            if seq.numel() == 0:  # Empty sequence: leave the row zero-filled
                continue
            seq_len = seq.size(0)
            pad_seq[b, :seq_len] = seq
            pad_att[b, :att.size(0)] = att
            if pad_embed is not None and seq_len < max_len:
                pad_seq[b, seq_len:] = pad_embed  # broadcast over padded positions
            if labels is not None:
                pad_lab[b, :lab.size(0)] = lab

        return pad_seq, pad_att, pad_lab





def get_target_models(model):
    """Return the names of the sub-modules of ``model`` that match a target pattern.

    A name is kept when it ends in an attention projection (q/k/v/o) or an
    MLP projection (gate/up/down) under ``language_model.layers.<n>``, or in
    ``multi_modal_projector.linear_1`` / ``linear_2``. Names are returned in
    the order ``model.named_modules()`` yields them.
    """
    target_patterns = (
        re.compile(r".*language_model\.layers\.\d+\.self_attn\.(q|k|v|o)_proj$"),
        re.compile(r".*language_model\.layers\.\d+\.mlp\.(gate|up|down)_proj$"),
        re.compile(r".*multi_modal_projector\.linear_(1|2)$"),  # projector modules
    )

    def _is_target(module_name):
        for pattern in target_patterns:
            if pattern.search(module_name):
                return True
        return False

    return [
        module_name
        for module_name, _module in model.named_modules()
        if _is_target(module_name)
    ]

def check_trainable_parameters(model):
    """Print each named parameter of ``model`` with its ``requires_grad`` flag.

    One header line is printed first, then one line per entry yielded by
    ``model.named_parameters()``.
    """
    print("Checking trainable parameters for permutation_net:")
    for entry in model.named_parameters():
        param_name, tensor = entry
        print(f"{param_name}: requires_grad={tensor.requires_grad}")



class CustomLlavaForConditionalGeneration(LlavaForConditionalGeneration):
    """LLaVA model whose multimodal inputs are first rearranged by a token mixer.

    When pixel values are supplied, ``self.token_mixer`` turns the input ids
    and image into input embeddings (plus a matching attention mask and
    labels) which are then fed to the parent model. Text-only inputs and
    pre-computed ``inputs_embeds`` go straight to the parent model.
    """

    # Keyword arguments that must not be forwarded twice to the parent forward.
    _RESERVED_KWARGS = ("input_ids", "pixel_values", "inputs_embeds", "labels")

    def __init__(self, config, token_mixer_class = CosMixer):
        super().__init__(config)
        self.token_mixer = token_mixer_class()
        # Processor whose tokenizer the mixer uses; set through
        # set_token_mixer_processor() or passed per call.
        self.processor = None

    def load_token_mixer_config(self, path, message = True):
        """Load the mixer settings from a JSON file.

        Args:
            path: Either a JSON file or a directory containing ``config.json``.
            message: Whether to print the loaded settings.

        The file must contain ``similarity_threshold``, ``think_mode`` and
        ``enhance_mode``; ``text_enhance_mode`` defaults to ``False``. When
        the file does not exist a notice is printed and nothing is changed.
        """
        if os.path.isdir(path):
            json_path = os.path.join(path, "config.json")
        else:
            json_path = path

        if not os.path.isfile(json_path):
            print('!'*20, f"\nCustomLlavaForConditionalGeneration load fail load_token_mixer_config: file not found: {json_path}")
            return

        with open(json_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        self.token_mixer.similarity_threshold = config_dict['similarity_threshold']
        self.token_mixer.think_mode = config_dict['think_mode']
        self.token_mixer.enhance_mode = config_dict['enhance_mode']
        self.token_mixer.text_enhance_mode = config_dict.get('text_enhance_mode', False)

        if message:
            print(f"CustomLlavaForConditionalGeneration loaded with thres={self.token_mixer.similarity_threshold}",
                f", think_mode={self.token_mixer.think_mode}",
                f", enhance_mode={self.token_mixer.enhance_mode}",
                f", text_enhance_mode={self.token_mixer.text_enhance_mode}")

    def set_token_mixer_processor(self, processor):
        """Store the processor whose tokenizer the token mixer should use."""
        self.processor = processor

    def _resolve_processor(self, processor=None):
        """Return the per-call processor, falling back to the stored one.

        Raises:
            ValueError: if neither is available.
        """
        resolved = processor if processor is not None else self.processor
        if resolved is None:
            raise ValueError(
                "A processor is required for multimodal inputs: pass `processor=` "
                "or call set_token_mixer_processor() first."
            )
        return resolved

    def forward(self, input_ids=None, pixel_values=None, attention_mask=None, labels=None, inputs_embeds=None, processor=None, **kwargs):
        """Run the model, routing multimodal inputs through the token mixer.

        * ``inputs_embeds`` given: forwarded unchanged to the parent model.
        * no ``pixel_values``: text-only forward through the parent model.
        * otherwise: the token mixer builds embeddings, attention mask and
          labels, which replace the originals in the parent forward.

        ``processor`` overrides the stored processor for this call and is
        never passed on to the parent model.
        """
        filtered_kwargs = {k: v for k, v in kwargs.items() if k not in self._RESERVED_KWARGS}

        if inputs_embeds is not None:  # for testing: inputs_embeds is given
            return super().forward(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                **filtered_kwargs
            )

        if pixel_values is None:  # text-only prompt
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                **filtered_kwargs
            )
            if outputs.logits.dim() > 3:
                outputs.logits = outputs.logits.squeeze(1)
            return outputs

        # Training the mixer / multimodal inference from ids + pixels.
        embedding_layer = self.get_input_embeddings()
        inputs_embeds, new_attention_mask, new_labels = self.token_mixer(
            pixel_values, input_ids, attention_mask, labels,
            self._resolve_processor(processor).tokenizer,
            embedding_layer, self.vision_tower, self.multi_modal_projector,
            # Use the real embedding width instead of assuming 2048.
            llama_hidden_size=embedding_layer.weight.shape[-1],
        )
        return super().forward(
            inputs_embeds=inputs_embeds,
            attention_mask=new_attention_mask,
            labels=new_labels,
            **filtered_kwargs
        )

    def generate(
        self,
        processor = None,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        max_new_tokens=128,
        do_sample=False,
        temperature=1.0,
        top_p=1.0,
        **kwargs
    ):
        """
        Custom generate method for multimodal text generation.

        Text-only inputs are delegated to the parent ``generate`` and the
        prompt is stripped from the result. Multimodal inputs are embedded by
        the token mixer and decoded token by token with a KV cache; this path
        reads each new token with ``.item()`` and therefore handles a single
        sequence at a time.

        The model is switched to evaluation mode as a side effect.

        Args:
            processor: Processor providing the tokenizer; falls back to the
                one stored with ``set_token_mixer_processor``.
            input_ids (torch.LongTensor): Input token IDs.
            pixel_values (torch.FloatTensor, optional): Image pixel values.
            attention_mask (torch.LongTensor, optional): Attention mask for input_ids.
            max_new_tokens (int): Maximum number of new tokens to generate.
            do_sample (bool): Whether to use sampling (True) or greedy decoding (False).
            temperature (float): Temperature for sampling.
            top_p (float): Top-p probability for nucleus sampling.
            **kwargs: Additional arguments passed to the model.

        Returns:
            torch.LongTensor: Generated token IDs (without the prompt). The
            multimodal path returns a 1-D tensor that includes the EOS token
            when one was produced.

        Raises:
            ValueError: if pixel values are given but no processor is available.
        """
        self.eval()  # Set model to evaluation mode

        if pixel_values is None:
            # Handle text-only inputs by falling back to parent class
            gen_ids = super().generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=None,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                **kwargs
            )

            # Return only the generated tokens (excluding input prompt)
            prompt_len = input_ids.shape[1]
            if gen_ids.dim() == 2:  # Batch case
                return gen_ids[:, prompt_len:]
            return gen_ids[prompt_len:]  # Single sample case

        processor = self._resolve_processor(processor)
        embedding_layer = self.get_input_embeddings()
        eos_token_id = processor.tokenizer.eos_token_id

        # Process inputs using token_mixer
        inputs_embeds, attention_mask, _ = self.token_mixer(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            tokenizer=processor.tokenizer,
            language_model_embedding=embedding_layer,
            vision_tower=self.vision_tower,
            projector=self.multi_modal_projector,
            # Use the real embedding width instead of assuming 2048.
            llama_hidden_size=embedding_layer.weight.shape[-1],
        )

        generated_tokens = []
        current_embeds = inputs_embeds
        current_attention_mask = attention_mask
        past_key_values = None

        # Autoregressive generation loop
        with torch.no_grad():
            for step in range(max_new_tokens):
                # Prefill with the whole prompt, then decode one token at a time.
                step_embeds = current_embeds if step == 0 else current_embeds[:, -1:, :]
                outputs = self(
                    inputs_embeds=step_embeds,
                    attention_mask=current_attention_mask,
                    past_key_values=past_key_values,
                    processor=processor,
                    **kwargs
                )
                past_key_values = outputs.past_key_values
                next_token_logits = outputs.logits[:, -1, :]

                if do_sample:
                    probs = F.softmax(next_token_logits / temperature, dim=-1)
                    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                    sorted_remove = cumulative_probs > top_p
                    # Shift right so the first token crossing top_p is kept.
                    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
                    sorted_remove[..., 0] = False
                    # The mask is in sorted order; map it back to vocabulary
                    # order before applying it to the unsorted probabilities.
                    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
                    probs = probs.masked_fill(remove, 0.0)
                    probs = probs / probs.sum(dim=-1, keepdim=True)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                token_id = next_token.item()
                generated_tokens.append(token_id)
                if token_id == eos_token_id:
                    break

                # Update inputs for next step
                next_token_embed = embedding_layer(next_token)
                current_embeds = torch.cat([current_embeds, next_token_embed], dim=1)
                current_attention_mask = torch.cat(
                    [current_attention_mask, torch.ones_like(next_token, dtype=torch.long, device=self.device)],
                    dim=1
                )

        return torch.tensor(generated_tokens, dtype=torch.long, device=self.device)

        
def pixel_values_to_pil(pixel_values, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Convert a normalized image tensor back into a PIL image.

    Args:
        pixel_values: Tensor in [C, H, W] layout (it is permuted to
            [H, W, C]); the input tensor itself is not modified.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel standard deviation that was divided out.
            NOTE(review): the defaults should match the normalization of the
            processor that produced ``pixel_values`` - confirm at call sites.

    Returns:
        PIL image built from the de-normalized values, clamped to [0, 1] and
        scaled to 8-bit.
    """
    # Immutable tuple defaults: a shared list default could be mutated by accident.
    chw = pixel_values.clone().detach().cpu()  # Ensure on CPU
    mean_t = torch.tensor(mean).view(-1, 1, 1)  # broadcast over H and W
    std_t = torch.tensor(std).view(-1, 1, 1)
    chw = (chw * std_t + mean_t).clamp(0, 1)  # Denormalize, then ensure [0, 1]
    hwc = chw.permute(1, 2, 0).numpy()  # [C, H, W] -> [H, W, C]
    return Image.fromarray((hwc * 255).astype(np.uint8))  # Scale to [0, 255]


import traceback
from PIL import Image
from IPython.display import display
import torch
import aiohttp
import asyncio
from io import BytesIO

IGNORE = -100  # Loss is not computed on these tokens

async def fetch_image(session, url):
    """Fetch a single image asynchronously and return it as an RGB PIL image.

    Args:
        session: Object whose ``get(url)`` returns an async context manager
            yielding a response with ``status`` and an awaitable ``read()``.
        url: Address of the image.

    Raises:
        ValueError: on a non-200 status, a failed download or undecodable
            image data. The original exception is attached as ``__cause__``.
    """
    try:
        async with session.get(url) as response:
            if response.status != 200:
                raise ValueError(f"Failed to download image from {url}")
            content = await response.read()
        # Decode after the response has been released.
        return Image.open(BytesIO(content)).convert("RGB")
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Image fetch failed: {e}") from e

def calculate_dataset_sizes(dataset_size, split_ratio, metric_data_percetage = 1):
    """Split ``dataset_size`` into train / test / metric sample counts.

    Args:
        dataset_size: Total number of samples available.
        split_ratio: Fraction of the samples assigned to training.
        metric_data_percetage: Fraction of the test samples used for metric
            computation (parameter name kept as-is for existing callers).

    Returns:
        Tuple ``(train_ds_num, test_ds_num, metric_ds_num)``. The train and
        test counts always add up to ``dataset_size``.
    """
    train_ds_num = int(split_ratio * dataset_size)
    # Take the remainder instead of int((1 - split_ratio) * dataset_size):
    # floating-point error in (1 - split_ratio) can land just below the exact
    # value (e.g. (1 - 0.9) * 100 == 9.999999999999998) and drop a sample.
    test_ds_num = int(dataset_size) - train_ds_num
    metric_ds_num = int(test_ds_num * metric_data_percetage)
    return train_ds_num, test_ds_num, metric_ds_num


def unwrap_150k_row(row, image_dir = "../coco/train2017/"):
    """Extract the question, answer and image from one LLaVA-Instruct row.

    Args:
        row: Mapping with a ``"conversations"`` list of
            ``{"from": ..., "value": ...}`` messages and an ``"image"``
            filename.
        image_dir: Directory the image filename is resolved against.

    Returns:
        Tuple ``(user_text, asst_text, image)``: ``user_text`` is the first
        human message with its ``<image>`` placeholder moved to the front
        (``"<image>\\n" + question``), ``asst_text`` is the first gpt message
        stripped, and ``image`` is the RGB PIL image.

    Raises:
        ValueError: If a human or gpt message is missing, the human message
            does not contain exactly one ``<image>`` token, the question or
            answer is empty, or the image file does not exist.
    """
    conversations = row["conversations"]
    # Use a default so a missing speaker surfaces as ValueError rather than a
    # bare StopIteration escaping from next().
    user_msg = next((m["value"] for m in conversations if m["from"] == "human"), None)
    asst_msg = next((m["value"] for m in conversations if m["from"] == "gpt"), None)
    if user_msg is None or asst_msg is None:
        raise ValueError("Row is missing a human or gpt message")

    num_image_tokens = user_msg.count("<image>")
    if num_image_tokens != 1:
        raise ValueError(f"Skipping row with {num_image_tokens} <image> tokens")

    # Validate the question BEFORE prepending the placeholder; once the prefix
    # is added the text can never be empty and the check would be dead code.
    question = re.sub(r'\n*<image>\n*', '', user_msg).strip()
    asst_text = asst_msg.strip()
    if not question or not asst_text:
        raise ValueError("Empty user or assistant text")
    user_text = '<image>\n' + question

    # Load image from local directory
    image_filename = row["image"]  # e.g., '000000168718.jpg'
    image_path = os.path.join(image_dir, image_filename)

    if not os.path.exists(image_path):
        raise ValueError(f"Image not found at {image_path}")

    # convert() returns an independent image, so the file handle can be
    # released right away instead of lingering until garbage collection.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    return user_text, asst_text, image



def load_llava_instruct_150k_local(ds, max_samples=200000, num_samples_to_show=3, processor=None, image_dir="coco/train2017/"):
    """Tokenize LLaVA-Instruct rows whose images are stored on local disk.

    Each row is unwrapped with ``unwrap_150k_row``, rendered as
    ``"<prompt>\\nASSISTANT: <answer> <eos>"`` and encoded with ``processor``.
    The leading ``prefix_len`` label positions (the prompt) are set to -100.

    Args:
        ds: Iterable of dataset rows.
        max_samples: Stop once this many examples have been prepared.
        num_samples_to_show: Number of prepared samples to print and show.
        processor: Callable processor exposing ``.tokenizer``. Required; the
            ``None`` default exists only for signature compatibility.
        image_dir: Directory containing the images referenced by the rows.

    Returns:
        List of dicts with ``input_ids``, ``attention_mask``,
        ``pixel_values`` and ``labels`` tensors (one dict per kept row).

    Raises:
        ValueError: If ``processor`` is None. Previously every row then
            failed inside the per-row ``except`` and an empty list came back.
    """
    if processor is None:
        raise ValueError("load_llava_instruct_150k_local requires a processor")

    print("Loading llava 150k")

    ignore_index = -100  # Constant for ignoring tokens in labels
    eos_token = processor.tokenizer.eos_token  # loop-invariant, look up once
    processed, samples_to_show = [], []
    skips = 0

    for idx, row in enumerate(tqdm(ds)):
        if len(processed) >= max_samples:
            break

        try:
            user_text, asst_text, image = unwrap_150k_row(row, image_dir)

            # Build full conversation
            prompt = f"{user_text}\nASSISTANT:"
            convo = f"{prompt} {asst_text} {eos_token}"

            # Tokenize conversation
            enc = processor(
                text=convo,
                images=image,
                return_tensors="pt",
                truncation=True,
                padding=False,
            )
            input_ids = enc.input_ids[0]
            attention_mask = enc.attention_mask[0]
            pixel_values = enc.pixel_values[0]

            # Length of the prompt alone decides how many labels are masked.
            # NOTE(review): the prompt is encoded with add_special_tokens=False
            # while the conversation uses the processor default; if that
            # default prepends a BOS token the mask is one position short.
            # Confirm against the tokenizer in use.
            prefix_len = len(
                processor(
                    text=prompt,
                    images=image,
                    return_tensors="pt",
                    add_special_tokens=False
                ).input_ids[0]
            )

            labels = input_ids.clone()
            labels[:prefix_len] = ignore_index

            processed.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "labels": labels,
            })

            # Store samples for display
            if len(samples_to_show) < num_samples_to_show:
                samples_to_show.append(
                    dict(prompt=prompt, assistant_text=asst_text, image=image)
                )

        except Exception as e:
            # Best-effort loading: a bad row is counted and skipped; only the
            # first ten failures are printed to keep the log readable.
            if skips < 10:
                print(f"⚠️ Skipped row {idx}: {e}")
            skips += 1
            continue

    # Report & show
    print(f"✅ Prepared {len(processed)} examples, skipped {skips} rows\n")
    for i, s in enumerate(samples_to_show, 1):
        print(f"Sample {i}\nPrompt: {s['prompt']} {s['assistant_text']}")
        s['image'].show()  # Display the image

    return processed

    



async def load_llava_instruct_150k(max_samples=8000, seed=42, num_samples_to_show=3, processor=None, batch_size=10):
    """
    Load and process the LLaVA-Instruct-150K dataset in streaming mode with async image downloads.

    Rows are collected into batches; the images of a batch are downloaded
    concurrently with ``fetch_image`` and each row is then encoded as
    ``"<user text>\\nASSISTANT: <answer> <eos>"``. A batch is flushed when it
    holds ``batch_size`` rows, when it could complete ``max_samples``, and
    once more after the stream is exhausted so trailing rows are not lost.

    Args:
        max_samples (int): Maximum number of samples to process.
        seed (int): Random seed for shuffling.
        num_samples_to_show (int): Number of samples to display for verification.
        processor: Callable processor exposing ``.tokenizer``. Required; the
            ``None`` default exists only for signature compatibility.
        batch_size (int): Number of images to download concurrently.

    Returns:
        List of processed examples, each containing input_ids, attention_mask, pixel_values, and labels.

    Raises:
        ValueError: If ``processor`` is None. Previously every row then
            failed inside the per-row ``except`` and an empty list came back.
    """
    # Validate before any import or network access so misuse fails quickly.
    if processor is None:
        raise ValueError("load_llava_instruct_150k requires a processor")

    from datasets import load_dataset

    # Load dataset in streaming mode
    ds = load_dataset(
        "liuhaotian/LLaVA-Instruct-150K",
        split="train",
        streaming=True
    ).shuffle(seed=seed, buffer_size=1000)

    print("Dataset loaded in streaming mode")

    processed, samples_to_show = [], []
    skips = 0
    base_url = "http://images.cocodataset.org/train2017/"  # Adjust as needed
    eos_token = processor.tokenizer.eos_token  # loop-invariant, look up once

    def record_skip(row_idx, error):
        """Count a skipped row; print details for the first ten only."""
        nonlocal skips
        if skips < 10:
            print(f"⚠️ Skipped row {row_idx}: {error}")
            traceback.print_exc()
        skips += 1

    async def flush(session, pending):
        """Download the images of ``pending`` concurrently and encode each row."""
        images = await asyncio.gather(
            *(fetch_image(session, url) for _, _, _, url in pending),
            return_exceptions=True,
        )
        for (row_idx, user_text, asst_text, _), image in zip(pending, images):
            try:
                if isinstance(image, Exception):
                    raise ValueError(f"Image download failed: {image}")

                # Build full conversation
                prompt = f"{user_text}\nASSISTANT:"
                convo = f"{prompt} {asst_text} {eos_token}"

                # Tokenize conversation
                enc = processor(
                    text=convo,
                    images=image,
                    return_tensors="pt",
                    truncation=True,
                    padding=False,
                )
                input_ids = enc.input_ids[0]
                attention_mask = enc.attention_mask[0]
                pixel_values = enc.pixel_values[0]

                # Length of the prompt alone decides how many labels are masked.
                # NOTE(review): the prompt is encoded with
                # add_special_tokens=False while the conversation uses the
                # processor default; if that default prepends a BOS token the
                # mask is one position short. Confirm against the tokenizer.
                prefix_len = len(
                    processor(
                        text=prompt,
                        images=image,
                        return_tensors="pt",
                        add_special_tokens=False
                    ).input_ids[0]
                )

                labels = input_ids.clone()
                labels[:prefix_len] = IGNORE

                processed.append({
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "pixel_values": pixel_values,
                    "labels": labels,
                })

                # Store samples for display
                if len(samples_to_show) < num_samples_to_show:
                    samples_to_show.append(
                        dict(prompt=prompt, assistant_text=asst_text, image=image)
                    )

            except Exception as e:
                record_skip(row_idx, e)

    async with aiohttp.ClientSession() as session:
        batch = []
        for idx, row in enumerate(tqdm(ds)):
            if len(processed) >= max_samples:
                break

            try:
                # Pull texts
                user_msg = next(m["value"] for m in row["conversations"] if m["from"] == "human")
                asst_msg = next(m["value"] for m in row["conversations"] if m["from"] == "gpt")

                user_text = user_msg.strip()
                num_image_tokens = user_text.count("<image>")
                if num_image_tokens != 1:
                    raise ValueError(f"Skipping row with {num_image_tokens} <image> tokens")

                asst_text = asst_msg.strip()
                if not user_text or not asst_text:
                    raise ValueError("Empty user or assistant text")

                # Prepare image URL
                image_filename = row["image"]  # e.g., '000000168718.jpg'
                image_url = f"{base_url}{image_filename}"
            except Exception as e:
                record_skip(idx, e)
                continue

            batch.append((idx, user_text, asst_text, image_url))

            # A flush can add at most len(batch) examples, so triggering as
            # soon as the batch could reach max_samples never overshoots it.
            if len(batch) >= batch_size or len(processed) + len(batch) >= max_samples:
                await flush(session, batch)
                batch = []

        # The stream may end with a partially filled batch that met neither
        # trigger above; process it so those rows are not silently dropped.
        if batch:
            await flush(session, batch)

    # Report & show
    print(f"✅ Prepared {len(processed)} examples, skipped {skips} rows\n")
    for i, s in enumerate(samples_to_show, 1):
        print(f"Sample {i}\nPrompt: {s['prompt']} {s['assistant_text']}")
        display(s['image'])

    return processed



class CustomDataCollator:
    """Batch variable-length examples by right-padding to the longest one.

    ``input_ids`` are padded with ``pad_token_id``, ``attention_mask`` with
    zeros and ``labels`` with ``ignore_index``; ``pixel_values`` are stacked
    unchanged.
    """

    def __init__(self, pad_token_id=0, ignore_index=-100):
        self.pad_token_id = pad_token_id
        self.ignore_index = ignore_index

    @staticmethod
    def _right_pad(sequence, fill_value, amount):
        """Append ``amount`` copies of ``fill_value`` to a 1-D tensor."""
        return torch.cat([sequence, sequence.new_full((amount,), fill_value)])

    def __call__(self, features):
        """Collate ``features`` (a list of per-example dicts) into one batch.

        Raises:
            ValueError: If any feature lacks one of the required keys.
        """
        # Every example must carry all four tensors.
        required = {"input_ids", "attention_mask", "pixel_values", "labels"}
        for i, feature in enumerate(features):
            if not required.issubset(feature):
                raise ValueError(f"Feature {i} missing keys {required - feature.keys()}")

        # The longest input sequence defines the padded length of the batch.
        target_len = max(feature["input_ids"].size(0) for feature in features)

        padded_ids, padded_masks, padded_labels, images = [], [], [], []
        for feature in features:
            # The shortfall is measured on input_ids and reused for the mask
            # and the labels, which are padded by the same amount.
            shortfall = target_len - feature["input_ids"].size(0)
            padded_ids.append(self._right_pad(feature["input_ids"], self.pad_token_id, shortfall))
            padded_masks.append(self._right_pad(feature["attention_mask"], 0, shortfall))
            padded_labels.append(self._right_pad(feature["labels"], self.ignore_index, shortfall))
            images.append(feature["pixel_values"])

        return {
            "input_ids": torch.stack(padded_ids),
            "attention_mask": torch.stack(padded_masks),
            "labels": torch.stack(padded_labels),
            "pixel_values": torch.stack(images),
        }
    


def compute_bleu(model, processor, val_data, max_new_tokens=64, do_sample=False, temperature=0.7, top_p=0.9):
    """Generate an answer for every validation row and return the mean BLEU.

    Each row is unwrapped with ``unwrap_150k_row`` (default image directory),
    the question is suffixed with ``"\\nASSISTANT:"`` and passed to
    ``model.generate``. Predictions and references are lower-cased, stripped
    and split on whitespace, then scored with NLTK ``sentence_bleu`` using
    equal unigram/bigram weights and ``method1`` smoothing.

    Args:
        model: Model exposing ``eval()``, ``device`` and a ``generate`` that
            accepts ``processor``, ``pixel_values``, ``input_ids``,
            ``attention_mask`` and ``max_new_tokens``.
        processor: Processor used to encode the prompt and decode the output.
        val_data: Iterable of dataset rows.
        max_new_tokens: Generation budget forwarded to ``model.generate``.
        do_sample, temperature, top_p: Accepted but not forwarded to
            ``model.generate`` by this function.

    Returns:
        Average sentence-level BLEU, or 0.0 when ``val_data`` yields no rows.
    """
    model.eval()
    hypotheses, reference_sets = [], []

    for sample in tqdm(val_data, desc="Computing BLEU"):
        processor.patch_size = 14

        user_text, gt_answer, img = unwrap_150k_row(sample)
        prompt = f"{user_text}\nASSISTANT:"

        # Encode the prompt and move the tensors to the model's device.
        encoded = processor(
            text=prompt,
            images=img,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        inputs = encoded.to(model.device)

        token_ids = model.generate(
            processor=processor,
            pixel_values=inputs["pixel_values"],
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
        )

        # Normalize both sides identically before scoring.
        decoded = processor.tokenizer.decode(token_ids, skip_special_tokens=True)
        hypotheses.append(decoded.lower().strip())
        reference_sets.append([gt_answer.lower().strip()])

    # Score every prediction against its single reference.
    bleu_scores = [
        sentence_bleu(
            references=[reference.split() for reference in refs],
            hypothesis=hypothesis.split(),
            weights=(0.5, 0.5),  # Unigrams and bigrams, equal weights
            smoothing_function=SmoothingFunction().method1,
        )
        for hypothesis, refs in zip(hypotheses, reference_sets)
    ]

    if bleu_scores:
        avg_bleu = sum(bleu_scores) / len(bleu_scores)
    else:
        avg_bleu = 0.0
    print(f"Current validation BLEU Score: {avg_bleu:.4f}")
    return avg_bleu




class CustomTrainer(Trainer):
    """Trainer that additionally keeps validation rows and a processor.

    ``val_data`` and ``processor`` are stored on the instance; every other
    positional and keyword argument is forwarded to ``Trainer.__init__``.
    """

    def __init__(self, *args, val_data=None, processor=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.val_data, self.processor = val_data, processor

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """Run the parent evaluation and return its metrics unchanged.

        NOTE: BLEU reporting is currently disabled. To re-enable it, call
        ``compute_bleu(model=self.model, processor=self.processor,
        val_data=self.val_data, max_new_tokens=128, do_sample=False)`` when
        ``self.val_data`` is not None, store the score under
        ``f"{metric_key_prefix}_bleu"`` and pass the metrics to ``self.log``.
        """
        return super().evaluate(
            eval_dataset,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        

from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
from IPython.display import display
from datasets import load_dataset

def _split_prompt_and_answer(processor, input_ids):
    """Decode ``input_ids`` and split the text at the first ``ASSISTANT:``.

    Returns ``(prompt, gt_answer)``, or ``None`` (after printing the reason)
    when the marker or the ``<image>`` placeholder is missing.
    """
    full_convo = processor.tokenizer.decode(input_ids, skip_special_tokens=False)
    prompt, marker, gt_answer = full_convo.partition("ASSISTANT:")
    if not marker:
        print(f"Skipping sample: Could not split conversation: {full_convo}")
        return None

    prompt = prompt.strip()
    gt_answer = gt_answer.strip()

    # Remove EOS token from gt_answer if present. The truthiness guard keeps
    # a missing/empty EOS string from wiping the whole answer via [:-0].
    eos_token = processor.tokenizer.eos_token
    if eos_token and gt_answer.endswith(eos_token):
        gt_answer = gt_answer[: -len(eos_token)].strip()

    if "<image>" not in prompt:
        print(f"Skipping sample: No <image> placeholder in prompt: {prompt}")
        return None
    return prompt, gt_answer


def _sample_next_token(next_token_logits, do_sample, temperature, top_p):
    """Pick the next token id: greedy argmax, or temperature + top-p sampling."""
    if not do_sample:
        return torch.argmax(next_token_logits, dim=-1, keepdim=True)

    probs = F.softmax(next_token_logits / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Drop everything after the nucleus, shifted by one so the first token
    # that crosses top_p is still kept.
    sorted_remove = cumulative_probs > top_p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False

    # The mask above is in sorted-rank order; scatter it back to vocabulary
    # order before applying it to the unsorted probabilities.
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
    probs = probs.masked_fill(remove, 0.0)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    return torch.multinomial(probs, num_samples=1)


def _generate_with_token_mixer(
    model,
    processor,
    input_ids,
    attention_mask,
    pixel_values,
    max_new_tokens,
    do_sample,
    temperature,
    top_p,
    llama_hidden_size,
):
    """Decode up to ``max_new_tokens`` tokens from ``model`` and return the text."""
    # Set processor patch size (as in original code)
    processor.patch_size = 14
    device = model.device
    eos_token_id = processor.tokenizer.eos_token_id
    generated_tokens = []

    with torch.no_grad():
        # Merge image and text into one embedding sequence (batch of one).
        step_embeds, running_mask, _ = model.token_mixer(
            pixel_values=pixel_values.unsqueeze(0).to(device),
            input_ids=input_ids.unsqueeze(0).to(device),
            attention_mask=attention_mask.unsqueeze(0).to(device),
            labels=None,
            tokenizer=processor.tokenizer,
            language_model_embedding=model.get_input_embeddings(),
            vision_tower=model.vision_tower,
            projector=model.multi_modal_projector,
            llama_hidden_size=llama_hidden_size,
        )

        past_key_values = None
        for _ in range(max_new_tokens):
            # First step feeds the whole prompt; later steps feed only the
            # newest token embedding and rely on the KV cache for the rest.
            outputs = model(
                inputs_embeds=step_embeds,
                attention_mask=running_mask,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            next_token = _sample_next_token(
                outputs.logits[:, -1, :], do_sample, temperature, top_p
            )
            token_id = next_token.item()
            generated_tokens.append(token_id)
            if token_id == eos_token_id:
                break

            step_embeds = model.get_input_embeddings()(next_token)
            running_mask = torch.cat(
                [running_mask, torch.ones_like(next_token, dtype=torch.long, device=device)],
                dim=1,
            )

    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


def compare_infer_results(
    ds,
    base_model,
    base_processor,
    fine_tuned_model,
    fine_tuned_processor,
    idxes=(0,),
    max_new_tokens=64,
    do_sample=False,
    temperature=0.7,
    top_p=0.9,
    llama_hidden_size=2048,
):
    """Print base vs. fine-tuned answers for selected preprocessed samples.

    For every index in ``idxes`` the stored conversation is decoded to recover
    the prompt and ground-truth answer, both models generate an answer from
    the prompt tokens, and the results are printed together with the image.

    Compared with the earlier version of this function:
      * The models are fed only the prompt. The stored ``input_ids`` hold
        prompt + answer + EOS, so passing them whole let the models see the
        ground truth. The cut is the first label that is not -100, the same
        boundary the loaders use when masking the prompt.
      * Top-p filtering now removes the intended tokens (see
        ``_sample_next_token``).

    Args:
        ds: Indexable collection whose items provide ``input_ids``,
            ``attention_mask``, ``pixel_values`` and ``labels`` tensors.
        base_model, fine_tuned_model: Models exposing ``token_mixer``,
            ``get_input_embeddings``, ``vision_tower``,
            ``multi_modal_projector`` and ``device``.
        base_processor, fine_tuned_processor: Matching processors.
        idxes: Sample indices to compare.
        max_new_tokens: Maximum number of generated tokens per answer.
        do_sample: Sample with temperature/top-p instead of greedy decoding.
        temperature: Softmax temperature used when sampling.
        top_p: Nucleus probability mass used when sampling.
        llama_hidden_size: Forwarded to ``token_mixer``.

    Returns:
        None. Results are printed and the image is displayed.
    """
    base_model.eval()
    fine_tuned_model.eval()
    ignore_index = -100  # label value the loaders write over the prompt

    for idx in idxes:
        sample = ds[idx]
        input_ids = sample["input_ids"]  # Tokenized conversation (prompt + answer + EOS)
        attention_mask = sample["attention_mask"]
        pixel_values = sample["pixel_values"]  # Preprocessed image tensor
        labels = sample["labels"]  # Token IDs with prompt masked

        # Keep only the prompt for generation. Fall back to the full sequence
        # when the labels give no usable boundary.
        answer_positions = (torch.as_tensor(labels) != ignore_index).nonzero(as_tuple=True)[0]
        answer_start = int(answer_positions[0]) if answer_positions.numel() else 0
        if answer_start > 0:
            prompt_ids = input_ids[:answer_start]
            prompt_mask = attention_mask[:answer_start]
        else:
            prompt_ids, prompt_mask = input_ids, attention_mask

        # --- Base Model Inference ---
        if _split_prompt_and_answer(base_processor, input_ids) is None:
            continue
        answer_base = _generate_with_token_mixer(
            base_model, base_processor, prompt_ids, prompt_mask, pixel_values,
            max_new_tokens, do_sample, temperature, top_p, llama_hidden_size,
        )

        # --- Fine-Tuned Model Inference ---
        # The displayed prompt / ground truth come from this decode.
        split_ft = _split_prompt_and_answer(fine_tuned_processor, input_ids)
        if split_ft is None:
            continue
        prompt, gt_answer = split_ft
        answer_ft = _generate_with_token_mixer(
            fine_tuned_model, fine_tuned_processor, prompt_ids, prompt_mask, pixel_values,
            max_new_tokens, do_sample, temperature, top_p, llama_hidden_size,
        )

        # --- Display Results ---
        print("═" * 80)
        print(f"Sample {idx}")
        print("Prompt:", prompt.replace("<image>", ""))
        print("\nGround-truth:", gt_answer)
        print("Model Answer Before LORA:", answer_base or "<empty>")
        print("Fine-Tuned Model Answer:", answer_ft or "<empty>")
        display(pixel_values_to_pil(pixel_values))


