# Set tokenizers parallelism BEFORE any imports to avoid warnings when forking processes
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import json
import random
import argparse
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, TimeoutError as FuturesTimeoutError
from src.inference import LLM
from src.mesh_utils import load_mesh_from_stl, normalize_mesh, mesh_to_pointcloud, cadquery_to_mesh
from src.evaluation import chamfer_distance, compute_statistics, extract_code
from config.code_generation import CODE_GENERATION_SYSTEM_PROMPT, CODE_GENERATION_USER_PROMPT_TEMPLATE


# Data locations. Each value may be overridden by an environment variable of
# the same name; the second argument to os.getenv is the fallback default.
IMPROVED_DATA_DIR = os.getenv("IMPROVED_DATA_DIR", "./sft/improved_data")
SPLIT_UIDS_FILE = os.getenv("SPLIT_UIDS_FILE", "./data/train_val_test.json")
GT_MESH_DIR = os.getenv("GT_MESH_DIR", "./data/text2cad/deepcad_mesh")
TEXT2CAD_CSV = os.getenv("TEXT2CAD_CSV", "./data/text2cad_v1.1.csv")
# Model used by main() when no --model is supplied.
DEFAULT_MODEL = "Qwen/Qwen2.5-7B-Instruct"
# Presumably the fine-tuned (SFT) checkpoint path; not referenced in the code shown here.
SFT_MODEL = os.getenv("SFT_MODEL", "./saves/qwen2_5-7b/full/sft/checkpoint-100")


