import argparse
import asyncio
import json
import os
from datetime import datetime

from dotenv import find_dotenv, load_dotenv
from transformers import AutoTokenizer

from generator_flow import BugGeneratorWorkflow
from data_utils import load_data
from rllm.engine import OpenAIEngine
from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.rewards.reward_fn import code_reward_fn


def _to_fenced_python(code_or_text: str) -> str:
    """Return the input wrapped in a Markdown python code fence.

    Leading and trailing newlines are stripped first. Text that already begins
    with a triple-backtick fence is returned (stripped) without being wrapped
    again. A falsy input is treated as the empty string.
    """
    text = code_or_text if code_or_text else ""
    trimmed = text.strip("\n")
    already_fenced = trimmed.startswith("```")
    return trimmed if already_fenced else "```python\n" + trimmed + "\n```"


def _extract_solver_prompt_from_episode(episode) -> str | None:
    """
    Extract the user prompt sent to the solver (bug_fixer trajectory) from an Episode.
    Returns None if no solver trajectory is present.
    """
    for traj in getattr(episode, "trajectories", []) or []:
        if traj.name != "bug_fixer" or not getattr(traj, "steps", None):
            continue
        step = traj.steps[0]
        chat = getattr(step, "chat_completions", None) or []
        # The solver prompt is the last user message before assistant response.
        for msg in reversed(chat):
            if isinstance(msg, dict) and msg.get("role") == "user":
                return msg.get("content")
    return None


def episodes_to_hf_rows(episodes, solver_model: str | None = None) -> list[dict]:
    """
    Flatten Episode objects into a list of plain dicts, one row per episode.

    Each row carries selected task fields, the first-step action of the
    "bug_generator" trajectory (fenced as python), validity/test metrics, the
    solver prompt, the first-step action of the "bug_fixer" trajectory (fenced,
    or None when there is none), the solver pass flag and the generator reward.
    `solver_model` is copied verbatim into every row.
    """
    hf_rows: list[dict] = []
    for episode in episodes:
        task_info = episode.task or {}
        episode_metrics = episode.metrics or {}
        prompt_for_solver = _extract_solver_prompt_from_episode(episode)

        # Pick up the first-step action of each role; a later trajectory with the
        # same name overwrites an earlier one.
        generated_bug = None
        fixer_output = None
        for trajectory in getattr(episode, "trajectories", []) or []:
            role = trajectory.name
            if role == "bug_generator":
                if trajectory.steps:
                    generated_bug = trajectory.steps[0].action
            elif role == "bug_fixer":
                if trajectory.steps:
                    fixer_output = trajectory.steps[0].action

        if fixer_output is None:
            fenced_fix = None
        else:
            fenced_fix = _to_fenced_python(fixer_output or "")

        hf_rows.append(
            {
                "uid": episode.id,
                "data_source": task_info.get("data_source", "livecodebench"),
                "question": task_info.get("question"),
                "starter_code": task_info.get("starter_code", ""),
                "ground_truth": task_info.get("ground_truth"),
                "reference_solution": task_info.get("reference_solution"),
                "buggy_solution": _to_fenced_python(generated_bug or ""),
                "bug_valid": bool(episode_metrics.get("bug_valid", 0.0)),
                "bug_has_compile_error": bool(episode_metrics.get("bug_has_compile_error", 0.0)),
                "bug_total_tests": episode_metrics.get("bug_total_tests", None),
                "bug_passed_tests": episode_metrics.get("bug_passed_tests", None),
                "solver_model": solver_model,
                "solver_prompt": prompt_for_solver,
                "solver_solution": fenced_fix,
                "solver_pass": bool(episode_metrics.get("solver_pass", 0.0)),
                "generator_reward": float(episode_metrics.get("generator_reward", 0.0)),
            }
        )
    return hf_rows


