import argparse
import os
import sys

# Add the project root to the path so we can import internvl modules
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, AutoConfig

# Import the model classes
# Note: We import them with aliases to distinguish between the two use cases
try:
    # Try importing from the package __init__ first
    from internvl.model.detached_skiplink import PeVLChatModel as VLMModel
    from internvl.model.detached_skiplink import PeVLChatConfig as VLMConfig
    from internvl.model.r_probe import PeVLChatModel as ReconModel
    from internvl.model.r_probe import PeVLChatConfig as ReconConfig
except ImportError:
    # Fallback to direct file import if the __init__ is inconsistent with filenames
    from internvl.model.detached_skiplink.modeling_llama_pe_chat import PeVLChatModel as VLMModel
    from internvl.model.detached_skiplink.configuration_llama_pe_chat import PeVLChatConfig as VLMConfig
    from internvl.model.r_probe.modeling_llama_pe_chat import PeVLChatModel as ReconModel
    from internvl.model.r_probe.configuration_llama_pe_chat import PeVLChatConfig as ReconConfig
from internvl.train.dataset import build_transform, dynamic_preprocess

def load_image(image_path, input_size=448, max_num=6):
    """Load an image file and turn it into a batch of transformed tiles.

    The image is converted to RGB and split into tiles by ``dynamic_preprocess``
    (called with ``use_thumbnail=True``). Each tile is then passed through the
    evaluation transform returned by ``build_transform(is_train=False, ...)``.

    Args:
        image_path: Path to the image file on disk.
        input_size: Tile side length, passed to both the tiler and the transform.
        max_num: Maximum number of tiles, forwarded to ``dynamic_preprocess``.

    Returns:
        A tensor with one transformed tile per entry along dim 0 (``torch.stack``).
    """
    # Close the file handle as soon as the RGB copy exists instead of leaving it
    # to the garbage collector; ``convert`` loads the pixel data before returning.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    transform = build_transform(is_train=False, input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

def run_vlm_inference(args):
    """Run single-turn chat inference with the PE+Llama VLM model.

    Loads the tokenizer and ``VLMModel`` from ``args.model_path`` (bfloat16, moved
    to CUDA, eval mode). It tiles ``args.image_path`` with ``load_image``, asks
    ``args.prompt`` through ``model.chat`` and prints the question and the reply.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``image_path`` and ``prompt``.
    """
    print(f"Loading VLM model from {args.model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    model = VLMModel.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval().cuda()

    # Match the model's dtype and device.
    pixel_values = load_image(args.image_path).to(torch.bfloat16).cuda()

    # Greedy, single-beam decoding capped at 512 new tokens.
    generation_config = dict(
        num_beams=1,
        max_new_tokens=512,
        do_sample=False,
    )

    question = args.prompt
    # NOTE(review): prompt templating is delegated to ``model.chat``; confirm it
    # matches the template the checkpoint was trained with.
    with torch.no_grad():  # pure inference: no autograd graph needed
        response = model.chat(
            tokenizer,
            pixel_values,
            question,
            generation_config,
            history=None,
            return_history=False,
        )

    print(f"User: {question}")
    print(f"Assistant: {response}")

def _robust_save_pixel_values(pixel_values, save_dir, step=0, prefix=''):
    """Best-effort stand-in for ``save_original_pixel_values`` in the r_probe module.

    Denormalizes ``pixel_values`` with the ImageNet mean/std and writes one PNG per
    batch entry to ``save_dir``, named ``{prefix}_step{step}_{i}.png``. Any failure
    (including an unexpected tensor layout) is printed and swallowed, so a
    visualization problem never aborts the run.

    Args:
        pixel_values: Tensor expected as (batch, 3, H, W); other layouts are reported.
        save_dir: Output directory, created if missing.
        step: Integer embedded in the file names.
        prefix: String prepended to the file names.
    """
    try:
        import numpy as np
        from PIL import Image

        os.makedirs(save_dir, exist_ok=True)
        pixel_np = pixel_values.detach().to(torch.float32).cpu().numpy()
        if pixel_np.ndim != 4 or pixel_np.shape[1] != 3:
            raise ValueError(f"expected (B, 3, H, W) pixel values, got shape {pixel_np.shape}")

        # ImageNet mean/std reshaped to broadcast over (batch, channel, H, W).
        # NOTE(review): assumes the input transform normalized with ImageNet
        # statistics; confirm for the PE vision tower.
        mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 3, 1, 1)
        std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 3, 1, 1)
        pixel_np = pixel_np * std + mean

        pixel_np = (pixel_np * 255).clip(0, 255).astype(np.uint8)
        pixel_np = np.transpose(pixel_np, (0, 2, 3, 1))  # [B, H, W, C]

        for i, frame in enumerate(pixel_np):
            save_path = os.path.join(save_dir, f"{prefix}_step{step}_{i}.png")
            Image.fromarray(frame).save(save_path)
            print(f"Saved visualization to {save_path}")
    except Exception as e:  # deliberate best effort: visualization must not abort the run
        print(f"Error saving visualization: {e}")


