import sys
import argparse
from pathlib import Path
from datetime import datetime
from types import SimpleNamespace
import pandas as pd

sys.path.insert(0, str(Path(__file__).parent.absolute()))

from API.llm_clients import create_llm_client
from high_fidelity import HighFidelityBlackbox
from koh.optimizer import KOHOptimizer
from koh.utils import numpy_to_dict_list
import numpy as np

# ===== Run configuration (edit by hand before launching) =====
# Seeds for batch mode: main() runs one optimization per entry when this list is non-empty.
SEEDS = [1, 2, 3, 4, 5]
RUN_TAG_OVERRIDE = None  # Custom run prefix (None keeps the timestamp-based RUN_TAG below)
CLEAN_PREVIOUS = False  # If True, delete existing "<prefix>_*.csv" files in the output dir before a run
SINGLE_SEED = None  # Seed for a single run; only consulted when SEEDS is empty (None leaves random_seed unset)
# Whether to delete a seed's "<prefix>_*" files right after that seed finishes during batch runs
CLEAR_CACHE_BETWEEN_SEEDS = False  # False keeps every seed's files on disk

# Run parameters (override defaults in KOH_CONFIG)
MAX_HF_ITERATIONS = None  # None means use KOH_CONFIG default
N_CANDIDATES = None  # None means use KOH_CONFIG default
Q = None  # None means use KOH_CONFIG default
N_INITIAL_POINTS = None  # None means use KOH_CONFIG default

# Task basic information
TASK_NAME = "COF"
OUTPUT_ROOT = Path("outputs")  # Relative path, so it resolves against the current working directory
_TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M')  # Captured once at import, minute resolution
RUN_TAG = f"{TASK_NAME}_{_TIMESTAMP}"  # File prefix used when run_koh_optimization gets no prefix override


def _prepare_output_dirs(task_name: str):
    """Create the output directories for a task and return them by label.

    ``OUTPUT_ROOT/<task_name>`` and its ``KOH`` subdirectory are created when
    missing; directories that already exist are left untouched.

    Args:
        task_name: Name of the task; used as the directory name under
            ``OUTPUT_ROOT``.

    Returns:
        Dict with the keys ``"base"`` and ``"KOH"`` mapping to the two paths.
    """
    task_root = OUTPUT_ROOT / task_name
    layout = {"base": task_root, "KOH": task_root / "KOH"}
    for directory in layout.values():
        directory.mkdir(parents=True, exist_ok=True)
    return layout


def _clear_method_cache(cache_dir: Path, file_prefix: str):
    """Delete the files in ``cache_dir`` whose names start with ``<file_prefix>_``.

    Only regular files directly inside ``cache_dir`` are removed; matching
    directories are left in place. An empty prefix is a no-op. A summary line
    is printed when at least one file was removed.

    Args:
        cache_dir: Directory holding the per-run files (seed_points/history etc.).
        file_prefix: Run prefix. It is matched literally, not as a glob pattern.
    """
    if not file_prefix:
        return

    # Imported locally to keep this helper self-contained. Escaping matters
    # because a prefix such as "run[1]" would otherwise be read as a character
    # class and delete "run1_*" files while missing its own.
    from glob import escape

    removed = 0
    for path in cache_dir.glob(f"{escape(file_prefix)}_*"):
        if path.is_file():
            path.unlink()
            removed += 1
    if removed > 0:
        print(f"  ✓ Cleared cache files: {removed} ({cache_dir.name}, prefix={file_prefix})")


# Runs at import time: importing this module creates OUTPUT_ROOT/<TASK_NAME>/KOH on disk.
OUTPUT_DIRS = _prepare_output_dirs(TASK_NAME)
KOH_DATA_DIR = OUTPUT_DIRS["KOH"]  # Passed to KOHOptimizer as task_data_dir and used for cache cleanup

# Search space configuration. Only feature names and types are declared here;
# the actual bounds are taken from the HighFidelityBlackbox instance at run time.
# COF task parameters (must match PARAM_NAMES order in high_fidelity/cof.py)
_COF_FEATURE_NAMES = [
    "pore_diameter", "void_fraction", "surface_area", "crystal_density",
    "B", "O", "C", "H", "Si", "N", "S", "P", "halogens", "metals"
]
SPACE_CONFIG = {
    "names": _COF_FEATURE_NAMES,
    # One type entry per feature, derived from the name list so the two
    # cannot drift apart when a feature is added or removed.
    "types": ["float"] * len(_COF_FEATURE_NAMES),
    "normalize": False,
}