def load_split_uids(split: str = "test", path: str = None) -> set:
    """
    Load UIDs from train_val_test.json for the specified split.

    Args:
        split: "test" selects the "test_uids" list; any other value selects
               "train_uids" (a warning is printed for values other than
               "train" so a typo does not silently evaluate on train data).
        path: JSON file to read. Defaults to SPLIT_UIDS_FILE, looked up at
              call time.

    Returns:
        Set of UIDs. Returns an empty set (after printing a warning) if the
        file does not exist, or if the selected key is absent.
    """
    split_file = SPLIT_UIDS_FILE if path is None else path

    if not os.path.exists(split_file):
        print(f"Warning: Split file not found: {split_file}")
        return set()

    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(split_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if split not in ("train", "test"):
        print(f"Warning: unknown split {split!r}; falling back to 'train_uids'")
    uid_key = "test_uids" if split == "test" else "train_uids"
    return set(data.get(uid_key, []))


def load_improved_data_samples(data_dir: str = None, uid_filter: set = None):
    """
    Load samples from improved_data files.

    Args:
        data_dir: Directory containing improved_*.json files. Defaults to
                  IMPROVED_DATA_DIR, looked up at call time.
        uid_filter: Optional set of UIDs to keep. None disables filtering.
                    An empty set keeps nothing. For example, if the split file
                    is missing, load_split_uids returns set(), and that must not
                    silently turn into "use every sample".

    Returns:
        List of dicts with keys 'uid', 'modified_prompt', 'chamfer_distance'.
        Files are read in sorted filename order. Items without a uid or a
        non-empty 'final_modified_prompt' are skipped.
    """
    import glob

    if data_dir is None:
        data_dir = IMPROVED_DATA_DIR

    # Statuses in the original ('results') format that count as usable samples.
    accepted_statuses = {'accepted', 'fixed', 'regenerated'}
    samples = []

    for data_file in sorted(glob.glob(os.path.join(data_dir, 'improved_*.json'))):
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # 'samples' is the joined format; 'results' is the original format,
        # whose items carry a per-item status that must be checked. The status
        # check applies only to items that actually came from 'results'.
        if 'samples' in data:
            items, check_status = data['samples'], False
        else:
            items, check_status = data.get('results', []), True

        for item in items:
            if check_status and item.get('status') not in accepted_statuses:
                continue

            uid = item.get('uid')
            if uid_filter is not None and uid not in uid_filter:
                continue

            modified_prompt = item.get('final_modified_prompt')
            if not uid or not modified_prompt:
                continue

            samples.append({
                'uid': uid,
                'modified_prompt': modified_prompt,
                'chamfer_distance': item.get('final_chamfer_distance')
            })

    return samples


def load_misleading_data_samples(data_dir: str, uid_filter: set = None):
    """
    Load samples from misleading data files (misleading_batch_*.json format).

    Args:
        data_dir: Directory containing misleading_batch_*.json files
        uid_filter: Optional set of UIDs to keep. None disables filtering;
                    an empty set keeps nothing (consistent with
                    load_improved_data_samples).

    Returns:
        List of dicts with keys 'uid', 'misleading_description',
        'original_prompt', 'original_cd', 'config_name', 'k'. Entries lacking
        a uid or a non-empty 'misleading_description' are skipped.
    """
    import glob

    samples = []

    for data_file in sorted(glob.glob(os.path.join(data_dir, 'misleading_batch_*.json'))):
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        for result in data.get('results', []):
            uid = result.get('uid')
            if not uid:
                continue

            if uid_filter is not None and uid not in uid_filter:
                continue

            misleading_description = result.get('misleading_description')
            if not misleading_description:
                continue

            samples.append({
                'uid': uid,
                'misleading_description': misleading_description,
                'original_prompt': result.get('original_prompt'),
                'original_cd': result.get('original_cd'),
                'config_name': result.get('config_name'),
                'k': result.get('k')
            })

    return samples


def load_expert_prompts_from_csv(uids: list, csv_path: str = None) -> dict:
    """
    Load expert prompts from text2cad_v1.1.csv for given UIDs.

    Note: Filtered data UIDs are like "00079736", but CSV UIDs are like "0007/00079736".
    We match by taking the part of the CSV UID after the last slash.

    Args:
        uids: List of UIDs to load expert prompts for
        csv_path: Path to the text2cad CSV file. Defaults to TEXT2CAD_CSV,
                  looked up at call time.

    Returns:
        Dict mapping UID -> value of the row's 'expert' column ('' if the
        column is absent). UIDs not found in the CSV are simply missing.
    """
    import csv

    if csv_path is None:
        csv_path = TEXT2CAD_CSV

    uid_set = set(uids)
    uid_to_expert = {}

    print(f"Loading expert prompts from CSV for {len(uids)} UIDs...")

    # With nothing to look up there is no reason to scan the (large) CSV.
    if uid_set:
        # newline='' is required by the csv module: otherwise line breaks
        # inside quoted (multi-line) expert prompts are translated.
        with open(csv_path, 'r', encoding='utf-8', newline='') as f:
            for row in csv.DictReader(f):
                # "0007/00079736" -> "00079736"; UIDs without a slash pass through.
                numeric_uid = (row.get('uid') or '').rsplit('/', 1)[-1]

                if numeric_uid in uid_set:
                    uid_to_expert[numeric_uid] = row.get('expert', '')
                    if len(uid_to_expert) == len(uid_set):
                        break  # Found all UIDs, stop reading

    print(f"Found expert prompts for {len(uid_to_expert)}/{len(uids)} UIDs")
    return uid_to_expert


def load_processed_uids(output_dir: str) -> set:
    """
    Collect every UID already recorded in output_dir/batch_*.json.

    Used for resume support: any entry with a truthy 'uid' under a file's
    'results' list counts as processed, whether it succeeded or not.
    Unreadable or malformed files are reported with a warning and skipped.

    Args:
        output_dir: Directory containing batch_*.json files

    Returns:
        Set of already processed UIDs (empty if there are no batch files)
    """
    import glob

    seen = set()
    pattern = os.path.join(output_dir, 'batch_*.json')

    for path in glob.glob(pattern):
        try:
            with open(path, 'r') as fh:
                payload = json.load(fh)
            seen.update(
                entry.get('uid')
                for entry in payload.get('results', [])
                if entry.get('uid')
            )
        except Exception as e:
            # Best effort: one broken file must not abort the whole resume scan.
            print(f"Warning: Failed to read {path}: {e}")

    return seen


def _evaluate_code_worker(code: str, uid: str, gt_mesh_dir: str, n_points: int):
    """
    Score generated CadQuery code against <gt_mesh_dir>/<uid>.stl.

    Intended to run in a child process so the caller can enforce a timeout.
    Both meshes are normalized and sampled to n_points points before the
    Chamfer Distance is computed.

    Returns:
        Dict with 'success', 'chamfer_distance' and 'error'. On success it
        also has 'quality': 'excellent' (< 0.0001), 'good' (< 0.0002),
        'acceptable' (< 0.001) or 'poor'. Exceptions raised during evaluation
        are caught and reported via 'error'.
    """
    from pathlib import Path
    from src.mesh_utils import load_mesh_from_stl, normalize_mesh, mesh_to_pointcloud, cadquery_to_mesh
    from src.evaluation import chamfer_distance

    # Upper CD bounds for each quality label, checked in ascending order.
    quality_bands = ((0.0001, 'excellent'), (0.0002, 'good'), (0.001, 'acceptable'))

    outcome = {
        'success': False,
        'chamfer_distance': None,
        'error': None
    }

    mesh_file = Path(gt_mesh_dir) / f"{uid}.stl"
    if not mesh_file.exists():
        outcome['error'] = f"Ground truth mesh not found: {mesh_file}"
        return outcome

    def to_cloud(mesh):
        # Normalize first so the distance is independent of scale/position.
        return mesh_to_pointcloud(normalize_mesh(mesh), n_points)

    try:
        reference_cloud = to_cloud(load_mesh_from_stl(str(mesh_file)))
        candidate_cloud = to_cloud(cadquery_to_mesh(code))
        distance = chamfer_distance(reference_cloud, candidate_cloud)

        outcome['success'] = True
        outcome['chamfer_distance'] = distance

        grade = 'poor'
        for limit, label in quality_bands:
            if distance < limit:
                grade = label
                break
        outcome['quality'] = grade
    except Exception as e:
        outcome['error'] = str(e)

    return outcome


def _evaluate_code_in_child(conn, code: str, uid: str, gt_mesh_dir: str, n_points: int) -> None:
    """
    Child-process entry point for evaluate_against_gt_mesh.

    Runs _evaluate_code_worker and sends its result dict back through conn.
    Exceptions escaping the worker (e.g. a failing import) are converted into
    an error dict so the parent always gets an explanation when possible.
    """
    try:
        payload = _evaluate_code_worker(code, uid, gt_mesh_dir, n_points)
    except Exception as e:
        payload = {'success': False, 'chamfer_distance': None,
                   'error': f"Evaluation failed: {e}"}
    try:
        conn.send(payload)
    finally:
        conn.close()


def evaluate_against_gt_mesh(code: str, uid: str, gt_mesh_dir: str = None, n_points: int = 8192, timeout: float = 5.0) -> dict:
    """
    Evaluate generated code against ground truth mesh file using Chamfer Distance.

    The evaluation runs in a separate process. If it does not finish within
    `timeout` seconds, that process is terminated (then killed if needed), so
    hung or runaway CadQuery code cannot leak worker processes.

    Args:
        code: Generated CadQuery code
        uid: Sample UID (used to find ground truth mesh file <uid>.stl)
        gt_mesh_dir: Directory containing ground truth mesh files. Defaults
                     to GT_MESH_DIR, looked up at call time.
        n_points: Number of points for point cloud sampling
        timeout: Timeout in seconds (default: 5.0)

    Returns:
        Evaluation result dict with at least 'success', 'chamfer_distance'
        and 'error'. On success it is the worker's dict (including 'quality').
    """
    import multiprocessing
    import time

    if gt_mesh_dir is None:
        gt_mesh_dir = GT_MESH_DIR

    result = {
        'success': False,
        'chamfer_distance': None,
        'error': None
    }

    # Check if ground truth mesh file exists first (fast check, no subprocess needed)
    gt_mesh_path = Path(gt_mesh_dir) / f"{uid}.stl"
    if not gt_mesh_path.exists():
        result['error'] = f"Ground truth mesh not found: {gt_mesh_path}"
        return result

    start_time = time.time()
    started = False
    timed_out = False
    proc = None
    connections = []
    try:
        recv_conn, send_conn = multiprocessing.Pipe(duplex=False)
        connections = [recv_conn, send_conn]
        proc = multiprocessing.Process(
            target=_evaluate_code_in_child,
            args=(send_conn, code, uid, gt_mesh_dir, n_points),
        )
        proc.start()
        started = True
        # Drop the parent's copy of the sending end so recv() sees EOF if the
        # child dies (e.g. crashes in native code) without replying.
        send_conn.close()

        if recv_conn.poll(timeout):
            try:
                result = recv_conn.recv()
            except EOFError:
                proc.join(1.0)
                result['error'] = (
                    "Evaluation failed: worker process exited without a result "
                    f"(exit code {proc.exitcode})"
                )
        else:
            timed_out = True
            result['error'] = f"Evaluation timed out after {timeout} seconds (code may be too complex or contain infinite loops)"
    except Exception as e:
        elapsed = time.time() - start_time
        result['success'] = False
        result['error'] = f"Evaluation failed: {str(e)}"
        print(f"[Eval {uid}] ERROR after {elapsed:.2f}s: {str(e)}", flush=True)
    finally:
        if started:
            # A finished child gets a moment to exit; a timed-out one is stopped at once.
            proc.join(0 if timed_out else 1.0)
            if proc.is_alive():
                proc.terminate()
                proc.join(1.0)
            if proc.is_alive():
                proc.kill()
                proc.join()
            proc.close()
        for conn in connections:
            conn.close()  # idempotent for an already-closed connection

    return result


def build_messages(description: str) -> list:
    """
    Assemble the chat for one CAD code-generation request.

    The system turn is CODE_GENERATION_SYSTEM_PROMPT. The user turn is
    CODE_GENERATION_USER_PROMPT_TEMPLATE with {description} filled in.

    Args:
        description: Text description of the 3D shape

    Returns:
        Two-element list: the system message dict, then the user message dict
    """
    user_content = CODE_GENERATION_USER_PROMPT_TEMPLATE.format(description=description)
    system_message = {"role": "system", "content": CODE_GENERATION_SYSTEM_PROMPT}
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]


