import sys
import argparse
from pathlib import Path
from datetime import datetime
from types import SimpleNamespace
import pandas as pd

sys.path.insert(0, str(Path(__file__).parent))

from API.llm_clients import create_llm_client
from high_fidelity import HighFidelityBlackbox
from koh.optimizer import KOHOptimizer
from koh.utils import numpy_to_dict_list
import numpy as np

# --- Run configuration (module-level; main() may override some via CLI) ---
# Seeds for batch mode: when non-empty, main() calls run_koh_optimization once
# per seed; when empty, a single run is made with SINGLE_SEED.
SEEDS = [1, 2, 3, 4, 5] 
# If truthy, main() replaces RUN_TAG with this value.
RUN_TAG_OVERRIDE = None
# If True, run_koh_optimization deletes existing "<prefix>_*.csv" files in the
# KOH output directory before starting.
CLEAN_PREVIOUS = False
# Seed for a single (non-batch) run; None means no "random_seed" entry is added
# to the optimizer config.
SINGLE_SEED = None
# If True, batch mode deletes every file named "<seed prefix>_*" in the KOH
# output directory after each seed finishes.
CLEAR_CACHE_BETWEEN_SEEDS = False

# Optional overrides for KOH_CONFIG entries; None means "use the KOH_CONFIG value".
MAX_HF_ITERATIONS = None
# NOTE(review): verify that run_koh_optimization actually applies this override.
N_CANDIDATES = None
Q = None
N_INITIAL_POINTS = None
# Task name: used for the output subdirectory and passed to the blackbox/optimizer.
TASK_NAME = "Sandwich"
# Output root; a relative path, so it resolves against the current working directory.
OUTPUT_ROOT = Path("outputs")
# Captured once at import time (local time, minute resolution); every run in
# this process shares it.
_TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M')
# Default output-file prefix for a single (non-batch) run.
RUN_TAG = f"{TASK_NAME}_{_TIMESTAMP}"


def _prepare_output_dirs(task_name: str):
    """Create (if missing) and return the output directories for ``task_name``.

    Layout: ``OUTPUT_ROOT/<task_name>/`` plus its ``KOH`` subdirectory.
    Existing directories are left untouched.

    Returns:
        dict mapping ``"base"`` and ``"KOH"`` to the corresponding ``Path``.
    """
    task_root = OUTPUT_ROOT / task_name
    dirs = {"base": task_root, "KOH": task_root / "KOH"}
    # Dict order guarantees the base directory is created before the KOH one.
    for directory in dirs.values():
        directory.mkdir(parents=True, exist_ok=True)
    return dirs


def _clear_method_cache(cache_dir: Path, file_prefix: str):
    """Delete the regular files in ``cache_dir`` whose names match ``<file_prefix>_*``.

    Subdirectories that match the pattern are not removed. Nothing happens
    when ``file_prefix`` is empty or ``None``. A summary line is printed only
    if at least one file was deleted.
    """
    if not file_prefix:
        return
    doomed = [entry for entry in cache_dir.glob(f"{file_prefix}_*") if entry.is_file()]
    for entry in doomed:
        entry.unlink()
    if doomed:
        print(f"  ✓ Cleared cache files: {len(doomed)} ({cache_dir.name}, prefix={file_prefix})")


# Resolve the output directories for TASK_NAME. Note: this creates them on disk
# at import time (_prepare_output_dirs calls mkdir).
OUTPUT_DIRS = _prepare_output_dirs(TASK_NAME)
# Passed to KOHOptimizer as task_data_dir; also the directory scanned by
# CLEAN_PREVIOUS and by the per-seed cache clearing in main().
KOH_DATA_DIR = OUTPUT_DIRS["KOH"]

# Decision-variable names, passed as feature_names to both HighFidelityBlackbox
# and KOHOptimizer (so their order defines the column order).
_SPACE_NAMES = [
    "multigrain_bread", "whole_wheat_bread", "sourdough_bread",
    "chicken_protein", "tuna_protein", "tofu_protein", "hummus_protein", "egg_protein",
    "low_fat_cheese_dairy", "cheddar_cheese", "swiss_cheese_dairy",
    "collards", "cabbage", "onion_vegetables", "tomato_vegetables",
    "mayo_sauce", "olive_oil",
    "apples", "orange", "banana",
]

