import sys
import argparse
from pathlib import Path
from datetime import datetime
from types import SimpleNamespace
import pandas as pd

# Add project root directory to path
sys.path.insert(0, str(Path(__file__).parent.absolute()))

from API.llm_clients import create_llm_client
from API.llm_config_utils import create_llm_config_from_model_name
from high_fidelity import HighFidelityBlackbox
from koh.optimizer import KOHOptimizer
from koh.utils import numpy_to_dict_list
import numpy as np

# ===== Run configuration (manual adjustment) =====
# Seeds for batch mode. main() runs one optimization per seed when this list
# is non-empty, and a single run otherwise.
SEEDS = [1, 2, 3, 4, 5]
RUN_TAG_OVERRIDE = None  # Custom run prefix (None means use the timestamp-based RUN_TAG below)
CLEAN_PREVIOUS = False  # Whether to delete existing "<prefix>_*.csv" outputs before a run
SINGLE_SEED = None  # Seed for a single run (None means unseeded; only effective when SEEDS=[])
# Whether to delete the files of a seed (matching its prefix) once that seed
# has finished during batch runs.
CLEAR_CACHE_BETWEEN_SEEDS = False  # False keeps the result files of every seed

# Run parameters (optional overrides for the defaults in KOH_CONFIG)
MAX_HF_ITERATIONS = None  # None means use KOH_CONFIG default
N_CANDIDATES = None  # None means use KOH_CONFIG default
Q = None  # None means use KOH_CONFIG default
N_INITIAL_POINTS = None  # None means use KOH_CONFIG default

# Task basic information
TASK_NAME = "PCE10"
OUTPUT_ROOT = Path("outputs")  # Relative path: resolved against the current working directory
# Minute-resolution timestamp, captured once when the module is imported.
_TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M')
RUN_TAG = f"{TASK_NAME}_{_TIMESTAMP}"  # Default file prefix for a run


def _prepare_output_dirs(task_name: str):
    """Create the output directory tree for a task and return its paths.

    Args:
        task_name: Task name, used as the sub-directory under OUTPUT_ROOT.

    Returns:
        Dict with the keys "base" (OUTPUT_ROOT/<task_name>) and "KOH"
        (OUTPUT_ROOT/<task_name>/KOH). Both directories exist on return.
    """
    dirs = {"base": OUTPUT_ROOT / task_name}
    dirs["KOH"] = dirs["base"] / "KOH"
    # Parent first, then child; mkdir is a no-op for directories that exist.
    for directory in dirs.values():
        directory.mkdir(parents=True, exist_ok=True)
    return dirs


def _clear_method_cache(cache_dir: Path, file_prefix: str):
    """Delete the regular files in ``cache_dir`` named ``<file_prefix>_*``.

    Directories matching the pattern are left alone, and the search is not
    recursive. An empty or None prefix is a no-op so that a missing prefix
    can never match (and wipe) unrelated files.

    Args:
        cache_dir: Directory holding the cached/result files.
        file_prefix: Run prefix whose files should be removed.
    """
    if file_prefix:
        stale_files = [
            entry for entry in cache_dir.glob(f"{file_prefix}_*") if entry.is_file()
        ]
        for entry in stale_files:
            entry.unlink()
        # Stay silent when nothing was removed.
        if stale_files:
            print(f"  ✓ Cleared cache files: {len(stale_files)} ({cache_dir.name}, prefix={file_prefix})")


# NOTE: runs at import time, so importing this module creates
# OUTPUT_ROOT/<TASK_NAME>/ and OUTPUT_ROOT/<TASK_NAME>/KOH/ on disk.
OUTPUT_DIRS = _prepare_output_dirs(TASK_NAME)
KOH_DATA_DIR = OUTPUT_DIRS["KOH"]  # All KOH result files are written here

# Search space configuration. Only the parameter names and types are declared
# here; the bounds are not. run_koh_optimization() takes the actual bounds
# from the high-fidelity blackbox (its `bounds` attribute) at run time.
# PCE10 task 4 parameters (must match PARAM_NAMES order in high_fidelity/pce10.py)
SPACE_CONFIG = {
    "names": ["PCE10", "P3HT", "PCBM", "olDTBR"],
    "types": ["float"] * 4,  # All parameters are real numbers (weight fractions)
    "normalize": False  # Use original weight fractions, don't normalize
}

