# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class for evaluating programs proposed by the Sampler."""
import ast
from typing import Any
import copy
import logging
from fundcc import code_manipulation
from fundcc import sandbox
from pathlib import Path
import json
import aio_pika
import sys
import asyncio
import concurrent.futures  
from concurrent.futures import ProcessPoolExecutor, as_completed 
from multiprocessing import Manager  
import psutil
import shutil
from fundcc.profiling import async_time_execution
import time

# Logger shared by everything in this module. It is looked up by the fixed name
# 'main_logger' (not `__name__`), so it uses whatever handlers and level are
# configured under that name.
logger = logging.getLogger('main_logger')

class _FunctionLineVisitor(ast.NodeVisitor):
  """Visitor that finds the last line number of a function with a given name."""
  def __init__(self, target_function_name: str) -> None:
    self._target_function_name: str = target_function_name
    self._function_end_line: int | None = None

  def visit_FunctionDef(self, node: Any) -> None:  
    """Collects the end line number of the target function."""
    if node.name == self._target_function_name:
      self._function_end_line = node.end_lineno
    self.generic_visit(node)

  @property
  def function_end_line(self) -> int:
    """Line number of the final line of function `target_function_name`."""
    assert self._function_end_line is not None
    return self._function_end_line

def _trim_function_body(generated_code: str) -> str:
  """Extracts the body of the generated function, trimming anything after it."""
  if not generated_code:
    return ''
  # Wrap generated code in a fake function header.
  code = f'def fake_function_header():\n{generated_code}'
  tree = None
  # Keep trimming code from the end until parsing succeeds.
  while tree is None:
    try:
      tree = ast.parse(code)
    except SyntaxError as e:
      code = '\n'.join(code.splitlines()[:e.lineno - 1])
  if not code:
    return ''
  visitor = _FunctionLineVisitor('fake_function_header')
  visitor.visit(tree)
  body_lines = code.splitlines()[1:visitor.function_end_line]
  return '\n'.join(body_lines) + '\n\n'

def _sample_to_program(
    generated_code: str,
    version_generated: int | None,
    template: code_manipulation.Program,
    function_to_evolve: str,
) -> tuple[code_manipulation.Function, str]:
  """Builds the evolved function and the full program text from a sample.

  Args:
    generated_code: Raw sampled continuation; trimmed to a function body first.
    version_generated: When not None, the body is passed through
      `code_manipulation.rename_function_calls` with the versioned name
      `{function_to_evolve}_v{version_generated}` and the plain name
      (presumably rewriting calls to the versioned name; confirm in
      `code_manipulation`).
    template: Program skeleton; it is deep-copied and never mutated here.
    function_to_evolve: Name of the function whose body is replaced.

  Returns:
    The function object taken from the copied program (with the new body) and
    `str()` of that copied program.
  """
  trimmed_body = _trim_function_body(generated_code)
  if version_generated is not None:
    versioned_name = f'{function_to_evolve}_v{version_generated}'
    trimmed_body = code_manipulation.rename_function_calls(
        trimmed_body, versioned_name, function_to_evolve)
  # Work on a private copy so the caller's template stays untouched.
  program_copy = copy.deepcopy(template)
  target_function = program_copy.get_function(function_to_evolve)
  target_function.body = trimmed_body
  runnable_source = str(program_copy)
  return target_function, runnable_source

def run_evaluation(sandbox, program, function_to_run, input, timeout_seconds, call_count, call_count_lock):
    """Runs one sandboxed evaluation under a freshly reserved call index.

    The shared counter is read and incremented while holding
    `call_count_lock`; the value read is passed to `sandbox.run` as its last
    argument.

    Returns:
        The six values produced by `sandbox.run`, as a tuple:
        (result, runs_ok, cpu_time, call_data_folder, input_path, error_file).
    """
    # Reserve this call's index; the lock keeps read + increment atomic.
    with call_count_lock:
        call_index = call_count.value
        call_count.value += 1

    outcome = sandbox.run(program, function_to_run, input, timeout_seconds, call_index)
    # Unpack to enforce the six-element shape before handing it back.
    (result, runs_ok, cpu_time, call_data_folder, input_path, error_file) = outcome
    return (result, runs_ok, cpu_time, call_data_folder, input_path, error_file)

