from serpapi import GoogleSearch
import time
import requests

from typing import Any, Dict, List
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage

def get_user_question(messages: List[AnyMessage]) -> str:
    """Return the content of the earliest ``HumanMessage`` in ``messages``.

    Messages are scanned in order and the first human turn wins. If no
    ``HumanMessage`` is present, ``None`` is returned (despite the ``str``
    annotation).
    """
    human_contents = (msg.content for msg in messages if isinstance(msg, HumanMessage))
    return next(human_contents, None)


def get_search_results(query: str, api_key: str, num_results: int = 15, retry_attempt: int = 3, retry_delay: float = 2):
    """
    Run a Google search through SerpAPI and return simplified organic results.

    Args:
        query: The search query string.
        api_key: SerpAPI API key.
        num_results: Passed to SerpAPI as the ``num`` parameter.
        retry_attempt: Total number of attempts; values below 1 are treated
            as 1 so the search is always tried at least once.
        retry_delay: Seconds to sleep between a failed attempt and the next.

    Returns:
        On success, a list of dicts with ``title``, ``link``, ``snippet`` and
        ``displayed_link`` keys (missing fields are ``None``). Otherwise a
        human-readable message string: either "no results" or a connection
        error after every attempt raised.
    """
    params = {
        "engine": "google",
        "q": query,
        "api_key": api_key,
        "num": num_results,
    }
    # Previously retry_attempt <= 0 skipped the loop and fell off the end,
    # returning None; always make at least one attempt instead.
    attempts = max(1, retry_attempt)
    for attempt in range(1, attempts + 1):
        try:
            response = GoogleSearch(params).get_dict()
            organic_results = response.get("organic_results", [])
            if not organic_results:
                # NOTE(review): an error payload without "organic_results"
                # (e.g. an invalid key) also lands here; consider surfacing
                # it instead of reporting "no results".
                return f"No results found for query: {query}"
            return [
                {
                    "title": result.get("title"),
                    "link": result.get("link"),
                    "snippet": result.get("snippet"),
                    "displayed_link": result.get("displayed_link"),
                }
                for result in organic_results
            ]
        except Exception as e:
            # Deliberately broad: any failure of the request or parsing is
            # retried, and the caller always gets a string instead of a crash.
            print(f"Attempt {attempt} failed: {e}")
            if attempt < attempts:
                time.sleep(retry_delay)
    print("All retries failed.")
    return "Connection error to the search engine. Please try again later."

def visit(url: str, api_key: str, timeout: float = 60):
    """
    Fetch ``url`` through the ``https://r.jina.ai/`` proxy and return its text.

    Args:
        url: The page URL, appended verbatim to the proxy prefix.
        api_key: Bearer token sent in the ``Authorization`` header.
        timeout: Seconds passed to ``requests.get``. Without it a stalled
            connection would block the caller indefinitely.

    Returns:
        The response body on HTTP 200. Otherwise a string starting with
        ``"Website Visit Error: "`` followed by the status code or the
        network exception.
    """
    request_url = "https://r.jina.ai/" + url
    headers = {"Authorization": f"Bearer {api_key}"}
    try:
        response = requests.get(request_url, headers=headers, timeout=timeout)
    except requests.RequestException as e:
        # Report network failures the same way as HTTP errors instead of raising.
        return f"Website Visit Error: {e}"
    if response.status_code == 200:
        return response.text
    return f"Website Visit Error: {response.status_code}"


def truncate_webpage_content(content: str, max_length: int, topics: List[str] = None) -> str:
    """
    Intelligently truncate webpage content to fit within context limits.

    With ``topics``, the content is split into paragraphs on blank lines
    (``"\\n\\n"``). Paragraphs mentioning any topic (case-insensitive) are kept
    first, then the remaining paragraphs, until ``max_length`` is reached.
    Each group keeps its original order, but the page as a whole is
    reordered. Without ``topics``, the beginning and end of the page are
    kept around a truncation marker.

    Args:
        content: The full webpage content
        max_length: Maximum number of characters to keep
        topics: Optional list of topics to prioritize relevant sections

    Returns:
        ``content`` unchanged if it already fits. Otherwise a truncated
        version no longer than ``max_length``. When the budget is too small
        for either strategy to produce anything meaningful, the plain prefix
        ``content[:max_length]`` is returned.
    """
    if len(content) <= max_length:
        return content

    if topics:
        return _truncate_by_topics(content, max_length, topics)
    return _truncate_head_tail(content, max_length)


def _append_paragraphs(text: str, paragraphs: List[str], max_length: int) -> str:
    """Append paragraphs to ``text`` while they fit; stop at the first that doesn't."""
    for paragraph in paragraphs:
        if len(text) + len(paragraph) + 2 <= max_length:  # +2 for "\n\n"
            text += paragraph + "\n\n"
            continue
        # Add a partial paragraph, leaving a buffer below the limit.
        remaining_space = max_length - len(text) - 50
        if remaining_space > 100:  # only if meaningful content can fit
            text += paragraph[:remaining_space] + "...\n\n"
        break
    return text