def evaluate_results(results):
    """Print aggregate metrics for a collection of workflow episodes.

    Args:
        results: Sized iterable of episode objects. Each episode must expose
            ``metrics`` (a mapping or None) and ``is_correct``.

    Returns:
        None. All output goes to stdout. When ``results`` is empty a short
        notice is printed instead of the summary.

    Notes:
        * Count-style metrics are accumulated with ``int(...)``, which truncates,
          so a fractional value below 1 contributes 0 to its counter.
        * Test averages are taken over the episodes that actually report the
          corresponding metric; episodes where it is missing or None are skipped
          rather than counted as zero.
    """
    total_episodes = len(results)
    if total_episodes == 0:
        print("No results to evaluate.")
        return

    # Aggregate metrics
    bug_valid_count = 0
    generator_reward_count = 0
    episode_correct_count = 0
    solver_pass_count = 0

    bug_total_tests_sum = 0
    bug_total_tests_episodes = 0
    bug_passed_tests_sum = 0
    bug_passed_tests_episodes = 0

    for episode in results:
        metrics = episode.metrics or {}

        bug_valid_count += int(metrics.get("bug_valid", 0))
        generator_reward_count += int(metrics.get("generator_reward", 0))
        episode_correct_count += int(episode.is_correct)
        solver_pass_count += int(metrics.get("solver_pass", 0))

        # Track how many episodes report each test metric so the averages below
        # use the right denominator.
        total_tests = metrics.get("bug_total_tests")
        if total_tests is not None:
            bug_total_tests_sum += total_tests
            bug_total_tests_episodes += 1
        passed_tests = metrics.get("bug_passed_tests")
        if passed_tests is not None:
            bug_passed_tests_sum += passed_tests
            bug_passed_tests_episodes += 1

    # Print summary statistics
    print("\n" + "="*80)
    print("📊 BUG GENERATOR WORKFLOW RESULTS (vs Static Solver)")
    print("="*80)
    print(f"Total episodes: {total_episodes}")
    print("\nOverall Metrics:")
    print(f"  Bug Valid Rate: {bug_valid_count}/{total_episodes} ({100*bug_valid_count/total_episodes:.1f}%)")
    print(f"  Generator Reward Rate: {generator_reward_count}/{total_episodes} ({100*generator_reward_count/total_episodes:.1f}%)")
    print(f"  Solver Pass Rate: {solver_pass_count}/{total_episodes} ({100*solver_pass_count/total_episodes:.1f}%)")
    print(f"  Episode Correct Rate: {episode_correct_count}/{total_episodes} ({100*episode_correct_count/total_episodes:.1f}%)")

    if bug_total_tests_sum > 0:
        # bug_total_tests_sum > 0 implies at least one episode reported a total.
        avg_total_tests = bug_total_tests_sum / bug_total_tests_episodes
        if bug_passed_tests_episodes:
            avg_passed_tests = bug_passed_tests_sum / bug_passed_tests_episodes
        else:
            avg_passed_tests = 0.0
        print("\nBug Test Statistics:")
        print(f"  Average total tests per bug: {avg_total_tests:.2f}")
        print(f"  Average passed tests per bug: {avg_passed_tests:.2f}")
        print(f"  Average failed tests per bug: {avg_total_tests - avg_passed_tests:.2f}")

    print("="*80)


def print_reward_distributions(results_dict):
    """Print how often each generator reward and solver pass value occurs.

    Args:
        results_dict: Iterable of episode dicts. Each dict may hold a "metrics"
            mapping with "generator_reward" and "solver_pass" entries; a missing
            or None "metrics" entry is treated as empty, and missing values
            default to 0.0.

    Returns:
        None. Output goes to stdout. The solver section is printed only when at
        least one episode has a truthy "solver_pass" value.
    """
    from collections import Counter

    generator_rewards = []
    solver_passes = []

    for episode_dict in results_dict:
        # "metrics" can be present but null in serialized episodes; `or {}` covers
        # both the missing-key and the explicit-None case.
        metrics = episode_dict.get("metrics") or {}
        generator_rewards.append(metrics.get("generator_reward", 0.0))
        solver_passes.append(metrics.get("solver_pass", 0.0))

    total = len(generator_rewards)
    if total == 0:
        print("No episodes found for reward distribution.")
        return

    # Count distributions
    gen_counter = Counter(generator_rewards)
    solver_counter = Counter(solver_passes)

    print("\n" + "="*80)
    print("📈 REWARD DISTRIBUTIONS")
    print("="*80)
    print(f"Total episodes: {total}")

    print("\n🔧 Generator Rewards:")
    for reward_value in sorted(gen_counter):
        count = gen_counter[reward_value]
        percentage = 100.0 * count / total
        print(f"  {reward_value:.1f}: {count:5d} episodes ({percentage:5.1f}%)")

    if any(solver_passes):
        print("\n✅ Solver Pass Rates:")
        for pass_value in sorted(solver_counter):
            count = solver_counter[pass_value]
            percentage = 100.0 * count / total
            print(f"  {pass_value:.1f}: {count:5d} episodes ({percentage:5.1f}%)")

    print("="*80)