SPACE_CONFIG = {
    "names": _SPACE_NAMES,
    # One type per variable; derived from the names list so adding or removing
    # a variable can never leave the two lists with different lengths.
    "types": ["float"] * len(_SPACE_NAMES),
    "normalize": False,
}

def _get_total_score_range():
    """Return the ``[low, high]`` bounds for Total_Score as a new list each call."""
    lower_bound, upper_bound = 0.0, 250.0
    return [lower_bound, upper_bound]

# Total_Score bounds, stored in the LLM config as "value_range".
_total_score_range = _get_total_score_range()
# Base LLM settings. run_koh_optimization passes a copy of this dict as
# base_config to create_llm_config_from_model_name. It also reads an optional
# "model" key (falling back to "intern-s1"), which is not set here.
LLM_CONFIG = {
    "llm_type": "intern_s1",
    "value_range": _total_score_range
}

# Default KOH settings. run_koh_optimization copies this dict, adds
# "random_seed" when a seed is given, and passes it to KOHOptimizer as a
# SimpleNamespace. "max_hf_iterations", "n_initial_points" and "q" are also
# the fallbacks for the arguments of optimizer.run() when the module-level
# MAX_HF_ITERATIONS / N_INITIAL_POINTS / Q overrides are None.
# NOTE(review): the meaning of the other keys is defined by KOHOptimizer and
# is not visible here.
KOH_CONFIG = {
    "max_hf_iterations": 30,
    "n_candidates": 5000,
    "q": 2,
    "n_initial_points": 3,
    "mismatch_threshold": 0.8,
    "force_hf_after_n_lf": 20,
    "gp_training_iter": 50,
    "max_loops": 1000,
    "acquisition_type": "ucb",
    "acquisition_beta": 2.0
}

# Optional fixed initial design, converted with numpy_to_dict_list against
# SPACE_CONFIG["names"]. Presumably a sequence of points, each with one value
# per variable (TODO confirm). None means no fixed initial points.
FIXED_INIT_POINTS = None


def _get_config_value(config_value, config_key, default_config):
    """Resolve a setting, preferring an explicit override over the default.

    Args:
        config_value: Override value; ``None`` means "not set".
        config_key: Key looked up in ``default_config`` when there is no override.
        default_config: Mapping that holds the default settings.

    Returns:
        ``config_value`` if it is not ``None``; otherwise
        ``default_config[config_key]``. In that case a missing key raises
        ``KeyError``.
    """
    if config_value is None:
        return default_config[config_key]
    return config_value

