import os
import json
from itertools import product
import random
import argparse
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from tqdm import tqdm
from sampler import FaissRAGReasoningPreprocessor

def batch_generate_all(
    input_json_path: str,
    output_dir: str,
    model_name: str,
    tokenizer_name: str,
    neg_token_length: int = 128,
    cache_dir: str = "./cache",
    neg_strategy: str = "external_rag",
    anchor_source: str = "answer"
) -> None:
    """Build and save the negative-sample file for one strategy/anchor pair.

    The output is written to ``<output_dir>/neg_<neg_strategy>_<anchor_source>.json``
    by ``FaissRAGReasoningPreprocessor.build_and_save(..., mode="neg")``.
    If ``output_dir`` already contains any file whose name includes
    ``neg_<neg_strategy>_<anchor_source>``, nothing is generated and the
    function returns immediately (before loading any model or data).

    Args:
        input_json_path: Path to the JSON file holding the input samples.
        output_dir: Directory for the generated file; created if missing.
        model_name: Name passed to ``SentenceTransformer``.
        tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``.
        neg_token_length: Forwarded to the preprocessor as ``neg_token_length``.
        cache_dir: Forwarded to the preprocessor as ``cache_dir``.
        neg_strategy: Forwarded to the preprocessor; also part of the file name.
        anchor_source: Forwarded to the preprocessor; also part of the file name.

    Raises:
        FileNotFoundError: If ``input_json_path`` does not exist.
        json.JSONDecodeError: If the input file is not valid JSON.
    """
    output_stem = f"neg_{neg_strategy}_{anchor_source}"
    output_path_neg = os.path.join(output_dir, f"{output_stem}.json")

    # Only the directory listing can legitimately fail here, so keep the
    # ``try`` body down to that single call.
    try:
        existing_files = os.listdir(output_dir)
    except FileNotFoundError:
        # The output directory has not been created yet, so there is
        # nothing to skip; it is created further below.
        existing_files = []
    except OSError as err:
        # Best effort: an unreadable directory must not silently hide the
        # problem, but it also should not abort generation by itself.
        print(f"Could not list {output_dir!r} ({err}); proceeding with generation.")
        existing_files = []

    # Substring match on purpose: any file carrying this stem counts as an
    # already-generated result for the strategy/anchor pair.
    for fname in existing_files:
        if output_stem in fname:
            print(f"Matching file already exists: {fname}, skipping generation.")
            return

    with open(input_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    print(f"🔵 Loaded {len(data)} samples from {input_json_path}")

    embedder = SentenceTransformer(model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_indices = list(range(tokenizer.vocab_size))

    os.makedirs(output_dir, exist_ok=True)

    print(f"\nProcessing: neg_strategy={neg_strategy}, anchor_source={anchor_source}")

    # Only negative samples are produced by this entry point.
    preprocessor = FaissRAGReasoningPreprocessor(
        data, embedder, tokenizer, vocab_indices,
        neg_token_length=neg_token_length,
        neg_strategy=neg_strategy,
        anchor_source=anchor_source,
        cache_dir=cache_dir
    )
    preprocessor.build_and_save(output_path_neg, mode="neg")

if __name__ == "__main__":
    # Command-line options, declared in the order they appear in --help.
    option_table = [
        ("--input_json", {"type": str, "required": True, "help": "Path to input json file"}),
        ("--output_dir", {"type": str, "required": True, "help": "Directory to save all output files"}),
        ("--model_name", {"type": str, "default": "sentence-transformers/all-MiniLM-L6-v2", "help": "Embedding model name"}),
        ("--tokenizer_name", {"type": str, "default": "bert-base-uncased", "help": "Tokenizer name"}),
        ("--neg_token_length", {"type": int, "default": 20, "help": "Negative token length"}),
        ("--cache_dir", {"type": str, "default": "./cache", "help": "Directory to cache embeddings and indices"}),
        ("--neg_strategy", {"type": str, "default": "external_rag", "help": "Negative sampling strategy"}),
        ("--anchor_source", {"type": str, "default": "answer", "help": "Anchor source type"}),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    # Every parsed option maps one-to-one onto a keyword argument of
    # batch_generate_all, except --input_json, whose parameter is named
    # input_json_path.
    call_kwargs = dict(vars(cli.parse_args()))
    call_kwargs["input_json_path"] = call_kwargs.pop("input_json")

    batch_generate_all(**call_kwargs)