import re
import random
import numpy as np
from deap import gp, creator
import openai
import httpx

class BatchLLMSimplifier:
    """
    Provides batch simplification of multiple GP individuals using LLMs.
    Enhanced to support different simplification focuses and island-specific strategies.

    Expressions are sent to the LLM as DEAP prefix strings (``str(individual)``);
    each reply line is compiled back with ``gp.PrimitiveTree.from_string``
    against the caller's primitive set. Call ``close()`` when done to release
    the OpenRouter HTTP client.
    """

    # Matches a numbered reply line such as "1. add(x, 1)" or "2) mul(x, x)".
    # The negative lookahead stops a bare decimal like "1.5" from being read
    # as item "1" followed by the expression "5".
    _NUMBERED_LINE = re.compile(r'^\d+[\.\)](?!\d)\s*(.*)$')

    # Upper bound on history entries recorded per batch_simplify() call.
    _HISTORY_PER_CALL = 3

    def __init__(self, api_key, model="gpt-4o-mini", provider="openai",
                 openrouter_base_url="https://openrouter.ai/api/v1",
                 request_timeout=60.0, max_tokens=1000):
        """
        Initialize with API key and model selection.

        Args:
            api_key: API key for the chosen provider
            model: Model identifier (e.g., "gpt-4", "claude-2", "mistral-7b")
            provider: LLM provider ("openai" or "openrouter"), case-insensitive
            openrouter_base_url: Base URL for OpenRouter API
            request_timeout: Timeout in seconds for OpenRouter HTTP requests
                (httpx would otherwise apply its short default timeout)
            max_tokens: Completion token limit sent with every request

        Raises:
            ValueError: if `provider` is not supported.
        """
        self.api_key = api_key
        self.model = model
        self.provider = provider.lower()
        self.openrouter_base_url = openrouter_base_url
        self.max_tokens = max_tokens
        self.http_client = None  # only created for the "openrouter" provider

        # Configure client based on provider
        if self.provider == "openai":
            # The module-level openai client reads this global key.
            openai.api_key = api_key
        elif self.provider == "openrouter":
            self.http_client = httpx.Client(
                base_url=openrouter_base_url,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "HTTP-Referer": "https://github.com/your-repo",  # Replace with your actual referer
                    "X-Title": "Symbolic Regression LLM Simplifier"
                },
                timeout=request_timeout,
            )
        else:
            raise ValueError(f"Unsupported provider: {provider}")

        self.problem_description = None   # free-text context prepended to prompts
        self.variable_ranges = None       # optional per-variable domains shown in prompts
        self.simplification_history = []  # dicts appended by add_simplification_to_history()
        self.max_history = 20  # Keep only recent history to avoid token limits

    def close(self):
        """Close the OpenRouter HTTP client, if one was created. Safe to call repeatedly."""
        client = getattr(self, "http_client", None)
        if client is not None:
            client.close()
            self.http_client = None

    def set_problem_context(self, description, variable_ranges=None):
        """Set the problem description and variable ranges for context."""
        self.problem_description = description
        self.variable_ranges = variable_ranges

    def add_simplification_to_history(self, original, simplified, original_error, new_error):
        """
        Record a simplification (original -> simplified with both errors) in history.

        Only the most recent `max_history` entries are kept; the prompt builder
        shows the last three of them to the LLM as examples.
        """
        self.simplification_history.append({
            "original": original,
            "simplified": simplified,
            "original_error": original_error,
            "new_error": new_error,
        })

        # Drop the oldest entries so at most `max_history` remain.
        excess = len(self.simplification_history) - self.max_history
        if excess > 0:
            del self.simplification_history[:excess]

    def batch_simplify(self, individuals, pset, toolbox, batch_size=20, simplification_focus="balance", same_prompt=False):
        """
        Simplify multiple individuals, sending `batch_size` of them per LLM call.

        Args:
            individuals: List of GP individuals to simplify
            pset: Primitive set used to compile the LLM's expressions
            toolbox: DEAP toolbox; `toolbox.evaluate` scores each new individual
            batch_size: Number of individuals to process in one LLM call (>= 1)
            simplification_focus: Focus for simplification ("generalization", "simplicity", or "balance")
            same_prompt: If True, always use the generalization system prompt
                regardless of `simplification_focus`

        Returns:
            List of evaluated simplified individuals. Only replies that compiled
            and evaluated are included, so the list may be shorter than
            `individuals` and is not index-aligned with it.

        Raises:
            ValueError: if `batch_size` is smaller than 1.

        NOTE: the bundled system prompts ask the LLM for at most 10 expressions,
        so with batch_size > 10 the trailing slots of a batch typically come
        back empty and are skipped.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        simplified_individuals = []
        history_budget = self._HISTORY_PER_CALL

        # Process individuals in batches to avoid token limits
        for start in range(0, len(individuals), batch_size):
            batch = individuals[start:start + batch_size]

            prompt = self._generate_batch_prompt(batch, simplification_focus)
            simplified_exprs = self._get_batch_simplifications(
                prompt, len(batch), simplification_focus, same_prompt
            )

            # Replies are paired with their source expression by position.
            for original_ind, expr_str in zip(batch, simplified_exprs):
                if not expr_str:
                    # Placeholder for a missing reply line; nothing to compile.
                    continue
                try:
                    simp_ind = creator.Individual(gp.PrimitiveTree.from_string(expr_str, pset))
                    simp_ind.fitness.values = toolbox.evaluate(simp_ind)
                except Exception as e:
                    # Best effort: a malformed LLM expression only skips this slot.
                    print(f"Error creating individual from LLM response: {e}")
                    continue

                simplified_individuals.append(simp_ind)

                # Record at most a few examples per call to keep prompts short.
                if history_budget > 0:
                    original_error = self._fitness_value(original_ind)
                    new_error = self._fitness_value(simp_ind)
                    if original_error is not None and new_error is not None:
                        self.add_simplification_to_history(
                            str(original_ind), expr_str, original_error, new_error
                        )
                        history_budget -= 1

        return simplified_individuals

    @staticmethod
    def _fitness_value(individual):
        """Return the first fitness value of `individual`, or None if it has none yet."""
        values = getattr(getattr(individual, "fitness", None), "values", None)
        return values[0] if values else None

    @staticmethod
    def _format_error(value):
        """Format an error value for the prompt; None becomes "unknown"."""
        return "unknown" if value is None else f"{value:.4e}"

    def _generate_batch_prompt(self, individuals, simplification_focus="balance"):
        """
        Build the user prompt listing problem context, recent history and the expressions.

        `simplification_focus` is accepted for interface compatibility but is
        not used here; the focus is conveyed through the system prompt instead.
        """
        prompt_parts = []

        # Add problem description
        if self.problem_description:
            prompt_parts.append(f"PROBLEM CONTEXT:\n{self.problem_description}")

            if self.variable_ranges:
                if len(self.variable_ranges) == 1:
                    var_ranges_str = f"Variable x is in range {self.variable_ranges[0]}"
                else:
                    var_ranges_str = ", ".join(
                        f"x{i + 1} ∈ {r}" for i, r in enumerate(self.variable_ranges)
                    )
                prompt_parts.append(f"Variable domains: {var_ranges_str}")

        # Add simplification history examples (just the most recent 3)
        if self.simplification_history:
            prompt_parts.append("PREVIOUS SUCCESSFUL SIMPLIFICATIONS:")
            for i, entry in enumerate(self.simplification_history[-3:], start=1):
                prompt_parts.extend([
                    f"Example {i}:",
                    f"Original: {entry['original']}",
                    f"Simplified to: {entry['simplified']}",
                    f"Original error: {self._format_error(entry['original_error'])}",
                    f"New error: {self._format_error(entry['new_error'])}",
                ])

        # Add expressions to simplify (without focus-specific guidance)
        prompt_parts.append("EXPRESSIONS TO SIMPLIFY:")
        for i, ind in enumerate(individuals, start=1):
            prompt_parts.append(f"Expression {i}: {str(ind)}")
            prompt_parts.append(f"Current error: {self._format_error(self._fitness_value(ind))}")

        return "\n".join(prompt_parts)

    def _select_system_prompt(self, simplification_focus, same_prompt):
        """Pick the system prompt: generalization when `same_prompt`, else by focus (default balanced)."""
        if same_prompt or simplification_focus == "generalization":
            return self._get_generalization_system_prompt()
        if simplification_focus == "simplicity":
            return self._get_simplicity_system_prompt()
        return self._get_balanced_system_prompt()

    def _request_completion(self, messages):
        """
        Send `messages` to the configured provider and return the stripped reply text.

        Raises whatever the provider client raises (network, HTTP status,
        malformed JSON) and ValueError for an unknown provider.
        """
        if self.provider == "openai":
            response = openai.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=self.max_tokens
            )
            content = response.choices[0].message.content
        elif self.provider == "openrouter":
            response = self.http_client.post(
                "/chat/completions",
                json={
                    "model": self.model,
                    "messages": messages,
                    "max_tokens": self.max_tokens
                }
            )
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"]
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")
        # A reply may carry no text content (None); treat it as empty.
        return (content or "").strip()

    @classmethod
    def _parse_expressions(cls, content, num_expressions):
        """
        Extract exactly `num_expressions` expressions from an LLM reply.

        Numbered lines ("1. expr" / "1) expr") are preferred; only when the
        reply contains none are its non-empty lines used as-is. Markdown code
        fences and backticks are stripped. The result is truncated or padded
        with "" placeholders to `num_expressions` entries.
        """
        lines = [
            stripped
            for stripped in (raw.strip() for raw in content.splitlines())
            if stripped and not stripped.startswith("```")
        ]
        numbered = [m.group(1) for m in map(cls._NUMBERED_LINE.match, lines) if m]
        candidates = numbered if numbered else lines

        exprs = [c.strip().strip("`").strip() for c in candidates[:num_expressions]]
        exprs.extend([""] * (num_expressions - len(exprs)))
        return exprs

    def _get_batch_simplifications(self, prompt, num_expressions, simplification_focus="balance", same_prompt=False):
        """
        Get simplified expressions from the LLM for a batch of expressions.

        Always returns a list of exactly `num_expressions` strings; "" marks a
        slot with no usable reply. Any request failure is printed and yields
        all-placeholder output (best effort, never raises).
        """
        try:
            system_prompt = self._select_system_prompt(simplification_focus, same_prompt)
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt}
            ]
            content = self._request_completion(messages)
        except Exception as e:
            print(f"Error in LLM simplification: {e}")
            # Return empty strings if LLM call fails
            return [""] * num_expressions

        return self._parse_expressions(content, num_expressions)

    def _get_generalization_system_prompt(self):
        """Get system prompt focused on generalization."""
        return """You are a mathematics expert specializing in symbolic regression and model generalization.