def run_koh_optimization(seed_override: int = None, file_prefix_override: str = None):
    """Run a single KOH optimization for the Sandwich task.

    Builds the LLM client, the high-fidelity blackbox and a ``KOHOptimizer``
    from the module-level configuration, then calls ``optimizer.run()``.

    Overrides: ``MAX_HF_ITERATIONS``, ``N_INITIAL_POINTS``, ``Q`` and
    ``N_CANDIDATES`` replace the matching ``KOH_CONFIG`` entries when they are
    not ``None``. The resolved values are written into the config handed to
    the optimizer and, where applicable, passed to ``run()``, so the two agree.

    Args:
        seed_override: Random seed for this run (used by batch mode). When
            ``None``, ``SINGLE_SEED`` is used. If that is also ``None``, no
            ``random_seed`` is put into the optimizer config.
        file_prefix_override: Prefix for this run's output files (used by
            batch mode). When ``None``, ``RUN_TAG`` is used.
    """
    current_seed = seed_override if seed_override is not None else SINGLE_SEED
    current_file_prefix = file_prefix_override if file_prefix_override is not None else RUN_TAG

    seed_suffix = f" (Seed: {current_seed})" if current_seed is not None else ""
    print(f"\n{'='*80}")
    print(f"{'KOH Optimization - Sandwich Task' + seed_suffix:^80}")
    print(f"{'='*80}")

    print("\nInitializing components...")

    output_dir = KOH_DATA_DIR
    print(f"  ✓ Output directory: {output_dir}")
    print(f"  ✓ Run prefix: {current_file_prefix}")
    if current_seed is not None:
        print(f"  ✓ Random seed: {current_seed} (only affects sampling, not LLM)")
    if CLEAN_PREVIOUS:
        # Only CSV files with this run's prefix are removed; other files stay.
        removed = 0
        for file in output_dir.glob(f"{current_file_prefix}_*.csv"):
            file.unlink()
            removed += 1
        print(f"  ✓ Cleared old files with prefix: {removed}")

    # Project import kept local to the function, as in the original script.
    from API.llm_config_utils import create_llm_config_from_model_name

    # Falls back to "intern-s1" when LLM_CONFIG has no "model" key.
    llm_model = LLM_CONFIG.get("model", "intern-s1")
    llm_config = create_llm_config_from_model_name(
        llm_model=llm_model,
        base_config=LLM_CONFIG.copy()
    )

    llm_client = create_llm_client(llm_config)
    print(f"  ✓ LLM client ready ({llm_config.llm_type}, model: {llm_model})")
    print(f"  ✓ Total_Score range: {LLM_CONFIG['value_range']}")

    hf_blackbox = HighFidelityBlackbox(
        task_name=TASK_NAME,
        csv_path=None,
        feature_names=SPACE_CONFIG["names"],
        target_col="Total_Score"
    )
    print("  ✓ High-fidelity blackbox ready")

    # Variable bounds are taken from the blackbox, not from SPACE_CONFIG.
    actual_bounds = hf_blackbox.bounds
    print(f"  ✓ Search space: {SPACE_CONFIG['names']}")
    print(f"  ✓ Number of variables: {len(SPACE_CONFIG['names'])}")
    print(f"  ✓ Variable ranges: {list(zip(SPACE_CONFIG['names'], actual_bounds))}")

    # Resolve the module-level overrides once and use them everywhere, so the
    # optimizer config and the run() arguments cannot disagree.
    max_iterations = _get_config_value(MAX_HF_ITERATIONS, "max_hf_iterations", KOH_CONFIG)
    n_initial_points = _get_config_value(N_INITIAL_POINTS, "n_initial_points", KOH_CONFIG)
    q = _get_config_value(Q, "q", KOH_CONFIG)
    n_candidates = _get_config_value(N_CANDIDATES, "n_candidates", KOH_CONFIG)

    # Work on a copy so KOH_CONFIG stays untouched across seeds.
    koh_config_dict = KOH_CONFIG.copy()
    koh_config_dict.update(
        max_hf_iterations=max_iterations,
        n_initial_points=n_initial_points,
        q=q,
        n_candidates=n_candidates,
    )
    if current_seed is not None:
        koh_config_dict["random_seed"] = current_seed

    optimizer = KOHOptimizer(
        task_name=TASK_NAME,
        task_data_dir=str(KOH_DATA_DIR),
        feature_names=SPACE_CONFIG["names"],
        feature_types=SPACE_CONFIG["types"],
        bounds=actual_bounds,
        target_name="Total_Score",
        llm_client=llm_client,
        hf_blackbox=hf_blackbox,
        llm_config=llm_config,
        koh_config=SimpleNamespace(**koh_config_dict),
        file_prefix=current_file_prefix
    )
    print("  ✓ KOH optimizer ready")

    print(f"\n{'='*80}")
    print(f"{'Starting KOH Optimization':^80}")
    print(f"{'='*80}")

    fixed_initial_points_dict = None
    # Explicit None/len check: also works when FIXED_INIT_POINTS is a numpy
    # array, whose plain truth value would raise ValueError.
    if FIXED_INIT_POINTS is not None and len(FIXED_INIT_POINTS) > 0:
        fixed_array = np.array(FIXED_INIT_POINTS)
        fixed_initial_points_dict = numpy_to_dict_list(fixed_array, SPACE_CONFIG["names"])
        print(f"  ✓ Using fixed initial points: {len(fixed_initial_points_dict)} points")

    optimizer.run(
        max_iterations=max_iterations,
        n_initial_points=n_initial_points,
        q=q,
        fixed_initial_points=fixed_initial_points_dict
    )

    print(f"\n{'='*80}")
    print(f"{'KOH Optimization Complete!':^80}")
    print(f"{'='*80}")


