import os
import argparse
import torch

from peft import PeftModel, get_peft_model_state_dict
from transformers import AutoModelForCausalLM, AutoTokenizer


def _load_finetuned_model(base_model_name, model_path, adapter_name):
    """
    Load a fine-tuned model as a full (non-PEFT) model in bfloat16.

    If ``model_path`` contains an ``adapter_config.json`` it is treated as a
    PEFT adapter: a fresh copy of the base model is loaded, the adapter is
    attached to it and merged via ``merge_and_unload()``. Otherwise
    ``model_path`` is loaded directly as a causal-LM checkpoint.

    Args:
        base_model_name: Name/path of the base model (only loaded for adapters)
        model_path: Path to an adapter directory or a full model checkpoint
        adapter_name: Name under which the adapter is registered before merging
    """
    if os.path.exists(os.path.join(model_path, "adapter_config.json")):
        # The base copy is only loaded when it is actually needed, so a full
        # checkpoint does not cost an extra base-model load.
        base_copy = AutoModelForCausalLM.from_pretrained(
            base_model_name, torch_dtype=torch.bfloat16
        )
        peft_model = PeftModel.from_pretrained(
            base_copy,
            model_path,
            adapter_name=adapter_name,
            is_trainable=False,
            torch_dtype=torch.bfloat16,
        )
        return peft_model.merge_and_unload()

    return AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16
    )


def _weight_deltas(base_weights, finetuned_weights):
    """Return ``finetuned - base`` for every floating-point tensor in ``base_weights``."""
    return {
        key: finetuned_weights[key] - base
        for key, base in base_weights.items()
        if base.is_floating_point()
    }


def merge_adapters(
    base_model_name, adapter_path1, adapter_path2, output_path, alpha1=0.5, alpha2=0.5
):
    """
    Linearly combine the weight deltas of two fine-tuned models and save the result.

    Each of the two paths may be either a PEFT adapter directory (detected by
    the presence of ``adapter_config.json``) or a full model checkpoint. For
    every floating-point tensor the result is
    ``base + alpha1 * (model1 - base) + alpha2 * (model2 - base)``;
    non-floating-point tensors are copied from the base model unchanged.
    The merged model and the base model's tokenizer are written to
    ``output_path``.

    Args:
        base_model_name: Name/path of the base model
        adapter_path1: Path to the first adapter or fine-tuned model
        adapter_path2: Path to the second adapter or fine-tuned model
        output_path: Path to save the merged model
        alpha1: Weight for the first weight delta (default: 0.5)
        alpha2: Weight for the second weight delta (default: 0.5)
    """

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    base_weights = base_model.state_dict()

    # Fixed adapter names are used instead of the adapter paths: the adapter
    # name becomes a torch submodule key, and module names may not contain
    # ".", which filesystem paths frequently do.
    finetuned1 = _load_finetuned_model(base_model_name, adapter_path1, "adapter1")
    # state_dict() is taken once per model rather than once per key.
    diff_weights1 = _weight_deltas(base_weights, finetuned1.state_dict())
    del finetuned1  # the deltas are new tensors; the model is no longer needed

    finetuned2 = _load_finetuned_model(base_model_name, adapter_path2, "adapter2")
    diff_weights2 = _weight_deltas(base_weights, finetuned2.state_dict())
    del finetuned2

    # Compute cosine similarity between the two differences. Tensors are
    # flattened so each parameter yields one similarity over all its entries.
    cosine_sims = {}
    for key, delta1 in diff_weights1.items():
        similarity = torch.nn.functional.cosine_similarity(
            delta1.flatten(), diff_weights2[key].flatten(), dim=0
        ).item()
        print(f"Cosine similarity between {key} and {key}2: {similarity}")
        cosine_sims[key] = similarity

    # Print average cosine similarity (skipped when nothing was compared)
    if cosine_sims:
        avg_cosine_sim = sum(cosine_sims.values()) / len(cosine_sims)
        print(f"Average cosine similarity across all layers: {avg_cosine_sim:.4f}")

    # Apply a linear combination of the two differences. Every base key is
    # kept so that the strict load_state_dict below sees a complete state dict.
    combined_weights = {}
    for key, base in base_weights.items():
        if key in diff_weights1:
            combined_weights[key] = (
                base + diff_weights1[key] * alpha1 + diff_weights2[key] * alpha2
            )
        else:
            # Non-floating-point tensors: use the base model value as-is
            combined_weights[key] = base

    base_model.load_state_dict(combined_weights)

    # Create the output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    # Save the combined weights
    base_model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)