def load_json_results(json_file):
    """Load previously saved results from a JSON file.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The decoded JSON document, or None (after printing an error) when the
        path does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    if not os.path.exists(json_file):
        print(f"Error: JSON file not found: {json_file}")
        return None

    print(f"Loading results from: {json_file}")
    # Decode explicitly as UTF-8 instead of the locale default so files holding
    # non-ASCII text load the same way on every platform.
    with open(json_file, "r", encoding="utf-8") as f:
        results_dict = json.load(f)

    print(f"Loaded {len(results_dict)} episodes")
    return results_dict


def exclude_token_ids(data):
    """Strip ``prompt_ids`` and ``completion_ids`` keys at every nesting level.

    Dicts and lists are rebuilt recursively (the input is not modified); any
    other value, including tuples, is returned unchanged.
    """
    if isinstance(data, list):
        return [exclude_token_ids(element) for element in data]
    if not isinstance(data, dict):
        # Scalars and any other container types pass straight through.
        return data

    dropped_keys = ("prompt_ids", "completion_ids")
    return {
        name: exclude_token_ids(child)
        for name, child in data.items()
        if name not in dropped_keys
    }


def print_sample_episodes(results, n_samples=3):
    """Print detailed information for the first few episodes.

    Args:
        results: Sliceable sequence of episode objects exposing ``id``, ``task``,
            ``metrics``, ``is_correct`` and, optionally, ``trajectories``.
        n_samples: Maximum number of episodes to print.

    Returns:
        None. Output goes to stdout.

    A None task, a None question, a missing/None trajectory list and a None
    first-step action are all tolerated and shown as empty text.
    """
    print("\n" + "="*80)
    print(f"📝 SAMPLE EPISODES (showing first {min(n_samples, len(results))})")
    print("="*80)

    for idx, episode in enumerate(results[:n_samples]):
        print(f"\n--- Episode {idx + 1} ---")
        print(f"Task ID: {episode.id}")
        question = (episode.task or {}).get("question") or ""
        print(f"Question: {question[:200]}...")

        # Locate the generator and solver trajectories; when a name repeats,
        # the last one wins.
        bug_traj = None
        solver_traj = None
        for traj in getattr(episode, "trajectories", None) or []:
            if traj.name == "bug_generator":
                bug_traj = traj
            elif traj.name == "bug_fixer":
                solver_traj = traj

        if bug_traj and bug_traj.steps:
            bug_code = bug_traj.steps[0].action or ""
            print("\n🐛 Bug Generator Output (first 300 chars):")
            print(bug_code[:300] + "..." if len(bug_code) > 300 else bug_code)
            print(f"Reward: {bug_traj.steps[0].reward}")

        if solver_traj and solver_traj.steps:
            fixed_code = solver_traj.steps[0].action or ""
            print("\n🔧 Solver Output (first 300 chars):")
            print(fixed_code[:300] + "..." if len(fixed_code) > 300 else fixed_code)

        print(f"\nMetrics: {episode.metrics}")
        print(f"Episode Correct: {episode.is_correct}")
        print("-"*80)


if __name__ == "__main__":
    # Script entry point. Steps, in order:
    #   1. parse CLI options and load a .env file if one is found;
    #   2. with --load_json, print reward distributions from a saved file and exit;
    #   3. build the generator engine and, when a solver model is given, the solver engine;
    #   4. run BugGeneratorWorkflow over the loaded tasks;
    #   5. print metrics/samples and optionally save JSON and push rows to the HF Hub.
    parser = argparse.ArgumentParser(description="Run Bug Generator workflow on DeepCoder tasks")
    parser.add_argument("--n_tasks", type=int, default=0, help="Number of tasks to process; 0 processes all loaded tasks (default: 0)")
    parser.add_argument("--n_repeats", type=int, default=1, help="Number of times to repeat each task (default: 1)")
    parser.add_argument("--split", type=str, default="train", help="Dataset split to use (default: train)")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-Coder-7B-Instruct", help="Model name (default: Qwen/Qwen2.5-Coder-7B-Instruct)")
    parser.add_argument("--base_url", type=str, default="http://localhost:30000/v1", help="Base URL for the API (default: http://localhost:30000/v1)")
    parser.add_argument("--api_key", type=str, default="None", help="API key (default: None)")
    parser.add_argument("--n_parallel", type=int, default=64, help="Number of parallel tasks (default: 64)")
    parser.add_argument("--max_prompt_length", type=int, default=8192, help="Maximum prompt length (default: 8192)")
    parser.add_argument("--max_response_length", type=int, default=8192, help="Maximum response length (default: 8192)")
    parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature (default: 0.6)")
    parser.add_argument("--top_p", type=float, default=0.95, help="Sampling top_p (default: 0.95)")
    parser.add_argument("--generator_system_prompt", type=str, default=None, help="Optional system prompt for bug generator")
    parser.add_argument("--solver_model", type=str, default="openai/gpt-oss-20b", help="Model name for static solver (e.g., gpt-4o-mini). Pass an empty string to disable the solver (default: openai/gpt-oss-20b).")
    parser.add_argument("--solver_base_url", type=str, default="http://localhost:30001/v1", help="Base URL for solver API; an empty string falls back to --base_url (default: http://localhost:30001/v1)")
    parser.add_argument("--solver_temperature", type=float, default=0.6, help="Sampling temperature for solver (default: 0.6)")
    parser.add_argument("--solver_top_p", type=float, default=0.95, help="Sampling top_p for solver (default: 0.95)")
    parser.add_argument("--solver_system_prompt", type=str, default=None, help="Optional system prompt for static solver")
    parser.add_argument("--save_results", action="store_true", help="Save results to JSON file")
    parser.add_argument("--output_dir", type=str, default="logs", help="Directory to save results (default: logs)")
    parser.add_argument(
        "--print_solver_prompts",
        action="store_true",
        help="Print the exact prompt sent to the solver for each episode (requires --solver_model).",
    )
    parser.add_argument(
        "--save_hf_rows",
        action="store_true",
        help="Save HuggingFace-dataset-ready rows (buggy_solution + validity + solver fix/pass) to JSON.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Push the saved HF rows as a dataset to the HuggingFace Hub (requires --hub_repo_id).",
    )
    parser.add_argument(
        "--hub_repo_id",
        type=str,
        default=os.getenv("HF_DATASET_REPO_ID", ""),
        help="HF dataset repo id, e.g. `username/my_dataset` (or set HF_DATASET_REPO_ID).",
    )
    parser.add_argument(
        "--hub_split",
        type=str,
        default=None,
        help="Split name to push (default: --split).",
    )
    parser.add_argument(
        "--hub_private",
        action="store_true",
        help="Create/push the Hub repo as private (requires permissions).",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=os.getenv("HF_TOKEN", os.getenv("HUGGINGFACE_HUB_TOKEN", "")),
        help="HuggingFace token (or set HF_TOKEN / HUGGINGFACE_HUB_TOKEN). If empty, relies on local HF login.",
    )
    parser.add_argument("--print_samples", type=int, default=3, help="Number of sample episodes to print in detail (default: 3)")
    parser.add_argument("--load_json", type=str, default=None, help="Load results from a JSON file and print reward distributions")
    parser.add_argument("--dataset", type=str, default="deepcoder", help="Dataset to use: 'deepcoder', 'bigcodebench', or custom dataset name (default: deepcoder)")

    args = parser.parse_args()

    # Load environment variables from .env file
    env_path = find_dotenv()
    if env_path:
        print(f"Loading environment variables from: {env_path}")
        load_dotenv(env_path)
    else:
        print("Warning: No .env file found. Will use environment variables only.")

    # The argparse defaults for the Hub options were read from os.environ before the
    # .env file was loaded, so values defined only in .env were missed. Re-read them
    # now unless a non-empty value was already supplied.
    if not args.hub_repo_id:
        args.hub_repo_id = os.getenv("HF_DATASET_REPO_ID", "")
    if not args.hub_token:
        args.hub_token = os.getenv("HF_TOKEN", os.getenv("HUGGINGFACE_HUB_TOKEN", ""))

    # If --load_json is specified, load and print distributions, then exit
    if args.load_json:
        results_dict = load_json_results(args.load_json)
        if results_dict is None:
            # load_json_results already printed the reason; report failure via exit status.
            raise SystemExit(1)
        print_reward_distributions(results_dict)
        raise SystemExit(0)

    # Fail fast: check the Hub target before spending time on the workflow run.
    hub_repo_id = str(args.hub_repo_id).strip()
    if args.push_to_hub and not hub_repo_id:
        raise ValueError("--push_to_hub requires --hub_repo_id (or env HF_DATASET_REPO_ID)")

    os.environ["TOKENIZERS_PARALLELISM"] = "true"

    # Initialize tokenizer
    print(f"Loading tokenizer for model: {args.model}")
    tokenizer = AutoTokenizer.from_pretrained(args.model)

    # Initialize rollout engine for generator
    print("Initializing generator rollout engine...")
    rollout_engine = OpenAIEngine(
        model=args.model,
        tokenizer=tokenizer,
        max_prompt_length=args.max_prompt_length,
        max_response_length=args.max_response_length,
        base_url=args.base_url,
        api_key=args.api_key,
        sampling_params={
            "temperature": args.temperature,
            "top_p": args.top_p,
        },
    )

    # Initialize solver rollout engine if solver model is specified
    solver_rollout_engine = None
    if args.solver_model:
        print(f"Initializing solver rollout engine for model: {args.solver_model}")
        solver_base_url = args.solver_base_url if args.solver_base_url else args.base_url
        # Load solver API key from environment variable (from .env file)
        solver_api_key = os.getenv("OPENAI_API_KEY")
        if not solver_api_key:
            raise ValueError("OPENAI_API_KEY not found in environment. Please set it in your .env file or environment variables.")
        # Strip whitespace (common issue with .env files)
        solver_api_key = solver_api_key.strip()
        if not solver_api_key:
            raise ValueError("OPENAI_API_KEY is empty after stripping whitespace. Please check your .env file.")
        # Debug: print masked API key to confirm it's loaded. Keys longer than 11
        # characters show their first 7 and last 4 characters; shorter keys are fully masked.
        if len(solver_api_key) > 11:
            masked_key = solver_api_key[:7] + "*" * (len(solver_api_key) - 11) + solver_api_key[-4:]
        else:
            masked_key = "*" * len(solver_api_key)
        print(f"  Loaded API key: {masked_key}")

        # Check if solver model should use the OpenAI chat-completions mode (no tokenizer).
        # This is true for:
        # - OpenAI-hosted models
        # - OpenAI-like model IDs such as `openai/gpt-oss-*` served behind an OpenAI-compatible endpoint (e.g., vLLM)
        solver_model_l = str(args.solver_model).lower()
        solver_base_url_l = str(solver_base_url).lower()
        is_openai_model = (
            solver_model_l.startswith(("gpt-", "o1", "o3", "openai/"))
            or "gpt-oss" in solver_model_l
            or "api.openai.com" in solver_base_url_l
        )

        # Load tokenizer for solver (skip for OpenAI models, use same tokenizer if same model, otherwise load new one)
        if is_openai_model:
            solver_tokenizer = None
            print("  Using OpenAI API directly (no tokenizer needed)")
        elif args.solver_model == args.model:
            solver_tokenizer = tokenizer
        else:
            solver_tokenizer = AutoTokenizer.from_pretrained(args.solver_model)

        def _build_solver_engine(engine_tokenizer):
            # Single place that constructs the solver engine so the primary attempt and
            # the fallback below cannot drift apart; only the tokenizer differs.
            return OpenAIEngine(
                model=args.solver_model,
                tokenizer=engine_tokenizer,
                max_prompt_length=args.max_prompt_length,
                max_response_length=args.max_response_length,
                base_url=solver_base_url,
                api_key=solver_api_key,
                sampling_params={
                    "temperature": args.solver_temperature,
                    "top_p": args.solver_top_p,
                },
            )

        try:
            solver_rollout_engine = _build_solver_engine(solver_tokenizer)
        except AssertionError as e:
            # If the tokenizer's chat template doesn't match our parser equivalence tests,
            # fall back to chat-completions mode (tokenizer=None). This is common when serving
            # models behind OpenAI-compatible endpoints where we don't have a matching HF chat template.
            print(f"Warning: Solver tokenizer/chat-template parser equivalence failed: {e}")
            print("Falling back to chat-completions mode for solver (tokenizer=None).")
            solver_rollout_engine = _build_solver_engine(None)
    else:
        print("No solver model specified. Running generator-only workflow.")

    # Create workflow engine
    print("Creating workflow engine...")
    engine = AgentWorkflowEngine(
        workflow_cls=BugGeneratorWorkflow,
        workflow_args={
            "reward_function": code_reward_fn,
            "generator_system_prompt": args.generator_system_prompt,
            "solver_rollout_engine": solver_rollout_engine,
            "solver_system_prompt": args.solver_system_prompt,
        },
        rollout_engine=rollout_engine,
        config=None,
        n_parallel_tasks=args.n_parallel,
        retry_limit=1,
    )

    # Load tasks
    print(f"\nLoading tasks from dataset '{args.dataset}' split '{args.split}'...")
    all_tasks = load_data(
        dataset_name=args.dataset,
        n=args.n_repeats,
        split=args.split,
    )

    if not all_tasks:
        print("No tasks loaded. Exiting.")
        raise SystemExit(1)

    # Limit to n_tasks if specified
    if args.n_tasks > 0:
        all_tasks = all_tasks[:args.n_tasks]

    print(f"Loaded {len(all_tasks)} tasks")
    print("Configuration:")
    print(f"  Generator Model: {args.model}")
    if args.solver_model:
        print(f"  Solver Model: {args.solver_model}")
        print(f"  Solver Temperature: {args.solver_temperature}, Top-p: {args.solver_top_p}")
    print(f"  Parallel tasks: {args.n_parallel}")
    print(f"  Generator Temperature: {args.temperature}, Top-p: {args.top_p}")
    print(f"  Max prompt length: {args.max_prompt_length}")
    print(f"  Max response length: {args.max_response_length}")

    # Execute workflow
    print(f"\n🚀 Executing workflow on {len(all_tasks)} tasks...")
    results = asyncio.run(engine.execute_tasks(all_tasks))

    # Evaluate results
    print("\n📊 Evaluating results...")
    evaluate_results(results)

    # Also print the two metrics requested explicitly
    total_eps = len(results)
    if total_eps:
        bug_valid_count = sum(int((ep.metrics or {}).get("bug_valid", 0.0)) for ep in results)
        solver_pass_count = sum(int((ep.metrics or {}).get("solver_pass", 0.0)) for ep in results)
        print(f"\nBug valid rate: {bug_valid_count}/{total_eps} ({100.0*bug_valid_count/total_eps:.1f}%)")
        print(f"Solve rate (solver_pass): {solver_pass_count}/{total_eps} ({100.0*solver_pass_count/total_eps:.1f}%)")

    # Print solver prompts for each episode if requested
    if args.print_solver_prompts:
        if not args.solver_model:
            print("Warning: --print_solver_prompts was set but --solver_model is None; no solver prompts to print.")
        else:
            print("\n" + "=" * 80)
            print("🧾 SOLVER PROMPTS (one per episode)")
            print("=" * 80)
            for i, ep in enumerate(results):
                prompt = _extract_solver_prompt_from_episode(ep) or ""
                print(f"\n--- Episode {i} | uid={ep.id} ---\n")
                print(prompt)

    # Print sample episodes
    if args.print_samples > 0:
        print_sample_episodes(results, n_samples=args.print_samples)

    # Save results if requested
    if args.save_results:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(
            args.output_dir,
            f"generator_flow_results_{timestamp}.json"
        )

        # Convert episodes to dict format for JSON serialization
        # Remove prompt_ids and completion_ids to reduce file size
        results_dict = [exclude_token_ids(episode.to_dict()) for episode in results]

        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(results_dict, f, indent=2)

        print(f"\n💾 Results saved to: {output_file}")

        # Print reward distributions
        print_reward_distributions(results_dict)

    # Save HuggingFace-dataset-ready rows + optional push
    if args.save_hf_rows or args.push_to_hub:
        os.makedirs(args.output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        split_name = args.hub_split if args.hub_split is not None else args.split
        hf_rows = episodes_to_hf_rows(results, solver_model=args.solver_model)

        hf_json_path = os.path.join(args.output_dir, f"lcb_bugbench_rows_{split_name}_{timestamp}.json")
        with open(hf_json_path, "w", encoding="utf-8") as f:
            json.dump(hf_rows, f, indent=2)
        print(f"\n💾 HF rows saved to: {hf_json_path} (rows={len(hf_rows)})")

        if args.push_to_hub:
            # hub_repo_id was validated as non-empty before the workflow ran.
            # Import here so we don't require datasets/huggingface_hub unless pushing.
            from datasets import Dataset as HFDataset
            from datasets import DatasetDict as HFDatasetDict

            ds_hf = HFDataset.from_list(hf_rows)
            dsd = HFDatasetDict({str(split_name): ds_hf})
            dsd.push_to_hub(
                hub_repo_id,
                token=str(args.hub_token).strip() or None,
                private=bool(args.hub_private),
            )
            print(f"✅ Pushed to Hub: {hub_repo_id} (split={split_name})")

    print("\n✅ Done!")