def build_reflection_message(code: str, eval_result: dict) -> str:
    """
    Compose the follow-up user turn that asks the model to revise its code.

    The message always quotes the previous code. It then adds one of two
    things: the error text, if eval_result has a truthy 'error'; otherwise
    the Chamfer Distance, if it is not None. It always ends with an
    instruction to output only code.

    Args:
        code: The generated code that needs improvement
        eval_result: Evaluation result dict

    Returns:
        Feedback message string
    """
    parts = [f"Your previous code attempt:\n```python\n{code}\n```\n\n"]

    error_text = eval_result.get('error')
    distance = eval_result.get('chamfer_distance')

    if error_text:
        parts.append(f"Error when running the code:\n{error_text}\n\n")
        parts.append("Please fix the syntax or runtime error and generate corrected code.")
    elif distance is not None:
        parts.append(f"The code runs successfully but the shape accuracy (Chamfer Distance) is {distance:.6f}.\n")
        parts.append("A lower Chamfer Distance means better accuracy. Target is < 0.0002.\n\n")
        parts.append("Please analyze the discrepancy and generate improved code that more precisely matches the described shape.")

    parts.append("\n\nOutput only the corrected Python code, no explanations.")
    return "".join(parts)


def inference_with_reflection(
    llm,
    description: str,
    uid: str,
    max_iterations: int = 3,
    cd_threshold: float = 0.0002
) -> dict:
    """
    Generate code for one sample, feeding evaluation results back to the model.

    Each attempt calls llm.inference on the running conversation, extracts
    the code and scores it with evaluate_against_gt_mesh. The loop stops
    early in two cases: an attempt succeeds with a Chamfer Distance below
    cd_threshold, or the inference/extraction step raises. Otherwise, unless
    it was the last attempt, the model's reply and a reflection message are
    appended to the conversation.

    Args:
        llm: LLM instance
        description: Text description of the 3D shape
        uid: Sample UID (used to find ground truth mesh file)
        max_iterations: Maximum number of attempts
        cd_threshold: Stop if chamfer distance is below this threshold

    Returns:
        Dict with 'final' (the LAST attempt's record, not necessarily the
        best), 'history' (all attempt records) and 'iterations'
        (number of records)
    """
    history = []
    conversation = build_messages(description)

    for attempt in range(1, max_iterations + 1):
        print(f"  Iteration {attempt}/{max_iterations}...")

        try:
            response = llm.inference(conversation)
            code = extract_code(response)
        except Exception as e:
            history.append({
                'iteration': attempt,
                'success': False,
                'error': f"Inference error: {str(e)}",
                'code': None
            })
            break

        outcome = evaluate_against_gt_mesh(code, uid)
        outcome['iteration'] = attempt
        outcome['code'] = code
        history.append(outcome)

        succeeded = outcome['success']
        if succeeded and outcome['chamfer_distance'] < cd_threshold:
            print(f"    CD: {outcome['chamfer_distance']:.6f} - Target reached!")
            break

        if not succeeded:
            print(f"    Error: {outcome['error'][:50]}...")
        else:
            print(f"    CD: {outcome['chamfer_distance']:.6f} - Needs improvement")

        # Feedback is only useful if another attempt will read it.
        if attempt < max_iterations:
            feedback = build_reflection_message(code, outcome)
            conversation.append({"role": "assistant", "content": response})
            conversation.append({"role": "user", "content": feedback})

    if history:
        final_result = history[-1]
    else:
        final_result = {'success': False, 'error': 'No iterations completed'}

    return {
        'final': final_result,
        'history': history,
        'iterations': len(history)
    }


