#!/usr/bin/env python3
"""
Improved InternVL3 Custom to HuggingFace conversion script.
Fixes token/feature mismatch issues by properly copying all configuration files.
"""

import json
import os
import argparse
import shutil
from collections import OrderedDict
from copy import deepcopy
from pathlib import Path

import torch
from safetensors import safe_open
from transformers import (AutoConfig, AutoModel, AutoModelForImageTextToText,
                          AutoTokenizer, AutoProcessor)
from huggingface_hub import hf_hub_download


def copy_all_config_files(hf_path, save_path):
    """
    Copy the known configuration files of an HF reference model into save_path.

    Every file name in the internal list is handled independently. When
    ``hf_path`` is an existing local directory the file is copied with
    ``shutil.copy2``; otherwise ``hf_path`` is treated as a hub repo id and the
    file is fetched with ``hf_hub_download``. A file that is absent or fails to
    transfer is recorded as missing and does not stop the loop. If a
    ``config.json`` exists in ``save_path`` afterwards, ``merge_config_files``
    is invoked on it.

    Args:
        hf_path: Path to HF reference model (local or HF hub model name)
        save_path: Directory to copy configs to (created if it does not exist)

    Returns:
        Tuple ``(copied_files, missing_files)``: two lists of file names, in
        the order they were processed.
    """
    print(f"\n🔧 Copying ALL HF reference configurations...")

    # Comprehensive list of config files to copy
    config_files = [
        'config.json',                    # Main model configuration
        'preprocessor_config.json',       # Image preprocessing settings
        'processor_config.json',          # Processor configuration
        'tokenizer_config.json',          # Tokenizer settings
        'special_tokens_map.json',        # Special token mappings
        'generation_config.json',         # Generation parameters
        'chat_template.json',             # Chat template if exists
        'tokenizer.json',                 # Fast tokenizer vocabulary
        'added_tokens.json',              # Additional tokens
    ]

    copied_files = []
    missing_files = []

    # shutil.copy2 cannot write into a directory that does not exist, so make
    # sure the destination is there. Stay best-effort: if it cannot be created
    # the per-file handlers below record each failure.
    try:
        os.makedirs(save_path, exist_ok=True)
    except OSError as e:
        print(f"  ❌ Could not create {save_path}: {e}")

    # The source kind does not change between files, so decide it once.
    is_local = os.path.isdir(hf_path)

    for config_file in config_files:
        try:
            if is_local:
                # Local HF model path
                src_file = os.path.join(hf_path, config_file)
                if os.path.exists(src_file):
                    shutil.copy2(src_file, os.path.join(save_path, config_file))
                    copied_files.append(config_file)
                    print(f"  ✅ Copied {config_file} from local path")
                else:
                    missing_files.append(config_file)
                    print(f"  ⚠️  {config_file} not found in local path")
            else:
                # HF hub model name - download the config
                try:
                    hf_hub_download(
                        hf_path,
                        config_file,
                        local_dir=save_path,
                        local_dir_use_symlinks=False
                    )
                    copied_files.append(config_file)
                    print(f"  ✅ Downloaded {config_file} from HF hub")
                except Exception as e:
                    missing_files.append(config_file)
                    # A 404 just means the repo has no such optional file;
                    # only other failures are worth reporting.
                    if "404 Client Error" not in str(e):
                        print(f"  ⚠️  Could not download {config_file}: {e}")

        except Exception as e:
            missing_files.append(config_file)
            print(f"  ❌ Error copying {config_file}: {e}")

    # Special handling for config.json - run the merge step on it.
    # NOTE(review): when the copy above succeeded, save_path/config.json is
    # already the reference file, so the merge reads the reference copy as its
    # "converted" input. Confirm whether a pre-existing config.json in
    # save_path was meant to be preserved instead.
    config_path = os.path.join(save_path, 'config.json')
    if os.path.exists(config_path):
        print(f"\n🔄 Merging config.json to preserve model weights mapping...")
        merge_config_files(hf_path, save_path)

    print(f"\n📊 Configuration copy summary:")
    print(f"  ✅ Successfully copied: {len(copied_files)} files")
    if missing_files:
        print(f"  ⚠️  Missing (optional): {len(missing_files)} files")

    return copied_files, missing_files


