import os
import time

import tiktoken
from anthropic import Anthropic
from textgrad.engine.anthropic import ChatAnthropic

class ThinkingChatAnthropic(ChatAnthropic):
    """Extended ChatAnthropic engine with thinking parameter support.

    Every request is streamed. When ``thinking_enabled`` is true, an
    Anthropic ``thinking`` config is attached to the request. After each
    successful call, the thinking text and approximate token counts are
    stored on the instance.
    """

    def __init__(self, model_string="claude-3-7-sonnet-20250219",
                 system_prompt="You are a helpful, creative, and smart assistant.",
                 thinking_enabled=True,
                 thinking_budget=16000):
        """Create the engine.

        Args:
            model_string: Anthropic model identifier, forwarded to the base engine.
            system_prompt: Default system prompt, used when ``generate`` gets none.
            thinking_enabled: Whether to attach a thinking config to requests.
            thinking_budget: Value sent as ``budget_tokens`` in the thinking config.
        """
        super().__init__(model_string=model_string, system_prompt=system_prompt)
        self.thinking_enabled = thinking_enabled
        self.thinking_budget = thinking_budget
        # Results of the most recent successful generate() call. They start as
        # None so the accessors work before any call has been made.
        self.last_thinking = None
        self.last_thinking_tokens = None
        self.last_completion_tokens = None
        self.last_total_tokens = None

    def generate(self, prompt, system_prompt=None, temperature=0, max_tokens=60000, top_p=0.99):
        """Send ``prompt`` as a single user turn and return the model's text reply.

        Args:
            prompt: User message content, sent as-is.
            system_prompt: Overrides the engine's default system prompt when truthy.
            temperature: Accepted for interface compatibility but NOT sent. The
                Anthropic API restricts temperature/top_p modifications while
                extended thinking is enabled.
            max_tokens: ``max_tokens`` for the request. Per the Anthropic API,
                the thinking budget must be smaller than this value.
            top_p: Accepted for interface compatibility but NOT sent.

        Returns:
            str: The concatenated text blocks of the streamed response.

        Raises:
            Exception: Whatever the final attempt raised, re-raised unchanged
                after ``max_retries`` failed attempts. Every exception is
                retried, including ones that will not succeed on retry.
        """
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        # NOTE: response caching through the base engine's _check_cache is
        # currently disabled for this engine (it was commented out).

        thinking_config = None
        if self.thinking_enabled:
            thinking_config = {
                "type": "enabled",
                "budget_tokens": self.thinking_budget,
            }

        max_retries = 3
        backoff_factor = 2  # seconds; exponential backoff gives waits of 2s, then 4s

        for attempt in range(1, max_retries + 1):
            try:
                complete_response, thinking_content = self._stream_response(
                    content=prompt,
                    system_prompt=sys_prompt_arg,
                    thinking_config=thinking_config,
                    max_tokens=max_tokens,
                )
            except Exception as e:
                if attempt == max_retries:
                    # Out of attempts: propagate the original exception type.
                    raise
                wait_time = backoff_factor * (2 ** (attempt - 1))
                print(f"\nStreaming attempt {attempt} failed: {str(e)}")
                print(f"Retrying in {wait_time} seconds...")
                time.sleep(wait_time)
            else:
                # Token accounting stays outside the try block so a local
                # counting problem can never trigger another API call.
                self._record_usage(complete_response, thinking_content)
                return complete_response

    def _record_usage(self, complete_response, thinking_content):
        """Store the thinking text and approximate token counts of the last call."""
        # cl100k_base is an OpenAI tokenizer, so these counts only approximate
        # the tokens Anthropic actually bills.
        encoder = tiktoken.get_encoding("cl100k_base")
        # disallowed_special=() encodes text such as "<|endoftext|>" as
        # ordinary text. The default would raise ValueError on it.
        self.last_thinking_tokens = len(encoder.encode(thinking_content, disallowed_special=()))
        self.last_thinking = thinking_content
        self.last_completion_tokens = len(encoder.encode(complete_response, disallowed_special=()))
        self.last_total_tokens = self.last_thinking_tokens + self.last_completion_tokens

    def _stream_response(self, content, system_prompt, thinking_config, max_tokens=60000):
        """Stream one request, accumulating thinking and text deltas.

        No client-side timeout is applied.

        Args:
            content: User message content.
            system_prompt: System prompt for the request.
            thinking_config: Anthropic thinking config dict. If None, the
                ``thinking`` key is omitted from the request entirely.
            max_tokens: ``max_tokens`` for the request.

        Returns:
            tuple: (complete_response, thinking_content)
        """
        request = {
            "model": self.model_string,
            "max_tokens": max_tokens,
            "system": system_prompt,
            "messages": [
                {"role": "user", "content": content}
            ],
        }
        # Leave the parameter out instead of sending an explicit null.
        if thinking_config is not None:
            request["thinking"] = thinking_config

        response_parts = []
        thinking_parts = []

        with self.client.messages.stream(**request) as stream:
            print("\nStarting streaming response...")

            for event in stream:
                if event.type == "content_block_start":
                    print(f"\nStarting {event.content_block.type} block...")

                elif event.type == "content_block_delta":
                    if event.delta.type == "thinking_delta":
                        thinking_parts.append(event.delta.thinking)
                    elif event.delta.type == "text_delta":
                        response_parts.append(event.delta.text)

                elif event.type == "content_block_stop":
                    print("\nBlock complete.")

        print("\nStreaming completed successfully")
        return "".join(response_parts), "".join(thinking_parts)

    def get_last_thinking_tokens(self):
        """Return the approximate thinking-token count of the last successful call.

        Returns None if no call has succeeded yet.
        """
        return self.last_thinking_tokens