import argparse
import ast
import asyncio
import importlib
import importlib.util
import json
import logging
import os
import signal
import socket
import sys
from logging import FileHandler
from multiprocessing import current_process, Process

import aio_pika
from yarl import URL

from fundcc.scaling_utils import ResourceManager
from fundcc import code_manipulation, evaluator

# Force TOKENIZERS_PARALLELISM to "false" for this process (and anything that
# inherits its environment) to prevent tokenizer parallelism issues.
os.environ.update(TOKENIZERS_PARALLELISM="false")


def load_config(config_path):
    """
    Dynamically load a configuration module from a specified file path.

    The file is executed as a module named ``"config"`` and an instance of
    the ``Config`` class it defines is returned.

    Args:
        config_path: Path to a Python source file defining a ``Config`` class
            that can be instantiated without arguments.

    Returns:
        A fresh ``Config()`` instance from the loaded module.

    Raises:
        ImportError: If no import spec/loader can be built for ``config_path``
            (for example, the file does not have a recognised Python suffix).
        FileNotFoundError: If ``config_path`` does not exist.
        AttributeError: If the module does not define ``Config``.
    """
    spec = importlib.util.spec_from_file_location("config", config_path)
    # spec_from_file_location returns None when it cannot pick a loader for
    # the path; fail with a clear message instead of an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot load configuration module from {config_path!r}."
        )
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    return config_module.Config()