def _get_gcmc_y_range(csv_path=None):
    """Return the ``[min, max]`` of the ``gcmc_y`` column, clamped to ``[0, 20]``.

    The default range ``[0.0, 20.0]`` is returned whenever a usable range
    cannot be derived: the file is missing, it cannot be read, it has no
    ``gcmc_y`` column, the column has no numeric values, or the data lies
    entirely outside ``[0, 20]`` (which would otherwise produce a range whose
    lower end is above its upper end).

    Args:
        csv_path: CSV file to inspect. ``None`` selects
            ``high_fidelity/cof.csv`` next to this module.

    Returns:
        A two-element list ``[lower, upper]`` of floats.
    """
    default_range = [0.0, 20.0]  # COF task target value range
    lower_limit, upper_limit = default_range

    if csv_path is None:
        csv_path = Path(__file__).parent / "high_fidelity" / "cof.csv"
    csv_path = Path(csv_path)
    if not csv_path.exists():
        return default_range

    # Deliberately broad: this runs while the module is being imported, and a
    # malformed CSV should degrade to the default range, not abort the script.
    try:
        df = pd.read_csv(csv_path)
        if "gcmc_y" not in df.columns:
            return default_range
        min_val = float(df["gcmc_y"].min())
        max_val = float(df["gcmc_y"].max())
    except Exception as e:
        print(f"Warning: Unable to read gcmc_y range, using default: {e}")
        return default_range

    # An empty or all-NaN column yields NaN for both statistics.
    if pd.isna(min_val) or pd.isna(max_val):
        return default_range

    # Use actual range, but ensure it doesn't exceed [0, 20]
    actual_min = max(lower_limit, min_val)
    actual_max = min(upper_limit, max_val)
    if actual_min > actual_max:
        # Data entirely below 0 or above 20: clamping inverted the bounds.
        print(
            f"Warning: gcmc_y range [{min_val}, {max_val}] lies outside "
            f"{default_range}, using default"
        )
        return default_range
    return [actual_min, actual_max]

# LLM configuration
# The CSV lookup below runs once, when this module is imported.
_gcmc_y_range = _get_gcmc_y_range()
# NOTE(review): run_koh_optimization reads LLM_CONFIG.get("model", "intern-s1");
# no "model" key is set here, so that default always applies — confirm this is intended.
LLM_CONFIG = {
    "llm_type": "intern_s1",      # Specify which LLM client to use
    "value_range": _gcmc_y_range   # Allowed value range for low-fidelity predictions (from the CSV, or [0.0, 20.0] as fallback)
}

# KOH optimization configuration
# The whole dict is handed to KOHOptimizer as koh_config (wrapped in a SimpleNamespace).
KOH_CONFIG = {
    "max_hf_iterations": 30,  # Default for optimizer.run(max_iterations=...)
    "n_candidates": 5000,
    "q": 2,  # Default for optimizer.run(q=...)
    "n_initial_points": 3,  # Default for optimizer.run(n_initial_points=...)
    "mismatch_threshold": 0.75,
    "force_hf_after_n_lf": 20,  # Force one HF after 20 consecutive LFs
    "gp_training_iter": 100,
    "max_loops": 1000,
    "acquisition_type": "ucb",  # "ei" or "ucb"
    "acquisition_beta": 2.0     # UCB exploration parameter (larger beta means more exploration)
}

# Optional fixed starting points. When set, the value is converted with np.array
# and mapped onto SPACE_CONFIG["names"] — presumably one row per point, with values
# in that name order (TODO confirm). None passes fixed_initial_points=None to the optimizer.
FIXED_INIT_POINTS = None


def _get_config_value(config_value, config_key, default_config):
    """Resolve a setting: an explicit value wins, otherwise use the default.

    Args:
        config_value: Override value; ``None`` means "not overridden". Other
            falsy values such as ``0`` count as real overrides.
        config_key: Key to look up in ``default_config`` when there is no
            override.
        default_config: Mapping that holds the defaults.

    Returns:
        ``config_value`` if it is not ``None``, else ``default_config[config_key]``.
    """
    # The default is only looked up when it is actually needed.
    if config_value is None:
        return default_config[config_key]
    return config_value