def merge_config_files(hf_path, save_path):
    """
    Merge selected reference settings into the config.json stored in save_path.

    The config found at ``save_path/config.json`` is the base and every key in
    it is kept unless it appears in ``update_from_reference``; those keys are
    overwritten with the reference model's value when the reference defines
    them (``vision_config`` is replaced as a whole). Keys that exist only in
    the reference and are not listed are not added. The result is written back
    to ``save_path/config.json``.

    Failures are reported with a printed message and never raised, so callers
    can treat this step as best effort.

    Args:
        hf_path: Local directory containing the reference config.json, or a
            hub model name that is loaded through ``AutoConfig``.
        save_path: Directory whose config.json is rewritten in place.
    """
    try:
        # Load the config currently stored at the destination (merge base)
        converted_config_path = os.path.join(save_path, 'config.json')
        with open(converted_config_path, 'r', encoding='utf-8') as f:
            converted_config = json.load(f)

        # Load the reference config (source of the processing settings)
        if os.path.isdir(hf_path):
            ref_config_path = os.path.join(hf_path, 'config.json')
            with open(ref_config_path, 'r', encoding='utf-8') as f:
                ref_config = json.load(f)
        else:
            # Download from HF hub
            ref_config = AutoConfig.from_pretrained(hf_path, trust_remote_code=True).to_dict()

        # Fields taken from the reference (processing-related). Anything not
        # listed here - e.g. architectures, torch_dtype, transformers_version -
        # keeps the value from the destination config.
        update_from_reference = [
            'num_image_tokens',
            'image_size',
            'patch_size',
            'vision_config',
            'select_layer',
            'ps_version',
            'downsample_ratio',
            'mlp_depth',
            'image_seq_length',
            'vision_feature_layer',
            'vision_feature_select_strategy',
            'template',
            'use_thumbnail',
            'min_dynamic_patch',
            'max_dynamic_patch',
            'use_dynamic_image_aspect_ratio',
        ]

        # Create merged config
        merged_config = converted_config.copy()

        # Update specific fields from reference
        for field in update_from_reference:
            if field in ref_config:
                merged_config[field] = ref_config[field]
                print(f"    Updated {field}: {ref_config[field]}")

        # vision_config was already replaced by the loop above; only the
        # summary line remains to be printed for it.
        ref_vision_config = ref_config.get('vision_config')
        if isinstance(ref_vision_config, dict):
            print(f"    Updated vision_config with {len(ref_vision_config)} fields")

        # Serialize before opening the file for writing: opening with 'w'
        # truncates, so a serialization error must not leave an empty or
        # partial config.json behind.
        serialized = json.dumps(merged_config, indent=2)
        with open(converted_config_path, 'w', encoding='utf-8') as f:
            f.write(serialized)

        print(f"  ✅ Successfully merged config.json")

    except Exception as e:
        print(f"  ❌ Error merging configs: {e}")


def _read_json_file(path):
    """Return ``(data, None)`` for a readable JSON file, else ``(None, error_text)``."""
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f), None
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError.
        return None, str(e)