def run_reconstruction(args):
    """Run the R-Probe OCR reconstruction head on one image and save visualizations.

    Loads ``ReconModel`` from ``args.model_path`` (bfloat16, CUDA, eval mode) and
    extracts visual features for ``args.image_path``. It runs ``model.ocr_recon`` on
    them and passes the shifted predictions and targets to
    ``model.visualize_reconstruction``, which writes into ``args.output_dir``.
    All tensor work runs under ``torch.no_grad()``.

    If the model has no ``ocr_recon`` head (missing or ``None``), an error is printed
    to stderr and the function returns without running the image through the model.

    Args:
        args: Parsed CLI namespace; reads ``model_path``, ``image_path`` and ``output_dir``.
    """
    print(f"Loading Reconstruction model from {args.model_path}...")
    model = ReconModel.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval().cuda()

    if getattr(model, 'ocr_recon', None) is None:
        print(
            "Error: Model does not have 'ocr_recon' head. Ensure config.use_ocr_recon is True.",
            file=sys.stderr,
        )
        return

    pixel_values = load_image(args.image_path).to(torch.bfloat16).cuda()
    device = pixel_values.device
    os.makedirs(args.output_dir, exist_ok=True)

    # Pure inference. Without no_grad every intermediate would track the model's
    # parameters (extra memory), and ``.numpy()`` on them would require ``.detach()``.
    with torch.no_grad():
        # 1. Extract visual features.
        # NOTE(review): presumably [num_tiles, C, H, W] -> [num_tiles, num_tokens, hidden_dim]; confirm.
        vit_embeds = model.extract_feature(pixel_values)
        # NOTE(review): InternVL-style ``extract_feature`` often applies ``mlp1``
        # itself; confirm the projection is not applied twice here.
        if hasattr(model, 'mlp1'):
            vit_embeds = model.mlp1(vit_embeds)

        # 2. Run the reconstruction head on the embeddings.
        print("Running OCR Reconstruction Head...")
        recon = model.ocr_recon(inputs_embeds=vit_embeds)
        B, N, _ = recon.shape

        # Ground-truth patches, layer-normalized over the last dimension.
        patch_size = int(model.patch_size / model.downsample_ratio)
        pixel_values_patches = model.to_patches_bcpp(pixel_values, patch_size)
        pixel_values_patches = torch.nn.functional.layer_norm(
            pixel_values_patches, pixel_values_patches.shape[-1:], eps=1e-6
        )

        # Causal shift: the prediction at position t is compared with patch t+1.
        shift_recon = recon[:, :-1, :]
        shift_pixel_labels = pixel_values_patches[:, 1:, :]

        # Standard (unshuffled) order: position indices starting at 1, shape [B, N, 1].
        pixel_values_idx = torch.arange(1, N + 1, device=device).view(1, N, 1).expand(B, -1, -1)
        shift_pixel_labels_idx = pixel_values_idx[:, 1:, :]
        # Every shifted position is weighted equally.
        shift_weights_gen = torch.ones(B, N - 1, device=device)
        # No shuffling was applied, so the inverse permutation is the identity.
        unshuffle_idx = torch.arange(N, device=device).unsqueeze(0).repeat(B, 1)

        # Swap in the local saver (avoids the module's hardcoded paths/fonts)
        # only for the duration of the visualization call.
        import internvl.model.r_probe.modeling_llama_pe_chat as modeling_module

        missing = object()
        original_saver = getattr(modeling_module, 'save_original_pixel_values', missing)
        modeling_module.save_original_pixel_values = _robust_save_pixel_values
        try:
            print("Generating reconstruction visualization...")
            model.visualize_reconstruction(
                original_pixel_values=pixel_values,
                unshuffle_idx=unshuffle_idx,
                shift_pixel_labels=shift_pixel_labels,
                shift_recon=shift_recon,
                shift_weights_gen=shift_weights_gen,
                shift_pixel_labels_idx=shift_pixel_labels_idx,
                save_dir=args.output_dir,
                step=0,
            )
        finally:
            if original_saver is missing:
                del modeling_module.save_original_pixel_values
            else:
                modeling_module.save_original_pixel_values = original_saver

    print("Reconstruction complete.")

def main():
    """Parse command-line options and hand them to the selected demo mode."""
    parser = argparse.ArgumentParser(description="Inference and Reconstruction Demo")
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["vlm", "recon"],
        help="Mode: 'vlm' for text generation, 'recon' for r-probe reconstruction",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model checkpoint",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to input image",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Describe this image.",
        help="Prompt for VLM",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./outputs",
        help="Directory to save reconstruction outputs",
    )
    cli_args = parser.parse_args()

    # Mode name -> entry point; argparse ``choices`` restricts the possible keys.
    dispatch = {
        "vlm": run_vlm_inference,
        "recon": run_reconstruction,
    }
    entry_point = dispatch.get(cli_args.mode)
    if entry_point is not None:
        entry_point(cli_args)


if __name__ == "__main__":
    main()
