#!/usr/bin/env python
import os
import argparse
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, AutoModelForCausalLM
import sys

def _extract_reward_head(checkpoint):
    """Validate the checkpoint layout and return its weights and reward head.

    Args:
        checkpoint: Object returned by ``torch.load``; expected to be a
            mapping with a ``'state'`` entry holding the parameter tensors.

    Returns:
        Tuple ``(weights, old_weight, old_bias)`` where ``weights`` is
        ``checkpoint['state']`` and the other two are the tensors stored under
        ``'lm_head.linear.weight'`` and ``'lm_head.linear.bias'``.

    Raises:
        KeyError: If ``'state'`` or either reward-head key is missing.
    """
    try:
        weights = checkpoint['state']
    except KeyError:
        raise KeyError(
            f"Checkpoint has no 'state' entry; top-level keys are {list(checkpoint.keys())}"
        ) from None

    required = ('lm_head.linear.weight', 'lm_head.linear.bias')
    missing = [key for key in required if key not in weights]
    if missing:
        raise KeyError(f"Checkpoint 'state' is missing reward-head keys: {missing}")
    return weights, weights[required[0]], weights[required[1]]


def _build_two_class_head(hidden_size, old_weight, old_bias):
    """Build a 2-output linear head whose second output is the scalar reward.

    Output 0 has all-zero weight and bias, so its logit is always 0; output 1
    carries the original scalar head, so ``logit[1] - logit[0]`` equals the
    original reward.

    Args:
        hidden_size: Input feature size of the head.
        old_weight: Reward weight, shaped ``[hidden_size]`` or ``[1, hidden_size]``.
        old_bias: Reward bias, a 0-dim tensor or shaped ``[1]``.

    Returns:
        A new ``nn.Linear(hidden_size, 2, bias=True)`` with the rows filled in.
    """
    # Normalize to [1, hidden_size] and [1] so row 0 can be indexed uniformly.
    if old_weight.dim() == 1:
        old_weight = old_weight.unsqueeze(0)
    if old_bias.dim() == 0:
        old_bias = old_bias.unsqueeze(0)

    head = nn.Linear(hidden_size, 2, bias=True)
    with torch.no_grad():
        # Row 0 => constant zero logit.
        head.weight[0].fill_(0.0)
        head.bias[0].fill_(0.0)
        # Row 1 => the real reward weights/bias (copy_ casts dtype if needed).
        head.weight[1].copy_(old_weight[0])
        head.bias[1].copy_(old_bias[0])
    return head


def main(args):
    """Convert ``rewardModel.pt`` into a HuggingFace sequence-classification model.

    Loads ``<args.input_dir>/rewardModel.pt``, builds an
    ``AutoModelForSequenceClassification`` from the configuration of
    ``args.base_model`` (with ``num_labels`` forced to 2), loads the checkpoint
    weights non-strictly, replaces ``model.score`` with a 2-output head derived
    from the checkpoint's ``lm_head.linear`` layer, and saves the model plus
    the base model's tokenizer to ``args.output_dir``.

    Args:
        args: Namespace with ``input_dir``, ``output_dir`` and ``base_model``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint lacks ``'state'`` or the reward-head keys.
    """
    # Path to checkpoint file (assumes the checkpoint is in 'rewardModel.pt')
    checkpoint_path = os.path.join(args.input_dir, "rewardModel.pt")
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

    # Load state dict from the checkpoint.
    # NOTE(security): torch.load unpickles the file; only use trusted checkpoints.
    print(f"Loading state dictionary from {checkpoint_path}...")
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Fail fast: check the checkpoint layout before the (potentially slow and
    # memory-hungry) config download and model initialization below.
    weights, old_weight, old_bias = _extract_reward_head(state_dict)

    # Load configuration from a known model
    print(f"Loading model configuration from {args.base_model}...")
    config = AutoConfig.from_pretrained(args.base_model)
    # Force num_labels=2 to keep a standard classification head
    config.num_labels = 2

    print("State dict keys:")
    for key in state_dict.keys():
        print(key)

    # Initialize the model with the loaded configuration.
    print("Initializing model...")
    model = AutoModelForSequenceClassification.from_config(config)

    # Load the state dictionary into the model; strict=False because the
    # checkpoint's head is stored under a different name than model.score.
    print("Loading state dict into model...")
    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)
    if missing_keys:
        print("Warning: the following keys were missing:", missing_keys)
    if unexpected_keys:
        print("Warning: the following unexpected keys were found:", unexpected_keys)

    # Override the final "score" layer with shape (2, hidden_dim) instead of
    # (1, hidden_dim): zeros in row 0 and the "real" scalar weights in row 1.
    if getattr(model.score, "bias", None) is None:
        # The replacement head always carries a bias. If the architecture's own
        # head has none, a later from_pretrained() may ignore 'score.bias'.
        print("Warning: the model's built-in score head has no bias; the saved "
              "'score.bias' may be ignored when this model is reloaded.")
    hidden_size = model.score.in_features  # e.g. 4096
    model.score = _build_two_class_head(hidden_size, old_weight, old_bias)

    # Save the model in HuggingFace format.
    print(f"Saving model to {args.output_dir} in HuggingFace format...")
    os.makedirs(args.output_dir, exist_ok=True)
    model.save_pretrained(args.output_dir)

    # Save the tokenizer from the base model.
    print(f"Saving tokenizer from {args.base_model} to {args.output_dir}...")
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)
    tokenizer.save_pretrained(args.output_dir)

    print("Model conversion complete.")



if __name__ == "__main__":
    # Command-line entry point: parse the three required paths/identifiers and
    # run the conversion.
    parser = argparse.ArgumentParser(description="Convert a PyTorch checkpoint to HuggingFace format.")
    # The help text names the file that main() actually reads from this directory.
    parser.add_argument("--input_dir", type=str, required=True,
                        help="Directory containing the original checkpoint file 'rewardModel.pt'.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the HuggingFace-formatted model.")
    parser.add_argument("--base_model", type=str, required=True,
                        help="A HuggingFace model identifier (or local path) used to load the configuration and tokenizer (e.g., 'meta-llama/Llama-3.1-8B-Instruct').")
    
    # If running without CLI arguments (e.g., via VSCode's run button), inject defaults.
    # NOTE(review): these defaults are machine-specific absolute paths, and the
    # output directory belongs to a different run than the input directory —
    # confirm this is intended before relying on the no-argument mode.
    if len(sys.argv) == 1:
        sys.argv.extend([
            "--input_dir", "/mnt/raid10/amir/.cache/huggingface/amir/reward_infly_ordinal_symmetric_2025-05-08_07-52-10_892845/LATEST/",
            "--output_dir", "/mnt/raid10/amir/.cache/huggingface/amir/hs2_scaledBT01_llama3.1_2025-03-23_21-23-46_463835/step-34944/hf_test",
            "--base_model", "meta-llama/Llama-3.1-8B-Instruct"
        ])
    
    args = parser.parse_args()
    main(args)
