import argparse
import os
from string import Template

import torch
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()

parser.add_argument("--base_model", type=str, required=True, help="Path to HF checkpoint with the base model")

parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="Path to either a local directory or a HF checkpoint with reward model's weights",
)

parser.add_argument("--revision", type=str, required=False, help="Optional branch/commit of the HF checkpoint")

parser.add_argument("--device", type=int, default=0)
args = parser.parse_args()

# Name of the entry written under model_store/. Normalise the path before taking
# its last component: a plain split("/")[-1] gives "" for "./ckpt/" and "." for
# ".". abspath is pure string handling (plus getcwd), so a Hub id such as
# "org/model" still maps to "model".
model_name = os.path.basename(os.path.abspath(args.checkpoint))
# torch.device() given a bare integer treats it as an accelerator (CUDA) ordinal.
device = torch.device(args.device)


class RewardModel(nn.Module):
    """Causal-LM transformer trunk topped with a bias-free scalar value head.

    The trunk is the ``.transformer`` submodule of
    ``AutoModelForCausalLM.from_pretrained(checkpoint_path)``. ``v_head`` maps
    each hidden state (width ``config.n_embd``) to a single scalar. The
    ``.transformer`` / ``n_embd`` attribute names presumably assume a GPT-2-style
    architecture (TODO confirm for other base models).
    """

    def __init__(self, checkpoint_path, eos_token_id):
        super().__init__()
        causal_lm = AutoModelForCausalLM.from_pretrained(checkpoint_path)
        hidden_width = causal_lm.config.n_embd
        # Registration order (transformer, then v_head) is kept as-is, so the
        # state-dict key names stay "transformer.*" and "v_head.weight".
        self.transformer = causal_lm.transformer
        self.v_head = nn.Linear(hidden_width, 1, bias=False)
        self.eos_token_id = eos_token_id

    def forward(self, input_ids):
        """Return one scalar reward per row of ``input_ids``.

        ``input_ids`` is indexed along dim 1 (one row per sequence). The reward
        is read at the first position holding ``eos_token_id``. If a row
        contains no EOS, argmax over an all-zero mask returns 0, so position 0
        is used.
        """
        # Element 0 of the trunk output is presumably the last hidden state.
        hidden_states = self.transformer(input_ids)[0]
        per_token_rewards = self.v_head(hidden_states).squeeze(-1)
        eos_mask = (input_ids == self.eos_token_id).float()
        # torch.argmax returns the index of the first maximal value on ties.
        eos_positions = eos_mask.argmax(dim=1).view(-1, 1)
        return per_token_rewards.gather(1, eos_positions).squeeze(-1)


# Resolve the reward-model weights to a local directory: use it directly when it
# exists on disk, otherwise download the snapshot from the HF Hub.
if os.path.isdir(args.checkpoint):
    directory = args.checkpoint
else:
    directory = snapshot_download(args.checkpoint, revision=args.revision)

print(f"searching through {os.listdir(directory)} in {directory}")

# os.listdir order is arbitrary, so sort the candidates to make the choice
# deterministic. Fail with a clear error instead of a NameError on `checkpoint`
# when no weights file is present.
weight_files = sorted(
    fname
    for fname in os.listdir(directory)
    if fname.endswith((".pt", ".bin")) and os.path.isfile(os.path.join(directory, fname))
)
if not weight_files:
    raise FileNotFoundError(f"no .pt or .bin weights file found in {directory}")
checkpoint = os.path.join(directory, weight_files[0])

tokenizer = AutoTokenizer.from_pretrained(args.base_model)
model = RewardModel(args.base_model, tokenizer.eos_token_id)
# map_location="cpu" lets weights saved on any GPU load even if that GPU is not
# present here; the model is moved to `device` below.
# NOTE(review): torch.load unpickles the file, so only use trusted checkpoints.
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model.eval()
model.requires_grad_(False)
model = model.half().to(device)

# Sanity-check forward pass; the same input is reused as the tracing example.
# NOTE(review): unless the tokenizer appends EOS, this prompt has no EOS token,
# so RewardModel.forward reads the reward at position 0.
input = tokenizer("reward model's hash", return_tensors="pt").to(device)
print(f"{model(input.input_ids)=}")

traced_script_module = torch.jit.trace(model, input.input_ids)

# Triton model-repository layout (relative to the current working directory):
# model_store/<model_name>/1/traced-model.pt plus model_store/<model_name>/config.pbtxt
model_dir = f"model_store/{model_name}"
os.makedirs(f"{model_dir}/1", exist_ok=True)
traced_script_module.save(f"{model_dir}/1/traced-model.pt")

# Fill the $model_name placeholder of the config template that sits next to this script.
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt")
with open(config_path, encoding="utf-8") as f:
    template = Template(f.read())
config = template.substitute({"model_name": model_name})
with open(f"{model_dir}/config.pbtxt", "w", encoding="utf-8") as f:
    f.write(config)