def _next_batch_number(output_dir: str) -> int:
    """
    Return the number for the next batch_NNNN.json file in output_dir.

    Uses one past the highest existing batch number rather than the file
    count. A batch whose inference failed is never written, so numbering can
    have gaps; counting files would then reuse, and overwrite, an existing
    batch file on resume.
    """
    highest = 0
    for batch_path in Path(output_dir).glob('batch_*.json'):
        try:
            highest = max(highest, int(batch_path.stem[len('batch_'):]))
        except ValueError:
            continue  # e.g. batch_old.json carries no number
    return highest + 1


def _run_reflection_mode(llm, samples: list, max_iterations: int, threshold: float) -> list:
    """Process samples one at a time through the self-reflection loop; return per-sample results."""
    results = []
    for i, sample in enumerate(samples):
        uid = sample['uid']
        description = sample['prompt']

        print(f"\n[{i+1}/{len(samples)}] Processing sample: {uid}")

        result = inference_with_reflection(
            llm, description, uid,
            max_iterations=max_iterations,
            cd_threshold=threshold
        )
        result['uid'] = uid

        final = result['final']
        if final.get('success'):
            print(f"  Final CD: {final['chamfer_distance']:.6f} after {result['iterations']} iteration(s)")
        else:
            # `or` guards against an explicit None error, which [:50] cannot slice.
            print(f"  Failed after {result['iterations']} iteration(s): {(final.get('error') or 'Unknown error')[:50]}")

        results.append(result)
    return results


