#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Experience Replay - Pre-Evaluation Mixing (前置混合)

This script samples historical questions from Memory Bank and merges them with 
newly generated questions BEFORE the evaluation step. This allows the current 
Solver to re-generate pseudo-labels for historical questions.

Key Features:
- Samples historical questions based on specified strategy
- Merges them into the generated question files (before evaluate.py runs)
- Marks historical data for tracking purposes

Usage:
    python memory_bank/sample_for_replay.py \
        --experiment_name qwen3-4b_solver_v2 \
        --iteration 2 \
        --replay_ratio 0.3 \
        --sampling_strategy uniform \
        --model_abbr qwen3-4b
"""

import argparse
import json
import os
import random
from typing import List, Dict, Optional
from collections import defaultdict


def load_memory_bank(memory_bank_path: str, embedding_type: str = "nl") -> List[Dict]:
    """
    Read the stored question list from a Memory Bank directory.

    Args:
        memory_bank_path: Directory that holds the Memory Bank JSON files
        embedding_type: "code" selects question_code.json; any other value
            selects questions.json

    Returns:
        The parsed JSON content of the selected file, or an empty list when
        that file does not exist
    """
    # File names are meant to mirror the convention used by update_memory.py
    file_name = "question_code.json" if embedding_type == "code" else "questions.json"
    full_path = os.path.join(memory_bank_path, file_name)

    # Guard clause: a missing file is reported and treated as an empty bank
    if not os.path.exists(full_path):
        print(f"[Pre-Eval Replay] File not found: {full_path}")
        return []

    print(f"[Pre-Eval Replay] Loading from {file_name}")
    with open(full_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


def sample_uniform(historical: List[Dict], n_samples: int) -> List[Dict]:
    """Draw up to n_samples distinct entries uniformly at random, without replacement."""
    population_size = len(historical)
    # Never ask for more items than the population holds
    draw_count = n_samples if n_samples < population_size else population_size
    return random.sample(historical, draw_count)


def sample_stratified(historical: List[Dict], n_samples: int, current_iteration: int) -> List[Dict]:
    """
    Draw a recency-biased sample from the historical questions.

    Quota per stratum (each is int(n_samples * share), capped by stratum size):
    - iteration == current_iteration - 1: 50%
    - iteration == current_iteration - 2: 30%
    - iteration <  current_iteration - 2: 20% (all such iterations pooled)

    Any shortfall against n_samples is topped up with a uniform draw from the
    entries that were not selected above. Entries without an 'iteration' key
    are treated as iteration 0.
    """
    if not historical:
        return []

    # Bucket the questions by the iteration recorded on each entry
    buckets = defaultdict(list)
    for entry in historical:
        buckets[entry.get('iteration', 0)].append(entry)

    picked = []

    # Recent strata, most recent first: (iteration, share of n_samples)
    recent_shares = (
        (current_iteration - 1, 0.5),
        (current_iteration - 2, 0.3),
    )
    for stratum, share in recent_shares:
        if stratum not in buckets:
            continue
        pool = buckets[stratum]
        quota = int(n_samples * share)
        picked.extend(random.sample(pool, min(quota, len(pool))))

    # Everything older than two iterations back shares one pool
    # (kept in bucket order so the draw sees the same sequence)
    oldest_recent = current_iteration - 2
    older_pool = [
        entry
        for stratum, entries in buckets.items()
        if stratum < oldest_recent
        for entry in entries
    ]
    if older_pool:
        older_quota = int(n_samples * 0.2)
        picked.extend(random.sample(older_pool, min(older_quota, len(older_pool))))

    # Top up uniformly from entries not drawn yet; identity (not equality)
    # is used so equal-looking dicts are still treated as separate entries
    shortfall = n_samples - len(picked)
    if shortfall > 0:
        taken = {id(entry) for entry in picked}
        leftovers = [entry for entry in historical if id(entry) not in taken]
        if leftovers:
            picked.extend(random.sample(leftovers, min(shortfall, len(leftovers))))

    return picked[:n_samples]


def sample_recent_first(historical: List[Dict], n_samples: int) -> List[Dict]:
    """
    Return the first n_samples entries after ordering by iteration, newest first.

    Entries without an 'iteration' key are treated as iteration 0. The sort is
    stable, so entries from the same iteration keep their original relative
    order. The input list itself is not reordered.
    """
    def iteration_of(entry):
        return entry.get('iteration', 0)

    # Sort a copy in place rather than the caller's list
    newest_first = list(historical)
    newest_first.sort(key=iteration_of, reverse=True)
    return newest_first[:n_samples]


def sample_score_weighted(historical: List[Dict], n_samples: int) -> List[Dict]:
    """
    Score-weighted sampling: higher-score entries are more likely to be drawn.

    Each entry's selection probability is proportional to its 'score' value
    (plain linear normalisation: score / sum of scores). Entries are drawn
    without replacement.

    Score handling:
    - a missing, None, or non-numeric 'score' is treated as 0.5
    - if any score is negative, all scores are shifted so the minimum
      becomes 0.1
    - scores are floored at 0.01 so every entry keeps a non-zero chance

    Args:
        historical: Candidate question dictionaries
        n_samples: Number of entries wanted (capped at len(historical))

    Returns:
        List of sampled entries. If NumPy rejects the probability vector
        (for example because a score is NaN), the draw falls back to
        uniform sampling.
    """
    import numpy as np

    if not historical:
        return []

    default_score = 0.5

    # Coerce every score to float up front. A JSON null (None) or a stray
    # string would otherwise produce an object/str array and fail later with
    # a TypeError that the ValueError fallback below does not catch.
    numeric_scores = []
    for q in historical:
        raw_score = q.get('score', default_score)
        try:
            numeric_scores.append(float(raw_score))
        except (TypeError, ValueError):
            numeric_scores.append(default_score)
    scores = np.array(numeric_scores, dtype=float)

    # Shift so that the smallest score is slightly positive when negatives exist
    min_score = scores.min()
    if min_score < 0:
        scores = scores - min_score + 0.1

    # Avoid zero weights so every entry remains selectable
    scores = np.maximum(scores, 0.01)

    # Convert to probabilities (normalize)
    probabilities = scores / scores.sum()

    n_to_sample = min(n_samples, len(historical))

    try:
        # Note: this draw uses NumPy's global RNG, not the `random` module
        sampled_indices = np.random.choice(
            len(historical),
            size=n_to_sample,
            replace=False,
            p=probabilities
        )
        return [historical[i] for i in sampled_indices]
    except ValueError:
        # Fallback to uniform if probabilities are invalid
        return random.sample(historical, n_to_sample)


def sample_historical(
    memory_bank: List[Dict], 
    n_samples: int, 
    current_iteration: int,
    strategy: str = "uniform"
) -> List[Dict]:
    """
    Pick historical questions for replay and reshape them for evaluation.

    Args:
        memory_bank: All stored questions
        n_samples: Number of questions to draw
        current_iteration: Current iteration; only entries from strictly
            earlier iterations are eligible (a missing 'iteration' counts as 0)
        strategy: "stratified", "recent_first" or "score_weighted"; any other
            value selects uniform sampling

    Returns:
        One dict per sampled question with the keys 'question', 'answer',
        'score' (always 0), '_from_memory' (always True) and
        '_original_iteration'. Empty list when nothing is eligible.
    """
    # Defensive filter: drop anything recorded at or after the current iteration
    earlier = [entry for entry in memory_bank if entry.get('iteration', 0) < current_iteration]
    if not earlier:
        return []

    # Strategy dispatch; each thunk is only evaluated when selected
    samplers = {
        "stratified": lambda: sample_stratified(earlier, n_samples, current_iteration),
        "recent_first": lambda: sample_recent_first(earlier, n_samples),
        "score_weighted": lambda: sample_score_weighted(earlier, n_samples),
    }
    draw = samplers.get(strategy, lambda: sample_uniform(earlier, n_samples))
    chosen = draw()

    # Shape expected by evaluate.py: 'question', 'answer', 'score'.
    # score is reset to 0 so the current Solver re-evaluates the item;
    # the underscore-prefixed keys mark it as replay data for tracking.
    return [
        {
            "question": entry["question"],
            "answer": entry["answer"],
            "score": 0,
            "_from_memory": True,
            "_original_iteration": entry.get("iteration", 0),
        }
        for entry in chosen
    ]


def count_new_questions(storage_path: str, experiment_name: str) -> int:
    """
    Total the number of entries across the per-GPU generated-question files.

    Files are expected at
    <storage_path>/generated_question/<experiment_name>_<i>.json for
    i in 0..TOTAL_GPU_COUNT-1. Missing, unreadable or malformed files
    contribute zero.
    """
    def entries_in(path):
        # Size of one shard file; 0 when it is absent or cannot be parsed
        if not os.path.exists(path):
            return 0
        try:
            with open(path, 'r', encoding='utf-8') as handle:
                return len(json.load(handle))
        except (json.JSONDecodeError, IOError):
            return 0

    # The GPU count is read from the environment (default 8) so that
    # 6-GPU, 8-GPU and other configurations are all supported
    shard_count = int(os.getenv("TOTAL_GPU_COUNT", "8"))

    return sum(
        entries_in(os.path.join(storage_path, "generated_question", f"{experiment_name}_{shard}.json"))
        for shard in range(shard_count)
    )


def append_to_generated_files(
    storage_path: str, 
    experiment_name: str, 
    sampled_questions: List[Dict]
) -> None:
    """
    Append sampled historical questions to the generated question files.

    The samples are shuffled (on a copy, so the caller's list keeps its
    order) and split as evenly as possible across the per-GPU files
    <storage_path>/generated_question/<experiment_name>_<i>.json for
    i in 0..TOTAL_GPU_COUNT-1; the first (len % n_gpus) files receive one
    extra sample. Every file in that range is (re)written, including files
    that did not exist before.

    Args:
        storage_path: Base storage path
        experiment_name: Name of the experiment
        sampled_questions: List of sampled questions to append (not modified)
    """
    if not sampled_questions:
        return

    # The GPU count is read from the environment (default 8) so that
    # 6-GPU, 8-GPU and other configurations are all supported
    n_gpus = int(os.getenv("TOTAL_GPU_COUNT", "8"))

    # Shuffle a private copy: shuffling the argument in place would silently
    # reorder the list the caller still holds
    shuffled = list(sampled_questions)
    random.shuffle(shuffled)

    # Base share per file, plus how many files get one extra sample
    samples_per_file, remainder = divmod(len(shuffled), n_gpus)

    idx = 0
    for i in range(n_gpus):
        file_path = os.path.join(storage_path, "generated_question", f"{experiment_name}_{i}.json")

        # Load existing data; a missing file starts from an empty list
        data = []
        if os.path.exists(file_path):
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except (json.JSONDecodeError, IOError) as exc:
                # Keep the best-effort behaviour (start from empty), but say so:
                # the write below replaces whatever the file held
                print(f"[Pre-Eval Replay] Warning: could not read {file_path} ({exc}); "
                      f"its previous content will be replaced")
                data = []

        # Calculate how many samples to add to this file
        n_to_add = samples_per_file + (1 if i < remainder else 0)

        # Append this file's slice of the shuffled samples
        data.extend(shuffled[idx:idx + n_to_add])
        idx += n_to_add

        # Save back
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4, ensure_ascii=False)

        print(f"[Pre-Eval Replay] Added {n_to_add} historical questions to {experiment_name}_{i}.json")


def main():
    """
    Command-line entry point: mix replayed questions into the new question files.

    Steps: parse arguments, locate the Memory Bank under STORAGE_PATH, count
    the newly generated questions, sample int(n_new * replay_ratio)
    historical questions with the chosen strategy, and append them to the
    generated question files.

    Environment:
        STORAGE_PATH: base storage directory (required)
        EMBEDDING_TYPE: used only when --embedding_type is not given

    Exits through the argument parser's error path when STORAGE_PATH is
    unset or --replay_ratio is negative.
    """
    parser = argparse.ArgumentParser(
        description="Sample historical questions from Memory Bank for pre-eval experience replay"
    )
    parser.add_argument("--experiment_name", type=str, required=True,
                        help="Name of the experiment")
    parser.add_argument("--iteration", type=int, required=True,
                        help="Current iteration number")
    parser.add_argument("--replay_ratio", type=float, default=0.3,
                        help="Ratio of historical data to sample (relative to new data)")
    parser.add_argument("--sampling_strategy", type=str, default="uniform",
                        choices=["uniform", "stratified", "recent_first", "score_weighted"],
                        help="Sampling strategy for historical data")
    parser.add_argument("--model_abbr", type=str, default=None,
                        help="Model abbreviation for experiment isolation")
    # Default is None (not "nl") so the EMBEDDING_TYPE environment fallback
    # below can actually take effect when the flag is omitted
    parser.add_argument("--embedding_type", type=str, default=None,
                        choices=["nl", "code"],
                        help="Embedding type: 'nl' for natural language, 'code' for code "
                             "(default: $EMBEDDING_TYPE if set, otherwise 'nl')")
    args = parser.parse_args()

    storage_path = os.getenv("STORAGE_PATH")
    if storage_path is None:
        # Without this check os.path.join(None, ...) fails with an opaque TypeError
        parser.error("the STORAGE_PATH environment variable must be set")

    if args.replay_ratio < 0:
        # A negative ratio would turn into a negative sample count downstream
        parser.error("--replay_ratio must be non-negative")

    # Precedence: command-line flag, then environment variable, then "nl"
    embedding_type = args.embedding_type or os.getenv("EMBEDDING_TYPE") or "nl"
    if embedding_type not in ("nl", "code"):
        print(f"[Pre-Eval Replay] Unrecognized EMBEDDING_TYPE '{embedding_type}', using 'nl'")
        embedding_type = "nl"

    # Build Memory Bank path
    if args.model_abbr:
        memory_bank_path = os.path.join(storage_path, "memory_bank", args.model_abbr)
    else:
        memory_bank_path = os.path.join(storage_path, "memory_bank")

    print("=" * 70)
    print("[Pre-Eval Replay] Experience Replay - Pre-Evaluation Mixing")
    print("=" * 70)
    print(f"  Experiment: {args.experiment_name}")
    print(f"  Iteration: {args.iteration}")
    print(f"  Replay ratio: {args.replay_ratio}")
    print(f"  Sampling strategy: {args.sampling_strategy}")
    print(f"  Embedding type: {embedding_type}")
    print(f"  Memory Bank path: {memory_bank_path}")
    print("=" * 70)

    # Skip for first iteration (no historical data)
    if args.iteration <= 1:
        print("[Pre-Eval Replay] First iteration, no historical data to replay. Skipping.")
        return

    # Load Memory Bank
    memory_bank = load_memory_bank(memory_bank_path, embedding_type)
    print(f"[Pre-Eval Replay] Memory Bank contains {len(memory_bank)} total questions")

    if not memory_bank:
        print("[Pre-Eval Replay] Memory Bank is empty. Skipping.")
        return

    # Count new questions
    n_new = count_new_questions(storage_path, args.experiment_name)
    print(f"[Pre-Eval Replay] Newly generated questions: {n_new}")

    if n_new == 0:
        print("[Pre-Eval Replay] No new questions found. Skipping.")
        return

    # Replay size is proportional to the amount of new data (truncated)
    n_replay = int(n_new * args.replay_ratio)
    print(f"[Pre-Eval Replay] Target replay samples: {n_replay} (ratio={args.replay_ratio})")

    # Sample from Memory Bank
    sampled = sample_historical(
        memory_bank,
        n_replay,
        args.iteration,
        args.sampling_strategy
    )
    print(f"[Pre-Eval Replay] Actually sampled: {len(sampled)} questions")

    if not sampled:
        print("[Pre-Eval Replay] No samples available. Skipping.")
        return

    # Show how many sampled questions came from each original iteration
    iteration_counts = defaultdict(int)
    for s in sampled:
        iteration_counts[s.get('_original_iteration', 0)] += 1
    print("[Pre-Eval Replay] Sample distribution by iteration:")
    for iter_num in sorted(iteration_counts):
        print(f"    Iteration {iter_num}: {iteration_counts[iter_num]} questions")

    # Append to generated files
    append_to_generated_files(storage_path, args.experiment_name, sampled)

    print("=" * 70)
    print(f"[Pre-Eval Replay] Complete! {len(sampled)} historical questions added.")
    print(f"[Pre-Eval Replay] Total questions for evaluation: {n_new + len(sampled)}")
    print("=" * 70)


# Run the replay sampling only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