Analyze the expressions provided and improve each one to enhance generalization to unseen data. 

You will work with mathematical expressions in DEAP format:
- add(a, b)
- mul(a, b)
- sub(a, b)
- div(a, b)
- neg(a)
- sin(a)
- cos(a)
- exp(a)

Important: Preserve the DEAP format exactly (keep function names unchanged).
Maximum number of expressions to return is 10.

Your goal is to produce expressions that will generalize better.
The goal is NOT just algebraic simplification, but creating expressions that capture the underlying patterns
without fitting to noise. Use insights from previous simplification examples, noting which changes led to
improvements in error rates.

Respond with ONLY the improved expressions, one per line, numbered as:
1. [improved expression 1]
2. [improved expression 2]
etc.
Make sure to return diverse set of expressions even if the given expressions are similar.
Do not include any other text, explanations or comments."""

    def _get_simplicity_system_prompt(self):
        """Get system prompt focused on simplicity."""
        return """You are a mathematics expert specializing in algebraic simplification.
Analyze the expressions provided and simplify each one to be as concise as possible.

You will work with mathematical expressions in DEAP format:
- add(a, b)
- mul(a, b)
- sub(a, b)
- div(a, b)
- neg(a)
- sin(a)
- cos(a)
- exp(a)

Important: Preserve the DEAP format exactly (keep function names unchanged).
Maximum number of expressions to return is 10.