def run_koh_optimization(seed_override: int = None, file_prefix_override: str = None):
    """Run one KOH optimization for the COF task.

    Builds the LLM client, the high-fidelity blackbox and the KOH optimizer,
    then calls ``optimizer.run``. Progress is reported on stdout.

    The module-level overrides ``MAX_HF_ITERATIONS``, ``N_CANDIDATES``, ``Q``
    and ``N_INITIAL_POINTS`` replace the matching ``KOH_CONFIG`` entries when
    they are not ``None``. The merged values are what the optimizer receives,
    both in its ``koh_config`` and as ``run`` arguments, so the two always agree.

    Args:
        seed_override: Override random seed (for batch runs). Falls back to
            ``SINGLE_SEED`` when ``None``; if both are ``None`` no
            ``random_seed`` entry is added to the KOH config.
        file_prefix_override: Override file prefix (for batch runs). Falls
            back to ``RUN_TAG`` when ``None``.
    """
    current_seed = seed_override if seed_override is not None else SINGLE_SEED
    current_file_prefix = file_prefix_override if file_prefix_override is not None else RUN_TAG

    banner = "=" * 80
    seed_suffix = f" (Seed: {current_seed})" if current_seed is not None else ""
    print(f"\n{banner}")
    print(f"{'KOH Optimization - COF Task' + seed_suffix:^80}")
    print(banner)

    # Initialize components
    print("\nInitializing components...")

    output_dir = KOH_DATA_DIR
    print(f"  ✓ Output directory: {output_dir}")
    print(f"  ✓ Run prefix: {current_file_prefix}")
    if current_seed is not None:
        print(f"  ✓ Random seed: {current_seed} (only affects sampling, not LLM)")
    if CLEAN_PREVIOUS:
        removed = 0
        for file in output_dir.glob(f"{current_file_prefix}_*.csv"):
            # Skip anything that is not a regular file; unlink() would raise on a directory.
            if file.is_file():
                file.unlink()
                removed += 1
        print(f"  ✓ Cleared old files with prefix: {removed}")

    # LLM client - use global cache, model only loaded once
    from API.llm_config_utils import create_llm_config_from_model_name

    llm_model = LLM_CONFIG.get("model", "intern-s1")
    llm_config = create_llm_config_from_model_name(
        llm_model=llm_model,
        base_config=LLM_CONFIG.copy()
    )

    llm_client = create_llm_client(llm_config)
    print(f"  ✓ LLM client ready ({llm_config.llm_type}, model: {llm_model})")
    print(f"  ✓ gcmc_y range: {LLM_CONFIG['value_range']}")

    # High-fidelity blackbox (automatically get bounds etc. from evaluator)
    hf_blackbox = HighFidelityBlackbox(
        task_name=TASK_NAME,
        csv_path=None,
        feature_names=SPACE_CONFIG["names"],
        target_col="gcmc_y"
    )
    print("  ✓ High-fidelity blackbox ready")

    # Get actual bounds from evaluator
    actual_bounds = hf_blackbox.bounds
    print(f"  ✓ Search space: {SPACE_CONFIG['names']}")
    print(f"  ✓ Number of variables: {len(SPACE_CONFIG['names'])}")
    print(f"  ✓ Variable ranges: {list(zip(SPACE_CONFIG['names'], actual_bounds))}")

    # Prepare KOH config: start from KOH_CONFIG and apply the module-level run
    # overrides, so N_CANDIDATES (and the others) actually reach the optimizer.
    run_overrides = {
        "max_hf_iterations": MAX_HF_ITERATIONS,
        "n_candidates": N_CANDIDATES,
        "q": Q,
        "n_initial_points": N_INITIAL_POINTS,
    }
    koh_config_dict = {
        key: _get_config_value(run_overrides.get(key), key, KOH_CONFIG)
        for key in KOH_CONFIG
    }
    if current_seed is not None:
        koh_config_dict["random_seed"] = current_seed

    # KOH optimizer
    optimizer = KOHOptimizer(
        task_name=TASK_NAME,
        task_data_dir=str(KOH_DATA_DIR),
        feature_names=SPACE_CONFIG["names"],
        feature_types=SPACE_CONFIG["types"],
        bounds=actual_bounds,
        target_name="gcmc_y",
        llm_client=llm_client,
        hf_blackbox=hf_blackbox,
        llm_config=llm_config,
        koh_config=SimpleNamespace(**koh_config_dict),
        file_prefix=current_file_prefix
    )
    print("  ✓ KOH optimizer ready")

    # Run optimization
    print(f"\n{banner}")
    print(f"{'Starting KOH Optimization':^80}")
    print(banner)

    # Same resolved values that were placed in the optimizer's koh_config.
    max_iterations = koh_config_dict["max_hf_iterations"]
    n_initial_points = koh_config_dict["n_initial_points"]
    q = koh_config_dict["q"]

    # Convert fixed initial points to dict format if provided.
    # "is not None" + len() instead of plain truthiness, which raises for a
    # multi-element numpy array.
    fixed_initial_points_dict = None
    if FIXED_INIT_POINTS is not None and len(FIXED_INIT_POINTS) > 0:
        fixed_array = np.array(FIXED_INIT_POINTS)
        fixed_initial_points_dict = numpy_to_dict_list(fixed_array, SPACE_CONFIG["names"])
        print(f"  ✓ Using fixed initial points: {len(fixed_initial_points_dict)} points")

    optimizer.run(
        max_iterations=max_iterations,
        n_initial_points=n_initial_points,
        q=q,
        fixed_initial_points=fixed_initial_points_dict
    )

    print(f"\n{banner}")
    print(f"{'🎉 KOH Optimization Complete!':^80}")
    print(banner)


