import os

# Process-wide settings, applied ahead of the `transformers` import further
# down -- presumably so the libraries pick these values up when they load.
os.environ.update(
    {
        "TORCH_NCCL_BLOCKING_WAIT": "0",
        "DS_SKIP_CUDA_CHECK": "1",
        "TRANSFORMERS_NO_ADVISORY_WARNINGS": "1",
        "TOKENIZERS_PARALLELISM": "false",
        "HF_HOME": "/home/user/huggingface",
    }
)

import argparse
import json

from transformers import AutoTokenizer


def _load_results(path):
    """Load a results JSON file and return its trajectories and ratings.

    Args:
        path: Path of a JSON file holding top-level "trajectories" and
            "ratings" entries.

    Returns:
        A ``(trajectories, ratings)`` tuple.

    Raises:
        ValueError: If the two entries differ in length. The caller walks
            them in lockstep, so a mismatch would pair trajectories with
            the wrong ratings.
    """
    with open(path, "r", encoding="utf-8") as f:
        results = json.load(f)
    trajectories = results["trajectories"]
    ratings = results["ratings"]
    if len(trajectories) != len(ratings):
        raise ValueError(
            f"{path}: {len(trajectories)} trajectories but {len(ratings)} ratings"
        )
    return trajectories, ratings


def _token_lengths(tokenizer, texts):
    """Return, for each text, the number of ``input_ids`` the tokenizer yields."""
    return [len(tokenizer(text)["input_ids"]) for text in texts]


def main(args):
    """Compute token lengths of selected trajectories and save them as JSON.

    The tokenizer is loaded from ``sft-gpt2/checkpoint-50000`` under
    ``args.base_dir``. A base results file and a new (``depth2``) results
    file are read from the result directory, and two groups are measured:

    * base trajectories whose base rating is greater than 0.0;
    * new trajectories whose base rating equals 0.0 and whose new rating is
      greater than 0.0.

    The token counts are written to a ``lengths_final_...json`` file in the
    same result directory under the keys "base_lengths" and "new_lengths".

    Args:
        args: Namespace providing ``base_dir``, ``seed``, ``star_iter``,
            ``mode``, ``split``, ``start`` and ``offset``.

    Raises:
        ValueError: If ``args.star_iter`` is not 0, 1 or 2, or if the loaded
            trajectory/rating lists do not all have the same length.
    """
    # Build tokenizer
    ckpt = os.path.join(args.base_dir, "sft-gpt2", "checkpoint-50000")
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    # Define result directory
    if args.star_iter == 0:
        result_dir = ckpt
    elif args.star_iter in (1, 2):
        exp_name = f"star{args.star_iter}-final-{args.mode}-s{args.seed}-gpt2"
        result_dir = os.path.join(args.base_dir, exp_name, "checkpoint-20000")
    else:
        raise ValueError(f"Invalid star_iter: {args.star_iter}")
    start = args.start
    end = args.start + args.offset

    # Load base results
    base_result_file = f"results_final_s{args.seed}_{args.split}_{start}_{end}.json"
    base_trajectories, base_ratings = _load_results(
        os.path.join(result_dir, base_result_file)
    )

    # Load new results
    new_result_file = (
        f"results_final_{args.mode}_s{args.seed}_depth2_{args.split}_{start}_{end}.json"
    )
    new_trajectories, new_ratings = _load_results(
        os.path.join(result_dir, new_result_file)
    )

    # Check the number of trajectories. Raised explicitly rather than with
    # `assert`, which is removed when Python runs with -O.
    if len(base_trajectories) != len(new_trajectories):
        raise ValueError(
            f"Trajectory count mismatch: {len(base_trajectories)} base vs "
            f"{len(new_trajectories)} new"
        )

    # Filter base trajectories
    kept_base = [
        trajectory
        for trajectory, rating in zip(base_trajectories, base_ratings)
        if rating > 0.0
    ]

    # Filter new trajectories: entries at the same index in the two files are
    # compared, keeping those rated 0.0 in the base file and > 0.0 in the new.
    kept_new = [
        trajectory
        for trajectory, base_rating, new_rating in zip(
            new_trajectories, base_ratings, new_ratings
        )
        if base_rating == 0.0 and new_rating > 0.0
    ]

    # Compute lengths
    lengths = {
        "base_lengths": _token_lengths(tokenizer, kept_base),
        "new_lengths": _token_lengths(tokenizer, kept_new),
    }

    # Save lengths
    length_file = (
        f"lengths_final_{args.mode}_s{args.seed}_{args.split}_{start}_{end}.json"
    )
    length_file = os.path.join(result_dir, length_file)
    with open(length_file, "w", encoding="utf-8") as f:
        json.dump(lengths, f, indent=4)


if __name__ == "__main__":
    # Command-line interface: --base_dir is the only mandatory flag.
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_dir", type=str, required=True)

    # Optional flags as (name, type, default), registered in this order.
    for flag, kind, default in (
        ("--seed", int, 0),
        ("--star_iter", int, 0),
        ("--mode", str, "rand"),
        ("--split", str, "train"),
        ("--start", int, 0),
        ("--offset", int, 1000),
    ):
        parser.add_argument(flag, default=default, type=kind)

    args = parser.parse_args()

    # Hand the parsed options to the entry point.
    main(args)
