"""Warm-up utilities."""
import json
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional
from .prompt import get_warmup_prompt
from .predictor import LowFidelityPredictor
from koh.utils import numpy_to_dict_list

# Fixed seeds for the warm-up phase, deliberately independent of the main optimization seed.
# Seed for the Latin Hypercube sampling of warm-up exploration points.
DEFAULT_WARMUP_RANDOM_SEED = 42
# LLM seeds tried in this order (1..20) when requesting initial points; the first success wins.
DEFAULT_WARMUP_LLM_SEEDS = [*range(1, 21)]

# Snapshot/restore helpers for NumPy's global RNG, so that the warm-up phase can leave the
# main loop's random stream exactly where it found it.
def _save_random_state():
    """Return a snapshot of NumPy's global random state (``np.random.get_state()``)."""
    snapshot = np.random.get_state()
    return snapshot


def _restore_random_state(state):
    """Reinstate a snapshot previously returned by ``_save_random_state``."""
    np.random.set_state(state)


def _parse_number(text):
    """Return ``text`` as an int, else as a float, else ``None`` if it is not numeric."""
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return None


def _match_feature_key(raw_key, feature_names):
    """Map a raw ``key`` token to a feature name, tolerating list prefixes like ``[1] fen``."""
    import re

    key = raw_key.strip()
    if key in feature_names:
        return key
    # Accept "<prefix><non-word char><feature>" (e.g. "[1] fen", "1. fen", "- fen");
    # prefer the longest matching feature name so "x1" never shadows "ax1"-style names.
    best = None
    for feat in feature_names:
        if feat and re.search(r'(?:^|\W)' + re.escape(feat) + r'$', key):
            if best is None or len(feat) > len(best):
                best = feat
    return key if best is None else best


def _extract_json_points(text):
    """Return the ``"points"`` list of the first parseable ``{"points": [...]}`` object, else None."""
    import re

    decoder = json.JSONDecoder()
    # raw_decode parses one complete JSON value from the start of the slice and ignores
    # trailing text, so objects with extra keys or surrounding prose/code fences work.
    for match in re.finditer(r'\{\s*"points"\s*:', text):
        try:
            obj, _end = decoder.raw_decode(text[match.start():])
        except ValueError:
            continue
        if isinstance(obj, dict) and isinstance(obj.get("points"), list):
            return obj["points"]

    # Fallback: isolate only the list body (tolerates a malformed enclosing object).
    legacy = re.search(r'\{\s*"points"\s*:\s*\[(.*?)\]\s*\}', text, re.DOTALL)
    if legacy is not None:
        try:
            return json.loads('[' + legacy.group(1) + ']')
        except ValueError:
            return None
    return None


def parse_initial_points(response: str, feature_names: List[str], *, verbose: bool = False) -> List[dict]:
    """Parse initial points generated by the LLM from a JSON-like response.

    Anything up to and including a ``</think>`` tag is discarded first.

    Expected primary format (may be embedded in prose or a code fence, and may carry
    extra top-level keys):
        {"points": [{"fen": 70, "crn": 120, "hn": 260}, {"fen": 75, ...}, ...]}

    Fallback (legacy) format, one point per line, optionally enumerated:
        [1] fen=70, crn=120, hn=260
        fen=75, crn=110, hn=250
    Values are parsed as int, else float; non-numeric pairs are skipped.

    Args:
        response: Raw LLM output.
        feature_names: Features every returned point must contain.
        verbose: Accepted for interface compatibility; currently unused.

    Returns:
        List of points like [{"fen": 70, "crn": 120, "hn": 260}, ...]. Only points that
        contain every name in ``feature_names`` are kept.

    Raises:
        ValueError: if no complete point could be parsed.
    """
    text = (response or "").strip()

    # Remove the reasoning section if present.
    if '</think>' in text:
        text = text.split('</think>')[-1].strip()

    # Strategy 1: JSON format (preferred).
    candidates = _extract_json_points(text)
    if candidates:
        valid_points = [
            point for point in candidates
            if isinstance(point, dict) and all(feat in point for feat in feature_names)
        ]
        if valid_points:
            return valid_points

    # Strategy 2: legacy "key=value, key=value" lines.
    points = []
    for line in text.split('\n'):
        line = line.strip()
        if '=' not in line:
            continue
        point = {}
        for pair in line.split(','):
            # partition (not split) so stray extra '=' signs never raise.
            raw_key, sep, raw_value = pair.partition('=')
            if not sep:
                continue
            value = _parse_number(raw_value.strip())
            if value is not None:
                point[_match_feature_key(raw_key, feature_names)] = value
        if point and all(feat in point for feat in feature_names):
            points.append(point)

    if not points:
        raise ValueError(f"Failed to parse any initial points from response: {text[:500]}")

    return points