class Evaluator:
    """Scores sampled programs received over RabbitMQ and publishes the results.

    For every message taken from `evaluator_queue`, the sample is spliced into
    a copy of `template` and executed through an `ExternalProcessSandbox` once
    per entry of `inputs` (fanned out over a two-worker process pool). The
    outcome is published as JSON to the channel's default exchange with routing
    key 'database_queue'.
    """

    def __init__(self, connection, channel, evaluator_queue, database_queue, template, function_to_evolve, function_to_run, inputs, sandbox_base_path, timeout_seconds, local_id, target_solutions):
        """Stores the collaborators and creates the shared state and worker pool.

        Args:
            connection: RabbitMQ connection; closed in `shutdown`.
            channel: RabbitMQ channel used for consuming and publishing.
            evaluator_queue: Queue whose `iterator()` yields incoming samples.
            database_queue: Stored only; publishing uses the literal routing
                key 'database_queue' (NOTE(review): confirm they match).
            template: Program skeleton passed to `_sample_to_program`.
            function_to_evolve: Name of the function whose body is replaced.
            function_to_run: Name of the function the sandbox executes.
            inputs: Test inputs; each is run once and used as a dict key.
            sandbox_base_path: Passed to the sandbox as `base_path`.
            timeout_seconds: Per-run timeout handed to the sandbox.
            local_id: Identifier used in log messages and by the sandbox.
            target_solutions: Mapping of input -> score to reach, or a falsy
                value to disable the optimal-solution check.
        """
        self.connection = connection
        self.channel = channel
        self.evaluator_queue = evaluator_queue
        self.database_queue = database_queue
        self.template = template
        self.function_to_evolve = function_to_evolve
        self.function_to_run = function_to_run
        self.inputs = inputs
        self.timeout_seconds = timeout_seconds
        self.local_id = local_id
        self.target_solutions = target_solutions
        # Manager-backed counter and lock so pool workers can share them.
        self.manager = Manager()
        self.call_count = self.manager.Value('i', 0)
        self.call_count_lock = self.manager.Lock()
        self.sandbox = sandbox.ExternalProcessSandbox(
            base_path=sandbox_base_path,
            timeout_secs=timeout_seconds,
            python_path=sys.executable,
            local_id=self.local_id)
        self.executor = ProcessPoolExecutor(max_workers=2)
        self.cumulative_cpu_time = 0.0  # Track total CPU time.
        self.cpu_time_lock = self.manager.Lock()  # Lock to protect updates to CPU time.
        # Set by consume_and_process; defined here so shutdown can rely on it.
        self.consumer = None

    def _track_cpu_time(self):
        """Tracks CPU time for all child processes and adds to the cumulative total.

        Adds user + system time of every current descendant of this process to
        `cumulative_cpu_time`; children that vanish mid-scan are skipped.
        """
        parent = psutil.Process()
        with self.cpu_time_lock:
            for child in parent.children(recursive=True):
                try:
                    cpu_times = child.cpu_times()
                    self.cumulative_cpu_time += cpu_times.user + cpu_times.system
                except psutil.NoSuchProcess:
                    pass

    async def shutdown(self):
        """Releases messaging resources, the worker pool and child processes.

        Errors while closing the channel or connection are logged, not raised.
        Note that the final step terminates every descendant of the current
        process, not only those started by this evaluator.
        """
        logger.info(f"Evaluator {self.local_id}: Initiating shutdown process.")

        # Step 1: Drop the reference to the queue iterator. This alone does not
        # stop it; consumption ends when the channel is closed below.
        if getattr(self, "consumer", None):
            self.consumer = None
            logger.info(f"Evaluator {self.local_id}: Consumer stopped.")

        # Step 2: Close RabbitMQ connections properly
        if self.channel and not self.channel.is_closed:
            try:
                await self.channel.close()
                logger.info(f"Evaluator {self.local_id}: RabbitMQ channel closed.")
            except Exception as e:
                logger.warning(f"Evaluator {self.local_id}: Error closing channel: {e}")

        if self.connection and not self.connection.is_closed:
            try:
                await self.connection.close()
                logger.info(f"Evaluator {self.local_id}: RabbitMQ connection closed.")
            except Exception as e:
                logger.warning(f"Evaluator {self.local_id}: Error closing connection: {e}")

        # Step 3: Stop the pool from accepting work and drop queued tasks. Do
        # not wait: workers still running are terminated in the next step.
        executor = getattr(self, "executor", None)
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)

        # Step 4: Ensure child processes are terminated
        parent = psutil.Process()
        children = parent.children(recursive=True)
        if children:
            for child in children:
                logger.info(f"Evaluator {self.local_id}: Terminating child process PID {child.pid}")
                try:
                    child.terminate()
                except psutil.NoSuchProcess:
                    # Exited between being listed and being signalled.
                    continue
            gone, still_alive = psutil.wait_procs(children, timeout=5)
            for p in still_alive:
                logger.warning(f"Evaluator {self.local_id}: PID {p.pid} did not terminate, forcing kill.")
                try:
                    p.kill()
                    p.wait(timeout=2)
                except psutil.NoSuchProcess:
                    continue
                except psutil.TimeoutExpired:
                    logger.error(f"Evaluator {self.local_id}: PID {p.pid} did not exit after forced kill.")

        logger.info(f"Evaluator {self.local_id}: Shutdown process complete.")

    async def consume_and_process(self):
        """Consumes samples one at a time until the stream ends or fails.

        Each message is handled inside `message.process()` with a 60 second
        budget. A timeout or an ordinary error in one message is logged and
        the loop moves on. Cancellation is re-raised out of the loop so the
        consumer actually stops; the outer handler then logs it. `shutdown`
        always runs on the way out.
        """
        try:
            async with self.channel:
                await self.channel.set_qos(prefetch_count=1)
                async with self.evaluator_queue.iterator() as stream:
                    self.consumer = stream  # Store consumer reference
                    # Start timing before waiting so the figure really is the
                    # time spent waiting for the next message.
                    wait_started = time.perf_counter()
                    async for message in stream:
                        fetch_duration = time.perf_counter() - wait_started
                        logger.debug(f"Time to fetch message from queue: {fetch_duration:.6f} seconds")
                        async with message.process():
                            try:
                                await asyncio.wait_for(self.process_message(message), timeout=60)
                            except asyncio.TimeoutError:
                                logger.warning("Processing message timed out.")
                            except asyncio.CancelledError:
                                logger.warning("Evaluator: consume_and_process was cancelled.")
                                # Swallowing this would make the task
                                # impossible to cancel; let it propagate.
                                raise
                            except Exception as e:
                                logger.error(f"Evaluator: Error while processing message: {e}")
                        wait_started = time.perf_counter()
        except aio_pika.exceptions.ChannelClosed as e:
            logger.warning(f"Evaluator {self.local_id}: Channel closed by RPC timeout. {e}")
        except aio_pika.exceptions.AMQPError as e:
            logger.error(f"Evaluator {self.local_id}: AMQP error occurred: {e}")
        except asyncio.CancelledError:
            logger.info(f"Evaluator {self.local_id}: Consumer task cancelled while iterating messages.")
        except Exception as e:
            logger.warning(f"Evaluator {self.local_id}: Unexpected error while consuming messages: {e}")
        finally:
            logger.info(f"Evaluator {self.local_id}: Shutting down due to exception or completion.")
            await self.shutdown()

    async def process_message(self, message: aio_pika.IncomingMessage):
        """Evaluates one sampled function on every input and publishes the result.

        The message body is JSON with at least 'sample' and 'expected_version';
        'version_generated', 'island_id' and 'gpu_time' are optional. An empty
        function body, a missing score for any input, or all-zero scores cause
        the string "return" to be published in place of the function.

        Ordinary exceptions are logged and swallowed. The call folders reported
        by the sandbox runs are deleted afterwards in every case.
        """
        call_folders_to_cleanup = []  # Every call folder reported by a run.
        hash_value = None
        try:
            raw_data = message.body.decode()
            data = json.loads(raw_data)
            logger.debug(f"Data is {data}")
            logger.debug(f"Evaluator: Starts to analyze generated continuation of def priority: {data['sample']}")

            # Deserialize GPU time
            gpu_time = data.get("gpu_time", 0.0)
            logger.debug(f"Received GPU time from Sampler: {gpu_time} seconds")

            # Process the new function from the generated code
            new_function, program = _sample_to_program(data["sample"], data.get("version_generated"), self.template, self.function_to_evolve)
            logger.debug(f"New function body is {new_function.body}")
            logger.debug(f"New function is {program}")

            if new_function.body in [None, '']:
                logger.info("New function body is None or empty. Skipping execution but publishing 'return'.")
                result = ("return", data.get('island_id', None), {}, data['expected_version'], self.cumulative_cpu_time, gpu_time, False, {})
                await self.publish_to_database(result, message, hash_value)  # Publish "return" result
                return  # Early return after publishing

            # Submit each test input as a task for multiprocessing
            tasks = {
                self.executor.submit(run_evaluation, self.sandbox, program, self.function_to_run, test_input, self.timeout_seconds, self.call_count, self.call_count_lock): test_input
                for test_input in self.inputs
            }

            # Await the pool futures through asyncio instead of blocking on
            # concurrent.futures.as_completed, which would freeze the event
            # loop. Failures come back as values so that one bad run does not
            # hide the others.
            outcomes = await asyncio.gather(
                *(asyncio.wrap_future(future) for future in tasks),
                return_exceptions=True)

            scores_per_test = {}
            vt_overlap_per_test = {}
            for test_input, outcome in zip(tasks.values(), outcomes):
                if isinstance(outcome, (asyncio.CancelledError, concurrent.futures.CancelledError)):
                    logger.warning(f"Task for input {test_input} was cancelled.")
                    continue
                if isinstance(outcome, BaseException):
                    logger.error(f"Error during task execution for input {test_input}: {outcome}")
                    continue
                try:
                    test_output, runs_ok, cpu_time, call_data_folder, input_path, error_file = outcome
                    call_folders_to_cleanup.append(call_data_folder)

                    # Accumulate CPU time
                    with self.cpu_time_lock:
                        self.cumulative_cpu_time += cpu_time

                    if runs_ok and test_output[0] is not None:
                        if len(test_output) == 3:
                            test_score, hash_val, vt_overlap = test_output
                        else:
                            test_score, hash_val = test_output
                            vt_overlap = None  # Default if not present

                        scores_per_test[test_input] = test_score
                        if hash_val is not None:
                            hash_value = hash_val
                        if vt_overlap is not None:
                            vt_overlap_per_test[test_input] = vt_overlap
                        logger.debug(f"Evaluator: scores_per_test {scores_per_test}, hash_val {hash_val} and vt_overlap {vt_overlap}")
                except Exception as e:
                    # A malformed result for one input must not abort the rest.
                    logger.error(f"Error during task execution for input {test_input}: {e}")

            if self.target_solutions:
                # Every target must be met; an input without a score counts as 0.
                found_optimal_solution = all(
                    scores_per_test.get(dim, 0) >= self.target_solutions.get(dim, float("inf"))
                    for dim in self.target_solutions)
            else:
                found_optimal_solution = False

            # Prepare the result for publishing
            if len(scores_per_test) == len(self.inputs) and any(score != 0 for score in scores_per_test.values()):
                result = (new_function, data.get('island_id', None), scores_per_test, data['expected_version'], self.cumulative_cpu_time, gpu_time, found_optimal_solution, vt_overlap_per_test)
                logger.debug(f"Scores are {scores_per_test}")
            else:
                result = ("return", data.get('island_id', None), {}, data['expected_version'], self.cumulative_cpu_time, gpu_time, False, vt_overlap_per_test)

            # Publish the result
            await self.publish_to_database(result, message, hash_value)

            # Reset cumulative CPU time after publishing
            with self.cpu_time_lock:
                self.cumulative_cpu_time = 0.0

        except Exception as e:
            logger.error(f"Error in process_message: {e}")

        finally:
            # Cleanup: delete every recorded call folder after a short delay.
            await asyncio.sleep(1)  # Optional delay, adjust if needed
            for folder in call_folders_to_cleanup:
                try:
                    if folder and folder.exists():
                        shutil.rmtree(folder)
                except OSError as e:
                    # Best effort: one stubborn folder must not block the rest.
                    logger.warning(f"Evaluator {self.local_id}: Could not remove call folder {folder}: {e}")

    async def publish_to_database(self, result, message, hash_value):
        """Serializes `result` to JSON and publishes it for the database.

        Args:
            result: 8-tuple of (function, island_id, scores_per_test,
                expected_version, cpu_time, gpu_time, found_optimal_solution,
                vt_overlap_per_test).
            message: The originating message; not used here.
            hash_value: Value sent under the 'hash_value' key.

        Raises:
            ValueError: If `result` does not have exactly eight elements.
            Exception: Any publishing error other than `ChannelClosed`, which
                is logged at debug level and ignored.
        """
        # Unpack before the try block so the error handler below can always
        # refer to island_id.
        function, island_id, scores_per_test, expected_version, cpu_time, gpu_time, found_optimal_solution, vt_overlap_per_test = result
        try:
            serialized_result = {
                "new_function": function.serialize() if hasattr(function, 'serialize') else str(function),
                "island_id": island_id,
                # JSON object keys must be strings.
                "scores_per_test": {str(key): value for key, value in scores_per_test.items()},
                "expected_version": expected_version,
                "hash_value": hash_value,
                "cpu_time": cpu_time,
                "gpu_time": gpu_time,
                "found_optimal_solution": found_optimal_solution,
                "vt_overlap_per_test": {str(key): value for key, value in vt_overlap_per_test.items()},
            }
            message_body = json.dumps(serialized_result)
            publish_start_time = time.perf_counter()
            await self.channel.default_exchange.publish(
                aio_pika.Message(body=message_body.encode()),
                routing_key='database_queue'
            )
            publish_end_time = time.perf_counter()
            publish_duration = publish_end_time - publish_start_time
            logger.debug(f"Time to publish message to queue: {publish_duration:.6f} seconds")
            logger.debug(f"Evaluator: Successfully published to database for island_id {island_id}.")
        except aio_pika.exceptions.ChannelClosed:
            # RabbitMQ closed the channel due to RPC timeout — ignore it
            logger.debug(f"Evaluator {self.local_id}: ChannelClosed when publishing; ignoring.")
            return
        except Exception as e:
            logger.error(f"Evaluator: Problem in publishing to database for island_id {island_id}: {e}")
            raise