def validate_conversion(save_path):
    """
    Validate the converted model directory by checking for common issues.

    Checks that config.json exists, parses, and has the top-level fields
    ``num_image_tokens``, ``image_size`` and ``patch_size``; that
    processor_config.json and preprocessor_config.json exist (the latter must
    also parse); and warns when vision_config or tokenizer_config.json is
    absent. Findings are printed. A file that cannot be read or parsed is
    reported as a critical issue rather than raising.

    Args:
        save_path: Directory holding the converted model files.

    Returns:
        True when no critical issue was found (warnings do not count),
        False otherwise.
    """
    print(f"\n🔍 Validating converted model...")

    issues = []
    warnings = []

    # Check config.json
    config_path = os.path.join(save_path, 'config.json')
    if os.path.exists(config_path):
        config, error = _read_json_file(config_path)
        if error is not None:
            issues.append(f"config.json could not be read: {error}")
        elif not isinstance(config, dict):
            issues.append("config.json does not contain a JSON object")
        else:
            # Check for critical fields
            critical_fields = ['num_image_tokens', 'image_size', 'patch_size']
            for field in critical_fields:
                if field not in config:
                    issues.append(f"Missing critical field: {field}")

            # Check vision_config
            if 'vision_config' in config:
                vision_config = config['vision_config']
                if isinstance(vision_config, dict):
                    if 'image_size' in vision_config:
                        print(f"  ✅ Vision image_size: {vision_config['image_size']}")
                    if 'patch_size' in vision_config:
                        print(f"  ✅ Vision patch_size: {vision_config['patch_size']}")
                    if 'num_patches' in vision_config:
                        print(f"  ✅ Vision num_patches: {vision_config['num_patches']}")
            else:
                warnings.append("No vision_config found")
    else:
        issues.append("config.json not found")

    # Check processor configs
    processor_config_path = os.path.join(save_path, 'processor_config.json')
    if not os.path.exists(processor_config_path):
        issues.append("processor_config.json not found")

    preprocessor_config_path = os.path.join(save_path, 'preprocessor_config.json')
    if not os.path.exists(preprocessor_config_path):
        issues.append("preprocessor_config.json not found")
    else:
        preproc_config, error = _read_json_file(preprocessor_config_path)
        if error is not None:
            issues.append(f"preprocessor_config.json could not be read: {error}")
        elif not isinstance(preproc_config, dict):
            issues.append("preprocessor_config.json does not contain a JSON object")
        else:
            if 'image_size' in preproc_config:
                print(f"  ✅ Preprocessor image_size: {preproc_config['image_size']}")
            if 'size' in preproc_config:
                print(f"  ✅ Preprocessor size: {preproc_config['size']}")

    # Check tokenizer configs
    tokenizer_config_path = os.path.join(save_path, 'tokenizer_config.json')
    if not os.path.exists(tokenizer_config_path):
        warnings.append("tokenizer_config.json not found (may use default)")

    # Report results
    if issues:
        print(f"\n❌ Critical issues found:")
        for issue in issues:
            print(f"  - {issue}")

    if warnings:
        print(f"\n⚠️  Warnings:")
        for warning in warnings:
            print(f"  - {warning}")

    if not issues:
        print(f"\n✅ Validation passed!")

    return len(issues) == 0


def compute_l2_distance(model1, model2):
    """Print and return the summed per-tensor L2 distance between two models.

    For every state-dict key present in both models with identical shape, the
    L2 norm of the element-wise difference (taken in float32 on CPU) is added
    to a running total. Keys with mismatching shapes are reported and skipped.
    Keys that exist in only one of the models cannot be compared; their count
    is printed so that a partial key overlap is not mistaken for a match.

    Args:
        model1: Object exposing ``state_dict()`` (e.g. the converted model).
        model2: Object exposing ``state_dict()`` (e.g. the reference model).

    Returns:
        The sum of the per-tensor L2 norms as a float. This is the sum of
        norms, not the norm of the concatenated difference.
    """
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    # Follow model1's key order instead of iterating a set, so the float
    # accumulation order (and therefore the printed total) is reproducible.
    common_keys = [key for key in state_dict1 if key in state_dict2]

    only_in_first = len(state_dict1) - len(common_keys)
    only_in_second = len(state_dict2) - len(common_keys)
    if only_in_first or only_in_second:
        print(
            f"⚠️ Keys not compared: {only_in_first} only in first model, "
            f"{only_in_second} only in second model."
        )

    total_l2 = 0.0
    total_params = 0

    for key in common_keys:
        t1 = state_dict1[key].float().cpu()
        t2 = state_dict2[key].float().cpu()

        if t1.shape != t2.shape:
            print(f"⚠️ Shape mismatch at key: {key}, skipping.")
            continue

        diff = t1 - t2
        total_l2 += torch.norm(diff, p=2).item()
        total_params += diff.numel()

    print(f"\n✅ Total L2 distance: {total_l2:.6f}")
    if total_params > 0:
        print(f"✅ Average per-parameter L2: {total_l2 / total_params:.8f}")
    else:
        print('⚠️ No matching parameters.')

    return total_l2


