import argparse
import json
import logging
import os

import torch
from tqdm import tqdm
from transformers import set_seed
from utils import load_model_and_tokenizer

# Script-wide logging: timestamped lines tagged with level and logger name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line options for the model weight-distance comparison.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace with ``model1_path``, ``model2_path``,
        ``cache_dir``, ``seed`` and ``save_dir``.
    """
    parser = argparse.ArgumentParser(description="Compute per-parameter weight distances between two models")
    # Both model paths are mandatory: without them the loader would receive None.
    parser.add_argument("--model1_path", type=str, required=True, help="Path to the first pre-trained model")
    parser.add_argument("--model2_path", type=str, required=True, help="Path to the second pre-trained model")
    parser.add_argument("--cache_dir", type=str, default="./lm_cache", help="Cache directory for the model")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--save_dir", type=str, default="model_distances", help="Directory to save results")

    return parser.parse_args(argv)


def _parameter_distance(param1, param2):
    """Return a scalar distance between two same-shaped parameter tensors.

    - 0-D and 1-D tensors (e.g. biases): Euclidean (vector 2-) norm of the
      difference.
    - 2-D tensors: spectral norm (largest singular value) of the difference.
    - N-D tensors with N > 2 (e.g. conv kernels): flattened to
      ``(shape[0], -1)`` first, so the result is one scalar rather than a
      batch of per-slice norms (which ``.item()`` cannot convert).

    The difference is computed in at least float32, because the SVD behind
    ``matrix_norm(ord=2)`` does not support float16/bfloat16 inputs. It is
    also computed on detached tensors, so no autograd graph is built.
    """
    compute_dtype = torch.promote_types(torch.promote_types(param1.dtype, param2.dtype), torch.float32)
    diff = param1.detach().to(compute_dtype) - param2.detach().to(device=param1.device, dtype=compute_dtype)
    if diff.dim() <= 1:
        return torch.linalg.vector_norm(diff, ord=2).item()
    if diff.dim() > 2:
        diff = diff.reshape(diff.shape[0], -1)
    return torch.linalg.matrix_norm(diff, ord=2).item()


def _compare_parameters(model1, model2):
    """Map each parameter name shared by both models to its distance.

    Parameters are matched by name rather than by position. A parameter that
    is missing from, or extra in, one model therefore does not misalign every
    following pair. Parameters missing from model2, or with mismatched sizes,
    are skipped with a warning.
    """
    params1 = list(model1.named_parameters())
    params2 = dict(model2.named_parameters())

    distances = {}
    for name, param1 in tqdm(params1, desc="Comparing weights"):
        param2 = params2.get(name)
        if param2 is None:
            logger.warning(f"Skipping {name} because it does not exist in model2")
            continue
        if param1.size() != param2.size():
            logger.warning(
                f"Skipping {name} because of different sizes: {tuple(param1.size())} vs {tuple(param2.size())}"
            )
            continue
        distances[name] = _parameter_distance(param1, param2)

    unmatched = params2.keys() - {name for name, _ in params1}
    if unmatched:
        logger.warning(f"{len(unmatched)} parameter(s) of model2 have no counterpart in model1 and were ignored")
    return distances


def _model_label(model_path):
    """Return the last path component of *model_path* for the output filename.

    ``normpath`` strips a trailing separator, which would otherwise make
    ``basename`` return an empty string.
    """
    return os.path.basename(os.path.normpath(model_path))


def main():
    """Load two models, measure per-parameter weight distances and save them.

    Writes ``<save_dir>/<model1>_vs_<model2>.json`` with the keys
    ``average_distance``, ``total_distance``, ``param_count`` and
    ``distances`` (parameter name -> distance).
    """
    args = parse_args()

    logger.info(f"Arguments: {args}")

    # Set seed
    set_seed(args.seed)

    # Load models; the tokenizers are not needed for a weight comparison.
    model1, _ = load_model_and_tokenizer(
        model_path=args.model1_path,
        tokenizer_name=None,
        config_name=None,
        model_revision="main",
        cache_dir=args.cache_dir,
        low_cpu_mem_usage=False,
        trust_remote_code=True,
    )
    model2, _ = load_model_and_tokenizer(
        model_path=args.model2_path,
        tokenizer_name=None,
        config_name=None,
        model_revision="main",
        cache_dir=args.cache_dir,
        low_cpu_mem_usage=False,
        trust_remote_code=True,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model1.to(device)
    model2.to(device)

    # Compare weights
    distances = _compare_parameters(model1, model2)
    total_distance = sum(distances.values(), 0.0)
    param_count = len(distances)
    average_distance = total_distance / param_count if param_count > 0 else 0.0

    results = {
        "average_distance": average_distance,
        "total_distance": total_distance,
        "param_count": param_count,
        "distances": distances,
    }

    # Save results
    os.makedirs(args.save_dir, exist_ok=True)
    out_path = os.path.join(
        args.save_dir, f"{_model_label(args.model1_path)}_vs_{_model_label(args.model2_path)}.json"
    )
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4)

    logger.info(f"Average distance between models: {average_distance}")
    logger.info(f"Results saved to {out_path}")


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
