"""LLM generator: call LLM and parse numeric results."""
import re
import json
from datetime import datetime
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict, Any, Optional


class LLMGenerator:
    """LLM generator that calls the model and parses numeric outputs."""
    
    def __init__(
        self,
        llm_client,
        system_prompt: str = None,
        value_range: List[float] = None,
        log_path: Optional[str] = None,
        *,
        verbose: bool = False,
    ):
        self.llm_client = llm_client
        self.system_prompt = system_prompt
        self.value_range = value_range  # e.g., [-0.3, 1.5]
        self.verbose = bool(verbose)
        self.call_log = []  # detailed records for each call
        if log_path:
            self.log_path = Path(log_path)
        else:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            default_log_dir = Path("llm_logs")
            self.log_path = default_log_dir / f"llm_calls_{timestamp}.jsonl"
    
    def generate_single(
        self,
        user_prompt: str,
        x: dict,
        history: List[Dict] = None,
        seed: int = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 2048
    ) -> float:
        """Ask the LLM for one prediction at point ``x`` and return it as a number.

        The system prompt (if any) and the formatted user prompt are joined
        into one string and sent in a single ``llm_client.generate`` call.
        Every call, successful or not, is appended to ``self.call_log`` and
        passed to ``self._write_log``.

        Args:
            user_prompt: Template with ``{history_json}`` and ``{points_json}``
                placeholders, filled in with ``str.format``.
            x: Input point, e.g. ``{"fen": 87, "crn": 104, "hn": 220}``; its key
                order is used as the feature order.
            history: History list ``[{"x": {...}, "y": 0.85}, ...]``.
            seed: Random seed forwarded to the client.
            temperature: Sampling temperature forwarded to the client.
            top_p: Nucleus sampling parameter forwarded to the client.
            max_tokens: Maximum number of generated tokens.

        Returns:
            The parsed numeric prediction.

        Raises:
            ValueError: If no prediction can be parsed from the response, or
                the parsed value fails the range check.
            KeyError: If the template contains a placeholder other than the
                two supplied ones.
        """
        # Shared JSON formatting helpers live in prompt.py.
        from .prompt import format_history_json, format_points_json

        feature_order = list(x.keys())
        history_json = format_history_json(history, feature_order)
        # A single point is still sent as a one-element array.
        points_json = format_points_json([x], feature_order)

        formatted_user = user_prompt.format(
            history_json=history_json,
            points_json=points_json,
        )

        # Some APIs take separate system/user messages; here they are simply
        # concatenated into one prompt string.
        full_prompt = (
            f"{self.system_prompt}\n\n{formatted_user}"
            if self.system_prompt
            else formatted_user
        )

        sampling = {
            "seed": seed,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
        }
        call_record = {"input_prompt": full_prompt, **sampling}

        try:
            response = self.llm_client.generate(full_prompt, **sampling)
            call_record["response"] = response

            # Number of input features; 3 is the fallback for inputs without a length.
            if isinstance(x, dict) or hasattr(x, '__len__'):
                input_dim = len(x)
            else:
                input_dim = 3

            values = self._parse_json_predictions(response, n_points=1, input_dim=input_dim)
            if not values or values[0] is None:
                raise ValueError("Failed to parse prediction from JSON response")
            value = values[0]

            call_record["parsed_value"] = value
            call_record["in_range"] = self._is_valid_value(value)
            if not call_record["in_range"]:
                raise ValueError(f"Value {value:.6f} out of valid range {self.value_range}")

            return value

        except Exception as e:
            # Record the failure (and the absence of a response) before
            # letting the exception continue to the caller.
            call_record["error"] = str(e)
            call_record.setdefault("response", None)
            raise

        finally:
            self.call_log.append(call_record)
            self._write_log(call_record)
    
    def generate_batch(
        self,
        user_prompts: List[str],
        x: dict,
        history: List[Dict] = None,
        seeds: List[int] = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 2048
    ) -> List[List[float]]:
        """Try prompt × seed combinations in order and stop at the first success.

        Prompts are tried in the given order and, for each prompt, seeds in
        the given order. ``self.call_log`` is reset at the start of the call.

        Args:
            user_prompts: List of user prompt templates.
            x: Input point.
            history: History data list.
            seeds: List of random seeds; ``None`` means a single attempt per
                prompt with ``seed=None``.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            max_tokens: Maximum generated tokens.

        Returns:
            One row per prompt that was actually tried. Each failed attempt
            contributes ``None``; the successful value is the last element of
            the last row. Prompts and seeds after the first success are not
            tried and do not appear in the result.

        Raises:
            ValueError: If every attempted combination fails.
        """
        seed_candidates = [None] if seeds is None else seeds

        # Start from a clean log for this batch.
        self.call_log = []

        results = []
        for user_prompt in user_prompts:
            row = []
            results.append(row)
            for seed in seed_candidates:
                try:
                    value = self.generate_single(
                        user_prompt, x, history, seed, temperature, top_p, max_tokens
                    )
                except Exception:
                    # Parse failure, out-of-range value or client error:
                    # mark the attempt as failed and try the next seed.
                    row.append(None)
                    continue
                row.append(value)
                # First success ends the whole search.
                return results

        raise ValueError("All prompt×seed combinations failed")
    
    
    def _is_valid_value(self, value: float) -> bool:
        """Tell whether ``value`` lies inside ``self.value_range`` (bounds inclusive).

        When no range is configured (``value_range`` is ``None``) every value
        is accepted.
        """
        bounds = self.value_range
        if bounds is not None:
            lower, upper = bounds
            return lower <= value <= upper
        return True
    
    def generate_batch_multi_points(
        self,
        user_prompt_template: str,
        X_batch: List[dict],
        history: List[Dict] = None,
        seed: int = None,
        temperature: float = 0.7,
        top_p: float = 0.9,
        max_tokens: int = 2048
    ) -> List[float]:
        """Predict several points with a single LLM call, retrying over seeds.

        The prompt is built once and then sent with each candidate seed until
        a response yields at least one usable prediction. Every attempt,
        failed or successful, is appended to ``self.call_log`` and passed to
        ``self._write_log``; failed attempts carry an ``"error"`` entry.

        Args:
            user_prompt_template: Template with ``{history_json}`` and
                ``{points_json}`` placeholders, filled in with ``str.format``.
            X_batch: List of input points, e.g.
                ``[{"fen": 70, "crn": 120, "hn": 260}, ...]``. The key order of
                the first point is used as the feature order.
            history: History list.
            seed: Random seed; if ``None``, the seeds 42, 43 and 44 are tried
                in turn.
            temperature: Sampling temperature.
            top_p: Nucleus sampling parameter.
            max_tokens: Maximum number of generated tokens.

        Returns:
            List with the same length as ``X_batch``. Entries that could not
            be parsed or fall outside the valid range are ``None``; at least
            one entry is a number.

        Raises:
            ValueError: If the batch is empty, the prompt cannot be built, or
                every seed fails. The underlying error, when there is one, is
                attached as ``__cause__``.
        """
        n_points = len(X_batch)
        failure_message = f"All seeds failed, cannot predict {n_points} points"

        # An empty batch can never produce a usable prediction, so fail
        # before spending any LLM calls on it.
        if n_points == 0:
            raise ValueError(failure_message)

        # If seed is None, try a small default sequence.
        seeds_to_try = [seed] if seed is not None else [42, 43, 44]

        # The prompt does not depend on the seed: build it once. Any failure
        # here is reported as ValueError (the documented exception type) with
        # the real cause chained, instead of being retried per seed.
        try:
            from .prompt import format_history_json, format_points_json
            feature_order = list(X_batch[0].keys())
            history_json = format_history_json(history, feature_order)
            points_json = format_points_json(X_batch, feature_order)
            formatted_user = user_prompt_template.format(
                history_json=history_json,
                points_json=points_json
            )
        except Exception as exc:
            raise ValueError(failure_message) from exc

        if self.system_prompt:
            full_prompt = f"{self.system_prompt}\n\n{formatted_user}"
        else:
            full_prompt = formatted_user

        # Number of input features per point; 3 is the legacy fallback.
        input_dim = len(X_batch[0]) if isinstance(X_batch[0], dict) else 3

        last_error = None
        for attempt_seed in seeds_to_try:
            call_record = {
                "input_prompt": full_prompt,
                "response": None,
                "parsed_values": None,
                "n_points": n_points,
                "seed": attempt_seed,
                "temperature": temperature,
                "top_p": top_p,
                "max_tokens": max_tokens
            }
            try:
                response = self.llm_client.generate(
                    full_prompt,
                    seed=attempt_seed,
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=max_tokens
                )
                call_record["response"] = response

                values = self._parse_json_predictions(response, n_points, input_dim=input_dim)
                call_record["parsed_values"] = values

                if not any(v is not None for v in values):
                    raise ValueError("No valid predictions parsed from response")
            except Exception as exc:
                # Retry boundary: remember why this seed failed, log the
                # attempt, and move on to the next seed.
                last_error = exc
                call_record["error"] = str(exc)
                self.call_log.append(call_record)
                self._write_log(call_record)
                continue

            self.call_log.append(call_record)
            self._write_log(call_record)
            return values

        # All attempts failed.
        raise ValueError(failure_message) from last_error
    
    def _parse_json_predictions(self, response: str, n_points: int, input_dim: int = None) -> List[Optional[float]]:
        """Parse numeric predictions from an LLM response in JSON-like formats.

        Anything up to and including the last ``</think>`` tag is discarded.
        The following strategies are then tried in order; the first one that
        applies decides the result:

        0. The whole response is a JSON object with a ``"data_points"`` key.
        1. The response contains an embedded ``{"data_points": [...]}`` object.
        2. Legacy format: an embedded ``{"predictions": [...]}`` object whose
           array has exactly ``n_points`` entries.
        3. The first bare numeric array ``[...]``. If it holds at least
           ``n_points * (input_dim + 1)`` numbers it is read as consecutive
           ``inputs..., output`` groups and the last number of each group is
           taken; otherwise, if it holds at least ``n_points`` numbers, the
           first ``n_points`` are taken.
        4. Every number found anywhere in the text, sliced into
           ``inputs..., output`` groups as in strategy 3.

        Args:
            response: Raw text returned by the LLM.
            n_points: Number of predictions expected.
            input_dim: Number of input features per point; defaults to 3 when
                ``None`` (backward compatibility).

        Returns:
            List of length ``n_points``. Each entry is a float, or ``None`` if
            that prediction is missing, not a number (JSON booleans included)
            or outside the valid range.

        Raises:
            ValueError: If the whole response is valid JSON whose
                ``"data_points"`` value is not a list.
        """
        text = response.strip()

        # Keep only what follows the last closing thinking tag.
        if '</think>' in text:
            text = text.split('</think>')[-1].strip()

        # Strategy 0: the full response matches the current prompt format.
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict) and "data_points" in parsed:
            return self._extract_targets_from_data_points(parsed["data_points"], n_points)

        # Strategies 1-3 are best-effort: any failure only means "this
        # strategy does not apply", so the next one is tried.

        # Strategy 1: embedded {"data_points": [...]} block.
        try:
            match = re.search(r'\{\s*"data_points"\s*:\s*\[(.*?)\]\s*\}', text, re.DOTALL)
            if match:
                data_points = self._load_json_array(match.group(1))
                return self._extract_targets_from_data_points(data_points, n_points)
        except Exception:
            pass

        # Strategy 2: legacy "predictions" format.
        try:
            match = re.search(r'\{[^}]*"predictions"\s*:\s*\[([^\]]+)\][^}]*\}', text, re.DOTALL)
            if match:
                predictions = self._load_json_array(match.group(1))
                if len(predictions) == n_points:
                    return [self._coerce_prediction(value) for value in predictions]
        except Exception:
            pass

        # Each point occupies input_dim inputs plus one output.
        if input_dim is None:
            input_dim = 3  # default input dimension for backward compatibility
        total_dim = input_dim + 1

        # Strategy 3: raw numeric array like [0.856, 0.870, ...].
        try:
            match = re.search(r'\[([-+\d\.,\s]+)\]', text)
            if match:
                predictions = self._load_json_array(match.group(1))
                if len(predictions) >= n_points * total_dim:
                    # The output is the last coordinate of each group.
                    return [
                        self._coerce_prediction(predictions[i * total_dim + input_dim])
                        for i in range(n_points)
                    ]
                if len(predictions) >= n_points:
                    # Too short to slice by dimensionality: take the first n_points.
                    return [self._coerce_prediction(value) for value in predictions[:n_points]]
        except Exception:
            pass

        # Strategy 4: final fallback - all numbers in the text, sliced by
        # dimensionality.
        all_numbers = re.findall(r'([-+]?\d+\.?\d*)', text)
        if len(all_numbers) < n_points * total_dim:
            return [None] * n_points

        results = []
        for i in range(n_points):
            try:
                results.append(self._coerce_prediction(float(all_numbers[i * total_dim + input_dim])))
            except Exception:
                results.append(None)
        return results

    def _load_json_array(self, inner: str) -> list:
        """Parse the text found between ``[`` and ``]`` as a JSON array, dropping control characters first."""
        return json.loads(re.sub(r'[\x00-\x1f\x7f-\x9f]', '', '[' + inner + ']'))

    def _coerce_prediction(self, value: Any) -> Optional[float]:
        """Return ``value`` as an in-range float, or ``None`` if it is not a usable number."""
        # bool is a subclass of int, but a JSON true/false is not a prediction.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return None
        number = float(value)
        return number if self._is_valid_value(number) else None

    def _extract_targets_from_data_points(self, data_points: Any, n_points: int) -> List[Optional[float]]:
        """Extract ``"target"`` values from a ``data_points`` list.

        Args:
            data_points: Parsed ``data_points`` value; expected to be a list
                of dicts that each carry a numeric ``"target"``.
            n_points: Number of predictions expected.

        Returns:
            List of exactly ``n_points`` entries: surplus items are ignored
            and missing ones are padded with ``None``. An entry is ``None``
            when the item is not a dict, its ``"target"`` is missing or not a
            number (booleans included), or the value is out of range.

        Raises:
            ValueError: If ``data_points`` is not a list.
        """
        if not isinstance(data_points, list):
            raise ValueError("data_points field is not a list")

        predictions: List[Optional[float]] = []
        # Only the first n_points items can contribute to the result.
        for item in data_points[:n_points]:
            target = item.get("target") if isinstance(item, dict) else None
            # bool is a subclass of int, but a JSON true/false is not a
            # prediction and must not be coerced to 1.0/0.0.
            if isinstance(target, bool) or not isinstance(target, (int, float)):
                predictions.append(None)
                continue
            number = float(target)
            predictions.append(number if self._is_valid_value(number) else None)

        # Pad when the response contained fewer points than requested.
        predictions.extend([None] * (n_points - len(predictions)))
        return predictions

    def _write_log(self, record: Dict[str, Any]):
        """Append a single LLM call record to a JSONL log file (optional)."""
        if self.log_path is None:
            return
        try:
            self.log_path.parent.mkdir(parents=True, exist_ok=True)
            with self.log_path.open("a", encoding="utf-8") as f:
                json.dump(record, f, ensure_ascii=False)
                f.write("\n")
        except Exception:
            if self.verbose:
                pass

