"""KOH optimizer - main loop implementation"""
import numpy as np
import pandas as pd
from typing import List, Dict, Optional
from pathlib import Path

from .data_manager import DataManager
from .models.low_fidelity_gp import LowFidelityGP
from .models.residual_gp import ResidualGP
from .models.rho_manager import RhoManager
from .fusion import KOHFusion
from .acquisition import select_next_points_q_ei, compute_ei, compute_ucb
from .decision import MismatchDecision
from .utils import sample_candidates, numpy_to_dict_list, dict_list_to_numpy, find_high_error_points, filter_sandwich_constraints
from low_fidelity.warmup import warmup_phase

from low_fidelity.generator import LLMGenerator
from low_fidelity.predictor import LowFidelityPredictor


class KOHOptimizer:
    """KOH optimizer - complete MFBO main loop.

    Multi-fidelity Bayesian optimization in the Kennedy-O'Hagan (KOH) style:
    a GP fitted to LLM-based low-fidelity (LF) predictions, a scale factor
    ``rho`` and a residual GP fitted to ``y_H - rho * mu_LF`` on the
    high-fidelity (HF) history are combined by ``KOHFusion``. Each main-loop
    pass selects ``q`` candidates and a mismatch gate decides whether they are
    evaluated on the HF black box or only predicted by the LF model.
    """

    def __init__(
        self,
        task_name: str,
        task_data_dir: str,
        feature_names: List[str],
        feature_types: List[str],
        bounds: np.ndarray,
        target_name: str,
        llm_client,
        hf_blackbox,
        llm_config,
        koh_config,
        file_prefix: Optional[str] = None,
        objective_transform: float = 1.0
    ):
        """Initialize KOH optimizer.

        Args:
            task_name: Task name (used to load prompts; "Sandwich" additionally
                enables task-specific candidate constraints).
            task_data_dir: Task data directory.
            feature_names: List of feature names.
            feature_types: List of feature types (["int", "int", ...] or ["float", ...]).
            bounds: Search space bounds (array-like, shape=(n_dims, 2)).
            target_name: Target column name.
            llm_client: LLM client.
            hf_blackbox: High-fidelity blackbox model (must provide ``evaluate(x_dict)``).
            llm_config: LLM config object (SimpleNamespace) providing
                ``temperature``, ``top_p``, ``max_tokens``, ``alpha``, ``beta``
                and optionally ``value_range``.
            koh_config: KOH config object (SimpleNamespace); every field read
                from it is optional and falls back to a default.
            file_prefix: Optional file prefix; when given, LLM calls are logged
                to ``<task_data_dir>/<file_prefix>_llm_calls.jsonl``.
            objective_transform: Factor forwarded to the data manager and to the
                LF predictor (as ``y_transform``); presumably a sign/scale applied
                to the objective (e.g. -1.0 for minimization) - TODO confirm.

        Raises:
            ValueError: If no user prompt templates are loaded for the task.
        """
        self.task_name = task_name
        self.llm_client = llm_client
        self.hf_blackbox = hf_blackbox
        self.llm_config = llm_config
        self.koh_config = koh_config
        self.objective_transform = float(objective_transform)

        # Data manager
        self.data_manager = DataManager(
            task_data_dir=task_data_dir,
            feature_names=feature_names,
            target_name=target_name,
            file_prefix=file_prefix,
            objective_transform=self.objective_transform
        )

        # LLM generator
        from low_fidelity.prompt import load_prompts
        system_prompt, user_prompts = load_prompts(task_name)
        if not user_prompts:
            raise ValueError("No user prompt templates loaded")
        self.user_prompt = user_prompts[0]

        llm_log_path = None
        if file_prefix:
            llm_log_path = str(Path(task_data_dir) / f"{file_prefix}_llm_calls.jsonl")

        self.generator = LLMGenerator(
            llm_client=llm_client,
            system_prompt=system_prompt,
            value_range=getattr(llm_config, 'value_range', None),
            log_path=llm_log_path
        )

        self.predictor = LowFidelityPredictor(
            generator=self.generator,
            user_prompts=self.user_prompt,
            temperature=llm_config.temperature,
            top_p=llm_config.top_p,
            max_tokens=llm_config.max_tokens,
            alpha=llm_config.alpha,
            beta=llm_config.beta,
            y_transform=self.objective_transform
        )

        # Models (each GP trains its own hyperparameters)
        gp_training_iter = getattr(koh_config, 'gp_training_iter', 100)
        bounds_array = np.array(bounds)  # Ensure numpy array
        self.lf_gp = LowFidelityGP(training_iter=gp_training_iter, bounds=bounds_array)
        self.residual_gp = ResidualGP(training_iter=gp_training_iter, bounds=bounds_array)
        self.rho_manager = RhoManager()
        self.fusion = None  # Created by _train_models

        # Decision maker (gate 1: mismatch ratio)
        self.mismatch_decision = MismatchDecision(
            threshold=getattr(koh_config, 'mismatch_threshold', 0.8),
            consecutive_high_limit=getattr(koh_config, 'consecutive_high_limit', 3),
            force_hf_after_n_lf=getattr(koh_config, 'force_hf_after_n_lf', None)
        )

        # Optimization parameters
        self.bounds = bounds_array
        self.types = feature_types
        self.feature_names = feature_names

        # Result logging
        self.iteration_log = []

        # Candidate point cache (for seed-based sampling): {seed: candidates_array}
        self._seed_candidates_cache = {}

        # Main-loop random seed (applied only after warm-up)
        self.main_random_seed = getattr(koh_config, 'random_seed', None)

        # Always recompute LF predictions for the first N loops even without HF refresh
        self.always_update_lf_loops = getattr(self.koh_config, 'always_update_lf_loops', 12)

        # Acquisition function configuration (default UCB)
        self.acquisition_type = getattr(self.koh_config, 'acquisition_type', 'ucb').lower()
        self.acquisition_beta = getattr(self.koh_config, 'acquisition_beta', 2.0)

        # Whether the previous main-loop pass took the HF branch
        self._last_was_hf = False

    def run(self, max_iterations: int = 25, n_initial_points: int = 5, q: int = 2, fixed_initial_points: Optional[List[dict]] = None):
        """Run the complete KOH optimization loop (warm-up, then main loop).

        Each main-loop pass trains the models, samples candidates, selects
        ``q`` points with the configured acquisition (UCB or EI, greedy with
        fantasy updates), and lets the mismatch gate route them to the HF
        branch (real evaluations) or the LF branch (LLM predictions only).
        An exception inside a pass is printed (traceback) and the pass skipped.

        The loop stops after ``max_iterations`` HF branches or after
        ``koh_config.max_loops`` passes (default ``max_iterations *
        koh_config.max_loop_multiplier``, multiplier default 3).

        Args:
            max_iterations: Maximum HF iterations (only HF branch counts).
            n_initial_points: Number of initial points.
            q: Number of points selected per round.
            fixed_initial_points: Fixed initial point list (if provided, skip LLM
                generation and use these points directly).
                Format: [{"feat1": val1, "feat2": val2, ...}, ...]
        """
        # Warm-up phase
        warmup_phase(
            llm_client=self.llm_client,
            hf_blackbox=self.hf_blackbox,
            data_manager=self.data_manager,
            generator=self.generator,
            user_prompt=self.user_prompt,
            task_name=self.task_name,
            feature_names=self.feature_names,
            n_initial_points=n_initial_points,
            # Warmup phase:
            # - Initial warmup point generation: fixed use temperature=0.0, top_p=0.9 in generate_initial_points_with_llm
            # - LOO and LHS batch LF prediction: use temperature=0.7, top_p=0.9 passed below
            temperature=0.7,
            top_p=0.9,
            max_tokens=self.llm_config.max_tokens,
            fixed_initial_points=fixed_initial_points,
            y_transform=self.objective_transform
        )

        # Seed the main loop only after warm-up so the two phases stay decoupled
        if self.main_random_seed is not None:
            import torch
            np.random.seed(self.main_random_seed)
            torch.manual_seed(self.main_random_seed)

        iteration = 0  # HF iteration count (only increments on HF branch)
        loop_count = 0  # Total passes (bounds the number of LF branches)
        max_loops_config = getattr(self.koh_config, 'max_loops', None)
        if max_loops_config is not None:
            max_loops = int(max_loops_config)
        else:
            max_loops = max_iterations * getattr(self.koh_config, 'max_loop_multiplier', 3)
        n_candidates = getattr(self.koh_config, 'n_candidates', 5000)

        while iteration < max_iterations and loop_count < max_loops:
            loop_count += 1
            try:
                # Step 1: Train models; recompute all LF predictions on the first
                # pass and right after an HF branch added new history points
                force_recompute = loop_count == 1 or getattr(self, '_last_was_hf', False)
                self._train_models(force_recompute_lf=force_recompute, loop_count=loop_count)

                # Step 2: Sample candidate points (seeded per loop, history filtered)
                candidates = self._sample_candidates(n_samples=n_candidates, loop_count=loop_count)

                # Step 3: KOH posterior prediction (variances feed the mismatch gate)
                _, sigma2_H, _, sigma2_delta = self._koh_posterior_predict(candidates)

                # Step 4: greedy q-point selection with fantasy updates
                y_best = self._get_best_y()
                selected_indices, _ = self._select_q_points_greedy_with_fantasy(
                    candidates=candidates,
                    q=q,
                    y_best=y_best
                )
                x_next_list = [candidates[idx] for idx in selected_indices]

                # Step 5: Mismatch judgment (gate 1)
                do_hf, r_max, _ = self.mismatch_decision.decide(
                    selected_indices, sigma2_delta, sigma2_H
                )

                # Step 6: Branch decision (only gate 1)
                if do_hf:
                    iteration += 1  # Only HF branch increments iteration
                    self._high_fidelity_branch(x_next_list, iteration)
                    self._last_was_hf = True
                else:
                    self._low_fidelity_branch(x_next_list, iteration)
                    self._last_was_hf = False

                self._log_iteration(iteration, x_next_list, do_hf, r_max, y_best)

            except Exception:
                # Top-level boundary: report and move on to the next pass
                import traceback
                traceback.print_exc()
                continue

        self._print_summary()

    def _train_models(self, force_recompute_lf: bool = True, loop_count: int = 0):
        """Train all models (low-fidelity GP, residual GP, rho) and rebuild the fusion model.

        Args:
            force_recompute_lf: Refresh stored LF predictions before fitting.
            loop_count: Current main-loop count. LF predictions are also refreshed
                while ``loop_count <= always_update_lf_loops``; the value is also
                used as the ``iteration`` tag for rho and on-the-fly LF predictions.

        Raises:
            ValueError: If there are no HF history points, or the LF prediction
                failed for every HF history point.
        """
        update_lf_this_loop = force_recompute_lf or (loop_count <= self.always_update_lf_loops)

        # 1. Refresh stored LF predictions (both calls are incremental)
        if update_lf_this_loop:
            # LOO predictions only for new history points
            self.data_manager.recompute_all_lf_predictions(self.predictor, iteration=None, only_new=True)
            # Non-history points only when the history changed
            self.data_manager.recompute_non_history_lf_predictions(self.predictor, iteration=None, only_if_history_changed=True)
            self.data_manager.save_all()

        # 2. Low-fidelity GP on all stored LF predictions
        X_lf, mu_LF, sigma2_LF = self.data_manager.get_lf_training_data()
        if len(X_lf) > 0:
            self.lf_gp.fit(X_lf, mu_LF, sigma2_LF)

        # 3. HF history for the residual GP and rho
        X_hist, y_H = self.data_manager.get_history_points()
        if len(X_hist) == 0:
            raise ValueError("No HF point data, cannot train residual GP")

        mu_LF_hf = np.zeros(len(X_hist))
        sigma2_LF_hf = np.zeros(len(X_hist))
        for i, x_array in enumerate(X_hist):
            x_dict = dict(zip(self.feature_names, x_array))
            existing_pred = self.data_manager._get_existing_lf_prediction(x_dict)
            if existing_pred is not None:
                mu_LF_hf[i] = existing_pred["mu_LF"]
                sigma2_LF_hf[i] = existing_pred["sigma2_LF"]
                continue
            # No stored prediction: run a LOO prediction now (best effort)
            try:
                history_exclude = self.data_manager.get_history_exclude_point(x_dict)
                mu_LF, sigma2_LF, _ = self.predictor.predict(x_dict, history_exclude)
                mu_LF_hf[i] = mu_LF
                sigma2_LF_hf[i] = sigma2_LF
                self.data_manager.add_lf_prediction(x_dict, mu_LF, sigma2_LF, iteration=loop_count)
            except Exception:
                # Mark as NaN; the point is dropped by the filter below
                mu_LF_hf[i] = np.nan
                sigma2_LF_hf[i] = np.nan

        valid_mask = ~np.isnan(mu_LF_hf)
        mu_LF_hf = mu_LF_hf[valid_mask]
        sigma2_LF_hf = sigma2_LF_hf[valid_mask]
        X_hf = X_hist[valid_mask]
        y_H_valid = y_H[valid_mask]

        if len(mu_LF_hf) == 0:
            raise ValueError("All HF point LF predictions failed")

        # 4. rho (least squares scaling) and residuals y_H - rho * mu_LF
        rho_value = self.rho_manager.compute_rho(y_H_valid, mu_LF_hf, iteration=loop_count)
        residuals = y_H_valid - rho_value * mu_LF_hf
        self.residual_gp.fit(X_hf, residuals)

        # 5. Fusion model
        self.fusion = KOHFusion(self.lf_gp, self.residual_gp, self.rho_manager)

    def _compute_acquisition(self, mu: np.ndarray, sigma2: np.ndarray, y_best: float) -> np.ndarray:
        """Evaluate the configured acquisition function (UCB or EI).

        Variances are clipped at zero before the square root: a slightly
        negative variance (e.g. from numerical round-off) would otherwise give a
        NaN, and ``np.argmax`` returns the index of the first NaN.

        Args:
            mu: Posterior means.
            sigma2: Posterior variances.
            y_best: Incumbent value (only used by EI).

        Returns:
            Acquisition values, one per entry of ``mu``.
        """
        sigma = np.sqrt(np.maximum(sigma2, 0.0))
        if self.acquisition_type == 'ucb':
            return compute_ucb(mu, sigma, self.acquisition_beta)
        return compute_ei(mu, sigma, y_best, xi=0.02)

    def _select_q_points_greedy_with_fantasy(self, candidates: np.ndarray, q: int, y_best: float):
        """Greedily pick ``q`` candidates, conditioning on a fantasy observation after each pick.

        After a point is chosen (and another pick remains), a fantasy HF value
        is drawn from the current fused posterior at that point. It is
        converted into a residual ``y_fantasy - rho * mu_LF(x)`` and appended to
        the training set of a *temporary* residual GP, which is used to choose
        the next point. The real models are never modified. Supports both UCB
        and EI via ``self.acquisition_type``.

        Args:
            candidates: Candidate points, one per row.
            q: Number of points to select (capped at ``len(candidates)``).
            y_best: Incumbent value, used by EI.

        Returns:
            (selected_indices, initial_acq_values): row indices into ``candidates``
            in selection order, and the acquisition values of all candidates
            under the un-updated fusion model.
        """
        mu_H_init, sigma2_H_init = self.fusion.predict(candidates)
        acq_init = self._compute_acquisition(mu_H_init, sigma2_H_init, y_best)

        n_select = min(q, len(candidates))
        remaining_mask = np.ones(len(candidates), dtype=bool)
        selected_indices: List[int] = []
        # LF GP and rho stay fixed; only this residual GP absorbs fantasies
        temp_residual_gp = self.residual_gp

        for step in range(n_select):
            remaining_idx = np.flatnonzero(remaining_mask)
            fusion_temp = KOHFusion(self.lf_gp, temp_residual_gp, self.rho_manager)
            mu_H_remain, sigma2_H_remain = fusion_temp.predict(candidates[remaining_mask])
            acq_values = self._compute_acquisition(mu_H_remain, sigma2_H_remain, y_best)

            local_best = int(np.argmax(acq_values))
            best_idx = int(remaining_idx[local_best])
            selected_indices.append(best_idx)
            remaining_mask[best_idx] = False

            # The fantasy model only matters for choosing a *next* point
            if step == n_select - 1:
                break

            x_sel = candidates[best_idx]
            mu_H_x = float(mu_H_remain[local_best])
            sigma2_H_x = float(sigma2_H_remain[local_best])
            if sigma2_H_x > 1e-12:
                y_fantasy = np.random.normal(loc=mu_H_x, scale=np.sqrt(sigma2_H_x))
            else:
                y_fantasy = mu_H_x

            mu_LF_x, _ = self.lf_gp.predict_with_variance(x_sel.reshape(1, -1))
            rho_current = self.rho_manager.get_rho()
            residual_delta = float(y_fantasy - rho_current * mu_LF_x[0])

            # Refit a fresh temporary residual GP on old data + fantasy point
            X_new = np.vstack([temp_residual_gp.X_train, x_sel.reshape(1, -1)])
            r_new = np.hstack([temp_residual_gp.residuals_train, [residual_delta]])
            temp_residual_gp_new = ResidualGP(training_iter=temp_residual_gp.training_iter, bounds=temp_residual_gp.bounds)
            temp_residual_gp_new.fit(X_new, r_new)
            temp_residual_gp = temp_residual_gp_new

        return selected_indices, acq_init

    def _sample_candidates(self, n_samples: int, loop_count: int = 1) -> np.ndarray:
        """Sample candidate points for this loop, deduplicated and with evaluated points removed.

        The global NumPy RNG is re-seeded with ``(random_seed or 0) + loop_count``
        before sampling, and the sampled array is cached per seed.

        Args:
            n_samples: Number of points to sample (before dedup/filtering).
            loop_count: Current loop count (starting from 1).

        Returns:
            Deduplicated candidate point array. Points already in the HF history
            are removed unless that would remove every candidate, in which case
            the deduplicated set is kept unfiltered. For the "Sandwich" task the
            task-specific hard constraints are applied last.
        """
        seed_key = int(self.main_random_seed or 0) + loop_count

        if seed_key in self._seed_candidates_cache:
            candidates = self._seed_candidates_cache[seed_key]
        else:
            try:
                np.random.seed(seed_key)
            except Exception:
                pass
            candidates = sample_candidates(self.bounds, n_samples, self.types)
            self._seed_candidates_cache[seed_key] = candidates

        # Shared point identity for dedup and history filtering.
        # IMPORTANT: never int() float features, otherwise [0,1] variables
        # (e.g. PCE10) would all collapse to 0; round floats to 6 decimals
        # to absorb floating point noise instead.
        if self.types and all(t == "int" for t in self.types):
            def point_key(row):
                return tuple(int(v) for v in row)
        else:
            def point_key(row):
                return tuple(round(float(v), 6) for v in row)

        # Deduplicate, keeping the first occurrence of each point
        if len(candidates) > 0:
            unique_points = {}
            for point in candidates:
                unique_points.setdefault(point_key(point), point)
            candidates = np.array(list(unique_points.values()))

        # Filter already-evaluated points (avoid duplicate HF evaluations)
        if len(self.data_manager.history_df) > 0:
            evaluated = {point_key(row) for row in self.data_manager.history_df[self.feature_names].values}
            mask = np.array([point_key(row) not in evaluated for row in candidates], dtype=bool)
            if mask.any():
                candidates = candidates[mask]

        # Task-specific hard constraints (e.g. nutritional limits for Sandwich)
        if self.task_name == "Sandwich":
            candidates = filter_sandwich_constraints(candidates, self.feature_names)

        return candidates

    def _koh_posterior_predict(self, candidates: np.ndarray):
        """KOH posterior prediction.

        Returns:
            (mu_H, sigma2_H, mu_delta, sigma2_delta): the fused posterior and the
            residual-only posterior (the latter feeds the mismatch ratio).
        """
        mu_H, sigma2_H = self.fusion.predict(candidates)
        mu_delta, sigma2_delta = self.fusion.predict_residual_only(candidates)
        return mu_H, sigma2_H, mu_delta, sigma2_delta

    def _get_best_y(self) -> float:
        """Return the maximum target value in the HF history (0.0 when empty)."""
        if len(self.data_manager.history_df) == 0:
            return 0.0
        return float(self.data_manager.history_df[self.data_manager.target_name].max())

    def _high_fidelity_branch(self, x_next_list: List[np.ndarray], iteration: int):
        """High-fidelity branch: evaluate each selected point on the HF black box and store it."""
        X_dict = numpy_to_dict_list(np.array(x_next_list), self.feature_names)
        for x_dict in X_dict:
            y_H = self.hf_blackbox.evaluate(x_dict)
            self.data_manager.add_hf_experiment(x_dict, y_H, iteration)

        # Only HF results are saved here; run() forces _train_models to
        # recompute LF predictions for the new HF points on the next pass.
        self.data_manager.save_all()

    def _low_fidelity_branch(self, x_next_list: List[np.ndarray], iteration: int):
        """Low-fidelity branch: LLM-predict the selected points and refit the LF GP.

        NaN predictions are not stored. The fusion model is rebuilt with the
        refitted LF GP and the unchanged residual GP.
        """
        # 1. ICL prediction (full history, batch)
        X_dict = numpy_to_dict_list(np.array(x_next_list), self.feature_names)
        history = self.data_manager.get_history_data()

        # Smaller batches for high-dimensional tasks to stay within token limits
        # COF(14 dim): batch_size=10, CB(4 dim): batch_size=20, FeCr(3 dim): batch_size=20
        input_dim = len(self.feature_names)
        if input_dim >= 10:
            batch_size = 10  # High-dimensional tasks (e.g., COF)
        elif input_dim >= 5:
            batch_size = 15  # Medium-dimensional tasks
        else:
            batch_size = 20  # Low-dimensional tasks (e.g., FeCr, CB)

        mu_LF_batch, sigma2_LF_batch = self.predictor.predict_batch(
            X_dict, history, batch_size=batch_size
        )

        # 2. Store valid predictions
        for x, mu, sigma2 in zip(X_dict, mu_LF_batch, sigma2_LF_batch):
            if not np.isnan(mu):
                self.data_manager.add_lf_prediction(x, mu, sigma2, iteration=iteration)

        # 3. Refit the low-fidelity GP and rebuild the fusion model
        X_lf, mu_LF, sigma2_LF = self.data_manager.get_lf_training_data()
        if len(X_lf) > 0:
            self.lf_gp.fit(X_lf, mu_LF, sigma2_LF)
            self.fusion = KOHFusion(self.lf_gp, self.residual_gp, self.rho_manager)

        self.data_manager.save_all()

    def _log_iteration(self, iteration: int, x_next_list: List[np.ndarray],
                      do_hf: bool, r_max: float, y_best: float):
        """Append one record describing this pass to ``self.iteration_log``."""
        self.iteration_log.append({
            "iteration": iteration,
            "selected_points": [x.tolist() for x in x_next_list],
            "do_hf": do_hf,
            "r_max": r_max,
            "y_best": y_best,
            "n_history": len(self.data_manager.history_df),
            "n_lf_predictions": len(self.data_manager.lf_predictions_df)
        })

    def _get_observation_noise(self, x_dict: dict) -> float:
        """Get observation noise (stored ``sigma2_LF``) for the specified point.

        Args:
            x_dict: Dictionary representation of point; features missing from
                it are not used for matching.

        Returns:
            ``sigma2_LF`` of the first matching LF prediction row, or NaN when
            there is no match or the lookup fails.
        """
        try:
            lf_df = self.data_manager.lf_predictions_df
            if len(lf_df) == 0:
                return float('nan')

            # Positional mask, independent of lf_df's index labels
            mask = np.ones(len(lf_df), dtype=bool)
            for feat in self.feature_names:
                if feat in x_dict:
                    mask &= (lf_df[feat] == x_dict[feat]).to_numpy()

            if not mask.any():
                return float('nan')
            return float(lf_df.loc[mask, "sigma2_LF"].iloc[0])
        except Exception:
            # Best effort: any lookup failure is reported as "unknown noise"
            return float('nan')

    def _print_summary(self):
        """Locate the best HF point found so far (prints nothing; silent by default).

        Returns:
            ``{feature_name: value}`` for the history row with the largest
            non-NaN target, or ``None`` when the history is empty or has no
            non-NaN target values.
        """
        history_df = self.data_manager.history_df
        if len(history_df) == 0:
            return None
        target_values = history_df[self.data_manager.target_name].to_numpy(dtype=float)
        if np.all(np.isnan(target_values)):
            return None
        # Positional lookup: Series.idxmax() returns a *label*, which must not be
        # fed to .iloc (wrong row / IndexError for non-default indexes).
        best_pos = int(np.nanargmax(target_values))
        best_row = history_df.iloc[best_pos]
        return {feat: best_row[feat] for feat in self.feature_names}