def merge_models(
    base_model_path, model_path1, model_path2, output_path, alpha1=0.5, alpha2=0.5
):
    """
    Combine task vectors from two models with respect to a base model.

    Task vector = fine_tuned_model - base_model
    Combined model = base_model + alpha1 * task_vector1 + alpha2 * task_vector2

    Task vectors are only computed for floating-point tensors that exist in
    all three models. Non-floating-point tensors, and tensors missing from
    either fine-tuned model, keep the base model's value; a warning is
    printed only for the latter. The merged model and the base model's
    tokenizer are written to ``output_path``.

    Args:
        base_model_path: Path to the base model
        model_path1: Path to the first fine-tuned model
        model_path2: Path to the second fine-tuned model
        output_path: Path to save the merged model
        alpha1: Weight for the first task vector
        alpha2: Weight for the second task vector
    """
    print(f"Loading base model from {base_model_path}")
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    print(f"Loading model 1 from {model_path1}")
    model1 = AutoModelForCausalLM.from_pretrained(
        model_path1, torch_dtype=torch.bfloat16
    )

    print(f"Loading model 2 from {model_path2}")
    model2 = AutoModelForCausalLM.from_pretrained(
        model_path2, torch_dtype=torch.bfloat16
    )

    base_weights = base_model.state_dict()
    weights1 = model1.state_dict()
    weights2 = model2.state_dict()

    # Compute task vectors (difference from base model). Iterating the base
    # state dict (rather than a set of keys) keeps the order deterministic.
    print("Computing task vectors...")
    task_vector1 = {}
    task_vector2 = {}
    for key, base in base_weights.items():
        if not base.is_floating_point():
            continue
        if key not in weights1 or key not in weights2:
            continue
        task_vector1[key] = weights1[key] - base
        task_vector2[key] = weights2[key] - base

    # Compute cosine similarity between the two task vectors; tensors are
    # flattened so each parameter yields a single similarity value.
    cosine_sims = {
        key: torch.nn.functional.cosine_similarity(
            tv1.flatten(), task_vector2[key].flatten(), dim=0
        ).item()
        for key, tv1 in task_vector1.items()
    }

    if cosine_sims:
        avg_cosine_sim = sum(cosine_sims.values()) / len(cosine_sims)
        print(f"Average cosine similarity between task vectors: {avg_cosine_sim:.4f}")

    # Combine task vectors and apply to base model
    print(f"Combining task vectors with weights α1={alpha1}, α2={alpha2}")
    combined_weights = {}

    for key, base in base_weights.items():
        if key in task_vector1:
            # Apply task vector arithmetic: base + α1*tv1 + α2*tv2
            combined_weights[key] = (
                base + alpha1 * task_vector1[key] + alpha2 * task_vector2[key]
            )
            continue

        # Warn only when the key is genuinely absent from a fine-tuned model;
        # non-floating-point tensors are intentionally taken from the base.
        if key not in weights1 or key not in weights2:
            print(f"Warning: key {key} not found in all models. Using base model value.")
        combined_weights[key] = base

    base_model.load_state_dict(combined_weights)

    # Create the output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    # Save the combined weights
    print(f"Saving merged model to {output_path}")
    base_model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)


def merge_adapters_svd(
    base_model_name, adapter_path1, adapter_path2, output_path, alpha1=0.5, alpha2=0.5
):
    """
    Merge two LoRA adapters using PEFT's built-in add_weighted_adapter method.

    This function combines two LoRA adapters by:
    1. Loading the base model
    2. Loading both adapters
    3. Using PEFT's add_weighted_adapter (combination_type="svd") to combine them
    4. Saving the result as a new LoRA adapter

    Args:
        base_model_name: Name/path of the base model
        adapter_path1: Path to the first LoRA adapter
        adapter_path2: Path to the second LoRA adapter
        output_path: Path to save the merged LoRA adapter
        alpha1: Weight for the first adapter (default: 0.5)
        alpha2: Weight for the second adapter (default: 0.5)
    """

    # Load base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )

    print(f"Loaded base model: {base_model_name}")
    # Load the first adapter
    model = PeftModel.from_pretrained(
        base_model,
        adapter_path1,
        adapter_name="adapter1",
        is_trainable=False,
        torch_dtype=torch.bfloat16,
    )
    print(f"Loaded first adapter: {adapter_path1}")
    # Load the second adapter
    model.load_adapter(
        adapter_path2,
        adapter_name="adapter2",
        is_trainable=False,
        torch_dtype=torch.bfloat16,
    )
    print(f"Loaded second adapter: {adapter_path2}")

    # The explicit SVD driver is a cuSOLVER option and is only accepted for
    # CUDA inputs; on hosts without CUDA fall back to the default driver.
    svd_driver = "gesvd" if torch.cuda.is_available() else None

    # Combine the adapters using add_weighted_adapter
    model.add_weighted_adapter(
        adapters=["adapter1", "adapter2"],
        weights=[alpha1, alpha2],
        adapter_name="merged_adapter",
        combination_type="svd",
        svd_driver=svd_driver,
    )
    print("Combined adapters using SVD")
    # Set the merged adapter as active
    model.set_adapter("merged_adapter")

    # Create the output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)

    # Save only the merged adapter, not the two source adapters
    model.save_pretrained(output_path, selected_adapters=["merged_adapter"])

    print(f"Merged LoRA adapter saved to: {output_path}")


