#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Memory Bank Update Script

This script updates the Memory Bank with filtered questions after each Solver training iteration.
It supports two modes:
    - nl (natural language): Uses BAAI/bge-large-en-v1.5 for text embeddings
    - code: Converts questions to Python code and uses jinaai/jina-code-embeddings-1.5b

Features:
    - Incremental update: appends new questions to existing Memory Bank
    - FIFO strategy: optionally removes questions from old iterations
    - Multi-GPU (data-parallel) code generation; embeddings are computed on a single device (cuda:0 or CPU)
    - Stores questions.json/question_code.json and embeddings.npy/embedding_code.npy

Usage:
    python memory_bank/update_memory.py \
        --experiment_name ${experiment_name} \
        --iteration ${iteration} \
        --embedding_type code \
        --max_score 0.8 \
        --min_score 0.3 \
        --max_iterations -1
"""

import argparse
import json
import os
import re
import subprocess
import sys
import textwrap
import time
import numpy as np
import torch
from typing import List, Dict, Tuple, Optional
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor, as_completed

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.prompts import QUESTION_TO_CODE_PROMPT


def get_available_gpus() -> List[int]:
    """Return the CUDA device indices torch can see (``[]`` when CUDA is unavailable).

    Indices are logical, i.e. relative to any ``CUDA_VISIBLE_DEVICES`` restriction
    already applied to this process.
    """
    return list(range(torch.cuda.device_count())) if torch.cuda.is_available() else []


# ============ Natural Language Embedding ============

def load_nl_embedding_model(model_name: str, device: torch.device):
    """Return ``(tokenizer, model)`` for *model_name*, with the model on *device* in eval mode.

    ``transformers`` is imported inside the function rather than at module level.
    """
    from transformers import AutoTokenizer, AutoModel

    nl_tokenizer = AutoTokenizer.from_pretrained(model_name)
    nl_model = AutoModel.from_pretrained(model_name).to(device)
    nl_model.eval()  # disable dropout etc. for inference

    return nl_tokenizer, nl_model


def compute_nl_embeddings(
    texts: List[str],
    model_name: str = "BAAI/bge-large-en-v1.5",
    batch_size: int = 32
) -> np.ndarray:
    """
    Embed *texts* with an NL encoder and return one L2-normalised row per text.

    Each text is prefixed with a fixed retrieval instruction, tokenised with
    truncation at 512 tokens, and represented by the hidden state of its first
    token. Runs on ``cuda:0`` when any GPU is visible, otherwise on CPU.

    Args:
        texts: input strings.
        model_name: Hugging Face model id/path of the encoder.
        batch_size: number of texts per forward pass.

    Returns:
        Array of shape ``(len(texts), hidden_size)``; for empty input, a
        ``(0, 1024)`` array.
    """
    if not texts:
        return np.array([]).reshape(0, 1024)

    device = torch.device("cuda:0" if get_available_gpus() else "cpu")
    print(f"[Memory Bank] Computing NL embeddings on {device}")

    tokenizer, model = load_nl_embedding_model(model_name, device)

    query_prefix = "Represent this question for similarity search: "
    inputs = [query_prefix + text for text in texts]

    chunks = []
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            batch_inputs = tokenizer(
                inputs[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt"
            )
            batch_inputs = {name: tensor.to(device) for name, tensor in batch_inputs.items()}

            hidden = model(**batch_inputs).last_hidden_state
            # First-token ([CLS]) pooling, then unit length so dot product == cosine.
            cls_vectors = torch.nn.functional.normalize(hidden[:, 0, :], p=2, dim=1)
            chunks.append(cls_vectors.cpu().numpy())

    return np.vstack(chunks)


# ============ Code Embedding ============

def load_code_embedding_model(model_name: str, device: torch.device):
    """Return ``(tokenizer, model)`` for the code encoder *model_name*, on *device* in eval mode.

    Both are loaded with ``trust_remote_code=True`` so custom model code shipped
    with the checkpoint is allowed to run.
    """
    from transformers import AutoTokenizer, AutoModel

    code_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    code_model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device)
    code_model.eval()

    return code_tokenizer, code_model


def compute_code_embeddings(
    codes: List[str],
    model_name: str = "jinaai/jina-code-embeddings-1.5b",
    batch_size: int = 16
) -> np.ndarray:
    """
    Embed Python code snippets and return one L2-normalised row per snippet.

    Snippets are tokenised with truncation at 1024 tokens and pooled with an
    attention-masked mean over the last hidden state. Runs on ``cuda:0`` when
    any GPU is visible, otherwise on CPU.

    Args:
        codes: code strings to embed.
        model_name: Hugging Face model id/path of the code encoder.
        batch_size: number of snippets per forward pass.

    Returns:
        Array of shape ``(len(codes), hidden_size)``; for empty input, a
        ``(0, 1536)`` array.
    """
    if not codes:
        return np.array([]).reshape(0, 1536)  # jina-code-embeddings dimension

    device = torch.device("cuda:0" if get_available_gpus() else "cpu")
    print(f"[Memory Bank] Computing code embeddings on {device}")

    tokenizer, model = load_code_embedding_model(model_name, device)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(codes), batch_size):
            batch_inputs = tokenizer(
                codes[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=1024,
                return_tensors="pt"
            )
            batch_inputs = {name: tensor.to(device) for name, tensor in batch_inputs.items()}

            token_states = model(**batch_inputs).last_hidden_state
            mask = batch_inputs['attention_mask']
            # Masked mean pooling: padding positions contribute nothing.
            # NOTE(review): confirm this pooling matches the model card's recommendation;
            # changing it would make new vectors incompatible with stored ones.
            summed = (token_states * mask.unsqueeze(-1)).sum(1)
            pooled = summed / mask.sum(-1, keepdim=True)
            chunks.append(torch.nn.functional.normalize(pooled, p=2, dim=1).cpu().numpy())

    return np.vstack(chunks)


# ============ Code Generation ============

def extract_code_block(raw_response: str) -> Tuple[Optional[str], Optional[str]]:
    """Pull the code snippet out of a model response.

    Lookup order:
      1. text between the first ``<CODE>`` and the next ``</CODE>`` (returned
         even if it is empty);
      2. everything after the first ``<CODE>`` when no closing tag follows
         (e.g. a truncated response or a stripped stop sequence), if non-blank;
      3. the first fenced Markdown block (optionally tagged ``python``).

    Returns:
        ``(code, note)`` where ``note`` is ``None`` for tag-based hits,
        ``"markdown_fallback"`` for case 3, and ``code`` is ``None`` with note
        ``"raw_response_is_none"`` or ``"no_code_block_found"`` on failure.
    """
    if raw_response is None:
        return None, "raw_response_is_none"

    open_tag, close_tag = "<CODE>", "</CODE>"
    tag_pos = raw_response.find(open_tag)
    if tag_pos != -1:
        body_start = tag_pos + len(open_tag)
        close_pos = raw_response.find(close_tag, body_start)
        if close_pos != -1:
            return raw_response[body_start:close_pos].strip(), None
        # No closing tag: take the rest of the response, but only if non-blank.
        tail = raw_response[body_start:].strip()
        if tail:
            return tail, None

    fenced = re.search(r"```(?:python)?\s*(.*?)```", raw_response, re.DOTALL)
    if fenced:
        return fenced.group(1).strip(), "markdown_fallback"

    return None, "no_code_block_found"


def postprocess_python(extracted_code: str) -> Tuple[Optional[str], Optional[str]]:
    """Normalise extracted Python code and require a ``solver`` definition.

    Returns:
        ``(code, None)`` with surrounding whitespace removed on success, else
        ``(None, reason)`` where reason is ``"empty_extracted_code"``,
        ``"empty_code_after_cleanup"`` or ``"no_solver_function_found"``.
    """
    if not extracted_code:
        return None, "empty_extracted_code"

    # Stripping the whole string already removes leading/trailing blank lines.
    cleaned = extracted_code.strip()
    if not cleaned:
        return None, "empty_code_after_cleanup"

    # Plain substring test: any "def solver..." occurrence counts.
    if "def solver" not in cleaned:
        return None, "no_solver_function_found"

    return cleaned, None


def _physical_gpu_id(gpu_id: int, visible_devices: Optional[str]) -> str:
    """Translate a logical CUDA index into an entry for ``CUDA_VISIBLE_DEVICES``.

    ``gpu_id`` comes from ``get_available_gpus()`` in the parent, i.e. it indexes
    the devices the parent can see. If the parent is already restricted (e.g.
    ``CUDA_VISIBLE_DEVICES="2,3"``), logical index 0 means device ``"2"``;
    writing ``"0"`` would pin the worker to the wrong physical GPU. Falls back to
    ``str(gpu_id)`` when there is no restriction or the index is out of range.
    """
    if visible_devices:
        entries = [entry.strip() for entry in visible_devices.split(",") if entry.strip()]
        if 0 <= gpu_id < len(entries):
            return entries[gpu_id]
    return str(gpu_id)


def _worker_generate_codes(
    gpu_id: int,
    questions: List[str],
    indices: List[int],
    model_path: str,
    max_retries: int,
    result_queue,
    retry_temperature: float = 0.7,
):
    """
    Worker function for code generation on a single GPU.
    Runs in a separate process with its own vLLM instance.

    Args:
        gpu_id: logical GPU index in the parent process (see ``get_available_gpus``).
        questions: questions assigned to this worker.
        indices: original positions of ``questions`` in the caller's list (same length).
        model_path: model path used for both the vLLM engine and the tokenizer.
        max_retries: total generation attempts for questions whose output fails
            extraction/post-processing.
        result_queue: receives exactly one ``(orig_idx, code_or_None)`` tuple per question.
        retry_temperature: sampling temperature for attempts after the first. The
            first attempt is greedy; repeating greedy decoding on an unchanged prompt
            would just reproduce the failed output, so retries sample instead.
            Pass ``0.0`` for fully greedy retries.
    """
    import traceback

    # Pin this process to one device before vLLM/CUDA initialise; gpu_id is relative
    # to the parent's visible devices, so translate it first.
    os.environ["CUDA_VISIBLE_DEVICES"] = _physical_gpu_id(
        gpu_id, os.environ.get("CUDA_VISIBLE_DEVICES")
    )

    codes: List[Optional[str]] = [None] * len(questions)

    try:
        import vllm
        from transformers import AutoTokenizer

        print(f"[GPU {gpu_id}] Initializing vLLM for {len(questions)} questions...")

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=True
        )

        llm = vllm.LLM(
            model=model_path,
            tokenizer=model_path,
            tensor_parallel_size=1,  # Single GPU
            gpu_memory_utilization=0.9,
            trust_remote_code=True,
        )

        greedy_params = vllm.SamplingParams(
            max_tokens=1024,
            temperature=0.0,
            top_p=1.0,
            stop=["</CODE>"],
        )
        retry_params = vllm.SamplingParams(
            max_tokens=1024,
            temperature=retry_temperature,
            top_p=1.0,
            stop=["</CODE>"],
        )

        # Build prompts (chat-formatted when the tokenizer defines a template)
        prompts = []
        for question in questions:
            prompt = QUESTION_TO_CODE_PROMPT.format(question=question)
            messages = [{"role": "user", "content": prompt}]

            if tokenizer.chat_template:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                    add_special_tokens=True
                )
            else:
                formatted_prompt = prompt

            prompts.append(formatted_prompt)

        # Generate with retries; only still-failing questions are re-submitted.
        pending_local_indices = list(range(len(questions)))

        for attempt in range(max_retries):
            if not pending_local_indices:
                break

            print(f"[GPU {gpu_id}] Attempt {attempt + 1}/{max_retries}: {len(pending_local_indices)} pending")

            sample_params = greedy_params if attempt == 0 else retry_params
            pending_prompts = [prompts[i] for i in pending_local_indices]
            responses = llm.generate(pending_prompts, sampling_params=sample_params, use_tqdm=False)

            new_pending = []
            for local_idx, response in zip(pending_local_indices, responses):
                processed_code = None
                if response and response.outputs:
                    extracted_code, _ = extract_code_block(response.outputs[0].text)
                    if extracted_code:
                        processed_code, _ = postprocess_python(extracted_code)

                if processed_code:
                    codes[local_idx] = processed_code
                else:
                    new_pending.append(local_idx)

            pending_local_indices = new_pending

        success_count = sum(1 for c in codes if c is not None)
        print(f"[GPU {gpu_id}] Complete: {success_count}/{len(questions)} success")

    except Exception as e:
        # Report and fall through: unsolved questions stay None, but codes produced by
        # attempts that finished before the failure are still returned.
        print(f"[GPU {gpu_id}] Error: {e}")
        traceback.print_exc()

    # Exactly one result per assigned question, tagged with its original index.
    for orig_idx, code in zip(indices, codes):
        result_queue.put((orig_idx, code))


def generate_codes_vllm_batch(
    questions: List[str],
    model_path: str,
    gpu_ids: List[int],
    max_retries: int = 3,
    poll_interval: float = 30.0,
    max_wait: Optional[float] = None,
) -> List[Optional[str]]:
    """
    Generate Python codes using vLLM with data parallel processing.
    Each GPU runs an independent vLLM instance, and questions are distributed evenly.

    Question ``i`` is assigned to ``gpu_ids[i % len(gpu_ids)]``. One spawned worker
    process runs per GPU that receives work, and each worker reports one
    ``(index, code_or_None)`` pair per question on a shared queue.

    Args:
        questions: question texts to convert to code.
        model_path: code-generation model path given to every worker.
        gpu_ids: logical GPU indices to use.
        max_retries: generation attempts per question inside each worker.
        poll_interval: seconds to block on the result queue before re-checking
            whether workers are still alive.
        max_wait: optional overall limit in seconds on waiting for results while
            workers are alive; ``None`` waits as long as any worker is running.

    Returns:
        A list aligned with ``questions``. An entry is ``None`` where generation
        failed or no result was received.

    Raises:
        ValueError: if ``questions`` is non-empty and ``gpu_ids`` is empty.
    """
    if not questions:
        return []
    if not gpu_ids:
        raise ValueError("generate_codes_vllm_batch requires at least one GPU id")

    import multiprocessing as mp
    import queue

    n_gpus = len(gpu_ids)
    n_questions = len(questions)

    print(f"[Memory Bank] Generating codes for {n_questions} questions using {n_gpus} GPUs (data parallel)")
    print(f"[Memory Bank] Model: {model_path}")
    print(f"[Memory Bank] GPUs: {gpu_ids}")

    # Round-robin split: worker k gets questions k, k + n_gpus, k + 2 * n_gpus, ...
    questions_per_gpu = [list(questions[k::n_gpus]) for k in range(n_gpus)]
    indices_per_gpu = [list(range(k, n_questions, n_gpus)) for k in range(n_gpus)]

    for i, gpu_id in enumerate(gpu_ids):
        print(f"[Memory Bank] GPU {gpu_id}: {len(questions_per_gpu[i])} questions")

    # 'spawn' gives each worker a fresh interpreter, avoiding inherited CUDA state.
    ctx = mp.get_context('spawn')
    result_queue = ctx.Queue()

    processes = []
    for i, gpu_id in enumerate(gpu_ids):
        if not questions_per_gpu[i]:
            continue

        p = ctx.Process(
            target=_worker_generate_codes,
            args=(
                gpu_id,
                questions_per_gpu[i],
                indices_per_gpu[i],
                model_path,
                max_retries,
                result_queue
            )
        )
        processes.append(p)
        p.start()

    # Collect results. Workers only report after loading the model and finishing
    # all attempts, which can take a long time, so a quiet queue is not a failure
    # by itself: keep waiting while any worker is alive, and stop once all have
    # exited and nothing more arrives.
    codes: List[Optional[str]] = [None] * n_questions
    received = set()
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while len(received) < n_questions:
        # Sample liveness *before* the get: anything a dead worker sent is already
        # in the pipe, so a short timeout suffices to drain it.
        workers_alive = any(p.is_alive() for p in processes)
        try:
            orig_idx, code = result_queue.get(timeout=poll_interval if workers_alive else 5.0)
        except queue.Empty:
            if not workers_alive:
                print(f"[Memory Bank] Workers exited with {n_questions - len(received)} results missing")
                break
            if deadline is not None and time.monotonic() >= deadline:
                print(f"[Memory Bank] Timed out after {max_wait}s waiting for code generation results")
                break
            continue
        except Exception as e:
            print(f"[Memory Bank] Error collecting results: {e}")
            break
        codes[orig_idx] = code
        received.add(orig_idx)

    # Wait for all processes to finish; terminate and reap any that hang.
    for p in processes:
        p.join(timeout=60)
        if p.is_alive():
            print(f"[Memory Bank] Terminating stuck process {p.pid}")
            p.terminate()
            p.join(timeout=10)

    success_count = sum(1 for c in codes if c is not None)
    failed_count = n_questions - success_count
    print(f"[Memory Bank] Code generation complete: {success_count} success, {failed_count} failed")

    return codes


# ============ Memory Bank Operations ============

def load_filtered_questions(
    experiment_name: str,
    storage_path: str,
    max_score: float = 0.8,
    min_score: float = 0.3
) -> List[Dict]:
    """
    Load and filter questions from the generated question files.
    This mimics the filtering logic from upload.py.

    Reads ``{storage_path}/generated_question/{experiment_name}_{i}_results.json``
    for ``i`` in ``range(TOTAL_GPU_COUNT)`` (environment variable, default 8) as
    UTF-8 JSON. Missing, undecodable, or non-list shards are skipped with a
    message; non-dict entries inside a shard are ignored.

    Args:
        experiment_name: Name of the experiment
        storage_path: Base storage path
        max_score: Maximum score threshold (inclusive)
        min_score: Minimum score threshold (inclusive)

    Returns:
        List of ``{'question', 'answer', 'score'}`` dicts whose score is not None
        and lies in ``[min_score, max_score]``, whose answer is not ``''`` or
        ``'None'``, and whose question is non-empty.
        NOTE(review): a JSON ``null`` answer is not rejected by this filter.
    """
    datas: List[Dict] = []

    # Number of per-GPU result shards; read from the environment so 6-GPU, 8-GPU,
    # etc. setups are supported.
    n_gpus = int(os.getenv("TOTAL_GPU_COUNT", "8"))

    # Load from all GPU result files (same as upload.py)
    for i in range(n_gpus):
        file_name = f"{experiment_name}_{i}_results.json"
        file_path = f"{storage_path}/generated_question/{file_name}"
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f"[Memory Bank] File {experiment_name}_{i}_results.json not found, skipping...")
            continue
        except (json.JSONDecodeError, UnicodeDecodeError):
            print(f"[Memory Bank] File {experiment_name}_{i}_results.json is not valid JSON, skipping...")
            continue

        if not isinstance(data, list):
            # Extending with a dict/str would add keys/characters and crash the filter below.
            print(f"[Memory Bank] File {file_name} does not contain a JSON list, skipping...")
            continue
        datas.extend(item for item in data if isinstance(item, dict))

    # Apply the same filtering as upload.py
    filtered_datas = [
        {
            'question': data['question'],
            'answer': data['answer'],
            'score': data['score']
        }
        for data in datas
        if data.get('score') is not None
        and data['score'] >= min_score
        and data['score'] <= max_score
        and data.get('answer', '') not in ['', 'None']
        and data.get('question', '') != ''
    ]

    print(f"[Memory Bank] Loaded {len(datas)} questions, {len(filtered_datas)} passed filtering")
    return filtered_datas


def load_memory_bank(memory_bank_path: str, embedding_type: str = "nl") -> Tuple[List[Dict], Optional[np.ndarray]]:
    """Load an existing Memory Bank for the given embedding type.

    Uses ``question_code.json`` / ``embedding_code.npy`` when ``embedding_type`` is
    ``"code"`` and ``questions.json`` / ``embeddings.npy`` otherwise. Either file
    may be absent.

    Returns:
        ``(questions, embeddings)``: ``questions`` is ``[]`` and ``embeddings`` is
        ``None`` when the respective file does not exist. A warning is printed if
        the number of questions and embedding rows differ.
    """
    if embedding_type == "code":
        questions_file = "question_code.json"
        embeddings_file = "embedding_code.npy"
    else:
        questions_file = "questions.json"
        embeddings_file = "embeddings.npy"

    questions_path = os.path.join(memory_bank_path, questions_file)
    embeddings_path = os.path.join(memory_bank_path, embeddings_file)

    questions: List[Dict] = []
    embeddings: Optional[np.ndarray] = None

    if os.path.exists(questions_path):
        # The bank may contain non-ASCII text; read it as UTF-8 regardless of locale.
        with open(questions_path, 'r', encoding='utf-8') as f:
            questions = json.load(f)
        print(f"[Memory Bank] Loaded {len(questions)} existing questions from {questions_file}")

    if os.path.exists(embeddings_path):
        embeddings = np.load(embeddings_path)
        print(f"[Memory Bank] Loaded existing embeddings with shape {embeddings.shape} from {embeddings_file}")

    # Surface out-of-sync files instead of letting them be silently dropped or overwritten later.
    n_rows = 0 if embeddings is None else embeddings.shape[0]
    if len(questions) != n_rows:
        print(f"[Memory Bank] WARNING: {questions_file} has {len(questions)} entries "
              f"but {embeddings_file} has {n_rows} rows")

    return questions, embeddings


def save_memory_bank(
    questions: List[Dict],
    embeddings: np.ndarray,
    memory_bank_path: str,
    embedding_type: str = "nl"
):
    """Save a Memory Bank for the given embedding type.

    Writes ``question_code.json`` / ``embedding_code.npy`` when ``embedding_type``
    is ``"code"`` and ``questions.json`` / ``embeddings.npy`` otherwise, creating
    ``memory_bank_path`` if needed. Each file is first written to a ``.tmp``
    sibling and then moved into place with ``os.replace``, so an interrupted run
    does not leave a half-written file at the final path.

    Raises:
        ValueError: if ``len(questions)`` differs from the number of embedding rows.
    """
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if len(questions) != embeddings.shape[0]:
        raise ValueError(
            f"Mismatch: {len(questions)} questions vs {embeddings.shape[0]} embeddings"
        )

    os.makedirs(memory_bank_path, exist_ok=True)

    if embedding_type == "code":
        questions_file = "question_code.json"
        embeddings_file = "embedding_code.npy"
    else:
        questions_file = "questions.json"
        embeddings_file = "embeddings.npy"

    questions_path = os.path.join(memory_bank_path, questions_file)
    embeddings_path = os.path.join(memory_bank_path, embeddings_file)
    tmp_questions_path = questions_path + ".tmp"
    tmp_embeddings_path = embeddings_path + ".tmp"

    try:
        # ensure_ascii=False emits raw non-ASCII text, so pin the file encoding.
        with open(tmp_questions_path, 'w', encoding='utf-8') as f:
            json.dump(questions, f, indent=2, ensure_ascii=False)
        # Saving through a file object stops np.save from appending another ".npy".
        with open(tmp_embeddings_path, 'wb') as f:
            np.save(f, embeddings)
        os.replace(tmp_questions_path, questions_path)
        os.replace(tmp_embeddings_path, embeddings_path)
    finally:
        # Only left behind if a write or rename failed.
        for tmp_path in (tmp_questions_path, tmp_embeddings_path):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    print(f"[Memory Bank] Saved {len(questions)} questions to {questions_path}")
    print(f"[Memory Bank] Saved embeddings with shape {embeddings.shape} to {embeddings_path}")


def apply_fifo_policy(
    questions: List[Dict],
    embeddings: np.ndarray,
    max_iterations: int,
    current_iteration: int
) -> Tuple[List[Dict], np.ndarray]:
    """
    Apply FIFO policy to remove questions from old iterations.

    A question is kept when its ``'iteration'`` (0 if missing) is at least
    ``current_iteration - max_iterations + 1``. The matching embedding rows are
    kept in the same order.

    Args:
        questions: List of question dictionaries with 'iteration' field
        embeddings: numpy array of embeddings, one row per question
        max_iterations: Maximum number of iterations to keep (-1 for unlimited)
        current_iteration: Current iteration number

    Returns:
        Tuple of (filtered questions, filtered embeddings). The inputs are
        returned unchanged when ``max_iterations <= 0``. When nothing is kept,
        the embeddings array has zero rows and keeps the input's width and dtype.
    """
    if max_iterations <= 0:
        return questions, embeddings

    # Calculate the minimum iteration to keep
    min_iteration = current_iteration - max_iterations + 1

    keep_indices = [
        i for i, q in enumerate(questions) if q.get('iteration', 0) >= min_iteration
    ]
    filtered_questions = [questions[i] for i in keep_indices]
    # Integer-array indexing also works for an empty selection, yielding shape
    # (0, dim) with the original dtype instead of a hard-coded width.
    filtered_embeddings = embeddings[np.asarray(keep_indices, dtype=np.intp)]

    if not keep_indices:
        print("[Memory Bank] No questions to keep, returning empty list")
        return [], filtered_embeddings

    removed_count = len(questions) - len(filtered_questions)
    if removed_count > 0:
        print(f"[Memory Bank] FIFO policy: removed {removed_count} questions from iterations < {min_iteration}")

    return filtered_questions, filtered_embeddings


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(description="Update Memory Bank with filtered questions")

    # Which questions to pull in and how long to retain them.
    parser.add_argument("--experiment_name", type=str, required=True,
                        help="Name of the experiment")
    parser.add_argument("--iteration", type=int, required=True,
                        help="Current iteration number")
    parser.add_argument("--max_score", type=float, default=0.8,
                        help="Maximum score threshold for filtering")
    parser.add_argument("--min_score", type=float, default=0.3,
                        help="Minimum score threshold for filtering")
    parser.add_argument("--max_iterations", type=int, default=-1,
                        help="Maximum iterations to keep in memory (-1 for unlimited)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for embedding computation")

    # Embedding mode and the models used by each mode.
    parser.add_argument("--embedding_type", type=str, default="nl",
                        choices=["nl", "code"],
                        help="Embedding type: 'nl' for natural language, 'code' for Python code")
    parser.add_argument("--question_to_code_model", type=str,
                        default="/path/to/models/Qwen2.5-Coder-7B-Instruct",
                        help="Model path for question-to-code generation")
    parser.add_argument("--nl_embedding_model", type=str,
                        default="BAAI/bge-large-en-v1.5",
                        help="Model for NL embeddings")
    parser.add_argument("--code_embedding_model", type=str,
                        default="jinaai/jina-code-embeddings-1.5b",
                        help="Model for code embeddings")
    parser.add_argument("--model_abbr", type=str, default=None,
                        help="Model abbreviation for experiment isolation (e.g., 'qwen_v1')")
    return parser


def main():
    """Update the Memory Bank with this iteration's filtered questions.

    Steps: load and filter new questions; load the existing bank; apply the FIFO
    retention policy; embed the new questions (directly in NL mode, or via
    generated code in code mode); append them to the bank; save.
    """
    args = _build_arg_parser().parse_args()

    storage_path = os.getenv("STORAGE_PATH", "/path/to/R-Zero/storage")

    # Give each model_abbr its own memory_bank subdirectory so experiments stay
    # isolated; without model_abbr, fall back to the original shared directory.
    memory_bank_path = (
        os.path.join(storage_path, "memory_bank", args.model_abbr)
        if args.model_abbr
        else os.path.join(storage_path, "memory_bank")
    )

    banner = "=" * 60
    print(banner)
    print("[Memory Bank] Updating Memory Bank")
    print(f"[Memory Bank] Experiment: {args.experiment_name}")
    print(f"[Memory Bank] Iteration: {args.iteration}")
    print(f"[Memory Bank] Model Abbr: {args.model_abbr}")
    print(f"[Memory Bank] Embedding Type: {args.embedding_type}")
    print(f"[Memory Bank] Storage path: {storage_path}")
    print(f"[Memory Bank] Memory bank path: {memory_bank_path}")
    if args.embedding_type != "code":
        print(f"[Memory Bank] NL embedding model: {args.nl_embedding_model}")
    else:
        print(f"[Memory Bank] Code generation model: {args.question_to_code_model}")
        print(f"[Memory Bank] Code embedding model: {args.code_embedding_model}")
    print(banner)

    # 1) Questions produced in this iteration that pass the score filter.
    new_questions = load_filtered_questions(
        experiment_name=args.experiment_name,
        storage_path=storage_path,
        max_score=args.max_score,
        min_score=args.min_score
    )
    if not new_questions:
        print("[Memory Bank] No new questions to add. Exiting.")
        return

    for entry in new_questions:
        entry.update(iteration=args.iteration, timestamp=datetime.now().isoformat())

    # 2) Existing bank, trimmed by the FIFO policy when both parts are present.
    existing_questions, existing_embeddings = load_memory_bank(memory_bank_path, args.embedding_type)
    if existing_questions and existing_embeddings is not None:
        existing_questions, existing_embeddings = apply_fifo_policy(
            existing_questions,
            existing_embeddings,
            args.max_iterations,
            args.iteration
        )

    # 3) Embed the new questions.
    question_texts = [entry['question'] for entry in new_questions]

    if args.embedding_type == "code":
        available_gpus = get_available_gpus()
        if not available_gpus:
            print("[Memory Bank] ERROR: No GPUs available for code generation!")
            return

        print(f"[Memory Bank] Using GPUs {available_gpus} for code generation")
        generated_codes = generate_codes_vllm_batch(
            question_texts,
            args.question_to_code_model,
            available_gpus,
            max_retries=3
        )

        # Keep only questions whose code was generated; attach the code to each.
        valid_questions = []
        for idx, (entry, code) in enumerate(zip(new_questions, generated_codes)):
            if code is None:
                print(f"[Memory Bank] Skipping question {idx} due to code generation failure")
                continue
            entry['code'] = code
            valid_questions.append(entry)
        valid_codes = [entry['code'] for entry in valid_questions]

        if not valid_codes:
            print("[Memory Bank] All code generation failed. Exiting.")
            return

        print(f"[Memory Bank] {len(valid_codes)}/{len(new_questions)} questions have valid codes")
        print(f"[Memory Bank] Computing code embeddings for {len(valid_codes)} codes...")
        new_embeddings = compute_code_embeddings(
            valid_codes,
            args.code_embedding_model,
            args.batch_size
        )
        new_questions = valid_questions
    else:
        print(f"[Memory Bank] Computing NL embeddings for {len(question_texts)} questions...")
        new_embeddings = compute_nl_embeddings(
            question_texts,
            args.nl_embedding_model,
            args.batch_size
        )

    print(f"[Memory Bank] Computed embeddings with shape {new_embeddings.shape}")

    # 4) Append to the existing bank only when it has usable embeddings.
    can_merge = (
        bool(existing_questions)
        and existing_embeddings is not None
        and existing_embeddings.size > 0
    )
    if can_merge:
        all_questions = existing_questions + new_questions
        all_embeddings = np.vstack([existing_embeddings, new_embeddings])
    else:
        all_questions, all_embeddings = new_questions, new_embeddings

    # 5) Persist.
    save_memory_bank(all_questions, all_embeddings, memory_bank_path, args.embedding_type)

    print(banner)
    print("[Memory Bank] Update complete!")
    print(f"[Memory Bank] Total questions in Memory Bank: {len(all_questions)}")
    print(banner)


if __name__ == "__main__":
    # Run only when executed as a script. The 'spawn' start method used for the
    # code-generation workers re-imports this module in each child, so the guard
    # also keeps workers from re-running main().
    main()
