#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Memory Bank Filter Script

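过滤掉指定 iteration 的 embedding 和 question 数据，保存到原文件。

Remove every embedding row / question record whose `iteration` equals the
given value, then overwrite both input files in place (no backup is kept).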

Usage:
    python memory_bank/filter_memory.py \
        --embedding_path storage/memory_bank/embedding_code.npy \
        --question_path storage/memory_bank/question_code.json \
        --filtered_iteration 1
"""

import argparse
import json
import numpy as np
from typing import List, Tuple


def load_data(embedding_path: str, question_path: str) -> Tuple[np.ndarray, List[dict]]:
    """Load the embedding array and the question records from disk.

    Args:
        embedding_path: Path to an array file readable by ``np.load``.
        question_path: Path to a UTF-8 encoded JSON file. The rest of this
            script treats its top-level value as a list of dicts.

    Returns:
        ``(embeddings, questions)`` exactly as loaded; no validation is done.
    """
    embeddings = np.load(embedding_path)
    # JSON is UTF-8 by specification, and this script writes the file back
    # with ensure_ascii=False (raw non-ASCII text), so never fall back to the
    # platform's locale encoding when reading it.
    with open(question_path, 'r', encoding='utf-8') as f:
        questions = json.load(f)
    return embeddings, questions


def filter_by_iteration(
    embeddings: np.ndarray,
    questions: List[dict],
    filtered_iteration: int
) -> Tuple[np.ndarray, List[dict]]:
    """Drop every record whose ``iteration`` equals *filtered_iteration*.

    Row ``i`` of *embeddings* is paired with ``questions[i]``; a row and its
    question are kept or dropped together so the pairing is preserved.
    Records lacking an ``iteration`` key are kept (``q.get`` yields ``None``),
    unless *filtered_iteration* itself is ``None``.

    Args:
        embeddings: Array indexed along axis 0 in the same order as
            *questions*.
        questions: Question records (dicts), one per embedding row.
        filtered_iteration: The iteration value to remove.

    Returns:
        ``(filtered_embeddings, filtered_questions)``. The embeddings keep the
        input dtype and all trailing dimensions, including when no row
        survives (shape ``(0, *embeddings.shape[1:])``).
    """
    keep_indices = [
        i for i, q in enumerate(questions)
        if q.get('iteration') != filtered_iteration
    ]
    filtered_questions = [questions[i] for i in keep_indices]
    # An explicit intp index array makes the empty selection well-defined and
    # avoids rebuilding the array by hand, which would lose dtype/shape.
    filtered_embeddings = embeddings[np.asarray(keep_indices, dtype=np.intp)]
    return filtered_embeddings, filtered_questions


def save_data(
    embeddings: np.ndarray,
    questions: List[dict],
    embedding_path: str,
    question_path: str
):
    """Overwrite the embedding and question files with the given data.

    Both files are truncated and rewritten in place; no backup is made.

    Args:
        embeddings: Array to store in ``.npy`` format at *embedding_path*.
        questions: JSON-serializable records to store at *question_path*.
        embedding_path: Destination path, used verbatim.
        question_path: Destination path for UTF-8 encoded, indented JSON.
    """
    # Write through an open handle: given a string path without a ".npy"
    # suffix, np.save appends one and would write to a different file than
    # the one that was loaded.
    with open(embedding_path, 'wb') as f:
        np.save(f, embeddings)
    # ensure_ascii=False emits raw non-ASCII characters, so pin UTF-8 rather
    # than depending on the platform's default encoding.
    with open(question_path, 'w', encoding='utf-8') as f:
        json.dump(questions, f, indent=2, ensure_ascii=False)


def main(argv=None):
    """Command-line entry point: remove one iteration from the memory bank.

    Loads the embedding and question files, drops every record whose
    ``iteration`` equals ``--filtered_iteration``, and overwrites both files
    in place. Nothing is written when no record matches.

    Args:
        argv: Optional argument list; ``None`` lets argparse read
            ``sys.argv[1:]``. Accepted so the CLI can be driven from code.

    Raises:
        SystemExit: On invalid arguments (from argparse), or when the number
            of question records differs from the number of embedding rows.
    """
    parser = argparse.ArgumentParser(description="Filter Memory Bank by iteration")
    parser.add_argument("--embedding_path", type=str, required=True,
                        help="Path to embedding .npy file")
    parser.add_argument("--question_path", type=str, required=True,
                        help="Path to question .json file")
    parser.add_argument("--filtered_iteration", type=int, required=True,
                        help="Iteration to filter out (remove)")

    args = parser.parse_args(argv)

    # Step 1: load the data
    print("[Filter] Loading data...")
    print(f"[Filter]   Embedding: {args.embedding_path}")
    print(f"[Filter]   Question: {args.question_path}")
    embeddings, questions = load_data(args.embedding_path, args.question_path)

    print(f"[Filter] Loaded {len(questions)} questions, embeddings shape: {embeddings.shape}")

    # Step 2: verify that questions and embedding rows line up one-to-one.
    # An explicit check instead of `assert`: asserts are stripped under
    # `python -O`, and filtering a misaligned bank would pair rows wrongly.
    if len(questions) != embeddings.shape[0]:
        raise SystemExit(
            f"Mismatch: {len(questions)} questions vs {embeddings.shape[0]} embeddings"
        )

    # Step 3: count how many records will be removed
    count_to_remove = sum(1 for q in questions if q.get('iteration') == args.filtered_iteration)
    print(f"[Filter] Found {count_to_remove} items with iteration={args.filtered_iteration}")

    if count_to_remove == 0:
        print("[Filter] No items to remove. Exiting without changes.")
        return

    # Step 4: filter
    print(f"[Filter] Filtering out iteration {args.filtered_iteration}...")
    filtered_embeddings, filtered_questions = filter_by_iteration(
        embeddings, questions, args.filtered_iteration
    )

    print(f"[Filter] After filtering: {len(filtered_questions)} questions, embeddings shape: {filtered_embeddings.shape}")

    # Step 5: overwrite the original files
    print("[Filter] Saving filtered data to original paths...")
    save_data(filtered_embeddings, filtered_questions, args.embedding_path, args.question_path)

    print(f"[Filter] Done! Removed {count_to_remove} items from iteration {args.filtered_iteration}")


if __name__ == "__main__":
    # Run the CLI. main() returns None, so the process exits with status 0
    # unless main() itself raises.
    raise SystemExit(main())

