import argparse
import collections
import math
import time
from typing import List
import os
from collections import Counter

import torch
from calib.utils import extract_answer, math_equal
from transformers import AutoTokenizer

def main():
    """Truncate every stored completion in a checkpoint and re-score the result.

    Command-line arguments:
        --input_file: path to a checkpoint written with ``torch.save``. This
            function reads ``ckpt["completion_ids"]`` (nested as
            ``[prompt][sample]`` -> token ids), ``ckpt["answers"]`` (indexed per
            prompt) and ``ckpt["args"]`` (an object with a ``model_name``
            attribute).
        --max_completion_length: number of leading tokens kept from each
            completion (default 1024). Must be a positive integer.

    The checkpoint's ``completion_ids`` are truncated, and ``completions``,
    ``completion_lengths``, ``responses`` and ``rewards`` are regenerated from
    the truncated ids. ``args.max_completion_length`` is updated to match. The
    result is written to a new ``.pt`` file next to the input; the input file
    itself is never overwritten.

    Exits through ``ArgumentParser.error`` (non-zero status) when the input
    file does not exist, the length is not positive, or the derived output
    path would be the input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Path to the completion file to shorten")
    parser.add_argument("--max_completion_length", type=int, default=1024)
    args = parser.parse_args()
    max_len = args.max_completion_length

    # Report errors through argparse so the process exits non-zero; printing
    # and returning would exit 0 and hide the failure from calling scripts.
    if not os.path.exists(args.input_file):
        parser.error(f"File {args.input_file} does not exist")
    # A non-positive slice bound would silently drop tokens from the END of each
    # completion (ids[:-n]) or empty it entirely (ids[:0]).
    if max_len <= 0:
        parser.error("--max_completion_length must be a positive integer")

    # Derive the output path before doing any expensive work. Only a trailing
    # "4096" (presumably the length the source file was generated with — TODO
    # confirm) is swapped for the new length; str.replace would also rewrite any
    # "4096" appearing elsewhere in the directory or file name.
    base_name = os.path.splitext(args.input_file)[0]
    length_tag = "4096"
    if base_name.endswith(length_tag):
        new_file = f"{base_name[:-len(length_tag)]}{max_len}.pt"
    else:
        new_file = f"{base_name}_{max_len}.pt"
    # e.g. foo_4096.pt with --max_completion_length 4096 maps onto itself;
    # refuse rather than clobber the source checkpoint.
    if os.path.abspath(new_file) == os.path.abspath(args.input_file):
        parser.error(f"Output path {new_file} would overwrite the input file")

    # NOTE(review): weights_only=False unpickles arbitrary Python objects
    # (presumably needed because ckpt["args"] is a plain object). Only run this
    # on checkpoints from a trusted source.
    ckpt = torch.load(args.input_file, map_location="cpu", weights_only=False)

    # Truncate row by row so prompts with differing sample counts (or an empty
    # checkpoint) are handled; the old code sized every row by row 0.
    completion_ids = [[ids[:max_len] for ids in row] for row in ckpt["completion_ids"]]
    ckpt["completion_ids"] = completion_ids

    tokenizer = AutoTokenizer.from_pretrained(ckpt["args"].model_name)

    completions = [tokenizer.batch_decode(row, skip_special_tokens=True) for row in completion_ids]
    ckpt["completions"] = completions
    # Lengths are token counts of the truncated ids, not character counts.
    ckpt["completion_lengths"] = [[len(ids) for ids in row] for row in completion_ids]

    responses = [[extract_answer(text) for text in row] for row in completions]
    ckpt["responses"] = responses
    answers = ckpt["answers"]
    # Index answers by prompt position (rather than zip) so a too-short
    # answers list still raises instead of silently dropping prompts.
    ckpt["rewards"] = [
        [math_equal(response, answers[i]) for response in row]
        for i, row in enumerate(responses)
    ]

    ckpt["args"].max_completion_length = max_len

    torch.save(ckpt, new_file)
    print(f"Shortened completion saved to: {new_file}")

if __name__ == "__main__":
    # Script entry point: main() returns None, and SystemExit(None) exits with status 0.
    raise SystemExit(main())