# Degradation value range (a fixed constant; nothing is read from CSV)
def _get_degradation_range():
    """Return the fixed Degradation range allowed for LLM predictions.

    The range is the constant [0.0, 0.75]; no CSV file is read. Per the
    original author's note, a fixed range is used instead of the observed
    data range so that the LLM may predict values beyond the historical
    data, consistent with the range limits stated in the prompt.

    Returns:
        A new two-element list ``[lower, upper]``.
    """
    lower_bound, upper_bound = 0.0, 0.75
    return [lower_bound, upper_bound]

# LLM configuration
# No "model" key is set here, so run_koh_optimization() falls back to its
# default model name ("intern-s1") when building the client config.
_degradation_range = _get_degradation_range()
LLM_CONFIG = {
    "llm_type": "intern_s1",      # Specify which LLM client to use
    "value_range": _degradation_range,  # Allowed value range for low-fidelity predictions (fixed range from _get_degradation_range)
    # PCE10 is minimization problem: optimizer internally uses "maximize(-Degradation)" space
    # - LLM still outputs positive Degradation in original space
    # - Before entering optimizer/GP training, uniformly multiply by objective_transform (-1.0) to transform to internal space
    "objective_transform": -1.0
}

# KOH optimization configuration (defaults; see the run parameters at the top
# of the file for the optional per-run overrides)
KOH_CONFIG = {
    "max_hf_iterations": 30,
    "n_candidates": 5000,
    "q": 2,
    "n_initial_points": 3,
    "mismatch_threshold": 0.75,
    "force_hf_after_n_lf": 20,  # Force one HF after 20 consecutive LFs
    "gp_training_iter": 100,
    "max_loops": 5000,
    "acquisition_type": "ucb",  # "ei" or "ucb"
    "acquisition_beta": 2.0     # UCB exploration parameter (larger beta means more exploration)
}

# Optional fixed initial points, given as rows of values in
# SPACE_CONFIG["names"] order (None means no fixed initial points are passed).
FIXED_INIT_POINTS = None

# ============================================================================
# Minimization problem wrapper
# ============================================================================

class MinimizationBlackbox:
    """Adapter that presents a minimization objective as a maximization one.

    ``evaluate`` returns the negated value of the wrapped blackbox, so an
    optimizer that maximizes this wrapper effectively minimizes the original
    target. Values produced through this wrapper are therefore negative
    copies of the original ones and must be negated again to recover the
    original scale.
    """

    def __init__(self, hf_blackbox: HighFidelityBlackbox):
        """Wrap a high-fidelity blackbox.

        Args:
            hf_blackbox: Original high-fidelity blackbox (returns the original
                target value, where smaller is better).
        """
        self.hf_blackbox = hf_blackbox
        # Mirror the wrapped object's metadata so the wrapper can stand in for
        # it. The values are copied as-is; in particular ``y`` is NOT negated.
        for attr_name in ("task_name", "feature_names", "target_col", "bounds", "X", "y"):
            setattr(self, attr_name, getattr(hf_blackbox, attr_name))

    def evaluate(self, x: dict) -> float:
        """Evaluate ``x`` on the wrapped blackbox and return the negated result.

        Args:
            x: Input point dictionary.

        Returns:
            The wrapped blackbox's value multiplied by -1.
        """
        return -self.hf_blackbox.evaluate(x)


def _get_config_value(config_value, config_key, default_config):
    """Return an explicit override, or the default when no override is set.

    Only ``None`` counts as "not set"; falsy overrides such as 0 are kept.

    Args:
        config_value: The override value, or None.
        config_key: Key to look up in ``default_config`` when there is no override.
        default_config: Mapping that holds the default values.

    Returns:
        ``config_value`` if it is not None, else ``default_config[config_key]``.

    Raises:
        KeyError: If the default is needed but ``config_key`` is missing.
    """
    if config_value is None:
        return default_config[config_key]
    return config_value