class TaskManager:
    """Start evaluator worker processes and (optionally) a scaling loop.

    ``main_task`` connects to RabbitMQ, spawns ``config.num_evaluators``
    evaluator processes and hands the process list to the resource manager's
    scaling loop. Each worker runs ``evaluator_process`` with its own event
    loop and its own RabbitMQ connection.

    NOTE(review): ``main_task`` and ``evaluator_process`` read the
    module-level ``args`` namespace (``max_evaluators``, ``check_interval``,
    ``sandbox_base_path``) that is only created when this file runs as a
    script, so the class is not usable from an importing module as-is.
    """

    def __init__(self, specification: str, inputs, config, log_dir, target_solutions):
        """Store the run parameters and set up logging and resource tracking.

        Args:
            specification: Source text that ``main_task`` parses into the
                program template.
            inputs: Inputs forwarded unchanged to every evaluator.
            config: Configuration object; this class reads ``rabbitmq``,
                ``num_evaluators`` and ``evaluator.timeout`` from it.
            log_dir: Directory for the log file and the resource manager.
            target_solutions: Forwarded unchanged to every evaluator.
        """
        self.specification = specification
        self.inputs = inputs
        self.config = config
        self.logger = self.initialize_logger(log_dir)
        self.evaluator_processes = []
        self.tasks = []
        self.channels = []
        self.queues = []
        self.connection = None
        # Filled in by main_task() once the specification has been parsed.
        self.template = None
        # In evaluator mode, we run CPU only.
        self.resource_manager = ResourceManager(log_dir=log_dir, cpu_only=True)
        self.target_solutions = target_solutions

    def initialize_logger(self, log_dir):
        """Return the shared ``'main_logger'`` writing to ``eval_<hostname>.log``.

        The log directory is created if needed. The logger is process-wide,
        so a handler for the target file is attached only once; calling this
        again with the same directory neither duplicates records nor
        truncates the file a second time.

        Args:
            log_dir: Directory in which the log file is created.

        Returns:
            The configured ``logging.Logger``.
        """
        logger = logging.getLogger('main_logger')
        logger.setLevel(logging.DEBUG)
        logger.propagate = False
        os.makedirs(log_dir, exist_ok=True)
        hostname = socket.gethostname()
        log_file_name = f'eval_{hostname}.log'
        # FileHandler stores an absolute path in baseFilename; normalise the
        # same way so the duplicate check below compares like with like.
        log_file_path = os.path.abspath(os.path.join(log_dir, log_file_name))

        already_attached = any(
            isinstance(existing, FileHandler) and existing.baseFilename == log_file_path
            for existing in logger.handlers
        )
        if not already_attached:
            handler = FileHandler(log_file_path, mode='w')
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        return logger

    async def main_task(self, enable_scaling=True):
        """Connect to RabbitMQ, start the evaluators and run background tasks.

        Args:
            enable_scaling: When true, the resource manager's scaling loop is
                started alongside periodic resource logging.

        Exceptions raised after the connection is established are logged and
        swallowed. On the way out, any background task still running is
        cancelled and the connection opened here is closed.
        """
        rabbit = self.config.rabbitmq

        def build_url(vhost):
            # NOTE(review): credentials are interpolated unescaped; special
            # characters in the username/password may need percent-encoding.
            return URL(
                f'amqp://{rabbit.username}:{rabbit.password}@'
                f'{rabbit.host}:{rabbit.port}/{vhost}'
            ).update_query(heartbeat=300)

        # Connect to RabbitMQ; try using vhost first, then fall back.
        try:
            amqp_url = build_url(rabbit.vhost)
            evaluator_connection = await aio_pika.connect_robust(amqp_url)
        except Exception as e:
            self.logger.info("No vhost configured, connecting without.")
            # Any failure lands here, not only a missing vhost; record which.
            self.logger.debug(f"vhost connection attempt failed with {type(e).__name__}.")
            amqp_url = build_url('')
            evaluator_connection = await aio_pika.connect_robust(amqp_url)

        # Start resource logging.
        resource_logging_task = asyncio.create_task(
            self.resource_manager.log_resource_stats_periodically(interval=60)
        )
        self.tasks = [resource_logging_task]

        pid = os.getpid()
        self.logger.info(f"Main_task is running in process with PID: {pid}.")

        try:
            self.template = code_manipulation.text_to_program(self.specification)
            function_to_evolve = 'priority'

            # Start initial evaluator processes.
            self.start_initial_processes(self.template, function_to_evolve, amqp_url)

            self.logger.info("Creating connection for scaling logic...")
            evaluator_channel = await evaluator_connection.channel()

            evaluator_queue = await evaluator_channel.declare_queue(
                "evaluator_queue",
                durable=False,
                auto_delete=True,
            )
            self.logger.info("evaluator_queue declared for scaling logic.")

            if enable_scaling:
                scaling_task = asyncio.create_task(
                    self.resource_manager.run_scaling_loop(
                        evaluator_queue=evaluator_queue,
                        sampler_queue=None,
                        evaluator_processes=self.evaluator_processes,
                        sampler_processes=None,
                        evaluator_function=self.evaluator_process,
                        sampler_function=None,
                        evaluator_args=(self.template, self.inputs, amqp_url),
                        sampler_args=None,
                        max_evaluators=args.max_evaluators,
                        max_samplers=None,
                        check_interval=args.check_interval,
                    )
                )
                self.tasks.append(scaling_task)

            await asyncio.gather(*self.tasks)
        except Exception as e:
            # exc_info keeps the traceback in the log file for diagnosis.
            self.logger.error(f"Exception occurred in main_task: {e}", exc_info=True)
        finally:
            await self._release_main_resources(evaluator_connection)

    async def _release_main_resources(self, connection):
        """Cancel background tasks that are still running and close ``connection``."""
        unfinished = [task for task in self.tasks if not task.done()]
        for task in unfinished:
            task.cancel()
        if unfinished:
            # return_exceptions=True so the cancellations do not re-raise here.
            await asyncio.gather(*unfinished, return_exceptions=True)
        try:
            if not connection.is_closed:
                await connection.close()
        except Exception as e:
            self.logger.warning(f"Failed to close main RabbitMQ connection: {e}")

    def start_initial_processes(self, template, function_to_evolve, amqp_url):
        """Spawn ``config.num_evaluators`` evaluator processes and record them.

        Args:
            template: Program template passed to each evaluator.
            function_to_evolve: Accepted for interface compatibility; it is
                not used by this method.
            amqp_url: RabbitMQ URL; converted to ``str`` before being handed
                to the child processes.
        """
        amqp_url = str(amqp_url)
        # Start initial evaluator processes.
        for i in range(self.config.num_evaluators):
            proc = Process(
                target=self.evaluator_process,
                args=(template, self.inputs, amqp_url),
                name=f"Evaluator-{i}"
            )
            proc.start()
            self.logger.info(f"Started Evaluator Process {i} with PID: {proc.pid}")
            self.evaluator_processes.append(proc)

    def evaluator_process(self, template, inputs, amqp_url):
        """Process entry point: run one evaluator on a private event loop.

        Connects to RabbitMQ, declares ``evaluator_queue`` and
        ``database_queue``, builds an ``evaluator.Evaluator`` and awaits its
        ``consume_and_process``. An ``AMQPConnectionError`` triggers a retry
        after five seconds. SIGTERM/SIGINT cancel every task on the loop so
        the channel and connection are closed before the loop is shut down.

        Args:
            template: Program template passed to the evaluator.
            inputs: Inputs passed to the evaluator.
            amqp_url: RabbitMQ URL to connect to.
        """
        local_id = current_process().pid
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def close_quietly(resource, label):
            # Best-effort close; a failure here must not mask the real error.
            try:
                if resource and not resource.is_closed:
                    self.logger.info(f"Evaluator {local_id}: Closing {label}...")
                    await asyncio.shield(resource.close())
            except Exception as e:
                self.logger.warning(f"Evaluator {local_id}: Failed to close {label}: {e}")

        async def run_session():
            # One connect -> consume cycle. Both handles start as None so the
            # cleanup below is valid even when connecting itself fails.
            connection = None
            channel = None
            try:
                self.logger.info(f"Evaluator {local_id}: Connecting to RabbitMQ.")
                connection = await aio_pika.connect_robust(amqp_url, timeout=300)
                channel = await connection.channel()

                evaluator_queue = await channel.declare_queue(
                    "evaluator_queue", durable=False, auto_delete=True
                )
                database_queue = await channel.declare_queue(
                    "database_queue", durable=False, auto_delete=True
                )

                evaluator_instance = evaluator.Evaluator(
                    connection, channel, evaluator_queue, database_queue,
                    template, 'priority', 'evaluate', inputs,
                    timeout_seconds=self.config.evaluator.timeout,
                    local_id=local_id,
                    sandbox_base_path=args.sandbox_base_path,
                    target_solutions=self.target_solutions,
                )

                await evaluator_instance.consume_and_process()
            finally:
                await close_quietly(channel, "channel")
                await close_quietly(connection, "connection")

        async def run_evaluator():
            try:
                # Iterative retry: each failed session is cleaned up before
                # the next attempt, and the stack does not grow per retry.
                while True:
                    try:
                        await run_session()
                        return
                    except aio_pika.exceptions.AMQPConnectionError as e:
                        self.logger.error(
                            f"Evaluator {local_id}: Connection lost, attempting to reconnect. Error: {e}"
                        )
                        await asyncio.sleep(5)  # Wait and retry
            except asyncio.CancelledError:
                self.logger.warning(f"Evaluator {local_id}: Task was cancelled. Ignoring...")
            except Exception as e:
                self.logger.error(f"Evaluator {local_id}: Unhandled error: {e}")
            finally:
                self.logger.info(f"Evaluator {local_id}: Cleanup done, exiting.")

        def shutdown_callback():
            # Cancelling is enough: run_evaluator() handles CancelledError,
            # cleans up and returns, which ends run_until_complete() below.
            self.logger.info(f"Evaluator {local_id}: Sending shutdown exception...")
            for task in asyncio.all_tasks(loop):
                task.cancel()

        loop.add_signal_handler(signal.SIGTERM, shutdown_callback)
        loop.add_signal_handler(signal.SIGINT, shutdown_callback)

        try:
            loop.run_until_complete(run_evaluator())
        finally:
            loop.close()
            self.logger.info(f"Evaluator {local_id}: Event loop closed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run the TaskManager for evaluators with configurable scaling interval."
    )

    # General settings.
    parser.add_argument(
        "--log-dir",
        type=str,
        default=os.path.join(os.getcwd(), "logs"),
        help="Directory where logs will be stored. Defaults to './logs'."
    )
    parser.add_argument(
        "--config-path",
        type=str,
        default=os.path.join(os.getcwd(), "config.py"),
        help="Path to the configuration file. Defaults to './config.py'."
    )
    parser.add_argument(
        "--sandbox_base_path",
        type=str,
        default=os.path.join(os.getcwd(), "sandbox"),
        help="Path to the sandbox directory. Defaults to './sandbox'."
    )

    # Resource-related arguments.
    parser.add_argument(
        "--check_interval",
        type=int,
        default=120,
        help="Time interval between scaling checks (default: 120s)."
    )

    parser.add_argument(
        "--no-dynamic-scaling",
        action="store_true",
        help="Disable dynamic scaling (enabled by default)."
    )

    parser.add_argument(
        "--max_evaluators",
        type=int,
        default=1000,
        help="Maximum evaluators (default: 1000)."
    )

    args = parser.parse_args()

    enable_dynamic_scaling = not args.no_dynamic_scaling

    async def main():
        config = load_config(args.config_path)
        # Process target_solutions JSON
        try:
            if isinstance(config.prompt.target_solutions, str):
                target_solutions = json.loads(config.prompt.target_solutions)
                target_solutions = {ast.literal_eval(k): v for k, v in target_solutions.items()}
            else:
                target_solutions = config.prompt.target_solutions
        except json.JSONDecodeError:
            raise ValueError("Invalid JSON format for --target_solutions.")

        spec_path = config.prompt.spec_path
        try:
            with open(spec_path, 'r') as file:
                specification = file.read()
            if not specification.strip():
                raise ValueError("Specification must be a non-empty string.")
        except FileNotFoundError:
            print(f"Error: Specification file not found at {spec_path}")
            sys.exit(1)
        except ValueError as e:
            print(f"Error in specification: {e}")
            sys.exit(1)

        if not (len(config.prompt.s_values) == len(config.prompt.start_n) == len(config.prompt.end_n)):
            raise ValueError("The number of elements in s_values, start_n, and end_n must match.")

        inputs = [
            (n, s)
            for s, start_n, end_n in zip(config.prompt.s_values, config.prompt.start_n, config.prompt.end_n)
            for n in range(start_n, end_n + 1)
        ]

        task_manager = TaskManager(
            specification=specification,
            inputs=inputs,
            config=config,
            log_dir=args.log_dir,
            target_solutions=target_solutions
        )
        task = asyncio.create_task(task_manager.main_task(enable_scaling=enable_dynamic_scaling))
        await task

    try:
        asyncio.run(main())
    except Exception as e:
        print(f"Error in asyncio.run(main()): {e}")
