import os
import torch
import argparse

from tqdm import tqdm

def _load_shard(base_path, split, shard_id):
    """Load the six per-field files of one shard, in a fixed order.

    Reads ``{base_path}/{split}_{field}_shard{shard_id}.pt`` for each field in
    (embeddings, masks, labels, prompts, responses, scores) and returns the
    ``torch.load`` results as a tuple in that same order.
    """
    fields = ("embeddings", "masks", "labels", "prompts", "responses", "scores")
    return tuple(
        torch.load(os.path.join(base_path, f"{split}_{field}_shard{shard_id}.pt"))
        for field in fields
    )


def _trim_mask_to_embedding(mask, emb_len):
    """Return the last ``emb_len`` positions of ``mask`` if it is longer, else ``mask`` as-is."""
    mask_len = mask.shape[0]
    if mask_len > emb_len:
        # Slice from an explicit start index: ``mask[-emb_len:]`` would return
        # the WHOLE mask when emb_len == 0, because -0 == 0.
        mask = mask[mask_len - emb_len:]
    return mask


def load_and_fix_shards(base_path, split, num_shards):
    """Load ``num_shards`` shards of ``split`` from ``base_path`` and flatten them.

    For each shard ``k`` in ``range(num_shards)`` the files
    ``{split}_{embeddings,masks,labels,prompts,responses,scores}_shard{k}.pt``
    are loaded and their entries appended, in order, to six flat lists.

    Per entry:
      * the mask's sum must equal the embedding's first dimension;
      * the mask is trimmed to its last ``emb.shape[0]`` positions;
      * the embedding is cast to ``torch.bfloat16``;
      * the stored score is replaced by ``1 - score``. Presumably this turns an
        "unsafe" score into a "safe" one to match the is_safe labels.
        TODO confirm.

    Args:
        base_path: Directory containing the shard files.
        split: Split name used as the filename prefix (e.g. ``"train"``).
        num_shards: Number of shards, numbered ``0 .. num_shards - 1``.

    Returns:
        Tuple ``(embeddings, masks, labels, prompts, responses, scores)`` of lists.

    Raises:
        ValueError: if a shard's six files hold different numbers of entries, or
            if an entry's mask sum does not match its embedding length.
    """
    all_embeddings = []
    all_masks = []
    all_labels = []
    all_prompts = []
    all_responses = []
    all_scores = []

    for shard_id in tqdm(range(num_shards), desc=f"Loading {split} shards"):
        embeddings, masks, labels, prompts, responses, scores = _load_shard(
            base_path, split, shard_id
        )

        # zip() would silently truncate mismatched files; fail loudly instead.
        lengths = {
            "embeddings": len(embeddings),
            "masks": len(masks),
            "labels": len(labels),
            "prompts": len(prompts),
            "responses": len(responses),
            "scores": len(scores),
        }
        if len(set(lengths.values())) != 1:
            raise ValueError(
                f"[Shard {shard_id}] Per-entry files have mismatched lengths: {lengths}"
            )

        for i, (emb, mask, label, prompt, response, score) in enumerate(
            zip(embeddings, masks, labels, prompts, responses, scores)
        ):
            emb_len = emb.shape[0]
            n_response = mask.sum().item()

            # Ensure response-only mask aligns with embeddings
            if n_response != emb_len:
                raise ValueError(
                    f"[Shard {shard_id}, Entry {i}] Mask has {n_response} response tokens, "
                    f"but embedding has shape {emb.shape}"
                )

            all_embeddings.append(emb.to(torch.bfloat16))
            # Keeps the trailing emb_len positions. This assumes the response
            # tokens sit at the end of the mask; the sum check above does not
            # verify their position. TODO confirm with the shard writer.
            all_masks.append(_trim_mask_to_embedding(mask, emb_len))
            all_labels.append(label)
            all_prompts.append(prompt)
            all_responses.append(response)
            all_scores.append(1 - score)

    return all_embeddings, all_masks, all_labels, all_prompts, all_responses, all_scores

def save_stacked_data(embeddings, masks, labels, save_path, prompts, responses, scores, split):
    """Write each merged field for ``split`` to ``save_path`` as a separate ``.pt`` file.

    Files are named ``{split}_{suffix}.pt``. ``labels`` is written twice, as
    ``{split}_binary_is_safes.pt`` and ``{split}_is_safes.pt``. A confirmation
    line is printed once everything has been saved.
    """
    outputs = [
        ("embeddings", embeddings),
        ("masks", masks),
        ("binary_is_safes", labels),
        ("is_safes", labels),
        ("prompts", prompts),
        ("responses", responses),
        ("scores", scores),
    ]
    for suffix, payload in outputs:
        target = os.path.join(save_path, f"{split}_{suffix}.pt")
        torch.save(payload, target)
    print(f"Saved stacked files to {save_path}")

def main():
    """Command-line entry point: merge one split's shards into single files.

    Shards are read from, and merged files written back to,
    ``<root_dir>/<dataset>_llm_generated/<model_name with '/' replaced by '_'>``.
    """
    cli = argparse.ArgumentParser()
    flag_specs = (
        ("--dataset", str),
        ("--model_name", str),
        ("--split", str),
        ("--num_shards", int),
        ("--root_dir", str),
    )
    for flag, flag_type in flag_specs:
        cli.add_argument(flag, type=flag_type, required=True)
    opts = cli.parse_args()

    # Model names like "org/model" become "org_model" so they form one path segment.
    model_dir = opts.model_name.replace("/", "_")
    base_path = "/".join([opts.root_dir, f"{opts.dataset}_llm_generated", model_dir])

    embeddings, masks, labels, prompts, responses, scores = load_and_fix_shards(
        base_path=base_path, split=opts.split, num_shards=opts.num_shards
    )
    save_stacked_data(
        embeddings=embeddings,
        masks=masks,
        labels=labels,
        save_path=base_path,
        prompts=prompts,
        responses=responses,
        scores=scores,
        split=opts.split,
    )

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