def _truncate_by_topics(content: str, max_length: int, topics: List[str]) -> str:
    """Keep topic-matching paragraphs first, then the others, within ``max_length``."""
    lowered_topics = [topic.lower() for topic in topics]
    relevant_paragraphs = []
    other_paragraphs = []
    for paragraph in content.split('\n\n'):
        lowered = paragraph.lower()
        if any(topic in lowered for topic in lowered_topics):
            relevant_paragraphs.append(paragraph)
        else:
            other_paragraphs.append(paragraph)

    truncated = _append_paragraphs("", relevant_paragraphs, max_length)
    truncated = _append_paragraphs(truncated, other_paragraphs, max_length).strip()
    # A single oversized paragraph and a small budget can leave nothing at
    # all; a plain prefix is more useful than an empty page.
    return truncated or content[:max(max_length, 0)]


def _truncate_head_tail(content: str, max_length: int) -> str:
    """Keep the beginning and end of ``content`` around a truncation marker."""
    half_length = (max_length - 100) // 2  # Reserve 100 chars for "..." separator
    if half_length <= 0:
        # Too small for two halves plus the marker. Note that content[-0:]
        # would be the WHOLE string, so fall back to a plain prefix.
        return content[:max(max_length, 0)]

    beginning = content[:half_length]
    end = content[-half_length:]

    # Try to cut at word boundaries: drop the partial word at the end of
    # `beginning` and the partial word at the start of `end`.
    if ' ' in beginning[-50:]:
        beginning = beginning[:beginning.rfind(' ', -50)]
    if ' ' in end[:50]:
        # Search the same window the guard checked (the first 50 chars).
        end = end[end.find(' ', 0, 50) + 1:]

    return f"{beginning}\n\n... [Content truncated due to length] ...\n\n{end}"

def split_webpage_content(content: str, max_length: int, overlap: int = 0) -> List[str]:
    """
    Split webpage content into parts of a given length, splitting at natural boundaries.

    Each chunk is at most ``max_length`` characters. Within each window the
    cut is placed at the best boundary found by ``_find_best_split_point``,
    falling back to a hard cut at ``max_length`` when none qualifies.

    Args:
        content: The webpage content to split
        max_length: Maximum length of each chunk (must be positive)
        overlap: Number of characters to overlap between chunks (default: 0,
            must be non-negative). The start always advances by at least one
            character, so a very large overlap yields many small-step chunks
            rather than an infinite loop.

    Returns:
        List of content chunks split at natural boundaries

    Raises:
        ValueError: If ``max_length`` is not positive or ``overlap`` is negative.
    """
    if max_length <= 0:
        # A non-positive window used to emit one empty string per character.
        raise ValueError(f"max_length must be positive, got {max_length}")
    if overlap < 0:
        # A negative overlap would move the next start past the previous
        # chunk's end and silently drop text.
        raise ValueError(f"overlap must be non-negative, got {overlap}")
    if len(content) <= max_length:
        return [content]

    chunks = []
    total_length = len(content)
    current_pos = 0
    while current_pos < total_length:
        end_pos = current_pos + max_length
        # The final window takes everything remaining.
        if end_pos >= total_length:
            chunks.append(content[current_pos:])
            break

        split_point = _find_best_split_point(content[current_pos:end_pos])
        # Use the natural boundary if one was found, otherwise a hard cut.
        actual_end = current_pos + split_point if split_point > 0 else end_pos
        chunks.append(content[current_pos:actual_end])
        # Step back by `overlap`, but always make forward progress.
        current_pos = max(actual_end - overlap, current_pos + 1)

    return chunks


def _find_best_split_point(text: str) -> int:
    """
    Find the best point to split text, preferring natural boundaries.
    
    Returns the position within the text to split at, or 0 if no good split found.
    """
    # Priority order for split points (from best to worst)
    split_patterns = [
        # Double newlines (paragraph breaks)
        '\n\n',
        # Single newlines at end of sentences
        '.\n',
        '!\n', 
        '?\n',
        # Headers and list items
        '\n#',
        '\n-',
        '\n*',
        '\n1.',
        '\n2.',
        '\n3.',
        '\n4.',
        '\n5.',
        # End of sentences with space
        '. ',
        '! ',
        '? ',
        # Commas and semicolons (less preferred)
        ', ',
        '; ',
        # Single newlines (least preferred natural boundary)
        '\n',
        # Spaces (fallback)
        ' ',
    ]
    
    # Look for split points starting from the end of the text
    # This ensures we get the longest possible chunk while staying under limit
    for pattern in split_patterns:
        # Find the last occurrence of this pattern
        last_pos = text.rfind(pattern)
        if last_pos > len(text) * 0.5:  # Only use if split point is in latter half
            return last_pos + len(pattern)
    
    return 0  # No good split point found