def run_koh_optimization(seed_override: int = None, file_prefix_override: str = None):
    """Run one KOH optimization for the PCE10 task and post-process its outputs.

    Steps:
      1. Resolve the seed and the output file prefix.
      2. Optionally delete existing ``<prefix>_*.csv`` files (CLEAN_PREVIOUS).
      3. Build the LLM client, the high-fidelity blackbox (wrapped so that it
         returns negated values) and the KOH optimizer.
      4. Run the optimizer.
      5. Convert the saved CSV files back to the original Degradation scale.

    The module-level run parameters (MAX_HF_ITERATIONS, N_CANDIDATES, Q,
    N_INITIAL_POINTS) override the matching KOH_CONFIG entries when they are
    not None. The overrides are applied to the config object handed to the
    optimizer as well as to the arguments of ``optimizer.run``.

    Args:
        seed_override: Override random seed (for batch runs); falls back to
            SINGLE_SEED when None.
        file_prefix_override: Override file prefix (for batch runs); falls
            back to RUN_TAG when None.
    """
    current_seed = seed_override if seed_override is not None else SINGLE_SEED
    current_file_prefix = file_prefix_override if file_prefix_override is not None else RUN_TAG

    seed_suffix = f" (Seed: {current_seed})" if current_seed is not None else ""
    print(f"\n{'='*80}")
    print(f"{'KOH Optimization - PCE10 Task' + seed_suffix:^80}")
    print(f"{'='*80}")

    # Initialize components
    print("\nInitializing components...")

    output_dir = KOH_DATA_DIR
    print(f"  ✓ Output directory: {output_dir}")
    print(f"  ✓ Run prefix: {current_file_prefix}")
    if current_seed is not None:
        print(f"  ✓ Random seed: {current_seed} (only affects sampling, not LLM)")
    if CLEAN_PREVIOUS:
        removed = 0
        for file in output_dir.glob(f"{current_file_prefix}_*.csv"):
            file.unlink()
            removed += 1
        print(f"  ✓ Cleared old files with prefix: {removed}")

    # LLM client - use global cache, model only loaded once
    llm_model = LLM_CONFIG.get("model", "intern-s1")
    llm_config = create_llm_config_from_model_name(
        llm_model=llm_model,
        base_config=LLM_CONFIG.copy()
    )

    llm_client = create_llm_client(llm_config)
    print(f"  ✓ LLM client ready ({llm_config.llm_type}, model: {llm_model})")
    print(f"  ✓ Degradation range: {LLM_CONFIG['value_range']}")
    print("  ⚠ Note: PCE10 is minimization problem (lower Degradation is better), automatically converted to maximization problem")
    print(f"  ✓ objective_transform: {getattr(llm_config, 'objective_transform', 1.0)} (internal optimization target space)")

    # High-fidelity blackbox (automatically get bounds etc. from evaluator)
    hf_blackbox_original = HighFidelityBlackbox(
        task_name=TASK_NAME,
        csv_path=None,
        feature_names=SPACE_CONFIG["names"],
        target_col="Degradation"
    )
    print("  ✓ High-fidelity blackbox ready")

    # Wrap as minimization problem (take negative)
    hf_blackbox = MinimizationBlackbox(hf_blackbox_original)

    # Get actual bounds from evaluator
    actual_bounds = hf_blackbox.bounds
    print(f"  ✓ Search space: {SPACE_CONFIG['names']}")
    print(f"  ✓ Number of variables: {len(SPACE_CONFIG['names'])}")
    print(f"  ✓ Variable ranges: {list(zip(SPACE_CONFIG['names'], actual_bounds))}")

    # Prepare KOH config. Apply the module-level run parameters here so that
    # the config object seen by the optimizer and the arguments passed to
    # optimizer.run() can never disagree (and so N_CANDIDATES takes effect).
    koh_config_dict = KOH_CONFIG.copy()
    run_overrides = {
        "max_hf_iterations": MAX_HF_ITERATIONS,
        "n_candidates": N_CANDIDATES,
        "q": Q,
        "n_initial_points": N_INITIAL_POINTS,
    }
    for config_key, override in run_overrides.items():
        koh_config_dict[config_key] = _get_config_value(override, config_key, KOH_CONFIG)
    # Set random_seed only if a seed was provided
    if current_seed is not None:
        koh_config_dict["random_seed"] = current_seed

    # KOH optimizer (note: target_name uses negative value, but will be converted back when saving)
    optimizer = KOHOptimizer(
        task_name=TASK_NAME,
        task_data_dir=str(KOH_DATA_DIR),
        feature_names=SPACE_CONFIG["names"],
        feature_types=SPACE_CONFIG["types"],
        bounds=actual_bounds,
        target_name="Degradation",  # Note: although internally uses negative value, target_name remains original
        llm_client=llm_client,
        hf_blackbox=hf_blackbox,  # Use wrapped blackbox (returns negative value)
        llm_config=llm_config,
        koh_config=SimpleNamespace(**koh_config_dict),
        file_prefix=current_file_prefix,
        objective_transform=getattr(llm_config, "objective_transform", 1.0)
    )
    print("  ✓ KOH optimizer ready")

    # Run optimization
    print(f"\n{'='*80}")
    print(f"{'Starting KOH Optimization':^80}")
    print(f"{'='*80}")

    # Convert fixed initial points to dict format if provided.
    # Use an explicit None/length check: plain truthiness raises ValueError
    # when FIXED_INIT_POINTS is a numpy array with more than one element.
    fixed_initial_points_dict = None
    if FIXED_INIT_POINTS is not None and len(FIXED_INIT_POINTS) > 0:
        # Convert list format to dict format
        fixed_array = np.array(FIXED_INIT_POINTS)
        fixed_initial_points_dict = numpy_to_dict_list(fixed_array, SPACE_CONFIG["names"])
        print(f"  ✓ Using fixed initial points: {len(fixed_initial_points_dict)} points")

    optimizer.run(
        max_iterations=koh_config_dict["max_hf_iterations"],
        n_initial_points=koh_config_dict["n_initial_points"],
        q=koh_config_dict["q"],
        fixed_initial_points=fixed_initial_points_dict
    )

    # Convert results: convert negative values back to original values
    print(f"\n{'='*80}")
    print(f"{'Converting Results (Minimization Problem)':^80}")
    print(f"{'='*80}")
    _convert_results_to_minimization(KOH_DATA_DIR, current_file_prefix)

    print(f"\n{'='*80}")
    print(f"{'KOH Optimization Complete':^80}")
    print(f"{'='*80}")