def _lhs_sample(bounds: np.ndarray, n_samples: int, seed: int = 42) -> np.ndarray:
    """Draw ``n_samples`` Latin Hypercube samples inside per-dimension ``bounds``.

    Each dimension's ``[low, high]`` range is cut into ``n_samples`` equal-width strata.
    A random permutation assigns one stratum to each sample, and a uniform offset is
    drawn inside that stratum. Dimensions are processed in order, drawing the
    permutation and then the offsets from one ``np.random.default_rng(seed)``
    generator, so identical inputs always yield identical samples.

    Args:
        bounds: One ``(low, high)`` pair per dimension — shape (d, 2).
        n_samples: Number of points to draw.
        seed: Seed for the private generator (NumPy's global state is not touched).

    Returns:
        Array of shape ``(n_samples, d)``.
    """
    stream = np.random.default_rng(seed)
    out = np.zeros((n_samples, len(bounds)))
    for dim, (lo, hi) in enumerate(bounds):
        width = (hi - lo) / n_samples
        strata = stream.permutation(n_samples)
        offsets = stream.uniform(0, width, n_samples)
        out[:, dim] = lo + strata * width + offsets
    return out


def _log_llm_call(log_path: Optional[str], record: Dict, *, verbose: bool = False):
    """Append one LLM call record as a single JSON line to ``log_path``.

    Logging is best-effort. A falsy ``log_path`` disables it, and serialization or
    file-system failures are swallowed so they never interrupt the caller.

    The record is serialized *before* the file is opened. A record that cannot be
    encoded (e.g. it holds a non-JSON-serializable value) is then dropped entirely.
    Streaming it with ``json.dump`` would instead leave a truncated, unterminated
    fragment that corrupts the following JSONL entry.

    Args:
        log_path: Destination file; missing parent directories are created.
        record: Mapping describing the call; must be JSON-serializable to be logged.
        verbose: Accepted for interface compatibility; currently unused.
    """
    if not log_path:
        return
    try:
        line = json.dumps(record, ensure_ascii=False)
    except (TypeError, ValueError, RecursionError):
        return
    try:
        path = Path(log_path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("a", encoding="utf-8") as f:
            f.write(line + "\n")
    except (OSError, TypeError, ValueError):
        return


def generate_initial_points_with_llm(
    llm_client,
    task_name: str,
    feature_names: List[str],
    n_points: int = 5,
    llm_seed_sequence: Optional[List[int]] = None,
    log_path: Optional[str] = None,
    max_tokens: int = 2048,
    *,
    verbose: bool = False,
) -> List[dict]:
    """Ask the LLM for warm-up initial points, retrying over seeds and sampling settings.

    The warm-up prompt (system + user text from ``get_warmup_prompt``) is sent once per
    seed, first with a deterministic setting (temperature 0.0, top_p 1.0) and then, if
    every seed failed, with a stochastic one (temperature 0.7, top_p 0.9). The first
    response accepted by ``parse_initial_points`` wins. Every attempt, successful or not,
    is appended to ``log_path`` through ``_log_llm_call``.

    Args:
        llm_client: Object exposing ``generate(prompt, seed=, temperature=, top_p=, max_tokens=)``.
        task_name: Task whose warm-up prompt is requested.
        feature_names: Features each parsed point must contain.
        n_points: Number of points requested (and the truncation limit).
        llm_seed_sequence: Seeds to try in order; falls back to ``DEFAULT_WARMUP_LLM_SEEDS``
            when ``None`` or empty.
        log_path: Optional JSONL log for the calls.
        max_tokens: Generation budget per call.
        verbose: Forwarded to the parser/logger.

    Returns:
        The parsed points, truncated to ``n_points`` when more were returned; a shorter
        list is returned as-is.

    Raises:
        ValueError: if no attempt produced parseable points.
    """
    system_text, user_text = get_warmup_prompt(task_name, n_points)
    prompt = f"{system_text}\n\n{user_text}"

    # Seeds are tried in a fixed order so the chosen response is reproducible.
    seeds = list(llm_seed_sequence or DEFAULT_WARMUP_LLM_SEEDS)
    sampling_settings = (
        (0.0, 1.0),  # deterministic pass
        (0.7, 0.9),  # stochastic fallback pass
    )

    parsed = None
    failure = None
    for temperature, top_p in sampling_settings:
        for seed in seeds:
            record = {
                "stage": "warmup_initial_points",
                "task": task_name,
                "seed": seed,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens,
                "input_prompt": prompt,
            }
            try:
                reply = llm_client.generate(
                    prompt,
                    seed=seed,
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=max_tokens,
                )
                record["response"] = reply
                parsed = parse_initial_points(reply, feature_names, verbose=verbose)
                record["status"] = "success"
                record["n_points"] = len(parsed) if parsed else 0
                _log_llm_call(log_path, record, verbose=verbose)
                break
            except Exception as exc:  # any failure is logged; move on to the next seed
                failure = exc
                record.setdefault("response", None)
                record["status"] = "error"
                record["error"] = str(exc)
                _log_llm_call(log_path, record, verbose=verbose)
        if parsed is not None:
            break

    if parsed is None:
        raise ValueError(
            f"LLM failed to generate initial points. Tried seeds: {seeds}, last error: {failure}"
        )

    return parsed[:n_points] if len(parsed) >= n_points else parsed


def _validate_fixed_initial_points(fixed_initial_points, feature_names):
    """Raise ValueError unless every fixed point is a dict containing all ``feature_names``."""
    for i, point in enumerate(fixed_initial_points):
        if not isinstance(point, dict):
            raise ValueError(f"fixed_initial_points must be a list of dicts, got: {type(point)}")
        missing_features = [feat for feat in feature_names if feat not in point]
        if missing_features:
            raise ValueError(f"Fixed initial point {i+1} is missing features: {missing_features}")


def _sample_exploration_points(
    bounds,
    n_points,
    initial_points,
    feature_names,
    seed,
    min_distance=0.1,
    max_retries=10,
):
    """LHS-sample ``n_points`` exploration points, re-drawing any too close to an initial point."""
    points = numpy_to_dict_list(_lhs_sample(bounds, n_points, seed=seed), feature_names)
    if len(initial_points) == 0:
        return points

    anchors = np.array([[x[feat] for feat in feature_names] for x in initial_points])

    def _min_distance(point):
        # Euclidean distance in the raw (unscaled) feature space; note the threshold is
        # therefore absolute — NOTE(review): confirm 0.1 suits every task's feature scale.
        coords = np.array([point[feat] for feat in feature_names])
        return np.min(np.linalg.norm(anchors - coords, axis=1))

    kept = []
    for idx, candidate in enumerate(points):
        if _min_distance(candidate) > min_distance:
            kept.append(candidate)
            continue
        # Deterministic replacement seeds keep reruns identical.
        replacement = candidate  # kept as-is if no retry lands far enough away
        for retry in range(max_retries):
            draw = _lhs_sample(bounds, 1, seed=seed + 10000 + idx * 100 + retry)[0]
            new_point = {feat: float(draw[j]) for j, feat in enumerate(feature_names)}
            if _min_distance(new_point) > min_distance:
                replacement = new_point
                break
        kept.append(replacement)

    # Defensive top-up in case fewer points came back than requested.
    if len(kept) < n_points:
        needed = n_points - len(kept)
        kept.extend(numpy_to_dict_list(_lhs_sample(bounds, needed, seed=seed + 20000), feature_names))

    return kept[:n_points]


def _predict_exploration_points(predictor, data_manager, exploration_points, history, batch_size):
    """Store LF predictions for exploration points: batched first, per-point fallback on failure."""
    saved = set()  # indices already written, so the fallback never duplicates rows
    try:
        mu_batch, sigma2_batch = predictor.predict_batch(
            exploration_points,
            history=history,
            batch_size=batch_size,
        )
        for idx, (x, mu_LF, sigma2_LF) in enumerate(zip(exploration_points, mu_batch, sigma2_batch)):
            if not (np.isnan(mu_LF) or np.isnan(sigma2_LF)):
                data_manager.add_lf_prediction(x, float(mu_LF), float(sigma2_LF), iteration=0)
                saved.add(idx)
    except Exception:
        # Best-effort fallback: predict the points not yet saved one by one.
        for idx, x in enumerate(exploration_points):
            if idx in saved:
                continue
            try:
                mu_LF, sigma2_LF, _ = predictor.predict(x, history)
                data_manager.add_lf_prediction(x, mu_LF, sigma2_LF, iteration=0)
            except Exception:
                continue


def warmup_phase(
    llm_client,
    hf_blackbox,
    data_manager,
    generator,
    user_prompt: str,
    task_name: str,
    feature_names: List[str],
    n_initial_points: int = 5,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 2048,
    warmup_random_seed: int = DEFAULT_WARMUP_RANDOM_SEED,
    llm_seed_sequence: Optional[List[int]] = None,
    fixed_initial_points: Optional[List[dict]] = None,
    y_transform: float = 1.0,
    *,
    verbose: bool = False,
):
    """Warm-up phase: initial points → HF experiments → leave-one-out ICL → LF exploration.

    Steps:
        1. Obtain initial points. When ``fixed_initial_points`` is non-empty, they are
           validated and truncated to ``n_initial_points``. Otherwise they are generated
           by the LLM, logged to ``generator.log_path`` if that attribute exists.
        2. Evaluate each initial point with ``hf_blackbox``, store it as a seed point and
           as an HF experiment (iteration 0), then persist.
        3. Predict each initial point with ``LowFidelityPredictor`` against the history
           without that point (leave-one-out), storing successful LF predictions.
        4. LHS-sample exploration points (50 for ``"Sandwich"``, otherwise 20) with
           ``warmup_random_seed``, keeping them away from the initial points. Predict
           them against the full history and persist.

    NumPy's global random state is snapshotted on entry and restored on exit. This also
    happens when an exception propagates, so the warm-up never perturbs the main loop's
    random stream.

    Raises:
        ValueError: if a fixed initial point is not a dict or misses a feature, or if LLM
            point generation fails.
    """
    saved_random_state = _save_random_state()
    try:
        # 1. Initial points: fixed ones take precedence over LLM generation.
        if fixed_initial_points is not None and len(fixed_initial_points) > 0:
            _validate_fixed_initial_points(fixed_initial_points, feature_names)
            initial_points = fixed_initial_points[:n_initial_points]
        else:
            initial_points = generate_initial_points_with_llm(
                llm_client,
                task_name,
                feature_names,
                n_initial_points,
                llm_seed_sequence=llm_seed_sequence,
                log_path=getattr(generator, "log_path", None),
                max_tokens=max_tokens,
                verbose=verbose,
            )

        # 2. HF experiments on the initial points (iteration=0 marks warm-up).
        for x in initial_points:
            y_H = hf_blackbox.evaluate(x)
            data_manager.add_seed_point(x, y_H)
            data_manager.add_hf_experiment(x, y_H, iteration=0)
        data_manager.save_all()

        # 3. Leave-one-out ICL predictions for the initial points.
        predictor = LowFidelityPredictor(
            generator=generator,
            user_prompts=user_prompt,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            y_transform=y_transform,
            verbose=verbose,
        )
        for x in initial_points:
            history_exclude = data_manager.get_history_exclude_point(x)
            try:
                mu_LF, sigma2_LF, _details = predictor.predict(x, history_exclude)
                data_manager.add_lf_prediction(x, mu_LF, sigma2_LF, iteration=0)
            except Exception:
                continue  # best-effort: a failed LOO prediction is simply not stored

        # 4. Initial LF exploration; the high-dimensional (20D) Sandwich task gets more points.
        n_exploration_points = 50 if task_name == "Sandwich" else 20
        bounds = np.array(hf_blackbox.bounds)  # one (min, max) row per dimension
        exploration_points = _sample_exploration_points(
            bounds,
            n_exploration_points,
            initial_points,
            feature_names,
            seed=warmup_random_seed,
        )

        full_history = data_manager.get_history_data()
        input_dim = len(feature_names)
        batch_size = 10 if input_dim >= 10 else (15 if input_dim >= 5 else 20)
        _predict_exploration_points(predictor, data_manager, exploration_points, full_history, batch_size)

        data_manager.save_all()
    finally:
        _restore_random_state(saved_random_state)

