import ast
import asyncio
import json
import logging
import os
import random
import re
from typing import List, Optional

import aio_pika
import numpy as np
from accelerate.inference import prepare_pippy
from openai import AzureOpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer

from fundcc import programs_database
from fundcc.code_manipulation import text_to_function
from fundcc.profiling import sync_time_execution, sync_track_memory, async_track_memory, async_time_execution

# Module-wide logger. It is looked up by the fixed name 'main_logger' rather
# than __name__ -- presumably so it shares handlers configured elsewhere in the
# project (TODO confirm against the application's logging setup).
logger = logging.getLogger('main_logger')

class LLM_model:
    """Wrapper around an Azure OpenAI chat client that also tracks usage and cost.

    The client is configured from the environment variables
    ``AZURE_OPENAI_API_VERSION`` (default ``'2023-07-01-preview'``),
    ``AZURE_OPENAI_ENDPOINT`` and ``AZURE_OPENAI_API_KEY``.

    Attributes:
        samples_per_prompt: Number of completions requested per call (``n``).
        model: Value passed as ``model`` to the chat-completions endpoint.
        client: The ``AzureOpenAI`` client, or ``None`` if construction failed.
        total_requests: Number of requests that returned a response.
        total_prompt_tokens / total_completion_tokens / total_tokens:
            Running token counters accumulated from ``response.usage``.
        total_cost: Running cost in USD computed by :meth:`calculate_cost`.
    """

    # Rates for gpt-4o-mini in USD per token. They are applied regardless of
    # the value of ``self.model`` -- adjust if your model has different rates.
    PROMPT_RATE_PER_TOKEN = 0.150 / 1_000_000  # $0.150 per 1M input tokens
    COMPLETION_RATE_PER_TOKEN = 0.600 / 1_000_000  # $0.600 per 1M output tokens

    def __init__(self, samples_per_prompt: int, model="gpt-4o-mini"):
        """Create the client and reset all usage counters.

        A failure to construct the client is logged, not raised; in that case
        ``self.client`` is ``None`` and :meth:`draw_sample` returns ``[]``.

        Args:
            samples_per_prompt: How many completions to request per prompt.
            model: Model name forwarded to the chat-completions call.
        """
        self.samples_per_prompt = samples_per_prompt
        self.model = model
        logger.debug("In LLM")

        # Initialize counters for tracking usage
        self.total_requests = 0
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0
        self.total_tokens = 0
        self.total_cost = 0.0  # Track total cost

        # Always define the attribute so a failed construction is detectable
        # later instead of surfacing as an AttributeError on first use.
        self.client = None
        try:
            self.client = AzureOpenAI(
                api_version=os.getenv('AZURE_OPENAI_API_VERSION', '2023-07-01-preview'),  # Default API version if not set
                azure_endpoint=os.getenv('AZURE_OPENAI_ENDPOINT'),  # Azure OpenAI endpoint, e.g., "https://<resource-name>.openai.azure.com/"
                api_key=os.getenv('AZURE_OPENAI_API_KEY')  # Azure OpenAI API key
            )
        except Exception as e:
            logger.error(f"Failed to initialize Azure OpenAI client: {e}")

    def calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
        """Return the USD cost of one request.

        Args:
            prompt_tokens: Number of input tokens billed.
            completion_tokens: Number of output tokens billed.

        Returns:
            ``prompt_tokens * PROMPT_RATE_PER_TOKEN
            + completion_tokens * COMPLETION_RATE_PER_TOKEN``.
        """
        prompt_cost = prompt_tokens * self.PROMPT_RATE_PER_TOKEN
        completion_cost = completion_tokens * self.COMPLETION_RATE_PER_TOKEN
        return prompt_cost + completion_cost

    def draw_sample(self, prompt: str) -> List[str]:
        """Request ``samples_per_prompt`` completions for ``prompt``.

        This is a synchronous call that blocks until the service responds.
        Usage counters and the running cost are updated when the response
        carries usage information.

        Args:
            prompt: Text sent as the user message.

        Returns:
            The text of every returned choice that has content. An empty list
            is returned when the client is unavailable or when any error occurs
            (errors are logged, never raised).
        """
        client = getattr(self, "client", None)
        if client is None:
            logger.error("Azure OpenAI client is not initialized; cannot draw samples.")
            return []

        try:
            # Using the updated client for Azure OpenAI with `model`
            response = client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant specializing in Python programming."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=800,
                n=self.samples_per_prompt
            )

            # Log the entire response
            logger.debug(f"Full response: {response}")

            self.total_requests += 1

            # `usage` can be absent; missing accounting data must not cause
            # already-generated completions to be thrown away.
            usage = response.usage
            if usage is None:
                logger.warning("Response contained no usage information; token and cost counters were not updated.")
            else:
                self.total_prompt_tokens += usage.prompt_tokens
                self.total_completion_tokens += usage.completion_tokens
                self.total_tokens += usage.total_tokens

                # Calculate the cost for this request
                cost = self.calculate_cost(usage.prompt_tokens, usage.completion_tokens)
                self.total_cost += cost

                # Log the tokens and cost
                logger.debug(f"Tokens used in this request: prompt={usage.prompt_tokens}, completion={usage.completion_tokens}, total={usage.total_tokens}")
                logger.debug(f"Cost for this request: ${cost:.6f}")

            logger.debug(f"Total cost so far: ${self.total_cost:.6f}")
            logger.debug(f"Total requests so far: {self.total_requests}")
            logger.debug(f"Total tokens used so far: prompt={self.total_prompt_tokens}, completion={self.total_completion_tokens}, total={self.total_tokens}")

            # Keep only choices that actually carry text; a choice whose
            # content is None would break string processing in the caller.
            generated_responses = [
                choice.message.content
                for choice in response.choices
                if choice.message.content is not None
            ]
            logger.debug(f"Generated response from gpt mini is {generated_responses}")

            return generated_responses

        except Exception as e:
            # Deliberate best-effort: the caller treats [] as "no samples".
            logger.error(f"Unexpected error during draw_sample: {str(e)}")
            return []