def _run_batch_mode(llm, samples: list, batch_size: int, output_dir: str,
                    model_name: str, prompt_type: str) -> list:
    """Run batched inference + evaluation, saving each batch to batch_NNNN.json; return all results."""
    from datetime import datetime

    results = []
    total_batches = (len(samples) + batch_size - 1) // batch_size
    # Continue after the highest existing batch number so resume never overwrites a file.
    start_batch_num = _next_batch_number(output_dir)

    for batch_idx in range(total_batches):
        batch_start = batch_idx * batch_size
        batch_samples = samples[batch_start:batch_start + batch_size]
        batch_num = start_batch_num + batch_idx

        print(f"\n{'='*60}")
        print(f"Batch {batch_num} (processing {batch_idx + 1}/{total_batches}, {len(batch_samples)} samples)")
        print(f"{'='*60}")

        messages_list = [build_messages(s['prompt']) for s in batch_samples]

        try:
            responses = llm.batch_inference(messages_list, batch_size=batch_size)
        except Exception as e:
            print(f"Batch inference failed: {e}")
            # Not written to disk, so these UIDs are retried on the next resume.
            for sample in batch_samples:
                results.append({'uid': sample['uid'], 'success': False, 'error': str(e)})
            continue

        batch_results = []
        for sample, response in zip(batch_samples, responses):
            uid = sample['uid']

            if response.startswith("ERROR:"):
                eval_result = {'uid': uid, 'success': False, 'error': response}
            else:
                code = extract_code(response)
                eval_result = evaluate_against_gt_mesh(code, uid)
                eval_result['uid'] = uid
                eval_result['generated_code'] = code

            if eval_result.get('success'):
                print(f"  {uid}: CD={eval_result['chamfer_distance']:.6f} ({eval_result['quality']})")
            else:
                error_msg = (eval_result.get('error') or 'Unknown')[:50]
                print(f"  {uid}: Failed - {error_msg}")

            batch_results.append(eval_result)
            results.append(eval_result)

        batch_output_path = os.path.join(output_dir, f"batch_{batch_num:04d}.json")
        batch_data = {
            'metadata': {
                'batch_number': batch_num,
                'batch_size': len(batch_results),
                'model': model_name,
                'prompt_type': prompt_type,
                'processed_at': datetime.now().isoformat()
            },
            'results': batch_results
        }
        with open(batch_output_path, 'w') as f:
            json.dump(batch_data, f, indent=2, default=str)
        print(f"  Saved to: {batch_output_path}")

    return results


