#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Experience Replay - Post-Evaluation Mixing (后置混合)

This script samples historical questions from Memory Bank and merges them with
the current iteration's filtered data AFTER evaluation. The historical data 
retains its original pseudo-labels (generated by previous Solvers).

Key Features:
- Samples historical questions based on specified strategy
- Loads current iteration data from HuggingFace
- Merges and uploads combined dataset for training

Usage:
    python memory_bank/mix_training_data.py \
        --experiment_name qwen3-4b_solver_v2 \
        --iteration 2 \
        --replay_ratio 0.3 \
        --sampling_strategy uniform \
        --model_abbr qwen3-4b
"""

import argparse
import json
import os
import random
import sys
import time
from typing import List, Dict, Optional
from collections import defaultdict
from datasets import Dataset, load_dataset
from huggingface_hub import login


# ============ HuggingFace timeout and retry configuration ============
HF_DOWNLOAD_TIMEOUT = 120  # Download timeout in seconds (exported as HF_HUB_DOWNLOAD_TIMEOUT)
HF_MAX_RETRIES = 5  # Default number of attempts made by the retry helpers
HF_RETRY_BASE_WAIT = 10  # Base wait in seconds between attempts; multiplied by the attempt number


def load_dataset_with_retry(
    repo_name: str,
    config_name: Optional[str] = None,
    split: str = "train",
    max_retries: int = HF_MAX_RETRIES,
    timeout: int = HF_DOWNLOAD_TIMEOUT
) -> Optional[Dataset]:
    """
    Load a HuggingFace dataset, retrying on failure.

    Any exception raised by ``load_dataset`` counts as a failed attempt.
    Between attempts the function sleeps ``HF_RETRY_BASE_WAIT * attempt_number``
    seconds, so the wait grows linearly.

    Args:
        repo_name: HuggingFace repository name (e.g. "username/dataset_name")
        config_name: Optional config name; omitted from the call when falsy
        split: Dataset split to load
        max_retries: Total number of attempts
        timeout: Value (seconds) exported as HF_HUB_DOWNLOAD_TIMEOUT for the
            duration of the call

    Returns:
        The loaded Dataset, or None if every attempt failed (or
        ``max_retries`` is not positive).
    """
    # Export the timeout for the duration of the call and remember the old
    # value so it can be restored afterwards.
    # NOTE(review): huggingface_hub may read this variable only once at import
    # time, in which case setting it here has no effect - confirm.
    env_key = 'HF_HUB_DOWNLOAD_TIMEOUT'
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = str(timeout)

    # Only pass ``name`` when a config was actually requested.
    load_kwargs = {"split": split}
    if config_name:
        load_kwargs["name"] = config_name

    try:
        for attempt_no in range(1, max_retries + 1):
            try:
                print(f"[HuggingFace] Downloading (attempt {attempt_no}/{max_retries}, timeout={timeout}s)...")
                loaded = load_dataset(repo_name, **load_kwargs)
                print("[HuggingFace] Download successful!")
                return loaded
            except Exception as exc:
                # Last attempt: give up and report the final error.
                if attempt_no >= max_retries:
                    print(f"[HuggingFace] Failed after {max_retries} attempts: {exc}")
                    return None

                # Linearly increasing back-off.
                pause = HF_RETRY_BASE_WAIT * attempt_no
                lowered = str(exc).lower()
                if "timeout" in lowered or "timed out" in lowered:
                    print(f"[HuggingFace] Timeout error, retrying in {pause}s...")
                else:
                    print(f"[HuggingFace] Error: {exc}")
                    print(f"[HuggingFace] Retrying in {pause}s...")
                time.sleep(pause)

        return None
    finally:
        # Restore the environment variable to its original state.
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value


def push_to_hub_with_retry(
    dataset: Dataset,
    repo_name: str,
    config_name: Optional[str] = None,
    private: bool = True,
    max_retries: int = HF_MAX_RETRIES
) -> bool:
    """
    Upload a dataset to the HuggingFace Hub, retrying on failure.

    Any exception raised by ``dataset.push_to_hub`` counts as a failed
    attempt. Between attempts the function sleeps
    ``HF_RETRY_BASE_WAIT * attempt_number`` seconds.

    Args:
        dataset: Dataset to upload
        repo_name: HuggingFace repository name
        config_name: Optional config name; omitted from the call when falsy
        private: Whether the repository should be private
        max_retries: Total number of attempts

    Returns:
        True if an upload attempt succeeded, False otherwise.
    """
    # Only pass ``config_name`` when one was actually requested.
    push_kwargs = {"private": private}
    if config_name:
        push_kwargs["config_name"] = config_name

    for attempt_no in range(1, max_retries + 1):
        try:
            print(f"[HuggingFace] Uploading (attempt {attempt_no}/{max_retries})...")
            dataset.push_to_hub(repo_name, **push_kwargs)
            print("[HuggingFace] Upload successful!")
            return True
        except Exception as exc:
            # Last attempt: give up and report the final error.
            if attempt_no >= max_retries:
                print(f"[HuggingFace] Upload failed after {max_retries} attempts: {exc}")
                return False

            # Linearly increasing back-off.
            pause = HF_RETRY_BASE_WAIT * attempt_no
            print(f"[HuggingFace] Upload error: {exc}")
            print(f"[HuggingFace] Retrying in {pause}s...")
            time.sleep(pause)

    return False


def load_memory_bank(memory_bank_path: str, embedding_type: str = "nl") -> List[Dict]:
    """
    Read the Memory Bank question file from disk.

    The file is "question_code.json" when ``embedding_type`` is "code" and
    "questions.json" for any other value.

    Args:
        memory_bank_path: Directory that holds the memory bank files
        embedding_type: "code" selects the code file; anything else selects
            the natural-language file

    Returns:
        The parsed JSON content of the file, or an empty list when the file
        does not exist.
    """
    # The original comment says these names match update_memory.py, which is
    # not visible from this file.
    file_name = "question_code.json" if embedding_type == "code" else "questions.json"
    full_path = os.path.join(memory_bank_path, file_name)

    if not os.path.exists(full_path):
        print(f"[Post-Eval Replay] File not found: {full_path}")
        return []

    print(f"[Post-Eval Replay] Loading from {file_name}")
    with open(full_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


def sample_uniform(historical: List[Dict], n_samples: int) -> List[Dict]:
    """
    Uniform random sampling without replacement.

    Args:
        historical: Candidate questions
        n_samples: Requested number of questions

    Returns:
        A list of ``min(n_samples, len(historical))`` distinct items drawn
        with ``random.sample``; an empty list when that count is not positive.
    """
    count = min(n_samples, len(historical))
    # random.sample raises ValueError for a negative count, so a non-positive
    # request is answered with "nothing to sample" instead.
    if count <= 0:
        return []
    return random.sample(historical, count)


def sample_stratified(historical: List[Dict], n_samples: int, current_iteration: int) -> List[Dict]:
    """
    Stratified sampling that favours the most recent iterations.

    Quotas (each truncated to an integer share of ``n_samples``):
    - iteration ``current_iteration - 1``: 50%
    - iteration ``current_iteration - 2``: 30%
    - all earlier iterations pooled together: 20%

    If the quotas yield fewer than ``n_samples`` items (small strata or
    truncation), the shortfall is filled uniformly from the items not yet
    picked. Items without an 'iteration' key are treated as iteration 0.

    Returns:
        At most ``n_samples`` questions.
    """
    if not historical:
        return []

    # Bucket the questions by the iteration that produced them.
    buckets = defaultdict(list)
    for item in historical:
        buckets[item.get('iteration', 0)].append(item)

    picked = []

    # Recent strata, most recent first.
    recent_plan = (
        (current_iteration - 1, 0.5),
        (current_iteration - 2, 0.3),
    )
    for iteration, share in recent_plan:
        if iteration not in buckets:
            continue
        pool = buckets[iteration]
        quota = int(n_samples * share)
        picked.extend(random.sample(pool, min(quota, len(pool))))

    # Everything older than the two most recent iterations shares one quota.
    cutoff = current_iteration - 2
    older_pool = [
        item
        for iteration, group in buckets.items()
        if iteration < cutoff
        for item in group
    ]
    if older_pool:
        older_quota = int(n_samples * 0.2)
        picked.extend(random.sample(older_pool, min(older_quota, len(older_pool))))

    # Fill any shortfall uniformly from the items not picked yet. Identity
    # (id) is used because the dicts themselves are not hashable.
    shortfall = n_samples - len(picked)
    if shortfall > 0:
        taken = {id(item) for item in picked}
        leftovers = [item for item in historical if id(item) not in taken]
        if leftovers:
            picked.extend(random.sample(leftovers, min(shortfall, len(leftovers))))

    return picked[:n_samples]


def sample_recent_first(historical: List[Dict], n_samples: int) -> List[Dict]:
    """
    Return the first ``n_samples`` questions ordered by descending iteration.

    The selection is deterministic (no randomness). Items without an
    'iteration' key are treated as iteration 0, and items from the same
    iteration keep their original relative order. The input list is not
    modified.
    """
    def iteration_of(entry):
        return entry.get('iteration', 0)

    newest_first = list(historical)
    newest_first.sort(key=iteration_of, reverse=True)
    return newest_first[:n_samples]


def sample_score_weighted(historical: List[Dict], n_samples: int) -> List[Dict]:
    """
    Score-weighted sampling without replacement.

    Each question is drawn with probability proportional to its 'score', so
    higher-scored questions are picked more often. Weights are normalised
    linearly (score / sum of scores); no softmax or temperature is applied.

    Score handling:
    - a missing, null or non-numeric score counts as 0.5
    - if any score is negative, all scores are shifted so the minimum
      becomes 0.1
    - every weight is floored at 0.01 so no question has zero probability

    Args:
        historical: Candidate questions
        n_samples: Requested number of questions

    Returns:
        ``min(n_samples, len(historical))`` distinct questions, or an empty
        list when that count is not positive. If NumPy rejects the
        probabilities (e.g. NaN scores), falls back to uniform sampling.
    """
    # Imported locally because the rest of the module does not need NumPy.
    import numpy as np

    n_to_sample = min(n_samples, len(historical))
    if n_to_sample <= 0:
        return []

    def _score_of(question):
        # JSON nulls or malformed values would otherwise produce an
        # object-dtype array and crash the arithmetic below.
        try:
            return float(question.get('score', 0.5))
        except (TypeError, ValueError):
            return 0.5

    scores = np.array([_score_of(q) for q in historical], dtype=float)

    # Shift to strictly positive values when negative scores are present.
    min_score = scores.min()
    if min_score < 0:
        scores = scores - min_score + 0.1

    # Avoid zero weights.
    scores = np.maximum(scores, 0.01)

    probabilities = scores / scores.sum()

    try:
        sampled_indices = np.random.choice(
            len(historical),
            size=n_to_sample,
            replace=False,
            p=probabilities
        )
    except ValueError:
        # Invalid probabilities (e.g. NaN/inf scores): fall back to uniform.
        return random.sample(historical, n_to_sample)

    return [historical[int(i)] for i in sampled_indices]


def sample_historical(
    memory_bank: List[Dict], 
    n_samples: int, 
    current_iteration: int,
    strategy: str = "uniform"
) -> List[Dict]:
    """
    Pick historical questions for replay using the requested strategy.

    Only questions whose 'iteration' is strictly lower than
    ``current_iteration`` are eligible; entries without an 'iteration' key
    are treated as iteration 0.

    Args:
        memory_bank: All stored questions
        n_samples: Number of questions to sample
        current_iteration: Current iteration; questions from this iteration
            or later are excluded
        strategy: "stratified", "recent_first" or "score_weighted"; any other
            value selects uniform sampling

    Returns:
        List of sampled questions (empty when nothing is eligible)
    """
    # Keep only questions produced before the current iteration.
    eligible = []
    for entry in memory_bank:
        if entry.get('iteration', 0) < current_iteration:
            eligible.append(entry)

    if not eligible:
        return []

    # Dispatch on the strategy name; unknown names fall through to uniform.
    if strategy == "stratified":
        return sample_stratified(eligible, n_samples, current_iteration)
    if strategy == "recent_first":
        return sample_recent_first(eligible, n_samples)
    if strategy == "score_weighted":
        return sample_score_weighted(eligible, n_samples)
    return sample_uniform(eligible, n_samples)


def main():
    """
    Command-line entry point: mix replayed Memory Bank questions into the
    current iteration's dataset and upload the result.

    Steps:
    1. Parse arguments and read STORAGE_PATH, HUGGINGFACENAME and (as a
       fallback for --embedding_type) EMBEDDING_TYPE from the environment.
    2. Return early for iteration <= 1 (nothing to replay yet).
    3. Log in to HuggingFace with the token from ./tokens.json (best effort).
    4. Load the current data from "<HUGGINGFACENAME>/<experiment_name>",
       first with the experiment name as config, then with the default config.
    5. Sample replay questions from the Memory Bank and merge them with the
       current data.
    6. Upload the shuffled mix as config "<experiment_name>_mixed".

    Exit status:
        Exits with code 1 when a required environment variable is missing or
        when loading/uploading fails after all retries; argparse exits with
        code 2 on invalid arguments. Returns normally in every other case.
    """
    parser = argparse.ArgumentParser(
        description="Mix historical data with current data for post-eval experience replay"
    )
    parser.add_argument("--experiment_name", type=str, required=True,
                        help="Name of the experiment")
    parser.add_argument("--iteration", type=int, required=True,
                        help="Current iteration number")
    parser.add_argument("--replay_ratio", type=float, default=0.3,
                        help="Target fraction of replayed (historical) questions in the "
                             "mixed dataset; must satisfy 0 <= ratio < 1")
    parser.add_argument("--sampling_strategy", type=str, default="uniform",
                        choices=["uniform", "stratified", "recent_first", "score_weighted"],
                        help="Sampling strategy for historical data")
    parser.add_argument("--model_abbr", type=str, default=None,
                        help="Model abbreviation for experiment isolation")
    # No argparse default here: a truthy default would always win over the
    # EMBEDDING_TYPE environment fallback resolved below.
    parser.add_argument("--embedding_type", type=str, default=None,
                        choices=["nl", "code"],
                        help="Embedding type: 'nl' for natural language, 'code' for code "
                             "(default: $EMBEDDING_TYPE if set, otherwise 'nl')")
    args = parser.parse_args()

    # The replay size is n_current * r / (1 - r): r == 1 divides by zero and
    # values outside [0, 1) produce a negative sample size.
    if not 0 <= args.replay_ratio < 1:
        parser.error(f"--replay_ratio must be in [0, 1), got {args.replay_ratio}")

    storage_path = os.getenv("STORAGE_PATH")
    huggingface_name = os.getenv("HUGGINGFACENAME")

    # Command line takes precedence, then the environment, then 'nl'.
    embedding_type = args.embedding_type or os.getenv("EMBEDDING_TYPE") or "nl"
    if embedding_type not in ("nl", "code"):
        print(f"[Post-Eval Replay] Warning: Unknown EMBEDDING_TYPE '{embedding_type}', falling back to 'nl'")
        embedding_type = "nl"

    if storage_path is None:
        print("[Post-Eval Replay] ERROR: STORAGE_PATH environment variable is not set.")
        sys.exit(1)

    # Build Memory Bank path (optionally isolated per model)
    if args.model_abbr:
        memory_bank_path = os.path.join(storage_path, "memory_bank", args.model_abbr)
    else:
        memory_bank_path = os.path.join(storage_path, "memory_bank")

    print("=" * 70)
    print("[Post-Eval Replay] Experience Replay - Post-Evaluation Mixing")
    print("=" * 70)
    print(f"  Experiment: {args.experiment_name}")
    print(f"  Iteration: {args.iteration}")
    print(f"  Replay ratio: {args.replay_ratio}")
    print(f"  Sampling strategy: {args.sampling_strategy}")
    print(f"  Embedding type: {embedding_type}")
    print(f"  Memory Bank path: {memory_bank_path}")
    print(f"  HuggingFace name: {huggingface_name}")
    print("=" * 70)

    # Skip for first iteration (no historical data)
    if args.iteration <= 1:
        print("[Post-Eval Replay] First iteration, no historical data to replay. Skipping.")
        return

    # Without the account name the repo id would be "None/<experiment>" and
    # every download attempt would fail only after the full retry back-off.
    if not huggingface_name:
        print("[Post-Eval Replay] ERROR: HUGGINGFACENAME environment variable is not set.")
        sys.exit(1)

    repo_id = f"{huggingface_name}/{args.experiment_name}"

    # Login to HuggingFace. Best effort: a missing/invalid token file only
    # produces a warning and the run continues.
    try:
        with open('tokens.json', 'r') as f:
            token = json.load(f)['huggingface']
        login(token=token)
    except Exception as e:
        print(f"[Post-Eval Replay] Warning: Failed to load HuggingFace token: {e}")

    # Load current iteration data from HuggingFace (with retry)
    print(f"[Post-Eval Replay] Loading current data from {repo_id}")

    # First try the config named after the experiment
    current_dataset = load_dataset_with_retry(
        repo_id,
        config_name=args.experiment_name,
        split="train"
    )

    # If that fails, try again without a config name
    if current_dataset is None:
        print("[Post-Eval Replay] Attempting to load with default config...")
        current_dataset = load_dataset_with_retry(
            repo_id,
            config_name=None,
            split="train"
        )

    # Still failing: exit with an error code
    if current_dataset is None:
        print("[Post-Eval Replay] ERROR: Failed to load current data after all retries!")
        print("[Post-Eval Replay] Stopping to prevent training with incorrect data.")
        sys.exit(1)

    current_data = [
        {
            "problem": item["problem"],
            "answer": item["answer"],
            "score": item.get("score", 0.5)
        }
        for item in current_dataset
    ]
    print(f"[Post-Eval Replay] Current iteration data: {len(current_data)} questions")

    # NOTE(review): on the early-return paths below no "<experiment>_mixed"
    # config is uploaded; presumably the caller then trains on the unmixed
    # config - confirm it cannot pick up a stale mixed config instead.
    if not current_data:
        print("[Post-Eval Replay] No current data found. Skipping.")
        return

    # Load Memory Bank
    memory_bank = load_memory_bank(memory_bank_path, embedding_type)
    print(f"[Post-Eval Replay] Memory Bank contains {len(memory_bank)} total questions")

    if not memory_bank:
        print("[Post-Eval Replay] Memory Bank is empty. Using only current data.")
        return

    # Calculate replay sample size so that replay / (current + replay) is
    # approximately replay_ratio.
    n_current = len(current_data)
    n_replay = int(n_current * args.replay_ratio / (1 - args.replay_ratio))
    print(f"[Post-Eval Replay] Target replay samples: {n_replay}")

    # Sample from Memory Bank
    sampled = sample_historical(
        memory_bank,
        n_replay,
        args.iteration,
        args.sampling_strategy
    )
    print(f"[Post-Eval Replay] Actually sampled: {len(sampled)} questions")

    if not sampled:
        print("[Post-Eval Replay] No samples available. Using only current data.")
        return

    # Show sampling statistics
    iteration_counts = defaultdict(int)
    for s in sampled:
        iteration_counts[s.get('iteration', 0)] += 1
    print("[Post-Eval Replay] Sample distribution by iteration:")
    for iter_num in sorted(iteration_counts.keys()):
        print(f"    Iteration {iter_num}: {iteration_counts[iter_num]} questions")

    # Convert sampled data to training format (Memory Bank entries store the
    # text under "question"; the training set uses "problem")
    replay_data = [
        {
            "problem": q["question"],
            "answer": q["answer"],
            "score": q.get("score", 0.5)
        }
        for q in sampled
    ]

    # Merge data
    mixed_data = current_data + replay_data
    random.shuffle(mixed_data)

    print(f"[Post-Eval Replay] Mixed dataset size: {len(mixed_data)}")
    print(f"[Post-Eval Replay] Current data ratio: {len(current_data) / len(mixed_data):.2%}")
    print(f"[Post-Eval Replay] Replay data ratio: {len(replay_data) / len(mixed_data):.2%}")

    # Upload mixed dataset (with retry)
    print("[Post-Eval Replay] Uploading mixed dataset...")
    mixed_dataset = Dataset.from_list(mixed_data)

    upload_success = push_to_hub_with_retry(
        mixed_dataset,
        repo_id,
        config_name=f"{args.experiment_name}_mixed",
        private=True
    )

    if not upload_success:
        print("[Post-Eval Replay] ERROR: Failed to upload mixed dataset after all retries!")
        print("[Post-Eval Replay] Stopping to prevent training with incorrect data.")
        sys.exit(1)

    print("=" * 70)
    print("[Post-Eval Replay] Complete!")
    print(f"[Post-Eval Replay] Mixed dataset uploaded as config: {args.experiment_name}_mixed")
    print(f"[Post-Eval Replay] Training will use: {huggingface_name}/{args.experiment_name}@{args.experiment_name}_mixed")
    print("=" * 70)

# Run the mixing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