def main():
    """Parse command-line arguments and run KOH for one seed or several.

    Command-line options (both optional):
        --seeds 1,2,3  Run one optimization per listed seed (replaces SEEDS).
        --seed N       Run a single optimization with seed N.

    Without options, batch mode runs over the module-level ``SEEDS`` list.
    If ``SEEDS`` is empty, a single run uses ``SINGLE_SEED``. A non-empty
    ``RUN_TAG_OVERRIDE`` replaces ``RUN_TAG``. It is also the base of the
    per-seed file prefixes in batch mode.
    """
    global RUN_TAG, SEEDS, SINGLE_SEED

    # Parse command line arguments (optional)
    parser = argparse.ArgumentParser(description="Sandwich task run script")
    parser.add_argument("--seed", type=int, help="Single run random seed (override SINGLE_SEED)")
    parser.add_argument("--seeds", type=str, help="Comma-separated random seed list, run one by one (override SEEDS)")
    args = parser.parse_args()

    if args.seeds:
        try:
            parsed_seeds = [int(s.strip()) for s in args.seeds.split(",") if s.strip()]
        except ValueError:
            parser.error(f"--seeds must be a comma-separated list of integers, got {args.seeds!r}")
        if not parsed_seeds:
            parser.error("--seeds must contain at least one integer")
        SEEDS = parsed_seeds
        SINGLE_SEED = None
    elif args.seed is not None:
        SINGLE_SEED = args.seed
        # Leave batch mode. Otherwise the default SEEDS list would take over
        # below and --seed would be silently ignored.
        SEEDS = []

    # If run prefix specified, override global RUN_TAG
    if RUN_TAG_OVERRIDE:
        RUN_TAG = RUN_TAG_OVERRIDE

    # Batch run logic
    def run_batch_seeds(run_func, prefix_template, description="", cache_dir: Path = None):
        """Run ``run_func`` once per seed in SEEDS.

        Each run gets the prefix ``prefix_template.format(seed=seed)``. When
        CLEAR_CACHE_BETWEEN_SEEDS is set, that seed's files are removed from
        ``cache_dir`` afterwards.
        """
        seed_list = [int(s) for s in SEEDS]
        print(f"\n{'='*80}")
        print(f"{f'Batch Run Mode ({description})':^80}" if description else f"{'Batch Run Mode':^80}")
        print(f"{'='*80}")
        print(f"Will run {len(seed_list)} random seeds in order: {seed_list}")
        print(f"Note: These random seeds only affect sampling (numpy, torch)")

        for idx, seed in enumerate(seed_list, 1):
            print(f"\n{'='*80}")
            print(f"{'Seed %d (%d/%d)' % (seed, idx, len(seed_list)):^80}")
            print(f"{'='*80}")
            run_func(seed_override=seed, file_prefix_override=prefix_template.format(seed=seed))
            print(f"\n✓ Seed {seed} run complete")

            if CLEAR_CACHE_BETWEEN_SEEDS and cache_dir is not None:
                _clear_method_cache(cache_dir, prefix_template.format(seed=seed))

        print(f"\n{'='*80}")
        print(f"{'🎉 All Seeds Run Complete!':^80}")
        print(f"{'='*80}")

    # Method mapping.
    # Per-seed prefixes are built from RUN_TAG so RUN_TAG_OVERRIDE also applies
    # to batch runs. By default this equals "<TASK_NAME>_<_TIMESTAMP>_seed<N>".
    # Braces are escaped because the template goes through str.format.
    escaped_tag = RUN_TAG.replace("{", "{{").replace("}", "}}")
    prefix_template = f"{escaped_tag}_seed{{seed}}"
    description = "KOH"
    cache_dir = KOH_DATA_DIR

    if SEEDS:
        run_batch_seeds(run_koh_optimization, prefix_template, description, cache_dir)
    else:
        run_koh_optimization()


# Script entry point: main() runs only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()