class Sampler:
    """Worker that turns prompt messages into LLM samples for the evaluator.

    It consumes JSON messages from ``sampler_queue``, asks the LLM for
    completions, optionally extracts the newest ``priority_v{N}`` function body
    from each completion, and publishes one JSON message per completion with
    routing key ``'evaluator_queue'`` on the channel's default exchange.
    """

    # Matches generated function names such as "priority_v3"; group 1 is the version.
    _PRIORITY_NAME_RE = re.compile(r"^priority_v(\d+)$")

    def __init__(self, connection, channel, sampler_queue, evaluator_queue, config, config_prompt, local_id):
        """Store the messaging handles and build the LLM wrapper.

        Args:
            connection: Connection object; closed in :meth:`shutdown`.
            channel: Channel used for QoS, publishing, and closed in :meth:`shutdown`.
            sampler_queue: Queue whose ``iterator()`` yields incoming prompt messages.
            evaluator_queue: Stored on the instance; the code in this class
                publishes by the literal routing key ``'evaluator_queue'``
                and does not read this attribute.
            config: Must expose ``samples_per_prompt``.
            config_prompt: Must expose ``reasoning`` and ``challenge_vtcodes``.
            local_id: Identifier used only in log messages.
        """
        self.connection = connection
        self.channel = channel
        self.sampler_queue = sampler_queue
        self.evaluator_queue = evaluator_queue
        self.config = config
        self._llm = LLM_model(samples_per_prompt=self.config.samples_per_prompt)
        self.config_prompt = config_prompt
        self.prefetch_count = 10
        self.local_id = local_id
        self.reasoning = config_prompt.reasoning
        self.challenge_vtcodes = config_prompt.challenge_vtcodes

    async def shutdown(self):
        """Close the channel and the connection, logging (not raising) failures."""
        logger.info(f"Sampler {self.local_id}: Initiating shutdown process.")

        # Step 1: Drop the consumer reference if one was ever attached.
        # Nothing in this class assigns `self.consumer`, hence the getattr.
        if getattr(self, "consumer", None):
            self.consumer = None
            logger.info(f"Sampler {self.local_id}: Consumer stopped.")

        # Step 2: Close RabbitMQ channel, then connection.
        if self.channel and not self.channel.is_closed:
            try:
                await self.channel.close()
                logger.info(f"Sampler {self.local_id}: RabbitMQ channel closed.")
            except Exception as e:
                logger.warning(f"Sampler {self.local_id}: Error closing channel: {e}")

        if self.connection and not self.connection.is_closed:
            try:
                await self.connection.close()
                logger.info(f"Sampler {self.local_id}: RabbitMQ connection closed.")
            except Exception as e:
                logger.warning(f"Sampler {self.local_id}: Error closing connection: {e}")

        logger.info(f"Sampler {self.local_id}: Shutdown process complete.")

    async def consume_and_process(self):
        """Consume prompt messages until the stream ends, fails, or is cancelled.

        Per-message errors are logged and the loop continues. Cancellation is
        propagated out of the message loop so the task actually stops; it is
        then absorbed at the outermost level, and :meth:`shutdown` always runs.
        """
        try:
            await self.channel.set_qos(prefetch_count=self.prefetch_count)

            async with self.sampler_queue.iterator() as stream:
                async for message in stream:
                    async with message.process():
                        try:
                            await self._handle_message(message)
                        except asyncio.CancelledError:
                            logger.warning("Sampler: consume_and_process was cancelled.")
                            # Re-raise: swallowing here would ignore the cancel
                            # request and keep the consumer loop running.
                            raise
                        except Exception as e:
                            logger.error(f"Sampler: Error processing message: {e}")

        except aio_pika.exceptions.ChannelClosed as e:
            logger.warning(f"Sampler {self.local_id}: Channel closed by RPC timeout. {e}")
        except aio_pika.exceptions.AMQPError as e:
            logger.error(f"Sampler {self.local_id}: AMQP error occurred: {e}")
        except asyncio.CancelledError:
            logger.info(f"Sampler {self.local_id}: Consumer task cancelled while iterating messages.")
        except Exception as e:
            logger.warning(f"Sampler {self.local_id}: Unexpected error while consuming messages: {e}")
        finally:
            logger.info(f"Sampler {self.local_id}: Shutting down due to exception or completion.")
            await self.shutdown()

    async def _handle_message(self, message):
        """Sample the LLM for one prompt message and publish every sample."""
        gpu_time = 0  # Always reported as 0 by this sampler.
        data = json.loads(message.body.decode())
        prompt = programs_database.Prompt.deserialize(data["prompt"])

        # draw_sample is a blocking network call; run it in the default
        # executor so the event loop (and with it AMQP I/O) stays responsive.
        loop = asyncio.get_running_loop()
        responses = await loop.run_in_executor(None, self._llm.draw_sample, prompt.code)
        logger.debug(f"Responses: {responses}")

        for response in responses:
            # Isolate extraction failures so one malformed completion does not
            # discard the remaining completions of the same request.
            try:
                sample = self.extract_latest_priority_function_body(response)
            except Exception as e:
                logger.error(f"Sampler: Error extracting function body from response: {e}")
                continue

            message_data = {
                "sample": sample,
                "island_id": prompt.island_id,
                "version_generated": prompt.version_generated,
                "expected_version": prompt.expected_version,
                "gpu_time": gpu_time,
            }
            serialized_message = json.dumps(message_data)

            try:
                await self.channel.default_exchange.publish(
                    aio_pika.Message(body=serialized_message.encode()),
                    routing_key='evaluator_queue'
                )
                logger.debug("Successfully published message to evaluator_queue")
            except Exception as e:
                logger.error(f"Error publishing to evaluator_queue: {e}")

    def extract_latest_priority_function_body(self, llm_output: str) -> Optional[str]:
        """Extract the body of the highest-versioned ``priority_v{N}`` function.

        When neither ``self.reasoning`` nor ``self.challenge_vtcodes`` is set,
        ``llm_output`` is returned unchanged. Otherwise fenced ```` ```python ````
        blocks are searched; only if there are none, unlabeled fenced blocks
        that contain a ``def priority_v{N}(`` line are used instead. Blocks that
        fail to parse are skipped. Only top-level, non-async function
        definitions are considered.

        Args:
            llm_output: Raw LLM output containing markdown code blocks.

        Returns:
            ``text_to_function(source).body`` for the selected function, or
            ``None`` if no suitable function definition was found.
        """
        if not self.reasoning and not self.challenge_vtcodes:
            return llm_output

        # Python-labeled fences win; generic fences are only a fallback.
        valid_code_blocks = re.findall(r"```python\s*\n(.*?)```", llm_output, re.DOTALL)
        if not valid_code_blocks:
            generic_blocks = re.findall(r"```\s*\n(.*?)```", llm_output, re.DOTALL)
            valid_code_blocks = [
                block for block in generic_blocks
                if re.search(r"^\s*def\s+priority_v\d+\s*\(", block, re.MULTILINE)
            ]

        if not valid_code_blocks:
            logger.warning("No valid Python code blocks containing priority functions found.")
            return None

        logger.info(f"Extracted candidate code blocks: {valid_code_blocks}")

        latest_version = -1
        latest_func_code = None

        for block in valid_code_blocks:
            try:
                module = ast.parse(block)
            except (SyntaxError, ValueError) as e:
                # ValueError covers sources containing null bytes on some
                # Python versions; either way the block is unusable.
                logger.error(f"Skipping invalid Python block due to AST parsing error: {e}")
                continue

            block_lines = block.splitlines()
            for node in module.body:
                if not isinstance(node, ast.FunctionDef):
                    continue
                m = self._PRIORITY_NAME_RE.match(node.name)
                if m is None:
                    continue

                version = int(m.group(1))
                if version <= latest_version:
                    continue  # Ties keep the first function seen.

                if getattr(node, "end_lineno", None) is None:  # Needs Python 3.8+
                    logger.warning("Incomplete function definition (missing end_lineno); skipping.")
                    continue

                # Slice out exactly the source lines of this function.
                latest_func_code = "\n".join(block_lines[node.lineno - 1 : node.end_lineno])
                latest_version = version

        if latest_func_code is None:
            logger.warning("No complete priority function definition found.")
            return None

        # text_to_function parses the source; only its `.body` is returned.
        func_obj = text_to_function(latest_func_code)
        return func_obj.body


# Script entry point is intentionally a no-op: running this module directly
# does nothing; the classes above are meant to be imported and driven elsewhere.
if __name__ == "__main__":
    pass
