"""
Train Generator with SFT on Solver's Failed Code Generation Attempts (Buggy Code).

This script:
1. Runs inference with a solver model on a code generation dataset using AgentWorkflowEngine
2. Collects failed solutions (code that fails unit tests = "buggy code")
3. Creates SFT data where the Generator learns to produce bugs given (problem, correct_solution)
4. Generates separate train and validation parquet files from different tasks
5. Optionally runs SFT training using verl infrastructure

The idea: If we have a dataset of (problem, correct_solution, buggy_code) triples,
we can train a Generator to produce realistic bugs by imitating the Solver's mistakes.

Usage:
    # Step 1: Generate buggy code data from solver failures (with train/val split)
    python -m examples.bugs.sft_generator_on_solver_failures \
        --mode generate \
        --solver_model Qwen/Qwen2.5-Coder-7B-Instruct \
        --solver_base_url http://localhost:30000/v1 \
        --dataset deepcoder_bugs \
        --split train \
        --output_train solver_failures_sft_train.parquet \
        --output_val solver_failures_sft_val.parquet \
        --val_ratio 0.1 \
        --num_samples 1000 \
        --samples_per_problem 4

    # Step 2: Train Generator on the collected buggy code
    python -m examples.bugs.sft_generator_on_solver_failures \
        --mode train \
        --train_file solver_failures_sft_train.parquet \
        --val_file solver_failures_sft_val.parquet \
        --model Qwen/Qwen2.5-Coder-7B-Instruct
"""
from __future__ import annotations

import argparse
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from dataclasses import dataclass

import pandas as pd

from rllm.agents.agent import Episode, Step, Trajectory
from rllm.data.dataset import DatasetRegistry
from rllm.engine import ModelOutput, RolloutEngine, OpenAIEngine
from rllm.rewards.reward_fn import RewardFunction, code_reward_fn
from rllm.rewards.reward_types import RewardOutput
from rllm.workflows.workflow import Workflow

from examples.bugs.prompts import _build_code_generation_prompt, _build_bug_generator_prompt


@dataclass
class SolverFailure:
    """One failed solver attempt, kept as a (problem, reference, buggy code) record.

    Instances are built in ``collect_solver_failures_from_data`` from episode
    ``info``/``metrics`` and consumed by ``create_sft_dataset``.
    """
    # Id of the episode that produced this attempt.
    task_id: str
    # Problem statement taken from the task.
    problem: str
    # Reference solution taken from the task.
    correct_solution: str
    buggy_code: str  # The solver's failed attempt
    # Error text copied from the reward metadata, if any was reported.
    error_info: Optional[str] = None
    # Test counts copied from the episode metrics; 0 when not reported.
    passed_tests: int = 0
    total_tests: int = 0


def _extract_code_from_response(response: str) -> str:
    """Extract code from markdown blocks if present."""
    if "```python" in response:
        parts = response.split("```python")
        if len(parts) > 1:
            code_part = parts[1].split("```")[0]
            return code_part.strip()
    elif "```" in response:
        parts = response.split("```")
        if len(parts) >= 2:
            return parts[1].strip()
    return response.strip()


def _check_compile_error(meta: Dict[str, Any]) -> bool:
    """Check if there's a compile/syntax error in the metadata."""
    error_msg = meta.get("error", "")
    if isinstance(error_msg, str):
        compile_patterns = ["syntax", "syntaxerror", "indentation", "invalid syntax"]
        if any(p in error_msg.lower() for p in compile_patterns):
            return True
    
    # Also check test_results for compile errors
    test_results = meta.get("test_results", [])
    if isinstance(test_results, list):
        for test in test_results:
            if isinstance(test, dict):
                test_error = str(test.get("error_message", "")).lower()
                if "error during testing:" in test_error:
                    return True
                compile_indicators = [
                    "syntaxerror", "indentationerror", "was never closed",
                    "unexpected", "invalid syntax"
                ]
                if any(p in test_error for p in compile_indicators):
                    return True
    return False


