"""Low-fidelity predictor."""
import numpy as np
from typing import List, Dict, Tuple


class LowFidelityPredictor:
    """Low-fidelity predictor: single prompt + multiple seeds with history-aware fallback.

    Each prediction asks ``generator`` for a value, retrying over a fixed,
    deterministic sequence of seeds until one call yields a usable number.
    If every seed fails, :meth:`predict` falls back to the ``y`` of the
    nearest (Euclidean) history entry. Returned means are scaled by
    ``y_transform``; variances are always 0 (no observation noise).

    The generator is duck-typed. :meth:`predict` calls
    ``generate_single`` and :meth:`predict_batch` calls
    ``generate_batch_multi_points``. A ``call_log`` attribute, when present,
    is copied into the details returned by :meth:`predict`.
    """

    def __init__(
        self,
        generator,
        user_prompts,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 2048,
        alpha: float = 1.0,
        beta: float = 0.0,
        y_transform: float = 1.0,
        *,
        verbose: bool = False,
        seed_candidates: List[int] = None,
    ):
        """Configure the predictor.

        Args:
            generator: Object exposing ``generate_single`` and
                ``generate_batch_multi_points``.
            user_prompts: A prompt, or a non-empty list/tuple of prompts of
                which only the first is used.
            temperature, top_p, max_tokens: Sampling options forwarded to
                the generator on every call.
            alpha, beta: Stored on the instance. They are not used by this
                class's own methods (presumably read by callers — confirm).
            y_transform: Multiplier applied to every raw predicted value.
            verbose: When true, failures and fallbacks are printed.
            seed_candidates: Seeds tried in order. Defaults to 1..20.

        Raises:
            ValueError: If ``user_prompts`` or ``seed_candidates`` is empty.
        """
        self.generator = generator
        self.verbose = bool(verbose)

        if isinstance(user_prompts, (list, tuple)):
            if not user_prompts:
                raise ValueError("user_prompts cannot be empty")
            self.user_prompt = user_prompts[0]
        else:
            self.user_prompt = user_prompts

        # Deterministic retry order; the default is the fixed sequence 1-20.
        if seed_candidates is None:
            seed_candidates = range(1, 21)
        self.seed_candidates = list(seed_candidates)
        if not self.seed_candidates:
            raise ValueError("seed_candidates cannot be empty")

        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.alpha = alpha
        self.beta = beta
        self.y_transform = float(y_transform)

    def _log(self, message: str) -> None:
        """Print ``message`` only when the predictor is verbose."""
        if self.verbose:
            print(f"[LowFidelityPredictor] {message}")

    def predict(self, x: dict, history: List[Dict] = None) -> Tuple[float, float, Dict]:
        """Predict a single point using multiple seeds in sequence.

        A seed counts as failed when the generator raises, returns ``None``
        or returns something that cannot be converted to ``float``.

        Args:
            x: Input point, e.g. {"fen": 100, "crn": 125, "hn": 260}.
            history: History list [{"x": {...}, "y": 0.85}, ...]. Entries
                whose ``x`` equals ``x`` are excluded before use.

        Returns:
            (mean, variance, details) where:
            - mean: first successful prediction (after y_transform).
            - variance: always 0 (no observation noise).
            - details: diagnostic information about the call. The
              ``"errors"`` entry lists the failure of each seed tried.

        Raises:
            ValueError: If every seed fails and no history fallback exists.
        """
        history_filtered = self._exclude_current_point(history, x)
        raw_value = None
        success_seed = None
        fallback_used = False
        errors = []

        for seed in self.seed_candidates:
            try:
                value = self.generator.generate_single(
                    self.user_prompt,
                    x,
                    history_filtered,
                    seed=seed,
                    temperature=self.temperature,
                    top_p=self.top_p,
                    max_tokens=self.max_tokens,
                )
                if value is None:
                    raise ValueError("generator returned None")
                raw_value = float(value)
            except Exception as exc:  # retry boundary: any failure -> next seed
                errors.append(f"seed {seed}: {exc!r}")
                self._log(f"predict: seed {seed} failed: {exc!r}")
                continue
            success_seed = seed
            break

        if raw_value is None:
            raw_value = self._fallback_from_history(x, history_filtered)
            fallback_used = True
            self._log(f"predict: all seeds failed, history fallback -> {raw_value}")

        mean = raw_value * self.y_transform
        variance = 0.0  # no observation noise, always 0

        details = {
            "success_value": mean,
            "raw_value": raw_value,
            "y_transform": self.y_transform,
            "seed_used": success_seed,
            "seed_candidates": list(self.seed_candidates),
            "fallback_used": fallback_used,
            "input_x": x,
            "history_size": len(history_filtered),
            "call_log": getattr(self.generator, "call_log", None),
            "errors": errors,
        }

        return mean, variance, details

    def predict_batch(
        self,
        X_batch: List[dict],
        history: List[Dict] = None,
        batch_size: int = 20
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Predict a batch of points using true batched LLM calls.

        Each chunk of ``batch_size`` points is sent in one generator call,
        retried over ``seed_candidates``. A seed counts as failed when the
        result is missing, has the wrong length, or contains non-numeric
        values. If every seed fails for a chunk, its points are predicted
        one by one via :meth:`predict`. Points that still fail get NaN. The
        outputs always align index-for-index with ``X_batch``.

        Args:
            X_batch: List of input points.
            history: History data list (passed as-is to the batched call).
            batch_size: Number of points per LLM call (controls token budget).

        Returns:
            (means, variances) arrays of shape (n_points,). Variances are 0
            for successful points and NaN for points with no prediction.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        n_points = len(X_batch)
        if n_points == 0:
            return np.array([]), np.array([])
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        all_means = []
        all_variances = []

        for batch_start in range(0, n_points, batch_size):
            chunk = X_batch[batch_start:batch_start + batch_size]
            try:
                chunk_means, chunk_variances = self._predict_chunk_batched(chunk, history)
            except ValueError as exc:
                self._log(f"predict_batch: {exc}; falling back to per-point prediction")
                chunk_means, chunk_variances = self._predict_chunk_pointwise(chunk, history)
            # Results are committed only once the whole chunk is resolved,
            # so a failure can never leave a partially appended chunk.
            all_means.extend(chunk_means)
            all_variances.extend(chunk_variances)

        return np.array(all_means, dtype=float), np.array(all_variances, dtype=float)

    def _predict_chunk_batched(self, chunk: List[dict], history: List[Dict]) -> Tuple[List[float], List[float]]:
        """Predict ``chunk`` with one batched generator call per seed attempt.

        Raises:
            ValueError: If every seed fails.
        """
        for seed in self.seed_candidates:
            try:
                values = self.generator.generate_batch_multi_points(
                    user_prompt_template=self.user_prompt,
                    X_batch=chunk,
                    history=history,
                    seed=seed,  # use fixed seed sequence for determinism
                    temperature=self.temperature,
                    top_p=self.top_p,
                    max_tokens=self.max_tokens,
                )
                return self._convert_batch_values(values, len(chunk))
            except Exception as exc:  # retry boundary: any failure -> next seed
                self._log(f"predict_batch: seed {seed} failed: {exc!r}")
        raise ValueError(f"All seeds failed, cannot batch-predict {len(chunk)} points")

    def _convert_batch_values(self, values, expected: int) -> Tuple[List[float], List[float]]:
        """Validate a batched result and scale it. ``None`` items become NaN."""
        if values is None:
            raise ValueError("generator returned None for the batch")
        values = list(values)
        if len(values) != expected:
            raise ValueError(f"expected {expected} values, got {len(values)}")
        means = []
        variances = []
        for value in values:
            if value is None:
                means.append(np.nan)
                variances.append(np.nan)
            else:
                means.append(float(value) * self.y_transform)
                variances.append(0.0)  # no observation noise
        return means, variances

    def _predict_chunk_pointwise(self, chunk: List[dict], history: List[Dict]) -> Tuple[List[float], List[float]]:
        """Predict each point of ``chunk`` separately. Points that fail get NaN."""
        means = []
        variances = []
        for x in chunk:
            try:
                mean, variance, _ = self.predict(x, history)
            except Exception as exc:  # best effort: one bad point must not abort the batch
                self._log(f"predict_batch: single-point prediction failed for {x}: {exc!r}")
                mean, variance = np.nan, np.nan
            means.append(mean)
            variances.append(variance)
        return means, variances

    @staticmethod
    def _is_same_point(a: dict, b: dict) -> bool:
        """Return True when both points have the same keys and equal values."""
        if a is None or b is None:
            return False
        if set(a.keys()) != set(b.keys()):
            return False
        return not any(a[key] != b[key] for key in a)

    def _exclude_current_point(self, history: List[Dict], x: dict) -> List[Dict]:
        """Return ``history`` without entries whose ``x`` equals ``x``. ``None`` gives ``[]``."""
        if not history:
            return []
        return [entry for entry in history if not self._is_same_point(entry.get("x", {}), x)]

    def _fallback_from_history(self, x: dict, history: List[Dict]) -> float:
        """Return the raw ``y`` of the nearest history entry that has a ``y``.

        Distance is Euclidean over the keys of ``x``. A key missing from a
        history point counts as 0.0. Entries whose ``y`` is ``None`` are
        skipped. On ties, the earliest entry wins.

        Raises:
            ValueError: If ``history`` is empty or no entry has a usable ``y``.
        """
        if not history:
            raise ValueError("Cannot perform fallback prediction from history (history is empty)")

        feature_order = list(x.keys())
        target_vec = np.array([float(x[k]) for k in feature_order], dtype=np.float64)

        best_y = None
        best_distance = float("inf")

        for entry in history:
            y = entry.get("y")
            if y is None:
                continue  # cannot serve as a fallback target
            x_hist = entry.get("x") or {}
            vec = np.array([float(x_hist.get(k, 0.0)) for k in feature_order], dtype=np.float64)
            dist = float(np.linalg.norm(vec - target_vec))
            if dist < best_distance:
                best_distance = dist
                best_y = y

        if best_y is None:
            raise ValueError("Cannot find a valid fallback prediction from history")

        return float(best_y)