def _run_batch_seeds(run_func, seeds, prefix_template, description="", cache_dir: Path = None):
    """Call ``run_func`` once per seed, giving each run its own file prefix.

    Args:
        run_func: Callable accepting ``seed_override`` and
            ``file_prefix_override`` keyword arguments.
        seeds: Seeds to run, in order.
        prefix_template: ``str.format`` template with a ``{seed}`` field.
        description: Optional label shown in the batch banner.
        cache_dir: Directory whose ``<prefix>_*`` files are removed after each
            seed when ``CLEAR_CACHE_BETWEEN_SEEDS`` is set.
    """
    seed_list = [int(s) for s in seeds]
    total = len(seed_list)
    banner = "=" * 80
    title = f"Batch Run Mode ({description})" if description else "Batch Run Mode"

    print(f"\n{banner}")
    print(f"{title:^80}")
    print(banner)
    print(f"Will run {total} random seeds in order: {seed_list}")
    print("Note: These random seeds only affect sampling (numpy, torch)")

    for idx, seed in enumerate(seed_list, 1):
        seed_title = f"Seed {seed} ({idx}/{total})"
        seed_prefix = prefix_template.format(seed=seed)
        print(f"\n{banner}")
        print(f"{seed_title:^80}")
        print(banner)
        run_func(seed_override=seed, file_prefix_override=seed_prefix)
        print(f"\n✓ Seed {seed} run complete")

        if CLEAR_CACHE_BETWEEN_SEEDS and cache_dir is not None:
            _clear_method_cache(cache_dir, seed_prefix)

    print(f"\n{banner}")
    print(f"{'🎉 All Seeds Run Complete!':^80}")
    print(banner)


def main():
    """Parse the command line and start a single run or a batch of seeded runs.

    ``--seeds`` replaces ``SEEDS`` and selects batch mode. ``--seed`` selects
    one run with that seed. If both are given, ``--seeds`` takes precedence.
    Without either flag, the module-level ``SEEDS`` / ``SINGLE_SEED`` decide.
    """
    global RUN_TAG, SEEDS, SINGLE_SEED

    parser = argparse.ArgumentParser(description="COF task run script")
    parser.add_argument("--seed", type=int, help="Single run random seed (override SINGLE_SEED)")
    parser.add_argument("--seeds", type=str, help="Comma-separated random seed list, run one by one (override SEEDS)")
    args = parser.parse_args()

    if args.seeds:
        try:
            parsed_seeds = [int(s.strip()) for s in args.seeds.split(",") if s.strip()]
        except ValueError:
            # parser.error prints usage and exits instead of dumping a traceback.
            parser.error(f"--seeds expects comma-separated integers, got {args.seeds!r}")
        if not parsed_seeds:
            parser.error(f"--seeds contains no seed values: {args.seeds!r}")
        SEEDS = parsed_seeds
        SINGLE_SEED = None
    elif args.seed is not None:
        SINGLE_SEED = args.seed
        # An explicit --seed asks for exactly one run. Without clearing SEEDS the
        # default batch list would take over and the flag would be ignored.
        SEEDS = []

    # If run prefix specified, override global RUN_TAG
    if RUN_TAG_OVERRIDE:
        RUN_TAG = RUN_TAG_OVERRIDE

    # Build the per-seed prefix from RUN_TAG so a custom prefix also applies in
    # batch mode. Braces are doubled because the result is passed to str.format.
    escaped_tag = RUN_TAG.replace("{", "{{").replace("}", "}}")
    prefix_template = f"{escaped_tag}_seed{{seed}}"
    description = "KOH"
    cache_dir = KOH_DATA_DIR

    if SEEDS:
        _run_batch_seeds(run_koh_optimization, SEEDS, prefix_template, description, cache_dir)
    else:
        run_koh_optimization()


if __name__ == "__main__":
    main()