class SolverFailureCollectorWorkflow(Workflow):
    """Workflow that runs a solver on code generation tasks and collects failures.

    For each task the workflow:
    1. Builds a code generation prompt from the task
    2. Asks the rollout engine for a completion and extracts the code from it
    3. Scores the code with the configured reward function
    4. Flags the attempt as a "valid failure" when it was actually evaluated,
       did not pass, and shows no compile/syntax error

    The resulting episodes are used to build SFT data for a bug generator.
    """

    def __init__(
        self,
        rollout_engine: RolloutEngine,
        executor: ThreadPoolExecutor,
        reward_function: RewardFunction,
        system_prompt: Optional[str] = None,
        **kwargs,
    ):
        """Store the reward function and optional system prompt.

        Args:
            rollout_engine: Engine used to query the solver model.
            executor: Thread pool forwarded to the base ``Workflow``.
            reward_function: Callable invoked as
                ``reward_function(task_info=..., action=...)`` to score code.
            system_prompt: Optional system message prepended to the chat.
            **kwargs: Forwarded unchanged to ``Workflow.__init__``.
        """
        super().__init__(rollout_engine=rollout_engine, executor=executor, **kwargs)
        self.reward_function = reward_function
        self.system_prompt = system_prompt

    async def run(self, task: Dict[str, Any], uid: str, **kwargs) -> Episode:
        """Generate code for ``task``, score it, and package the result.

        Args:
            task: Task dict; the problem text is read from ``"question"`` or
                ``"problem"`` and the reference solution from
                ``"reference_solution"``, ``"ground_truth"`` or ``"solution"``.
            uid: Identifier used as the episode id.

        Returns:
            An ``Episode`` whose ``info`` carries the problem, the reference
            solution, the generated code and the ``is_valid_failure`` flag.
        """
        self.reset(task, uid)

        # Build prompt for code generation
        prompt = _build_code_generation_prompt(task)

        messages: List[Dict[str, str]] = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": prompt})

        # Generate code
        model_output: ModelOutput = await self.rollout_engine.get_model_response(messages)
        raw_response = model_output.content or ""
        generated_code = _extract_code_from_response(raw_response)

        # Build trajectory
        chat_completions = messages + [{"role": "assistant", "content": raw_response}]
        step = Step(
            chat_completions=chat_completions,
            action=generated_code,
            model_response=raw_response,
            model_output=model_output,
        )
        trajectory = Trajectory(name="solver", steps=[step])

        # Evaluate with reward function.
        # NOTE(review): this call is synchronous inside a coroutine; if the
        # reward function runs unit tests it presumably blocks the event loop
        # for that long - confirm whether it should go through the executor.
        task_info = task.get("extra_info", task)
        evaluation_failed = False
        try:
            reward_output: RewardOutput = self.reward_function(task_info=task_info, action=generated_code)
        except Exception as e:
            # The scorer itself crashed, so the code was never judged. Record
            # the error but do not let this attempt count as a real bug.
            evaluation_failed = True
            reward_output = RewardOutput(
                reward=0.0,
                is_correct=False,
                metadata={"error": str(e)},
            )

        meta = reward_output.metadata or {}
        is_correct = bool(reward_output.is_correct)
        has_compile_error = _check_compile_error(meta)

        # A "valid failure" is one where the code was evaluated, runs, and
        # produces wrong results.
        is_valid_failure = (not is_correct) and (not has_compile_error) and (not evaluation_failed)

        # Get problem and correct solution for SFT data
        problem = task.get("question", task.get("problem", ""))
        correct_solution = task.get("reference_solution", task.get("ground_truth", task.get("solution", "")))

        step.reward = float(reward_output.reward)
        trajectory.reward = float(reward_output.reward)

        metrics: Dict[str, Any] = {
            "solver_pass": float(is_correct),
            "has_compile_error": float(has_compile_error),
            "is_valid_failure": float(is_valid_failure),
            # "or 0" keeps a None count in the metadata from crashing int().
            "passed_tests": int(meta.get("passed_tests") or 0),
            "total_tests": int(meta.get("total_tests") or 0),
        }

        episode = Episode(
            id=uid,
            task=task,
            trajectories=[trajectory],
            is_correct=is_correct,
            metrics=metrics,
            info={
                "problem": problem,
                "correct_solution": correct_solution,
                "generated_code": generated_code,
                "raw_response": raw_response,
                "is_valid_failure": is_valid_failure,
                "evaluation_error": evaluation_failed,
                "error_info": str(meta.get("error", meta.get("error_message", ""))),
            },
        )
        return episode


