import argparse
import json
import os
import torch
import types
from transformers import AutoModelForCausalLM, AutoTokenizer

from block_hadamard_hira import (
    BlockHadamardHiRAConfig, 
    get_block_hadamard_hira_model, 
    apply_block_hadamard_hira,
    get_adapter_state_dict,
    set_adapter_state_dict
)


def _load_adapter_config_dict(adapter_path):
    """Read the adapter's JSON configuration and return it as a dict.

    Candidate files are tried in this order:
      1. ``<adapter_path>/adapter_config.json``
      2. ``block_hadamard_hira_config.json`` next to ``os.path.dirname(adapter_path)``
      3. ``block_hadamard_hira_config.json`` in the real parent directory

    Raises:
        FileNotFoundError: If none of the candidate files exists.
    """
    # os.path.dirname("dir/adapter/") returns "dir/adapter", not "dir"; normalising
    # first drops the trailing separator so the real parent directory is searched too.
    true_parent = os.path.dirname(os.path.normpath(adapter_path))
    candidates = (
        os.path.join(adapter_path, "adapter_config.json"),
        os.path.join(os.path.dirname(adapter_path), "block_hadamard_hira_config.json"),
        os.path.join(true_parent, "block_hadamard_hira_config.json"),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            with open(candidate, "r", encoding="utf-8") as f:
                return json.load(f)
    raise FileNotFoundError(f"Neither adapter_config.json nor block_hadamard_hira_config.json found in {adapter_path} or parent directory")


def _move_parameters_to_device(module, device):
    """Move, in place, every parameter of *module* that is not already on *device*."""
    for _, param in module.named_parameters():
        if param.device != device:
            param.data = param.data.to(device)


def _print_adapter_parameter_summary(model, adapter_name):
    """Print per-layer and total parameter/block counts for the named adapter."""
    print("📊 Block Hadamard HiRA Parameter Information:")
    block_hira_param_count = 0
    total_blocks = 0
    for name, module in model.named_modules():
        if not (hasattr(module, 'block_lora_A') and adapter_name in module.block_lora_A):
            continue
        lora_a = module.block_lora_A[adapter_name]
        lora_b = module.block_lora_B[adapter_name]
        block_lora_a_params = lora_a.numel()
        block_lora_b_params = lora_b.numel()
        block_hira_param_count += block_lora_a_params + block_lora_b_params
        layer_blocks = module.num_blocks * module.num_blocks
        total_blocks += layer_blocks
        print(f"   {name}: A={lora_a.shape} ({block_lora_a_params:,}), B={lora_b.shape} ({block_lora_b_params:,})")
        print(f"     └── Blocks: {module.num_blocks}×{module.num_blocks} = {layer_blocks} blocks")
    print(f"   Total Block Hadamard HiRA parameters: {block_hira_param_count:,}")
    print(f"   Total blocks across all layers: {total_blocks}")


def _merge_adapter_layers(model, adapter_name, device):
    """Call ``merge()`` on every module carrying the named adapter; return how many were merged."""
    merged_layers = 0
    for name, module in model.named_modules():
        if not (hasattr(module, 'merge') and hasattr(module, 'block_lora_A')):
            continue
        if adapter_name not in module.block_lora_A:
            continue
        print(f"   Merging Block Hadamard HiRA layer: {name}")
        # Ensure all tensors in the module are on the correct device
        _move_parameters_to_device(module, device)
        module.merge()
        merged_layers += 1
    return merged_layers


def _extract_base_state_dict(merged_state_dict, device):
    """Build a state dict without adapter entries and with wrapper prefixes removed.

    Keys containing ``block_lora_A``, ``block_lora_B`` or ``adapter`` are dropped and
    ``.base_layer.`` is collapsed to ``.`` in the remaining keys.

    Returns:
        A ``(clean_state_dict, copied_weights)`` tuple, where ``copied_weights`` counts
        the kept keys whose name contains ``weight`` or ``bias``.
    """
    adapter_markers = ('block_lora_A', 'block_lora_B', 'adapter')
    clean_state_dict = {}
    copied_weights = 0
    for key, value in merged_state_dict.items():
        if any(marker in key for marker in adapter_markers):
            continue
        clean_key = key.replace('.base_layer.', '.')
        clean_state_dict[clean_key] = value.to(device)
        if 'weight' in clean_key or 'bias' in clean_key:
            copied_weights += 1
    return clean_state_dict, copied_weights


def load_and_merge_block_hadamard_hira_adapter(base_model_name, adapter_path, output_path):
    """
    Load a Block Hadamard HiRA adapter, apply it to a base model, merge the weights, and save the result.

    The adapter configuration and weights file are validated before the base model is
    loaded, so a wrong ``adapter_path`` fails immediately instead of after a model load.

    Args:
        base_model_name: Name or path of the base model
        adapter_path: Path to the saved Block Hadamard HiRA adapter directory
        output_path: Path to save the merged model

    Returns:
        ``output_path``, after the merged model and tokenizer have been saved there.

    Raises:
        FileNotFoundError: If the adapter configuration or ``adapter_model.bin`` is missing.
        ValueError: If applying the adapter to the base model yields ``None``.
        RuntimeError: If merging or copying the merged weights fails (the original
            exception is attached as ``__cause__``).
    """
    adapter_name = "block_hira_cr"

    # Fail fast: validate the cheap on-disk inputs before any expensive model load.
    print(f"Loading Block Hadamard HiRA adapter from: {adapter_path}")
    config_dict = _load_adapter_config_dict(adapter_path)
    adapter_model_path = os.path.join(adapter_path, "adapter_model.bin")
    if not os.path.exists(adapter_model_path):
        raise FileNotFoundError(f"Adapter weights not found at {adapter_model_path}")

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"Loading base model: {base_model_name}")
    # NOTE(review): device_map="auto" followed by .to(device) may trigger warnings from
    # the dispatch machinery on multi-device hosts -- confirm this combination is intended.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"  # Let transformers handle device placement
    )

    # Ensure model is on GPU
    if torch.cuda.is_available():
        base_model = base_model.to(device)

    print("Creating Block Hadamard HiRA configuration")
    # Filter relevant config parameters for Block Hadamard HiRA
    block_hira_params = {
        'r', 'alpha', 'dropout', 'target_modules',
        'bias', 'init_lora_weights', 'num_blocks', 'block_arrangement'
    }
    filtered_config = {k: v for k, v in config_dict.items() if k in block_hira_params}
    adapter_config = BlockHadamardHiRAConfig(**filtered_config)

    # Apply Block Hadamard HiRA to the model
    print("Applying Block Hadamard HiRA to the model")
    model = get_block_hadamard_hira_model(base_model, adapter_config, adapter_name=adapter_name)

    # Check before the model is used; checking after .to()/set_adapter_state_dict is too late.
    if model is None:
        raise ValueError("Failed to load the Block Hadamard HiRA adapter model")

    # Ensure model is on GPU after Block Hadamard HiRA application
    if torch.cuda.is_available():
        model = model.to(device)

    print(f"Loading Block Hadamard HiRA adapter weights to {device}")
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load adapter checkpoints that come from a trusted source.
    adapter_state_dict = torch.load(
        adapter_model_path,
        map_location=device  # Load directly to the target device
    )

    # Ensure all adapter weights are on the correct device
    adapter_state_dict = {key: tensor.to(device) for key, tensor in adapter_state_dict.items()}

    # Set the adapter weights
    set_adapter_state_dict(model, adapter_state_dict, adapter_name)

    _print_adapter_parameter_summary(model, adapter_name)

    # Merge weights using Block Hadamard HiRA's merge mechanism
    print("Merging Block Hadamard HiRA adapter weights into the base model")

    try:
        # Ensure all model parameters are on the same device before merging
        _move_parameters_to_device(model, device)

        # Manual merge approach - directly merge each Block Hadamard HiRA layer
        print("🔄 Using manual Block Hadamard HiRA merge method...")
        merged_layers = _merge_adapter_layers(model, adapter_name, device)
        print(f"   ✅ Successfully merged {merged_layers} Block Hadamard HiRA layers")

        # Create a clean model from the base model; it receives the merged weights
        # so the saved checkpoint has no adapter wrappers in it.
        print("Creating clean merged model...")
        merged_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )

        if torch.cuda.is_available():
            merged_model = merged_model.to(device)

        # Filter out adapter parameters and fix key names for the clean model
        clean_state_dict, copied_weights = _extract_base_state_dict(model.state_dict(), device)

        # Load the merged weights into the clean model
        missing_keys, unexpected_keys = merged_model.load_state_dict(clean_state_dict, strict=False)
        if missing_keys:
            print(f"   Warning: Missing keys: {missing_keys[:5]}...")  # Show first 5
        if unexpected_keys:
            print(f"   Warning: Unexpected keys: {unexpected_keys[:5]}...")  # Show first 5

        print(f"   ✅ Successfully copied {copied_weights} weight tensors to clean model")
        print("✅ Successfully merged using manual Block Hadamard HiRA method")

    except Exception as e:
        print(f"❌ Block Hadamard HiRA merge failed: {e}")
        import traceback
        traceback.print_exc()
        # Chain the original exception so its traceback is preserved for callers.
        raise RuntimeError(f"Block Hadamard HiRA merge failed: {e}") from e

    # Ensure final model is on CPU for saving
    merged_model = merged_model.to("cpu")

    print(f"Saving merged model to: {output_path}")
    merged_model.save_pretrained(output_path)

    # Update the model config to include base_model_name_or_path and Block Hadamard HiRA info
    config_path = os.path.join(output_path, "config.json")
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding="utf-8") as f:
            saved_config = json.load(f)

        # Add the base model name for proper loading
        saved_config['base_model_name_or_path'] = base_model_name

        # Add Block Hadamard HiRA merge information
        peft_cfg = model.peft_config[adapter_name]
        saved_config['merged_adapters'] = {
            'method': 'Block_Hadamard_HiRA',
            'adapter_name': adapter_name,
            'num_blocks': getattr(peft_cfg, 'num_blocks', 4),
            'block_arrangement': getattr(peft_cfg, 'block_arrangement', 'square'),
            'r': getattr(peft_cfg, 'r', 32),
            'alpha': getattr(peft_cfg, 'alpha', 32)
        }

        with open(config_path, 'w', encoding="utf-8") as f:
            json.dump(saved_config, f, indent=2)
        print(f"   Updated config with base_model_name_or_path: {base_model_name}")
        print("   Added Block Hadamard HiRA merge info to config")

    print("Saving tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    tokenizer.save_pretrained(output_path)

    # Cleanup: drop every local reference to the models so empty_cache() can
    # actually return their device memory.
    del model, base_model, merged_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("✅ Block Hadamard HiRA merge and save completed successfully!")
    return output_path


def main():
    """Command-line entry point: parse the three required paths and run the merge."""
    cli = argparse.ArgumentParser(
        description="Merge Block Hadamard HiRA adapter weights with base model and save"
    )
    # All options are required string arguments; declare them in one table.
    required_options = (
        ("--base_model", "Base model name or path"),
        ("--adapter_path", "Path to the saved Block Hadamard HiRA adapter directory"),
        ("--output_path", "Path to save the merged model"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    opts = cli.parse_args()
    base_model = opts.base_model
    adapter_dir = opts.adapter_path
    destination = opts.output_path

    # Echo the resolved inputs before starting the (long-running) merge.
    print("🚀 Starting Block Hadamard HiRA merge process...")
    print(f"📊 Base model: {base_model}")
    print(f"🧱 Block Hadamard HiRA adapter: {adapter_dir}")
    print(f"💾 Output path: {destination}")

    load_and_merge_block_hadamard_hira_adapter(base_model, adapter_dir, destination)

    print("🎉 Block Hadamard HiRA merge process completed!")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()