def _print_summary(results: list, reflection: bool, model_name: str, prompt_type: str) -> None:
    """Print success counts, Chamfer Distance statistics and a quality breakdown for this run."""
    import numpy as np

    print("\n" + "="*60)
    print(f"SUMMARY {'(WITH REFLECTION)' if reflection else ''} - {prompt_type.upper()} PROMPTS")
    print("="*60)

    if reflection:
        successful = [r for r in results if r['final'].get('success')]
        print(f"Successful: {len(successful)}/{len(results)}")
        if not successful:
            return
        avg_iters = sum(r['iterations'] for r in results) / len(results)
        print(f"Average iterations: {avg_iters:.2f}")
        final_cds = [r['final']['chamfer_distance'] for r in successful]
        cd_stats = (np.mean(final_cds), np.median(final_cds), np.min(final_cds), np.max(final_cds))
        qualities = [r['final'].get('quality', 'unknown') for r in successful]
    else:
        successful = [r for r in results if r.get('success')]
        print(f"Successful: {len(successful)}/{len(results)}")
        if not successful:
            return
        stats = compute_statistics([{'cd': r['chamfer_distance']} for r in successful])
        cd_stats = (stats['mean_cd'], stats['median_cd'], stats['min_cd'], stats['max_cd'])
        qualities = [r.get('quality', 'unknown') for r in successful]

    mean_cd, median_cd, min_cd, max_cd = cd_stats
    print(f"\nModel: {model_name}")
    print(f"  Mean CD:   {mean_cd:.6f}")
    print(f"  Median CD: {median_cd:.6f}")
    print(f"  Min CD:    {min_cd:.6f}")
    print(f"  Max CD:    {max_cd:.6f}")

    quality_counts = {}
    for q in qualities:
        quality_counts[q] = quality_counts.get(q, 0) + 1
    print("\nQuality breakdown:")
    for q in ['excellent', 'good', 'acceptable', 'poor']:
        if q in quality_counts:
            print(f"  {q}: {quality_counts[q]}")