def _build_generator_sft_messages(problem: str, correct_code: str, buggy_code: str) -> List[Dict[str, str]]:
    """Assemble the two-turn chat used as one SFT example for the bug generator.

    The user turn is whatever ``_build_bug_generator_prompt`` returns for the
    problem and its correct solution (per the original author, the same prompt
    GeneratorSolverWorkflow uses). The assistant turn is the buggy code wrapped
    in a ``python`` markdown fence.
    """
    generator_prompt = _build_bug_generator_prompt(problem, correct_code)
    fenced_bug = "```python\n{}\n```".format(buggy_code)

    conversation = [
        {"role": "user", "content": generator_prompt},
        {"role": "assistant", "content": fenced_bug},
    ]
    return conversation


async def collect_solver_failures_from_data(
    data: List[Dict[str, Any]],
    solver_engine: "OpenAIEngine",
    n_parallel: int = 64,
    samples_per_problem: int = 4,
    split_name: str = "data",
) -> List[SolverFailure]:
    """Run the solver on ``data`` and collect its valid failures (buggy code).

    Every problem is attempted ``samples_per_problem`` times through an
    ``AgentWorkflowEngine`` driving ``SolverFailureCollectorWorkflow``.

    Args:
        data: Task dicts to attempt.
        solver_engine: Rollout engine used to query the solver model.
        n_parallel: Number of tasks the workflow engine runs in parallel.
        samples_per_problem: Attempts per problem; values below 1 yield no work.
        split_name: Label used in task ids and log messages.

    Returns:
        One ``SolverFailure`` per episode flagged ``is_valid_failure`` that has
        both a reference solution and generated code.
    """
    if not data:
        print(f"No data provided for {split_name}")
        return []
    if samples_per_problem < 1:
        print(f"samples_per_problem={samples_per_problem} requests no attempts for {split_name}")
        return []

    # Imported only once there is work to do, so trivial calls stay cheap.
    from rllm.engine.agent_workflow_engine import AgentWorkflowEngine

    print(f"Processing {len(data)} problems for {split_name}")

    # Replicate for multiple samples per problem
    tasks = []
    task_ids = []
    for i in range(samples_per_problem):
        for j, task in enumerate(data):
            task_copy = dict(task)
            uid = f"{split_name}_{j}_{i}"
            task_copy["_uid"] = uid
            task_copy["_original_idx"] = j
            tasks.append(task_copy)
            task_ids.append(uid)

    print(f"Total tasks to process for {split_name}: {len(tasks)} ({len(data)} problems x {samples_per_problem} samples)")

    # Setup workflow engine
    workflow_args = {
        "reward_function": code_reward_fn,
        "system_prompt": None,
    }

    engine = AgentWorkflowEngine(
        workflow_cls=SolverFailureCollectorWorkflow,
        workflow_args=workflow_args,
        rollout_engine=solver_engine,
        n_parallel_tasks=n_parallel,
        retry_limit=2,
        raise_on_error=False,
    )

    # Execute all tasks in parallel; always release the engine, even when
    # execution raises, so its resources are not leaked.
    print(f"Generating code and collecting failures for {split_name}...")
    try:
        episodes = await engine.execute_tasks(tasks, task_ids=task_ids)
    finally:
        engine.shutdown()

    # Extract failures from episodes
    failures: List[SolverFailure] = []
    for episode in episodes:
        if episode is None:
            continue

        info = episode.info or {}
        metrics = episode.metrics or {}

        # Only collect valid failures (code runs but fails tests)
        if not info.get("is_valid_failure", False):
            continue

        problem = info.get("problem", "")
        correct_solution = info.get("correct_solution", "")
        generated_code = info.get("generated_code", "")

        # Skip if we don't have the required fields
        if not correct_solution or not generated_code:
            continue

        failures.append(SolverFailure(
            task_id=episode.id,
            problem=problem,
            correct_solution=correct_solution,
            buggy_code=generated_code,
            error_info=info.get("error_info", ""),
            passed_tests=int(metrics.get("passed_tests", 0)),
            total_tests=int(metrics.get("total_tests", 0)),
        ))

    print(f"Collected {len(failures)} valid failures out of {len(tasks)} attempts for {split_name} "
          f"({100*len(failures)/max(1,len(tasks)):.1f}% failure rate)")

    return failures


