"""
Prompt Optimization

This module implements TextGrad-style optimization for the objective discovery prompt.
When discovered objectives fail the interpretability check, it computes textual gradients
and updates the prompt to improve interpretability.

Reference: Yuksekgonul et al., "TextGrad: Automatic 'Differentiation' via Text"
"""

from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass
from openai import OpenAI
import numpy as np

try:
    from .constants import OPENAI_API_KEY, DATASET_NAMES_DICT
    from .objectives_verifiers import HumanInterpretableVerifier
except ImportError:
    from constants import OPENAI_API_KEY, DATASET_NAMES_DICT
    from objectives_verifiers import HumanInterpretableVerifier


@dataclass
class InterpretabilityFeedback:
    """
    Result of the interpretability check for one discovered objective.

    Attributes:
        objective: Description of the objective that was checked.
        avg_difference: Average |r_n(x,y) - s_h(x,y|n)|, i.e. the mean gap
            between the objective scorer and the human scorers.
        is_interpretable: Verdict of the interpretability check for this
            objective (False when it did not pass).
        score_details: Full details from HumanInterpretableVerifier.
    """
    objective: str
    avg_difference: float
    is_interpretable: bool
    score_details: Dict[str, Any]


class PromptOptimizer:
    """
    TextGrad-based optimizer for the objective discovery prompt.

    When objectives fail the interpretability check, this optimizer:
    1. Collects interpretability results from multiple failed objectives
    2. Computes textual gradients for each objective w.r.t. the prompt
    3. Aggregates gradients across all failed objectives
    4. Updates the prompt using the aggregated gradient

    Every completed (non-skipped) step is appended to ``optimization_history``.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        temperature: float = 0.7,
        api_key: Optional[str] = None
    ):
        """
        Create the optimizer and its OpenAI client.

        Args:
            model: Chat model used for every gradient / aggregation / update call.
            temperature: Sampling temperature passed to each call.
            api_key: API key for the client; falls back to OPENAI_API_KEY when
                not given.
        """
        self.model = model
        self.temperature = temperature
        self.client = OpenAI(api_key=api_key or OPENAI_API_KEY)
        self.optimization_history: List[Dict[str, Any]] = []

    @staticmethod
    def _format_stat(value: Any, spec: str = ".2f") -> str:
        """Format a numeric statistic with ``spec``; return "N/A" if it is missing or non-numeric."""
        # bool is an int subclass but is never a meaningful score here.
        if isinstance(value, bool):
            return "N/A"
        if not isinstance(value, (int, float, np.integer, np.floating)):
            return "N/A"
        return format(value, spec)

    def _llm_call(self, prompt: str, system_prompt: str = "") -> str:
        """
        Send a single chat-completion request and return the stripped reply text.

        Args:
            prompt: User message content.
            system_prompt: Optional system message; omitted when empty.

        Returns:
            The first choice's message content with surrounding whitespace
            removed, or an empty string when the response carries no content.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            temperature=self.temperature,
            max_tokens=2048
        )
        # The message content is nullable in the API response; calling
        # .strip() on None would raise AttributeError.
        content = response.choices[0].message.content
        return content.strip() if content else ""

    def compute_gradient(
        self,
        current_prompt: str,
        feedback: "InterpretabilityFeedback"
    ) -> str:
        """
        Compute textual gradient for a single objective w.r.t. the discovery prompt.

        Builds a critique request from the feedback's ``avg_difference`` and the
        ``avg_objective_score`` / ``avg_human_score`` / ``per_human_model_stats``
        entries of ``score_details``. Entries that are missing or non-numeric
        are rendered as "N/A" instead of raising or being reported as 0.

        Args:
            current_prompt: The discovery prompt that produced the objective.
            feedback: Interpretability result for the objective.

        Returns:
            The LLM's feedback text (may be empty if the model returned nothing).
        """
        fmt = self._format_stat
        details = feedback.score_details if isinstance(feedback.score_details, dict) else {}

        # Extract per-model statistics if available
        stats = details.get('per_human_model_stats')
        stat_lines = []
        if isinstance(stats, dict):
            for model, s in stats.items():
                s = s if isinstance(s, dict) else {}
                stat_lines.append(
                    f"  - {model}: avg={fmt(s.get('avg'))}, std={fmt(s.get('std'))}"
                )
        per_model_info = "\n".join(stat_lines) if stat_lines else "  (not available)"

        avg_difference = fmt(feedback.avg_difference, ".4f")
        avg_objective_score = fmt(details.get('avg_objective_score'))
        avg_human_score = fmt(details.get('avg_human_score'))

        # The prompt is included in full, so no trailing "..." is appended:
        # an ellipsis would signal truncation to the model.
        gradient_prompt = f"""You are an optimization expert analyzing why a discovered objective failed the interpretability check.

DISCOVERY PROMPT (used to generate objectives):
{current_prompt}

OBJECTIVE THAT FAILED:
{feedback.objective}

INTERPRETABILITY CHECK RESULTS:
- Average difference between objective scorer and human scorers: {avg_difference}
- Average objective score: {avg_objective_score}
- Average human score: {avg_human_score}
- Per-model human scorer statistics:
{per_model_info}

PROBLEM:
The objective was scored very differently by the objective scorer vs human scorers,
indicating the objective is ambiguous or unclear in its meaning.

YOUR TASK:
Provide specific, actionable feedback on how to modify the DISCOVERY PROMPT to generate
objectives that are:
1. More specific and concrete (not vague or abstract)
2. More consistently interpretable across different evaluators
3. Clearer in describing measurable behaviors

Focus ONLY on prompt improvements. Do NOT suggest changes to the objective itself.

GRADIENT (specific feedback for the discovery prompt):"""

        return self._llm_call(gradient_prompt)

    def aggregate_gradients(self, gradients: List[str]) -> str:
        """
        Aggregate multiple textual gradients into a single coherent gradient.

        Blank gradients are ignored. No LLM call is made when zero or one
        usable gradient remains.

        Args:
            gradients: Individual feedback texts.

        Returns:
            "" when there is nothing to aggregate, the single gradient when only
            one is usable, otherwise the LLM-synthesized recommendations.
        """
        usable = [g for g in gradients if g and g.strip()]
        if not usable:
            return ""
        if len(usable) == 1:
            return usable[0]

        gradients_text = "\n\n---\n\n".join(
            f"FEEDBACK {i+1}:\n{g}" for i, g in enumerate(usable)
        )

        aggregation_prompt = f"""Synthesize {len(usable)} pieces of feedback into unified recommendations.

{gradients_text}

YOUR TASK:
Combine these into a single coherent set of recommendations:
- Identify general themes
- Prioritize changes that address multiple issues
- Keep it generic
- Remove redundant suggestions
- Keep recommendations specific and actionable

AGGREGATED RECOMMENDATIONS:"""

        return self._llm_call(aggregation_prompt)

    def apply_gradient(self, current_prompt: str, gradient: str) -> str:
        """
        Apply the textual gradient to update the prompt.

        Args:
            current_prompt: The prompt to improve.
            gradient: Aggregated feedback describing the desired changes.

        Returns:
            The rewritten prompt, or ``current_prompt`` unchanged when the
            gradient is blank or the LLM returns an empty reply (an empty
            string must never replace a working prompt).
        """
        if not gradient or not gradient.strip():
            return current_prompt

        update_prompt = f"""Improve the following prompt based on the feedback.

CURRENT PROMPT:
{current_prompt}

FEEDBACK:
{gradient}

REQUIREMENTS:
1. Make targeted changes addressing the feedback
2. Preserve format placeholders: {{trajectory_count}}, {{trajectories}}, {{num_objectives}}, {{existing_objectives_section}}
3. Keep similar length
4. Still request a numbered list of objectives

Output ONLY the improved prompt:"""

        updated_prompt = self._llm_call(update_prompt)
        return updated_prompt or current_prompt

    def optimize_step(
        self,
        current_prompt: str,
        failed_feedbacks: List["InterpretabilityFeedback"]
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Perform one TextGrad optimization step.

        Args:
            current_prompt: The current discovery prompt
            failed_feedbacks: List of InterpretabilityFeedback from failed objectives

        Returns:
            Tuple of (updated_prompt, step_info). When ``failed_feedbacks`` is
            empty the prompt is returned unchanged with a ``skipped`` marker and
            nothing is added to the history.
        """
        if not failed_feedbacks:
            return current_prompt, {"skipped": True, "reason": "No failed objectives"}

        # Compute gradients for each failed objective
        gradients = [
            self.compute_gradient(current_prompt, feedback)
            for feedback in failed_feedbacks
        ]

        # Aggregate gradients
        aggregated_gradient = self.aggregate_gradients(gradients)

        # Apply gradient to update prompt
        updated_prompt = self.apply_gradient(current_prompt, aggregated_gradient)

        # Record step info
        avg_differences = [f.avg_difference for f in failed_feedbacks]
        step_info = {
            "num_failed_objectives": len(failed_feedbacks),
            "objectives": [f.objective for f in failed_feedbacks],
            "avg_differences": avg_differences,
            "mean_avg_difference": float(np.mean(avg_differences)),
            "individual_gradients": gradients,
            "aggregated_gradient": aggregated_gradient
        }

        self.optimization_history.append(step_info)

        return updated_prompt, step_info

    def get_history(self) -> List[Dict[str, Any]]:
        """Return the optimization history (the live list, one entry per completed step)."""
        return self.optimization_history


def compute_interpretability_for_objectives(
    objectives: List[str],
    verification_samples: List[Dict[str, str]],
    model_sequence: List[str],
    scorer_model: str,
    human_scorer_models: List[str],
    dataset_type: str,
    epsilon: float = 0.15
) -> List[InterpretabilityFeedback]:
    """
    Run the HumanInterpretableVerifier alignment check on every objective.

    One verifier is built per objective and its ``compute_alignment`` result is
    wrapped in an InterpretabilityFeedback. If ``compute_alignment`` raises, the
    error is printed and a placeholder feedback is recorded instead
    (``avg_difference=inf``, ``is_interpretable=False``, ``score_details``
    holding only the error text), so the output always has one entry per
    objective, in input order.

    Args:
        objectives: Objective descriptions to evaluate
        verification_samples: Dataset samples passed to the verifier
        model_sequence: Model checkpoint paths passed to the verifier
        scorer_model: Model used as the objective scorer
        human_scorer_models: Models used as human scorers
        dataset_type: Type of dataset (e.g., 'hh', 'tldr')
        epsilon: Interpretability threshold handed to the verifier

    Returns:
        List of InterpretabilityFeedback, one per objective
    """
    collected = []

    for objective in objectives:
        print(f"Computing interpretability for: {objective[:50]}...")

        checker = HumanInterpretableVerifier(
            objective_description=objective,
            objective_model=scorer_model,
            human_models=human_scorer_models,
            model_sequence=model_sequence,
            epsilon=epsilon
        )

        try:
            gap, passed, details = checker.compute_alignment(
                dataset=verification_samples,
                dataset_type=dataset_type,
                use_provided_responses=False
            )
            result = InterpretabilityFeedback(
                objective=objective,
                avg_difference=gap,
                is_interpretable=passed,
                score_details=details
            )
        except Exception as exc:
            print(f"Error computing interpretability for '{objective}': {exc}")
            # Placeholder entry so the failed objective still appears in the output
            result = InterpretabilityFeedback(
                objective=objective,
                avg_difference=float('inf'),
                is_interpretable=False,
                score_details={"error": str(exc)}
            )

        collected.append(result)

    return collected


def optimize_discovery_prompt(
    current_prompt: str,
    objectives: List[str],
    verification_samples: List[Dict[str, str]],
    model_sequence: List[str],
    scorer_model: str,
    human_scorer_models: List[str],
    dataset_type: str,
    epsilon: float = 0.15,
    optimizer_model: str = "gpt-4o-mini",
    logger=None
) -> Tuple[str, Dict[str, Any]]:
    """
    Optimize the discovery prompt based on interpretability of objectives.

    This is the main entry point for prompt optimization. It:
    1. Computes interpretability scores for all provided objectives
    2. Keeps only the objectives that actually failed the check
    3. Creates a PromptOptimizer
    4. Computes and aggregates textual gradients for the failed objectives
    5. Returns the updated prompt

    Objectives that passed the check, and objectives whose verification
    raised (recorded with an ``"error"`` entry in ``score_details``), are not
    used for gradients: the gradient prompt asserts that the objective failed
    because it is ambiguous, which is not true for either group.

    Args:
        current_prompt: The current discovery prompt to optimize
        objectives: List of objectives to compute interpretability for
        verification_samples: Dataset samples for verification
        model_sequence: List of model checkpoint paths
        scorer_model: Model for objective scoring
        human_scorer_models: List of models for human scoring
        dataset_type: Type of dataset (e.g., 'hh', 'tldr')
        epsilon: Threshold for interpretability
        optimizer_model: LLM model for computing gradients
        logger: Optional logger

    Returns:
        Tuple of (updated_prompt, optimization_info). The prompt is returned
        unchanged with ``{"skipped": True, ...}`` when no objectives are given
        or none of them failed. When objectives were evaluated, the info dict
        additionally carries ``num_objectives_evaluated``,
        ``num_interpretable`` and ``num_errored``.
    """
    if not objectives:
        return current_prompt, {"skipped": True, "reason": "No objectives provided"}

    print(f"\n[Prompt Optimization] Computing interpretability for {len(objectives)} objectives...")

    if logger:
        logger.info("\n" + "="*60)
        logger.info("PROMPT OPTIMIZATION")
        logger.info("="*60)
        logger.info(f"Objectives to evaluate: {len(objectives)}")

    # Compute interpretability for all objectives
    feedbacks = compute_interpretability_for_objectives(
        objectives=objectives,
        verification_samples=verification_samples,
        model_sequence=model_sequence,
        scorer_model=scorer_model,
        human_scorer_models=human_scorer_models,
        dataset_type=dataset_type,
        epsilon=epsilon
    )

    def _errored(fb) -> bool:
        # Verification failures are recorded as score_details={"error": ...}.
        details = fb.score_details
        return isinstance(details, dict) and "error" in details

    num_errored = sum(1 for fb in feedbacks if _errored(fb))
    num_interpretable = sum(1 for fb in feedbacks if fb.is_interpretable)
    failed_feedbacks = [
        fb for fb in feedbacks
        if not fb.is_interpretable and not _errored(fb)
    ]

    # Create optimizer and perform optimization step
    optimizer = PromptOptimizer(model=optimizer_model)
    updated_prompt, step_info = optimizer.optimize_step(
        current_prompt=current_prompt,
        failed_feedbacks=failed_feedbacks
    )

    # Copy before extending so the optimizer's own history entry is untouched.
    step_info = {
        **step_info,
        "num_objectives_evaluated": len(feedbacks),
        "num_interpretable": num_interpretable,
        "num_errored": num_errored,
    }
    skipped = bool(step_info.get("skipped", False))

    if logger:
        logger.info(
            f"Evaluated: {len(feedbacks)}, interpretable: {num_interpretable}, "
            f"errored: {num_errored}, used for gradients: {len(failed_feedbacks)}"
        )
        if skipped:
            logger.info(f"Prompt optimization skipped: {step_info.get('reason', 'N/A')}")
        else:
            logger.info(f"Mean avg_difference: {step_info.get('mean_avg_difference', 'N/A')}")
            agg_grad = step_info.get('aggregated_gradient', 'N/A')
            logger.info(f"Aggregated gradient:\n{agg_grad[:500] if isinstance(agg_grad, str) else agg_grad}...")
            logger.info("Discovery prompt updated.")

    if skipped:
        print(f"[Prompt Optimization] Skipped: {step_info.get('reason', 'N/A')}.")
    else:
        print(f"[Prompt Optimization] Prompt updated based on {len(failed_feedbacks)} failed objectives.")
    return updated_prompt, step_info