def _convert_results_to_minimization(data_dir: Path, file_prefix: str):
    """Rewrite a run's result CSVs in the original (minimization) scale.

    The optimizer works on negated targets, so the values it saves are
    negated too. For each of ``<prefix>_hf_predictions.csv``,
    ``<prefix>_lf_predictions.csv``, ``<prefix>_history.csv`` and
    ``<prefix>_seed_points.csv`` that exists in ``data_dir``:

    - ``Degradation`` and ``mu_LF`` columns are negated.
    - ``best_objective`` is negated; when a ``Degradation`` column is also
      present it is then replaced by the running minimum of the (already
      restored) ``Degradation`` column.

    A file is rewritten in place only if at least one column was converted.
    Conversion is best-effort: a failure on one file is reported with a
    traceback and the remaining files are still processed.

    NOTE: the conversion is a plain sign flip, so calling this twice on the
    same files flips the values back.

    Args:
        data_dir: Directory containing the result files.
        file_prefix: Run prefix of the files to convert.
    """
    result_kinds = ("hf_predictions", "lf_predictions", "history", "seed_points")

    for kind in result_kinds:
        filename = f"{file_prefix}_{kind}.csv"
        filepath = data_dir / filename
        if not filepath.exists():
            continue

        try:
            df = pd.read_csv(filepath)
            touched = False

            # Plain sign flip: internal (negated) space -> original Degradation.
            for column in ("Degradation", "mu_LF"):
                if column not in df.columns:
                    continue
                df[column] = -df[column]
                touched = True
                print(f"  ✓ Converted {filename}: {column} column negated ({len(df)} rows)")

            if "best_objective" in df.columns:
                # max(negated values) corresponds to min(original values).
                df["best_objective"] = -df["best_objective"]
                if "Degradation" in df.columns:
                    # Degradation was restored above, so its running minimum
                    # is the best-so-far value of the minimization problem.
                    df["best_objective"] = df["Degradation"].cummin()
                touched = True
                print(f"  ✓ Converted {filename}: best_objective column recalculated as cumulative minimum")

            # Leave the file untouched when no relevant column was found.
            if touched:
                df.to_csv(filepath, index=False)

        except Exception as e:
            print(f"  ⚠ Conversion {filename} failed: {e}")
            import traceback
            traceback.print_exc()