def _split_train_val(
    data: List[Dict[str, Any]],
    val_ratio: float,
    seed: int = 42,
) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Partition ``data`` into (train, val) by a seeded shuffle of indices.

    A private ``random.Random`` is used so the global RNG state is untouched;
    it yields the same permutation as seeding the module-level RNG with
    ``seed``. At least one example always stays in the train part.
    """
    import random

    if val_ratio <= 0 or len(data) < 2:
        return data, []

    indices = list(range(len(data)))
    random.Random(seed).shuffle(indices)
    # At least one validation example, but never the whole dataset.
    n_val = min(max(1, int(len(data) * val_ratio)), len(data) - 1)
    val_indices = set(indices[:n_val])
    train_part = [example for i, example in enumerate(data) if i not in val_indices]
    val_part = [example for i, example in enumerate(data) if i in val_indices]
    return train_part, val_part


async def collect_solver_failures(
    solver_model: str,
    solver_base_url: str,
    dataset_name: str,
    split: str,
    num_samples: int,
    samples_per_problem: int,
    solver_temperature: float = 0.6,
    solver_top_p: float = 0.95,
    n_parallel: int = 64,
    solver_api_key: Optional[str] = None,
    max_prompt_length: int = 8192,
    max_response_length: int = 8192,
    val_ratio: float = 0.0,
) -> tuple[List[SolverFailure], List[SolverFailure]]:
    """Run the solver model on a registered dataset and collect its failures.

    The dataset is loaded, optionally truncated to the first ``num_samples``
    problems, split by problem into train/val, and each part is processed by
    ``collect_solver_failures_from_data`` with a shared ``OpenAIEngine``.

    Args:
        solver_model: Model name sent to the API and used to load a tokenizer.
        solver_base_url: Base URL of the OpenAI-compatible endpoint.
        dataset_name: Name registered in ``DatasetRegistry``.
        split: Dataset split to load.
        num_samples: Keep only the first N problems; 0 or less keeps all.
        samples_per_problem: Generation attempts per problem.
        solver_temperature: Sampling temperature.
        solver_top_p: Nucleus sampling parameter.
        n_parallel: Number of tasks run in parallel.
        solver_api_key: API key; falls back to ``OPENAI_API_KEY``, then "EMPTY".
        max_prompt_length: Passed to the engine as ``max_prompt_length``.
        max_response_length: Passed to the engine as ``max_response_length``.
        val_ratio: Fraction of problems held out for validation (0 disables).

    Returns:
        Tuple of (train_failures, val_failures). If val_ratio is 0, val_failures is empty.

    Raises:
        ValueError: If the dataset/split is not registered.
    """
    # Load dataset
    dataset = DatasetRegistry.load_dataset(dataset_name, split)
    if dataset is None:
        raise ValueError(f"Dataset '{dataset_name}' split '{split}' not found. "
                        f"Available: {DatasetRegistry.list_datasets()}")

    data = list(dataset.get_data())
    print(f"Loaded {len(data)} examples from {dataset_name}/{split}")

    # Limit samples
    if num_samples > 0 and num_samples < len(data):
        data = data[:num_samples]
        print(f"Using first {num_samples} samples")

    # Split into train and val (by problem, so the two never share a task)
    train_data, val_data = _split_train_val(data, val_ratio)
    if val_ratio > 0:
        print(f"Split data: {len(train_data)} train, {len(val_data)} val (val_ratio={val_ratio})")
        if not val_data:
            print("Warning: too few examples to hold any out; validation split is empty")

    # Setup solver engine; an empty key from either source becomes "EMPTY".
    api_key = solver_api_key or os.getenv("OPENAI_API_KEY") or "EMPTY"

    try:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(solver_model, trust_remote_code=True)
    except Exception as e:
        # Best effort: the engine is still constructed, just without a tokenizer.
        print(f"Warning: Could not load tokenizer for {solver_model}: {e}")
        tokenizer = None

    solver_engine = OpenAIEngine(
        model=solver_model,
        tokenizer=tokenizer,
        base_url=solver_base_url,
        api_key=api_key,
        max_prompt_length=max_prompt_length,
        max_response_length=max_response_length,
        sampling_params={
            "temperature": solver_temperature,
            "top_p": solver_top_p,
        },
    )

    # Collect failures for train data
    train_failures = await collect_solver_failures_from_data(
        data=train_data,
        solver_engine=solver_engine,
        n_parallel=n_parallel,
        samples_per_problem=samples_per_problem,
        split_name="train",
    )

    # Collect failures for val data if val_ratio > 0
    val_failures: List[SolverFailure] = []
    if val_data:
        val_failures = await collect_solver_failures_from_data(
            data=val_data,
            solver_engine=solver_engine,
            n_parallel=n_parallel,
            samples_per_problem=samples_per_problem,
            split_name="val",
        )

    print(f"Total: {len(train_failures)} train failures, {len(val_failures)} val failures")

    return train_failures, val_failures


def create_sft_dataset(failures: List[SolverFailure]) -> List[Dict[str, Any]]:
    """Turn solver failures into SFT rows.

    Failures lacking either a reference solution or buggy code are dropped.
    Each remaining failure becomes a dict with the chat ``messages`` plus its
    ``task_id`` and test counts.
    """
    usable = (item for item in failures if item.correct_solution and item.buggy_code)
    return [
        {
            "messages": _build_generator_sft_messages(
                problem=item.problem,
                correct_code=item.correct_solution,
                buggy_code=item.buggy_code,
            ),
            "task_id": item.task_id,
            "passed_tests": item.passed_tests,
            "total_tests": item.total_tests,
        }
        for item in usable
    ]


def save_sft_dataset(sft_data: List[Dict[str, Any]], output_path: str) -> pd.DataFrame:
    """Save the SFT dataset to a parquet file and print length statistics.

    The parent directory of ``output_path`` is created when missing, so a
    finished generation run is not lost to a non-existent output folder.

    Args:
        sft_data: SFT rows; each row has a ``"messages"`` list of dicts with
            a ``"content"`` string.
        output_path: Destination parquet path.

    Returns:
        The DataFrame that was written.
    """
    parent_dir = os.path.dirname(os.path.abspath(output_path))
    os.makedirs(parent_dir, exist_ok=True)

    df = pd.DataFrame(sft_data)
    df.to_parquet(output_path, index=False)
    print(f"Saved {len(sft_data)} examples to {output_path}")

    # Print stats: total characters across all messages of each example.
    if sft_data:
        msg_lengths = [
            sum(len(m["content"]) for m in ex["messages"])
            for ex in sft_data
        ]
        print(f"Message lengths: min={min(msg_lengths)}, max={max(msg_lengths)}, "
              f"avg={sum(msg_lengths)//len(msg_lengths)}")

    return df


def push_to_huggingface(
    sft_data: List[Dict[str, Any]],
    repo_id: str,
    private: bool = True,
    split: str = "train",
) -> str:
    """Push SFT dataset to Hugging Face Hub.

    Args:
        sft_data: List of SFT examples with 'messages' field
        repo_id: HuggingFace repo ID (e.g., 'username/dataset-name')
        private: Whether to make the dataset private
        split: Dataset split name (default: 'train')

    Returns:
        URL of the uploaded dataset

    Raises:
        ValueError: If ``sft_data`` is empty; there is nothing to upload.
    """
    # Fail fast with a clear message instead of an IndexError on dataset[0].
    if not sft_data:
        raise ValueError(f"Cannot push an empty dataset to {repo_id} (split={split})")

    from datasets import Dataset

    # Convert to HF-compatible format, coercing the scalar columns so every
    # row has the same types.
    hf_data = [
        {
            "messages": example["messages"],  # List of {"role": str, "content": str}
            "task_id": str(example.get("task_id", "")),
            "passed_tests": int(example.get("passed_tests", 0)),
            "total_tests": int(example.get("total_tests", 0)),
        }
        for example in sft_data
    ]

    # Create HuggingFace Dataset without explicit features; the schema is
    # left to Dataset.from_list.
    dataset = Dataset.from_list(hf_data)

    print(f"Created HuggingFace Dataset with {len(dataset)} examples")
    print(f"Features: {dataset.features}")
    print(f"Sample messages[0]: {dataset[0]['messages'][0]}")

    # Push to Hub
    print(f"Pushing to HuggingFace Hub: {repo_id} (private={private})")
    dataset.push_to_hub(
        repo_id,
        split=split,
        private=private,
    )

    url = f"https://huggingface.co/datasets/{repo_id}"
    print(f"Successfully pushed to: {url}")
    return url


async def generate_mode(args):
    """Generate SFT data from solver failures.

    Collects train/val failures with ``collect_solver_failures``, writes them
    to ``args.output_train`` / ``args.output_val``, and optionally uploads
    both splits to the HuggingFace Hub.

    Args:
        args: Parsed CLI namespace produced by ``main``.
    """
    train_failures, val_failures = await collect_solver_failures(
        solver_model=args.solver_model,
        solver_base_url=args.solver_base_url,
        dataset_name=args.dataset,
        split=args.split,
        num_samples=args.num_samples,
        samples_per_problem=args.samples_per_problem,
        solver_temperature=args.solver_temperature,
        solver_top_p=args.solver_top_p,
        n_parallel=args.n_parallel,
        solver_api_key=args.solver_api_key,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        val_ratio=args.val_ratio,
    )

    if not train_failures:
        print("No train failures collected! The solver might be too good or dataset too small.")
        return

    # Create and save train dataset
    train_sft_data = create_sft_dataset(train_failures)
    if not train_sft_data:
        # Every failure was filtered out; do not write or push an empty set.
        print("No usable train examples after filtering; nothing was written.")
        return
    save_sft_dataset(train_sft_data, args.output_train)

    # Create and save val dataset if we have val failures
    val_sft_data = []
    if val_failures:
        val_sft_data = create_sft_dataset(val_failures)
        if val_sft_data:
            save_sft_dataset(val_sft_data, args.output_val)
    elif args.val_ratio > 0:
        print("Warning: No validation failures collected despite val_ratio > 0")

    # Push to HuggingFace if requested
    if not args.push_to_hf:
        return
    if not args.hf_repo_id:
        # Previously this case was skipped silently; tell the user why.
        print("Warning: --push_to_hf was given without --hf_repo_id; skipping HuggingFace upload")
        return

    print("=" * 60)
    print("Pushing dataset to HuggingFace Hub")
    print("=" * 60)
    # Push train split
    push_to_huggingface(
        sft_data=train_sft_data,
        repo_id=args.hf_repo_id,
        private=args.hf_private,
        split="train",
    )
    # Push val split if available
    if val_sft_data:
        push_to_huggingface(
            sft_data=val_sft_data,
            repo_id=args.hf_repo_id,
            private=args.hf_private,
            split="validation",
        )


def train_mode(args):
    """Run SFT training using examples.sft.train_math_sft.

    Builds a ``torch.distributed.run`` command line from ``args`` and runs it
    as a subprocess with the current interpreter.

    Args:
        args: Parsed CLI namespace produced by ``main``.

    Raises:
        subprocess.CalledProcessError: If the training process exits with a
            non-zero status, so a failed run is not mistaken for success.
    """
    import shlex
    import subprocess
    import sys

    logger_backends = "['console', 'wandb']" if args.use_wandb else "['console']"
    # When no validation file is given, the train file doubles as validation.
    val_file = args.val_file or args.train_file

    cmd = [
        sys.executable, "-m", "torch.distributed.run",
        "--standalone",
        f"--nnodes={args.nnodes}",
        f"--nproc_per_node={args.nproc_per_node}",
        "-m", "examples.sft.train_math_sft",
        f"model.partial_pretrain={args.model}",
        "model.trust_remote_code=true",
        "model.enable_gradient_checkpointing=true",
        f"model.lora_rank={args.lora_rank}",
        f"model.lora_alpha={args.lora_alpha}",
        f"trainer.total_epochs={args.epochs}",
        f"data.train_batch_size={args.batch_size}",
        f"data.micro_batch_size_per_gpu={args.micro_batch_size}",
        f"data.max_length={args.max_length}",
        "data.truncation=right",
        "data.multiturn.enable=true",
        "data.multiturn.messages_key=messages",
        f"data.train_files={args.train_file}",
        f"data.val_files={val_file}",
        f"trainer.default_local_dir={args.output_dir}",
        f"trainer.logger={logger_backends}",
        f"trainer.project_name={args.wandb_project}",
        f"optim.lr={args.lr}",
    ]

    print("Running SFT training command:")
    # shlex.join quotes arguments with spaces/quotes (e.g. the logger list),
    # so the printed line can be pasted into a shell as-is.
    print(shlex.join(cmd))
    subprocess.run(cmd, check=True)


def main():
    """Parse CLI arguments and run data generation, training, or both.

    In ``both`` mode the training step consumes the files the generation step
    just wrote (``--output_train`` / ``--output_val``) unless ``--train_file``
    / ``--val_file`` are given explicitly, and training is skipped when
    generation produced no training file.
    """
    parser = argparse.ArgumentParser(
        description="Train Generator SFT on Solver's Code Generation Failures"
    )
    parser.add_argument("--mode", type=str, choices=["generate", "train", "both"],
                       default="both", help="Mode: generate data, train, or both")

    # Data generation arguments
    gen_group = parser.add_argument_group("Data Generation")
    gen_group.add_argument("--solver_model", type=str,
                          default="Qwen/Qwen2.5-Coder-7B-Instruct",
                          help="Solver model to generate code failures from")
    gen_group.add_argument("--solver_base_url", type=str,
                          default="http://localhost:30000/v1",
                          help="Base URL for solver model API")
    gen_group.add_argument("--solver_api_key", type=str, default=None,
                          help="API key for solver model (defaults to OPENAI_API_KEY)")
    gen_group.add_argument("--solver_temperature", type=float, default=0.6,
                          help="Sampling temperature for solver")
    gen_group.add_argument("--solver_top_p", type=float, default=0.95,
                          help="Top-p for solver sampling")
    gen_group.add_argument("--dataset", type=str, default="deepcoder_bugs",
                          help="Dataset name to use for code generation tasks")
    gen_group.add_argument("--split", type=str, default="train",
                          help="Dataset split to use")
    gen_group.add_argument("--num_samples", type=int, default=1000,
                          help="Number of problems to sample (0 = all)")
    gen_group.add_argument("--samples_per_problem", type=int, default=4,
                          help="Number of generation attempts per problem")
    gen_group.add_argument("--n_parallel", type=int, default=128,
                          help="Number of parallel workflow instances (AgentWorkflowEngine)")
    gen_group.add_argument("--max_prompt_length", type=int, default=8192,
                          help="Maximum prompt length for code generation")
    gen_group.add_argument("--max_response_length", type=int, default=8192,
                          help="Maximum response length for code generation")
    gen_group.add_argument("--val_ratio", type=float, default=0.1,
                          help="Ratio of data to use for validation (0 = no val split)")
    gen_group.add_argument("--output_train", type=str, default="solver_failures_sft_train.parquet",
                          help="Output path for train SFT data parquet")
    gen_group.add_argument("--output_val", type=str, default="solver_failures_sft_val.parquet",
                          help="Output path for validation SFT data parquet")

    # HuggingFace Hub arguments
    hf_group = parser.add_argument_group("HuggingFace Hub")
    hf_group.add_argument("--push_to_hf", action="store_true",
                         help="Push generated dataset to HuggingFace Hub")
    hf_group.add_argument("--hf_repo_id", type=str, default=None,
                         help="HuggingFace repo ID (e.g., 'username/dataset-name')")
    hf_group.add_argument("--hf_private", action="store_true", default=True,
                         help="Make the HuggingFace dataset private (default: True)")
    hf_group.add_argument("--hf_public", action="store_true",
                         help="Make the HuggingFace dataset public")

    # Training arguments
    train_group = parser.add_argument_group("Training")
    train_group.add_argument("--train_file", type=str, default=None,
                            help="Path to training data parquet (defaults to --output_train in "
                                 "'both' mode, otherwise solver_failures_sft.parquet)")
    train_group.add_argument("--val_file", type=str, default=None,
                            help="Path to validation data parquet (defaults to the generated "
                                 "--output_val in 'both' mode, otherwise to train_file)")
    train_group.add_argument("--model", type=str,
                            default="Qwen/Qwen2.5-Coder-7B-Instruct",
                            help="Model to train")
    train_group.add_argument("--output_dir", type=str, default="outputs/generator_sft",
                            help="Output directory for checkpoints")
    train_group.add_argument("--epochs", type=int, default=3,
                            help="Number of training epochs")
    train_group.add_argument("--batch_size", type=int, default=4,
                            help="Training batch size")
    train_group.add_argument("--micro_batch_size", type=int, default=2,
                            help="Micro batch size per GPU")
    train_group.add_argument("--max_length", type=int, default=8192,
                            help="Maximum sequence length")
    train_group.add_argument("--lr", type=float, default=1e-5,
                            help="Learning rate")
    train_group.add_argument("--lora_rank", type=int, default=32,
                            help="LoRA rank")
    train_group.add_argument("--lora_alpha", type=int, default=16,
                            help="LoRA alpha")
    train_group.add_argument("--nnodes", type=int, default=1,
                            help="Number of nodes")
    train_group.add_argument("--nproc_per_node", type=int, default=2,
                            help="Number of processes per node (GPUs)")
    train_group.add_argument("--use_wandb", action="store_true",
                            help="Enable wandb logging")
    train_group.add_argument("--wandb_project", type=str, default="generator-sft",
                            help="Wandb project name")

    args = parser.parse_args()

    # Handle hf_public flag (overrides hf_private)
    if args.hf_public:
        args.hf_private = False

    # The generation step writes --output_train, so in 'both' mode that is the
    # file training must read. Train-only mode keeps its historical default.
    if args.train_file is None:
        if args.mode == "both":
            args.train_file = args.output_train
        else:
            args.train_file = "solver_failures_sft.parquet"

    if args.mode in ["generate", "both"]:
        print("=" * 60)
        print("Step 1: Generating SFT data from solver failures")
        print("=" * 60)
        asyncio.run(generate_mode(args))

    if args.mode in ["train", "both"]:
        if args.mode == "both":
            # Generation returns without writing anything when no failures
            # were collected; training would then fail on a missing file.
            if not os.path.exists(args.train_file):
                print(f"Training data {args.train_file} was not produced; skipping training.")
                return
            # Prefer the held-out split written by generation over reusing
            # the training file for validation.
            if args.val_file is None and args.val_ratio > 0 and os.path.exists(args.output_val):
                args.val_file = args.output_val

        print("=" * 60)
        print("Step 2: Training Generator with SFT")
        print("=" * 60)
        train_mode(args)


# Run the CLI only when executed as a script or via ``python -m``, not on import.
if __name__ == "__main__":
    main()

