import torch

def _mean_over_samples(layer_data, key):
    """Concatenate ``layer_data[k][key]`` for every sample k along dim 0 and average over dim 0."""
    return torch.cat(
        [layer_data[k][key] for k in range(len(layer_data))], dim=0
    ).mean(dim=0)


def load_steering_vec(vec_path, layer_num=49):
    """Load saved hidden states and average them per layer for each of three keys.

    The file at ``vec_path`` is loaded with ``torch.load(..., weights_only=True)``.
    ``all_data[i]`` is expected to be a sequence of per-sample dicts with the keys
    ``"exec"``, ``"ref"`` and ``"tran"``. For every layer ``i`` in
    ``range(layer_num)``, the tensors under each key are concatenated along dim 0
    and averaged over dim 0.

    Args:
        vec_path: Path of the file written with ``torch.save``.
        layer_num: Number of layers to read, starting from index 0.

    Returns:
        A tuple ``(exec, ref, tran)`` of tensors. Each one stacks the per-layer
        means along a new leading dimension of size ``layer_num``.
        NOTE(review): the trailing dims depend on the rank of the stored
        per-sample tensors, which is not visible here; confirm against the
        code that produced ``vec_path``.

    Raises:
        Whatever indexing ``all_data`` raises (e.g. IndexError) if it holds
        fewer than ``layer_num`` layers; KeyError if a sample lacks one of the
        three keys; RuntimeError from ``torch.cat`` if a layer has no samples.
    """
    all_data = torch.load(vec_path, weights_only=True)

    exec_means, ref_means, tran_means = [], [], []
    # Iterate exactly layer_num layers (previously hard-coded to 49, which
    # ignored the parameter).
    for i in range(layer_num):
        layer_data = all_data[i]
        exec_means.append(_mean_over_samples(layer_data, "exec"))
        ref_means.append(_mean_over_samples(layer_data, "ref"))
        tran_means.append(_mean_over_samples(layer_data, "tran"))

    return (
        torch.stack(exec_means, dim=0),
        torch.stack(ref_means, dim=0),
        torch.stack(tran_means, dim=0),
    )

# Build one steering vector per layer and save each to its own file.
# The vector is the "exec" mean minus the mean of the concatenated "ref" and
# "tran" means. Names avoid rebinding the builtin ``exec`` at module scope.
exec_vecs, ref_vecs, tran_vecs = load_steering_vec("/path/to/hidden_states/hidden.pt")
# Follow the number of layers actually returned instead of a hard-coded 49.
for layer, (layer_exec, layer_ref, layer_tran) in enumerate(
    zip(exec_vecs, ref_vecs, tran_vecs)
):
    # NOTE(review): load_steering_vec already averaged over samples, so this
    # extra .mean(dim=0) collapses the leading dim of each per-layer vector.
    # If the stored per-sample tensors were 2-D, the result is a scalar rather
    # than a vector. Confirm the intended tensor rank.
    steer_vec = layer_exec.mean(dim=0) - torch.cat([layer_ref, layer_tran], dim=0).mean(dim=0)
    # torch.save accepts a path directly; no manual open() is needed.
    torch.save(steer_vec, f"/path/to/steering_vector/layer_{layer}_steering_vec.pt")