def convert_keys_to_hf(custom_state_dict):
    """Rename the entries of a custom state dict to the HuggingFace key layout.

    Renaming rules, as implemented below:
      * a handful of keys are mapped one-to-one (class/position embeddings,
        the language-model head);
      * keys starting with a known prefix get that prefix substituted and
        ``model.`` prepended (projector ``mlp1.*``, patch embedding,
        ``language_model.model.``);
      * ``vision_model.encoder.layers.<id>.*`` keys move under
        ``model.vision_tower.encoder.layer.<id>.`` with attention-projection,
        norm and layer-scale names translated;
      * keys already starting with ``model.`` are kept, anything else gets
        ``model.`` prepended.

    Fused ``attn.qkv`` weights/biases are held back and, after all other keys
    have been emitted, split into three equal slices along dimension 0 that
    are stored as ``q_proj``, ``k_proj`` and ``v_proj``.

    Args:
        custom_state_dict: Mapping of parameter name to value. Fused qkv
            values must expose ``.shape`` and support slicing.

    Returns:
        An ``OrderedDict`` with the renamed keys.
    """
    # Whole-key renames.
    exact_renames = {
        'vision_model.embeddings.class_embedding': 'model.vision_tower.embeddings.cls_token',
        'vision_model.embeddings.position_embedding': 'model.vision_tower.embeddings.position_embeddings',
        'language_model.lm_head.weight': 'lm_head.weight',
        'language_model.model.lm_head.weight': 'lm_head.weight',
    }
    # (old, new) substitutions; the result is additionally prefixed with 'model.'.
    prefix_renames = (
        ('mlp1.0.', 'multi_modal_projector.layer_norm.'),
        ('mlp1.1.', 'multi_modal_projector.linear_1.'),
        ('mlp1.3.', 'multi_modal_projector.linear_2.'),
        ('vision_model.embeddings.patch_embedding',
         'vision_tower.embeddings.patch_embeddings.projection'),
        ('language_model.model.', 'language_model.'),
    )
    # Renames applied to the part of an encoder-layer key after the layer id.
    sublayer_renames = (
        ('attn.proj.', 'attention.projection_layer.'),
        ('norm1.', 'layernorm_before.'),
        ('norm2.', 'layernorm_after.'),
    )
    layer_scale_renames = {'ls1': 'lambda_1', 'ls2': 'lambda_2'}

    def rename_layer_param(tail):
        # Prefixed sub-modules keep only their final component (weight/bias).
        for old_prefix, new_prefix in sublayer_renames:
            if tail.startswith(old_prefix):
                return new_prefix + tail.split('.')[-1]
        return layer_scale_renames.get(tail, tail)

    converted = OrderedDict()
    fused_qkv = {}

    for name, tensor in custom_state_dict.items():
        if name in exact_renames:
            converted[exact_renames[name]] = tensor
            continue

        substitution = next(
            (pair for pair in prefix_renames if name.startswith(pair[0])), None
        )
        if substitution is not None:
            converted['model.' + name.replace(*substitution)] = tensor
            continue

        if name.startswith('vision_model.encoder.layers.'):
            pieces = name.split('.')
            layer_id = pieces[3]
            tail = '.'.join(pieces[4:])

            qkv_kind = next(
                (kind for kind in ('weight', 'bias') if tail.startswith('attn.qkv.' + kind)),
                None,
            )
            if qkv_kind is not None:
                # Deferred: split into q/k/v once the main pass is done.
                fused_qkv[(layer_id, qkv_kind)] = tensor
                continue

            target = f"model.vision_tower.encoder.layer.{layer_id}." + rename_layer_param(tail)
        elif name.startswith('model.'):
            target = name
        else:
            target = 'model.' + name

        converted[target] = tensor

    # Split every fused qkv tensor into thirds along the first dimension.
    for (layer_id, kind), fused in fused_qkv.items():
        third = fused.shape[0] // 3
        thirds = (fused[:third], fused[third:2 * third], fused[2 * third:])
        attention_prefix = f"model.vision_tower.encoder.layer.{layer_id}.attention."
        for projection, part in zip(('q_proj', 'k_proj', 'v_proj'), thirds):
            converted[f"{attention_prefix}{projection}.{kind}"] = part

    return converted