def merge_individual_adapters(base_model_name, adapters_dir, output_dir):
    """
    Fold every adapter found in a directory into the base model, one at a time.

    A sub-directory of ``adapters_dir`` counts as an adapter when it contains
    an ``adapter_config.json``. Each adapter is merged into a freshly loaded
    base model and saved, together with the base tokenizer, under
    ``output_dir/<adapter sub-directory name>``.

    Args:
        base_model_name: Name/path of the base model
        adapters_dir: Directory containing the adapters
        output_dir: Directory to save the merged models
    """

    # Collect the sub-directories that look like adapters
    adapter_dirs = []
    for entry in os.listdir(adapters_dir):
        candidate = os.path.join(adapters_dir, entry)
        if not os.path.isdir(candidate):
            continue
        if os.path.exists(os.path.join(candidate, "adapter_config.json")):
            adapter_dirs.append(entry)

    print(f"Found {len(adapter_dirs)} adapters to merge: {adapter_dirs}")

    # The tokenizer comes from the base model, so one instance serves all outputs
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    cuda_available = torch.cuda.is_available

    for adapter_name in adapter_dirs:
        source_dir = os.path.join(adapters_dir, adapter_name)
        target_dir = os.path.join(output_dir, adapter_name)

        print(f"\nMerging adapter: {adapter_name}")
        print(f"Adapter path: {source_dir}")
        print(f"Output path: {target_dir}")

        # A new base model per adapter, so merges never stack on each other
        print(f"Loading base model: {base_model_name}")
        fresh_base = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

        # Attach the adapter to the base model
        print(f"Loading adapter: {source_dir}")
        peft_wrapped = PeftModel.from_pretrained(
            fresh_base,
            source_dir,
            is_trainable=False,
            torch_dtype=torch.bfloat16,
        )

        # Bake the adapter into the weights and drop the PEFT wrapper
        print("Merging adapter with base model...")
        full_model = peft_wrapped.merge_and_unload()

        # Make sure the destination exists, then write model and tokenizer
        os.makedirs(target_dir, exist_ok=True)
        print(f"Saving merged model to: {target_dir}")
        full_model.save_pretrained(target_dir)
        tokenizer.save_pretrained(target_dir)

        print(f"✓ Successfully merged and saved: {adapter_name}")

        # Drop references before the next iteration loads another base model
        del fresh_base, peft_wrapped, full_model
        if cuda_available():
            torch.cuda.empty_cache()

    print(f"\n✓ All adapters merged successfully! Output directory: {output_dir}")


def main():
    """
    Command-line entry point for the two-model task-vector merge.

    Parses ``--adapter1``, ``--adapter2``, ``--output``, ``--alpha1``,
    ``--alpha2`` and ``--base_model`` and forwards them to ``merge_models``.
    """
    parser = argparse.ArgumentParser(
        description="Merge two adapter models with linear interpolation"
    )

    # Required string-valued path options
    required_paths = (
        ("--adapter1", "Path to the first adapter"),
        ("--adapter2", "Path to the second adapter"),
        ("--output", "Path to save the merged model"),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    # Optional mixing coefficients, both defaulting to 0.5
    mixing_weights = (
        ("--alpha1", "Weight for the first model (default: 0.5)"),
        ("--alpha2", "Weight for the second model (default: 0.5)"),
    )
    for flag, help_text in mixing_weights:
        parser.add_argument(flag, type=float, default=0.5, help=help_text)

    parser.add_argument(
        "--base_model", type=str, required=True, help="Path to the base model"
    )

    opts = parser.parse_args()

    # Hand the parsed options to the merge routine
    merge_models(
        opts.base_model,
        opts.adapter1,
        opts.adapter2,
        opts.output,
        opts.alpha1,
        opts.alpha2,
    )
    
def main_merge_adapters():
    """
    Command-line entry point for merging every adapter in a directory.

    Parses ``--adapters_dir``, ``--output_dir`` and ``--base_model`` and
    forwards them to ``merge_individual_adapters``.
    """
    parser = argparse.ArgumentParser(
        # The description states what this command does: each adapter is
        # merged into the base model on its own, not interpolated with another.
        description="Merge each adapter in a directory into the base model and save full weights"
    )
    parser.add_argument(
        "--adapters_dir", type=str, required=True, help="Path to the directory containing the adapters"
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Path to save the merged models"
    )
    parser.add_argument(
        "--base_model", type=str, required=True, help="Path to the base model"
    )
    args = parser.parse_args()

    merge_individual_adapters(args.base_model, args.adapters_dir, args.output_dir)
    


# Script entry point: only main() (the two-model task-vector merge CLI) is run
# here; main_merge_adapters() is defined in this file but is not invoked.
if __name__ == "__main__":
    main()