def main():
    """Parse the command line and run either a batch of seeds or a single run.

    Command-line options:
        --seeds  Comma-separated seed list; replaces SEEDS and selects batch mode.
        --seed   Seed for one single run; replaces SINGLE_SEED and clears
                 SEEDS so that the request is not shadowed by the default
                 batch list. Ignored when --seeds is also given.

    Batch mode runs ``run_koh_optimization`` once per seed, using the file
    prefix ``<RUN_TAG>_seed<seed>``. RUN_TAG_OVERRIDE, when set, replaces
    RUN_TAG for both single and batch runs.
    """
    global RUN_TAG, SEEDS, SINGLE_SEED

    parser = argparse.ArgumentParser(description="PCE10 task run script")
    parser.add_argument("--seed", type=int, help="Single run random seed (override SINGLE_SEED)")
    parser.add_argument("--seeds", type=str, help="Comma-separated random seed list, run one by one (override SEEDS)")
    args = parser.parse_args()

    if args.seeds:
        try:
            SEEDS = [int(s.strip()) for s in args.seeds.split(",") if s.strip()]
        except ValueError:
            # Report a usage error instead of a raw traceback.
            parser.error(f"--seeds must be a comma-separated list of integers, got {args.seeds!r}")
        SINGLE_SEED = None
    elif args.seed is not None:
        SINGLE_SEED = args.seed
        # An explicit --seed asks for a single run. Without clearing SEEDS the
        # non-empty default list would select batch mode and ignore --seed.
        SEEDS = []

    # If run prefix specified, override global RUN_TAG
    if RUN_TAG_OVERRIDE:
        RUN_TAG = RUN_TAG_OVERRIDE

    # Batch run logic
    def run_batch_seeds(run_func, prefix_template, description="", cache_dir: Path = None):
        """Run ``run_func`` once per seed in SEEDS.

        Args:
            run_func: Callable accepting ``seed_override`` and
                ``file_prefix_override`` keyword arguments.
            prefix_template: ``str.format`` template with a ``{seed}`` field.
            description: Optional label shown in the banner.
            cache_dir: Directory to clean after each seed when
                CLEAR_CACHE_BETWEEN_SEEDS is enabled (None disables cleaning).
        """
        seed_list = [int(s) for s in SEEDS]
        banner_title = f"Batch Run Mode ({description})" if description else "Batch Run Mode"
        print(f"\n{'='*80}")
        print(f"{banner_title:^80}")
        print(f"{'='*80}")
        print(f"Will run {len(seed_list)} random seeds in order: {seed_list}")
        print("Note: These random seeds only affect sampling (numpy, torch)")

        for idx, seed in enumerate(seed_list, 1):
            seed_prefix = prefix_template.format(seed=seed)
            print(f"\n{'='*80}")
            print(f"{'Seed %d (%d/%d)' % (seed, idx, len(seed_list)):^80}")
            print(f"{'='*80}")
            run_func(seed_override=seed, file_prefix_override=seed_prefix)
            print(f"\n✓ Seed {seed} run complete")

            if CLEAR_CACHE_BETWEEN_SEEDS and cache_dir is not None:
                _clear_method_cache(cache_dir, seed_prefix)

        print(f"\n{'='*80}")
        print(f"{'🎉 All Seeds Run Complete!':^80}")
        print(f"{'='*80}")

    # Method mapping
    # Build the per-seed prefix from RUN_TAG so that RUN_TAG_OVERRIDE is also
    # honoured in batch mode. Braces in the tag are escaped because the
    # template is expanded with str.format().
    escaped_run_tag = RUN_TAG.replace("{", "{{").replace("}", "}}")
    prefix_template = f"{escaped_run_tag}_seed{{seed}}"
    description = "KOH"
    cache_dir = KOH_DATA_DIR

    if SEEDS:
        run_batch_seeds(run_koh_optimization, prefix_template, description, cache_dir)
    else:
        run_koh_optimization()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