def main(model: str = None, n_samples: int = None, seed: int = 42, 
         reflection: bool = False, max_iterations: int = 10, batch_size: int = 8,
         output_dir: str = "results/test", use_expert: bool = False, split: str = "test",
         misleading_dir: str = None):
    """
    Main function for testing CAD code generation.

    Args:
        model: Model name from config or HuggingFace/local path.
               Examples: "anthropic/claude-3.5-sonnet", "gpt-4.1-2025-04-14",
                        "Qwen/Qwen2.5-7B-Instruct", "/path/to/checkpoint-327"
        n_samples: Number of samples to test (None = all)
        seed: Random seed for sampling
        reflection: Use self-reflection loop for improved generation
        max_iterations: Max refinement iterations (only used with reflection)
        batch_size: Batch size for inference (only used without reflection; must be >= 1)
        output_dir: Directory to save batch results
        use_expert: Use expert prompts from text2cad CSV instead of modified prompts
        split: "train" or "test" - which split to use from train_val_test.json
        misleading_dir: If set, load prompts from misleading_batch_*.json in this
                        directory instead (split and use_expert are then ignored)

    Raises:
        ValueError: if batch_size < 1 in batch (non-reflection) mode.
    """
    random.seed(seed)

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Configuration
    model_name = model if model else DEFAULT_MODEL
    threshold = 0.0002
    # Label for metadata/summary. misleading_dir takes precedence because the
    # misleading branch below ignores use_expert.
    if misleading_dir:
        prompt_type = 'misleading'
    elif use_expert:
        prompt_type = 'expert'
    else:
        prompt_type = 'modified'

    # Load samples - either from misleading data or improved data
    if misleading_dir:
        # Load misleading data (no split filtering needed - use all samples)
        print(f"Loading misleading samples from: {misleading_dir}")
        good_samples = load_misleading_data_samples(misleading_dir, uid_filter=None)
        print(f"Found {len(good_samples)} misleading samples")

        # Use misleading_description as prompt
        for sample in good_samples:
            sample['prompt'] = sample['misleading_description']
    else:
        # Load UIDs for the specified split
        print(f"Loading {split} UIDs from: {SPLIT_UIDS_FILE}")
        uid_filter = load_split_uids(split)
        print(f"Found {len(uid_filter)} {split} UIDs")

        # Load samples from improved_data, filtered by split UIDs
        print(f"Loading samples from improved_data: {IMPROVED_DATA_DIR}")
        good_samples = load_improved_data_samples(IMPROVED_DATA_DIR, uid_filter)
        print(f"Found {len(good_samples)} samples in {split} split")

        if use_expert:
            uids = [s['uid'] for s in good_samples]
            uid_to_expert = load_expert_prompts_from_csv(uids)

            # Only keep samples that have expert prompts - no fallback
            filtered_samples = []
            skipped_count = 0
            for sample in good_samples:
                uid = sample['uid']
                if uid_to_expert.get(uid):
                    sample['prompt'] = uid_to_expert[uid]
                    filtered_samples.append(sample)
                else:
                    skipped_count += 1

            if skipped_count > 0:
                print(f"WARNING: Skipped {skipped_count} samples (UIDs not found in CSV)")
            good_samples = filtered_samples
            print(f"Using EXPERT prompts: {len(good_samples)} samples")
        else:
            for sample in good_samples:
                sample['prompt'] = sample['modified_prompt']
            print("Using MODIFIED prompts from filtered_data")

    if len(good_samples) == 0:
        print("No samples found.")
        return

    # Check for already processed UIDs (resume support)
    processed_uids = load_processed_uids(output_dir)
    if processed_uids:
        original_count = len(good_samples)
        good_samples = [s for s in good_samples if s['uid'] not in processed_uids]
        skipped = original_count - len(good_samples)
        print(f"Resume: Found {len(processed_uids)} already processed UIDs, skipping {skipped} samples")

    if len(good_samples) == 0:
        print("All samples already processed. Nothing to do.")
        return

    # Limit samples if specified
    if n_samples is not None and n_samples < len(good_samples):
        random.shuffle(good_samples)
        good_samples = good_samples[:n_samples]
        print(f"Randomly selected {n_samples} samples (seed={seed})")

    # Validate before the (expensive) model load.
    if not reflection and batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"\nLoading model: {model_name}")
    llm = LLM(model_name=model_name)
    try:
        print(f"Using {'API' if llm.use_api else 'transformer'} inference")
        if reflection:
            print(f"Reflection mode: max {max_iterations} iterations per sample")
        else:
            print(f"Batch size: {batch_size}")
        print(f"Evaluating against ground truth meshes in: {GT_MESH_DIR}")
        print(f"Saving results to: {output_dir}")

        if reflection:
            results = _run_reflection_mode(llm, good_samples, max_iterations, threshold)
        else:
            results = _run_batch_mode(llm, good_samples, batch_size, output_dir,
                                      model_name, prompt_type)

        _print_summary(results, reflection, model_name, prompt_type)

        # NOTE: the file-name suffix keeps its historical '_expert'/'_modified'
        # scheme (misleading runs use '_modified') so existing tooling still finds it.
        suffix = '_reflection' if reflection else ''
        suffix += '_expert' if use_expert else '_modified'
        output_path = os.path.join(output_dir, f"test_results{suffix}.json")
        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\nResults saved to {output_path}")
    finally:
        # Release the model even if inference or evaluation raised.
        llm.unload()


if __name__ == "__main__":
    # Command-line entry point. Every option's dest matches a main() keyword
    # argument one-to-one, so the parsed namespace is forwarded unchanged.
    cli = argparse.ArgumentParser(description='Test CAD generation with any model (API or local)')
    cli.add_argument('--model', type=str, default=None,
                     help='Model name from config or HuggingFace/local path')
    cli.add_argument('--n_samples', type=int, default=None,
                     help='Number of samples to test (default: all)')
    cli.add_argument('--seed', type=int, default=42,
                     help='Random seed for sampling')
    cli.add_argument('--batch_size', type=int, default=8,
                     help='Batch size for inference (zero-shot mode only)')
    cli.add_argument('--output_dir', type=str, default='results/test',
                     help='Directory to save batch results')
    cli.add_argument('--reflection', action='store_true',
                     help='Use self-reflection loop (disables batch inference)')
    cli.add_argument('--max_iterations', type=int, default=10,
                     help='Max refinement iterations (only with --reflection)')
    cli.add_argument('--use_expert', action='store_true',
                     help='Use expert prompts from text2cad CSV instead of modified prompts')
    cli.add_argument('--split', type=str, choices=['train', 'test'], default='test',
                     help='Which split to use: train or test (default: test)')
    cli.add_argument('--misleading_dir', type=str, default=None,
                     help='Directory containing misleading_batch_*.json files (overrides split)')

    main(**vars(cli.parse_args()))