def _print_step_banner(title):
    """Print a title framed by two 60-character rules, preceded by a blank line."""
    print(f"\n" + "="*60)
    print(title)
    print("="*60)


def _report_keys(label, keys, limit=10):
    """Print the size of a key list and at most ``limit`` of its entries."""
    if not keys:
        return
    print(f"\n⚠️  {label} keys ({len(keys)}):")
    for key in keys[:limit]:
        print(f"  - {key}")
    if len(keys) > limit:
        print(f"  ... and {len(keys) - limit} more")


def main():
    """Command-line entry point for the custom → HuggingFace conversion.

    Steps, in order:
      1. copy the reference configuration files into ``--save_path``;
      2. build a model from the saved config, read every ``*.safetensors``
         file in ``--custom_path``, rename the keys and load them
         (non-strict, missing/unexpected keys are printed);
      3. unless ``--skip_validation``: compare against the reference model;
      4. save the model, plus the tokenizer when no tokenizer file was copied;
      5. unless ``--skip_validation``: validate the output directory and run a
         load-and-forward smoke test.

    Raises:
        FileNotFoundError: if ``--custom_path`` contains no ``.safetensors``
            file, since continuing would save randomly initialized weights.
    """
    parser = argparse.ArgumentParser(description="Convert InternVL3 custom weights to HuggingFace format (FIXED version).")
    parser.add_argument('--custom_path', type=str, required=True, help='Path to original safetensors checkpoint folder')
    parser.add_argument('--hf_path', type=str, required=True, help='Path to reference HuggingFace model for configs')
    parser.add_argument('--save_path', type=str, required=True, help='Path to save the converted model')
    parser.add_argument('--skip_gpu', action='store_true', help='Skip GPU usage for conversion (use CPU only)')
    parser.add_argument('--skip_validation', action='store_true', help='Skip post-conversion validation')
    args = parser.parse_args()

    print(f"🚀 Starting InternVL3 Custom → HuggingFace conversion (FIXED)")
    print(f"📁 Custom model: {args.custom_path}")
    print(f"🔗 HF reference: {args.hf_path}")
    print(f"💾 Save to: {args.save_path}")

    # Create save directory
    os.makedirs(args.save_path, exist_ok=True)

    # Step 1: Copy ALL configuration files FIRST
    _print_step_banner("STEP 1: Copy configuration files")
    copied_files, missing_files = copy_all_config_files(args.hf_path, args.save_path)

    # Step 2: Load and convert model weights
    _print_step_banner("STEP 2: Convert model weights")

    # Locate the checkpoint shards before building the model so an empty or
    # wrong --custom_path fails fast. Sorted for a reproducible load order.
    checkpoint_paths = sorted(
        os.path.join(args.custom_path, f)
        for f in os.listdir(args.custom_path)
        if f.endswith('.safetensors')
    )
    print(f"🔍 Found checkpoint files: {[os.path.basename(p) for p in checkpoint_paths]}")
    if not checkpoint_paths:
        raise FileNotFoundError(f"No .safetensors files found in {args.custom_path}")

    # Load model configuration from the SAVED path (populated in step 1)
    config = AutoConfig.from_pretrained(args.save_path, trust_remote_code=True)

    # Choose device: honour --skip_gpu, and fall back to CPU when CUDA is not
    # available instead of crashing on .to('cuda').
    if args.skip_gpu:
        device = 'cpu'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        print("⚠️  CUDA is not available, falling back to CPU")
        device = 'cpu'
    model = AutoModelForImageTextToText.from_config(config, trust_remote_code=True).to(device)

    # Load custom safetensor weights (keys are still in the custom naming)
    custom_state_dict = {}
    for checkpoint_path in checkpoint_paths:
        with safe_open(checkpoint_path, framework='pt') as f:
            for k in f.keys():
                custom_state_dict[k] = f.get_tensor(k)

    # Convert key naming style
    print(f"🔄 Converting {len(custom_state_dict)} weight keys to HF format...")
    model_state_dict = convert_keys_to_hf(custom_state_dict)

    # Load weights into model; strict=False so mismatches are listed, not fatal
    missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)
    _report_keys("Missing", missing_keys)
    _report_keys("Unexpected", unexpected_keys)

    # Step 3: Compare with reference model (optional)
    if not args.skip_validation:
        _print_step_banner("STEP 3: Validate conversion")

        try:
            print(f"📊 Loading reference model for comparison...")
            model_compare = AutoModelForImageTextToText.from_pretrained(args.hf_path, trust_remote_code=True).to(device)
            compute_l2_distance(model, model_compare)
            del model_compare  # Free memory
        except Exception as e:
            print(f"⚠️  Could not compare with reference: {e}")

    # Step 4: Save the converted model
    _print_step_banner("STEP 4: Save converted model")

    print(f"💾 Saving model weights...")
    model.save_pretrained(args.save_path)

    # Save tokenizer (if not already copied)
    if 'tokenizer.model' not in copied_files and 'tokenizer.json' not in copied_files:
        print(f"💾 Saving tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(args.hf_path, trust_remote_code=True)
            tokenizer.save_pretrained(args.save_path)
        except Exception as e:
            print(f"⚠️  Could not save tokenizer: {e}")

    # Step 5: Validate the conversion
    if not args.skip_validation:
        _print_step_banner("STEP 5: Final validation")

        if validate_conversion(args.save_path):
            # Test loading
            print(f"\n🧪 Testing converted model...")
            try:
                # Test processor loading
                processor = AutoProcessor.from_pretrained(args.save_path, trust_remote_code=True)
                print(f"  ✅ Processor loads successfully: {type(processor).__name__}")

                # Test model loading
                test_model = AutoModel.from_pretrained(args.save_path, trust_remote_code=True, device_map='cpu')
                print(f"  ✅ Model loads successfully: {type(test_model).__name__}")

                # Test basic functionality with image
                from PIL import Image
                dummy_image = Image.new('RGB', (448, 448), color='red')
                test_input = processor(text="Test prompt", images=dummy_image, return_tensors="pt")

                # Show the shapes that must agree for token/feature consistency
                if 'input_ids' in test_input:
                    print(f"  ✅ Input IDs shape: {test_input['input_ids'].shape}")
                if 'pixel_values' in test_input:
                    print(f"  ✅ Pixel values shape: {test_input['pixel_values'].shape}")

                # Only success/failure of the forward pass matters here.
                with torch.no_grad():
                    test_model(**test_input)
                print(f"  ✅ Model inference works!")

            except Exception as e:
                print(f"  ❌ Post-conversion test failed: {e}")
                import traceback
                traceback.print_exc()

    # Final summary
    _print_step_banner("🎉 CONVERSION COMPLETED")
    print(f"📋 Summary:")
    print(f"  • Configuration files: {len(copied_files)} copied")
    print(f"  • Model weights: Converted successfully")
    print(f"  • Output location: {args.save_path}")
    print(f"\n✨ The converted model should now work correctly with TRL training!")
    print(f"💡 If you still encounter token/feature mismatches, run the debug script:")
    print(f"   python debug_internvl3_config.py --official {args.hf_path} --converted {args.save_path}")


# Run the conversion only when this file is executed as a script; importing
# the module just defines the helper functions.
if __name__ == '__main__':
    main()