Your goal is to produce maximally simplified expressions.
Focus intensely on reducing complexity.
Use insights from previous simplification examples, particularly noting simplifications
that led to shorter expressions while improving error rates.

Respond with ONLY the simplified expressions, one per line, numbered as:
1. [simplified expression 1]
2. [simplified expression 2]
etc.
Make sure to return diverse set of expressions even if the given expressions are similar.
Do not include any other text, explanations or comments."""

    def _get_balanced_system_prompt(self):
        """Get system prompt focused on balancing simplicity and generalization."""
        return """You are a mathematics expert specializing in optimizing mathematical expressions.
Analyze the expressions provided and improve each one to balance simplicity and generalization.

You will work with mathematical expressions in DEAP format:
- add(a, b)
- mul(a, b)
- sub(a, b)
- div(a, b)
- neg(a)
- sin(a)
- cos(a)
- exp(a)

Important: Preserve the DEAP format exactly (keep function names unchanged).
Maximum number of expressions to return is 10.

Your goal is to produce optimized expressions.
Seek expressions that are both simple enough to avoid overfitting but complex enough
to capture the underlying patterns. Use insights from previous simplification examples,
noting which types of changes led to better balance of simplicity and accuracy.

Respond with ONLY the optimized expressions, one per line, numbered as:
1. [optimized expression 1]
2. [optimized expression 2]
etc.
Make sure to return diverse set of expressions even if the given expressions are similar.
Do not include any other text, explanations or comments."""