import asyncio
import logging
from logging import FileHandler
import json
import aio_pika
from yarl import URL
import os
import signal
import sys
import argparse
import glob
import shutil
import datetime
import importlib.util
from multiprocessing import Process, current_process
from typing import Sequence, Any
import ast

# FundCC-specific imports
from fundcc import programs_database, sampler, code_manipulation, evaluator, gpt
from fundcc.scaling_utils import ResourceManager

# Prevent tokenizer parallelism issues
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_config(config_path):
    """
    Dynamically load a configuration module from a file and instantiate its ``Config`` class.

    Args:
        config_path: Filesystem path of a Python source file that defines a
            class named ``Config``.

    Returns:
        A new instance of the module's ``Config`` class, created with no arguments.

    Raises:
        FileNotFoundError: ``config_path`` does not point to an existing file.
        ImportError: No import spec/loader could be derived from the path
            (for example, the file has no recognised Python suffix).
        ValueError: The loaded module does not define ``Config``.
    """
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Configuration file not found at {config_path}")
    spec = importlib.util.spec_from_file_location("config", config_path)
    # spec_from_file_location returns None when it cannot choose a loader for
    # the path; fail with a clear message instead of an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load configuration file at {config_path} as a Python module.")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    if not hasattr(config_module, "Config"):
        raise ValueError(f"The configuration file at {config_path} must define a 'Config' class.")
    return config_module.Config()


def backup_python_files(src, dest, exclude_dirs=()):
    """
    Recursively copy all Python files in `src` to `dest`, preserving the
    directory layout relative to `src`.

    Args:
        src: Root directory that is searched for ``*.py`` files.
        dest: Directory the files are copied into; created on demand.
        exclude_dirs: Iterable of directories to skip. A file is skipped when
            it lives inside one of them (compared on absolute paths and whole
            path components, so excluding ``pkg/skip`` does not also exclude
            ``pkg/skipper``).

    Files located under any ``code_backup`` directory are always skipped.
    """
    # Trailing separator makes the prefix test match whole path components only.
    excluded_prefixes = [
        os.path.join(os.path.abspath(directory), "") for directory in (exclude_dirs or ())
    ]
    backup_marker = f"{os.sep}code_backup{os.sep}"

    for file_path in glob.glob(os.path.join(src, '**', '*.py'), recursive=True):
        if backup_marker in file_path:
            continue
        # The glob pattern also matches directories whose name ends in ".py".
        if not os.path.isfile(file_path):
            continue
        absolute_path = os.path.abspath(file_path)
        if any(absolute_path.startswith(prefix) for prefix in excluded_prefixes):
            continue
        new_path = os.path.join(dest, os.path.relpath(file_path, start=src))
        os.makedirs(os.path.dirname(new_path), exist_ok=True)
        shutil.copy(file_path, new_path)


class TaskManager:
    def __init__(self, specification: str, inputs: Sequence[Any], config, log_dir, target_solutions):
        """
        Parse the specification and set up the bookkeeping used by the run.

        Args:
            specification: Source text handed to ``code_manipulation.text_to_program``.
            inputs: Sequence of inputs later passed on to the evaluator processes.
            config: Configuration object; ``config.prompt.gpt`` selects a
                CPU-only ``ResourceManager``.
            log_dir: Directory used for the main log file and given to the
                ``ResourceManager``.
            target_solutions: Stored as-is and forwarded to the database and evaluators.
        """
        # The specification is parsed twice: once unchanged and once with
        # remove_classes=True (the latter is the template given to the database).
        self.template = code_manipulation.text_to_program(specification)
        self.template_pdb = code_manipulation.text_to_program(specification, remove_classes=True)

        self.inputs, self.config = inputs, config
        self.logger = self.initialize_logger(log_dir)

        # Handles of spawned worker processes, grouped by role.
        self.evaluator_processes, self.database_processes, self.sampler_processes = [], [], []

        # Runtime state that main_task fills in later.
        self.tasks, self.channels, self.queues = [], [], []
        self.connection = None

        # In GPT mode the resource manager is created with cpu_only=True.
        manager_kwargs = {"log_dir": log_dir}
        if self.config.prompt.gpt:
            manager_kwargs["cpu_only"] = True
        self.resource_manager = ResourceManager(**manager_kwargs)

        self.process_to_device_map = {}
        self.target_solutions = target_solutions

    def initialize_logger(self, log_dir):
        logger = logging.getLogger('main_logger')
        logger.setLevel(logging.INFO)
        os.makedirs(log_dir, exist_ok=True)
        log_file_path = os.path.join(log_dir, 'fundcc.log')
        handler = FileHandler(log_file_path, mode='w')
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.propagate = False
        return logger

    async def publish_initial_program_with_retry(self, amqp_url, initial_program_data, max_retries=5, delay=5):
        attempt = 0
        while attempt < max_retries:
            try:
                sampler_connection = await aio_pika.connect_robust(amqp_url, timeout=300)
                sampler_channel = await sampler_connection.channel()
                # Ensure evaluator_queue is declared
                await sampler_channel.declare_queue(
                    "evaluator_queue", durable=False, auto_delete=True,
                    #arguments={'x-consumer-timeout': 360000000}
                )
                await sampler_channel.default_exchange.publish(
                    aio_pika.Message(body=initial_program_data.encode()),
                    routing_key='evaluator_queue'
                )
                self.logger.info("Published initial program")
                await sampler_channel.close()
                await sampler_connection.close()
                return
            except Exception as e:
                attempt += 1
                self.logger.error(f"Attempt {attempt} failed to publish initial program: {e}")
                if attempt < max_retries:
                    self.logger.info(f"Retrying in {delay} seconds...")
                    await asyncio.sleep(delay)
                else:
                    self.logger.error("Max retries reached. Failed to publish initial program.")
                    raise e

    async def log_tasks(self):
        """
        Periodically logs details about active asyncio tasks.
        """
        while True:
            tasks = asyncio.all_tasks()
            self.logger.debug(f"Currently {len(tasks)} tasks running:")
            for task in tasks:
                coro_name = task.get_coro().__name__ if task.get_coro() else "Unknown"
                self.logger.debug(f"Task: {task.get_name()}, Function: {coro_name}, Status: {task._state}")
                if task._state == "PENDING":
                    for frame in task.get_stack():
                        self.logger.debug(f"Pending Task Frame: {frame}")
            await asyncio.sleep(60)

    async def main_task(self, save_checkpoints_path, enable_scaling=True, checkpoint_file=None):
        """
        Run the experiment: connect to RabbitMQ, start the programs database and
        the worker processes, seed the pipeline, then await the long-running tasks.

        Args:
            save_checkpoints_path: Forwarded to ``ProgramsDatabase``
                (presumably where checkpoints are written; confirm against that class).
            enable_scaling: When true, ``ResourceManager.run_scaling_loop`` is
                started as an additional task.
            checkpoint_file: Checkpoint to resume from, forwarded to
                ``ProgramsDatabase``. When ``None``, the body of the ``priority``
                function from the specification is published as the first
                sample; otherwise ``database.get_prompt()`` is awaited instead.

        Note:
            A failure of the fallback broker connection propagates to the
            caller. Every exception raised after that point is logged and
            swallowed, so the coroutine then simply returns ``None``.
        """
        # Try connecting to RabbitMQ with vhost; fall back if necessary.
        # Building the URL is inside the try, so a missing `vhost` attribute on
        # the config also takes the fallback path.
        try:
            amqp_url = URL(
                f'amqp://{self.config.rabbitmq.username}:{self.config.rabbitmq.password}@'
                f'{self.config.rabbitmq.host}:{self.config.rabbitmq.port}/{self.config.rabbitmq.vhost}'
            ).update_query(heartbeat=300)
            connection = await aio_pika.connect_robust(amqp_url)
        except Exception:
            self.logger.info("No vhost configured, connecting without.")
            amqp_url = URL(
                f'amqp://{self.config.rabbitmq.username}:{self.config.rabbitmq.password}@'
                f'{self.config.rabbitmq.host}:{self.config.rabbitmq.port}/'
            ).update_query(heartbeat=300)
            connection = await aio_pika.connect_robust(amqp_url)
        # NOTE(review): `connection` is not used below; the workers and the
        # database open their own connections from `amqp_url`.

        pid = os.getpid()
        self.logger.info(f"Main_task is running in process with PID: {pid}")

        function_to_evolve = 'priority'
        if checkpoint_file is None:
            # Only needed (and only defined) for a fresh start without checkpoint.
            initial_program_data = json.dumps({
                "sample": self.template.get_function(function_to_evolve).body,
                "island_id": None,
                "version_generated": None,
                "expected_version": 0
            })

        try:
            # Create connections and declare queues.
            sampler_connection = await aio_pika.connect_robust(amqp_url, timeout=300)
            self.sampler_channel = await sampler_connection.channel()

            database_connection = await aio_pika.connect_robust(amqp_url, timeout=300)
            self.database_channel = await database_connection.channel()

            evaluator_queue = await self.sampler_channel.declare_queue(
                "evaluator_queue", durable=False, auto_delete=True,
                #arguments={'x-consumer-timeout': 360000000}
            )
            sampler_queue = await self.sampler_channel.declare_queue(
                "sampler_queue", durable=False, auto_delete=True,
                #arguments={'x-consumer-timeout': 360000000}
            )
            database_queue = await self.database_channel.declare_queue(
                "database_queue", durable=False, auto_delete=True,
                #arguments={'x-consumer-timeout': 360000000}
            )

            # Create the database instance.
            try:
                database = programs_database.ProgramsDatabase(
                    database_connection, self.database_channel, database_queue,
                    sampler_queue, evaluator_queue, self.config.programs_database, self.config.prompt,
                    self.template_pdb, function_to_evolve, checkpoint_file, save_checkpoints_path, self.target_solutions
                )
                database_task = asyncio.create_task(database.consume_and_process())
            except Exception as e:
                self.logger.error(f"Exception in database: {e}")
                # Nothing below can work without the database. Re-raise so the
                # outer handler reports the real cause instead of a NameError
                # on the unbound `database` / `database_task` names.
                raise

            checkpoint_task = asyncio.create_task(database.periodic_checkpoint())

            # Start consumer processes.
            try:
                self.start_initial_processes(function_to_evolve, amqp_url, checkpoint_file)
                self.logger.info("Initial processes started successfully.")
            except Exception as e:
                self.logger.error(f"Failed to start initial processes: {e}")

            # Wait until at least `num_samplers` consumers are attached to
            # sampler_queue, then seed the pipeline: publish the initial
            # program (fresh start) or ask the database for a prompt (resume).
            while True:
                # A passive declare only queries the queue; its result carries
                # the current consumer count.
                sampler_queue = await self.sampler_channel.declare_queue("sampler_queue", passive=True)
                consumer_count = sampler_queue.declaration_result.consumer_count
                self.logger.info(f"Consumer count is {consumer_count} while config.num_samplers is {self.config.num_samplers}")
                if consumer_count > self.config.num_samplers - 1 and checkpoint_file is None:
                    await self.publish_initial_program_with_retry(amqp_url, initial_program_data)
                    break
                elif consumer_count > self.config.num_samplers - 1:
                    await database.get_prompt()
                    self.logger.info(f"Loading from checkpoint: {checkpoint_file}")
                    break
                else:
                    self.logger.info("No consumers yet on sampler_queue. Retrying in 10 seconds...")
                    await asyncio.sleep(10)

            resource_logging_task = asyncio.create_task(
                self.resource_manager.log_resource_stats_periodically(interval=60)
            )
            self.tasks = [database_task, checkpoint_task, resource_logging_task]

            if enable_scaling:
                try:
                    # NOTE(review): `args` is the module-level argparse
                    # namespace created in the `__main__` block; it does not
                    # exist when this module is imported rather than run.
                    scaling_task = asyncio.create_task(
                        self.resource_manager.run_scaling_loop(
                            evaluator_queue=evaluator_queue,
                            sampler_queue=sampler_queue,
                            evaluator_processes=self.evaluator_processes,
                            sampler_processes=self.sampler_processes,
                            evaluator_function=self.evaluator_process,
                            sampler_function=self.sampler_process,
                            evaluator_args=(self.template, self.inputs, amqp_url),
                            sampler_args=(amqp_url,),
                            max_evaluators=args.max_evaluators,
                            max_samplers=args.max_samplers,
                            check_interval=args.check_interval,
                        )
                    )
                    self.tasks.append(scaling_task)
                except Exception as e:
                    self.logger.error(f"Error enabling scaling: {e}")

            self.channels = [self.database_channel, self.sampler_channel]
            self.queues = ["database_queue", "sampler_queue", "evaluator_queue"]

            # Block until every long-running task has finished (or one fails).
            await asyncio.gather(*self.tasks)
        except Exception as e:
            self.logger.error(f"Exception occurred in main_task: {e}")

    def start_initial_processes(self, function_to_evolve, amqp_url, checkpoint_file):
        amqp_url = str(amqp_url)
        # Start sampler processes.
        if self.config.prompt.gpt:
            self.logger.info("GPT mode enabled. Starting sampler processes without GPU assignment.")
            for i in range(self.config.num_samplers):
                device = None
                try:
                    proc = Process(target=self.sampler_process, args=(amqp_url, device), name=f"Sampler-{i}")
                    proc.start()
                    self.sampler_processes.append(proc)
                    self.process_to_device_map[proc.pid] = device
                    self.logger.debug(f"Started Sampler Process {i} (GPT mode) with PID: {proc.pid}")
                except Exception as e:
                    self.logger.error(f"Error starting sampler {i}: {e}")
        else:
            assigned_gpus = set()
            for i in range(self.config.num_samplers):
                try:
                    assignment = self.resource_manager.assign_gpu_device(assigned_gpus=assigned_gpus)
                except Exception as e:
                    self.logger.error(f"Cannot start sampler {i}: No suitable GPU available ({e}).")
                    assignment = None

                if assignment is None:
                    self.logger.error("No suitable GPU available for sampler. Skipping.")
                    continue
                host_gpu, device = assignment
                assigned_gpus.add(device)
                self.logger.info(f"Assigning sampler {i} to GPU {device} (host GPU: {host_gpu})")
                try:
                    proc = Process(target=self.sampler_process, args=(amqp_url, device), name=f"Sampler-{i}")
                    proc.start()
                    self.sampler_processes.append(proc)
                    self.process_to_device_map[proc.pid] = device
                    self.logger.debug(f"Process-to-Device Map: {self.process_to_device_map}")
                except Exception as e:
                    self.logger.error(f"Failed to start sampler {i}: {e}")
                    continue

        # Start evaluator processes.
        for i in range(self.config.num_evaluators):
            proc = Process(target=self.evaluator_process,
                           args=(self.template, self.inputs, amqp_url),
                           name=f"Evaluator-{i}")
            proc.start()
            self.logger.debug(f"Started Evaluator Process {i} with PID: {proc.pid}")
            self.evaluator_processes.append(proc)

    def sampler_process(self, amqp_url, device=None):
        """
        Entry point of a sampler worker process.

        Creates a fresh event loop, connects to RabbitMQ, declares
        ``sampler_queue`` and ``evaluator_queue``, builds a sampler
        (``gpt.Sampler`` when ``config.prompt.gpt`` is set, ``sampler.Sampler``
        otherwise) and runs its ``consume_and_process()`` until it finishes,
        fails, or the process receives SIGTERM/SIGINT.

        Args:
            amqp_url: Broker URL passed to ``aio_pika.connect_robust``.
            device: Device handed to ``sampler.Sampler``; not passed to the GPT sampler.
        """
        local_id = current_process().pid
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def run_sampler():
            connection = None
            channel = None
            try:
                self.logger.debug(f"Sampler {local_id}: Connecting to RabbitMQ.")
                connection = await aio_pika.connect_robust(amqp_url, timeout=300)
                channel = await connection.channel()
                self.logger.debug(f"Sampler {local_id}: Channel established.")
                sampler_queue = await channel.declare_queue(
                    "sampler_queue", durable=False, auto_delete=True,
                    #arguments={'x-consumer-timeout': 360000000}
                )
                evaluator_queue = await channel.declare_queue(
                    "evaluator_queue", durable=False, auto_delete=True,
                    #arguments={'x-consumer-timeout': 360000000}
                )
                try:
                    if self.config.prompt.gpt:
                        sampler_instance = gpt.Sampler(connection, channel, sampler_queue, evaluator_queue, self.config.sampler, self.config.prompt, local_id)
                        self.logger.debug(f"Sampler {local_id}: Initialized GPT Sampler instance.")
                    else:
                        sampler_instance = sampler.Sampler(connection, channel, sampler_queue, evaluator_queue, self.config.sampler, device, local_id)
                        self.logger.debug(f"Sampler {local_id}: Initialized Sampler instance.")
                except Exception as e:
                    self.logger.error(f"Could not start Sampler instance: {e}")
                    return
                sampler_task = asyncio.create_task(sampler_instance.consume_and_process())
                await sampler_task
            except asyncio.CancelledError:
                self.logger.info(f"Sampler {local_id}: Process was cancelled.")
            except Exception as e:
                self.logger.error(f"Sampler {local_id} encountered an error: {e}")
            finally:
                # Close each resource independently so that a failing channel
                # close cannot prevent the connection from being closed.
                for label, resource in (("channel", channel), ("connection", connection)):
                    if not resource:
                        continue
                    try:
                        await resource.close()
                    except Exception as e:
                        self.logger.warning(f"Sampler {local_id}: Failed to close {label}: {e}")
                self.logger.debug(f"Sampler {local_id}: Connection closed.")

        def shutdown_callback():
            # Cancelling the tasks raises CancelledError inside run_sampler,
            # which then runs its cleanup and lets run_until_complete return.
            # No new task is scheduled here: a fresh run_sampler() would
            # reconnect during shutdown and be left pending at loop.close().
            self.logger.info(f"Sampler {local_id}: Sending shutdown exception...")
            for task in asyncio.all_tasks(loop):
                task.cancel()

        loop.add_signal_handler(signal.SIGTERM, shutdown_callback)
        loop.add_signal_handler(signal.SIGINT, shutdown_callback)

        try:
            loop.run_until_complete(run_sampler())
        finally:
            loop.close()
            self.logger.info(f"Sampler {local_id}: Event loop closed.")

    def evaluator_process(self, template, inputs, amqp_url):
        """
        Entry point of an evaluator worker process.

        Creates a fresh event loop, connects to RabbitMQ, declares
        ``evaluator_queue`` and ``database_queue``, builds an
        ``evaluator.Evaluator`` and runs its ``consume_and_process()``. When
        the broker connection is lost (``AMQPConnectionError``) the resources
        of the failed attempt are closed and a new attempt starts after five
        seconds. Any other error, a normal return, or SIGTERM/SIGINT ends the
        process after cleanup.

        Args:
            template: Program template forwarded to ``evaluator.Evaluator``.
            inputs: Inputs forwarded to ``evaluator.Evaluator``.
            amqp_url: Broker URL passed to ``aio_pika.connect_robust``.
        """
        local_id = current_process().pid
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def close_quietly(resource, label):
            # Best-effort close of a channel/connection; failures are only logged.
            try:
                if resource and not resource.is_closed:
                    self.logger.info(f"Evaluator {local_id}: Closing {label}...")
                    # shield lets the close itself finish even if this task is cancelled.
                    await asyncio.shield(resource.close())
            except Exception as e:
                self.logger.warning(f"Evaluator {local_id}: Failed to close {label}: {e}")

        async def run_evaluator():
            # A loop (rather than recursion) keeps exactly one connection
            # alive at a time and runs cleanup before every reconnect.
            while True:
                # Initialised up front so cleanup is safe even if connecting fails.
                connection = None
                channel = None
                reconnect = False
                try:
                    self.logger.info(f"Evaluator {local_id}: Connecting to RabbitMQ.")
                    connection = await aio_pika.connect_robust(amqp_url, timeout=300)
                    channel = await connection.channel()

                    evaluator_queue = await channel.declare_queue("evaluator_queue", durable=False, auto_delete=True)
                    database_queue = await channel.declare_queue("database_queue", durable=False, auto_delete=True)

                    # NOTE(review): `args` is the module-level argparse
                    # namespace from the `__main__` block; it is only visible
                    # here when the child process inherits module globals.
                    evaluator_instance = evaluator.Evaluator(
                        connection, channel, evaluator_queue, database_queue,
                        template, 'priority', 'evaluate', inputs, timeout_seconds=self.config.evaluator.timeout, local_id=local_id, sandbox_base_path=args.sandbox_base_path, target_solutions=self.target_solutions
                    )

                    await evaluator_instance.consume_and_process()

                except aio_pika.exceptions.AMQPConnectionError as e:
                    self.logger.error(f"Evaluator {local_id}: Connection lost, attempting to reconnect. Error: {e}")
                    reconnect = True

                except asyncio.CancelledError:
                    self.logger.warning(f"Evaluator {local_id}: Task was cancelled. Ignoring...")

                except Exception as e:
                    self.logger.error(f"Evaluator {local_id}: Unhandled error: {e}")

                finally:
                    # Ensure cleanup happens safely
                    await close_quietly(channel, "channel")
                    await close_quietly(connection, "connection")

                if not reconnect:
                    break

                try:
                    await asyncio.sleep(5)  # Wait and retry
                except asyncio.CancelledError:
                    self.logger.warning(f"Evaluator {local_id}: Task was cancelled. Ignoring...")
                    break

            self.logger.info(f"Evaluator {local_id}: Cleanup done, exiting.")

        def shutdown_callback():
            # Cancelling the tasks raises CancelledError inside run_evaluator,
            # which cleans up and returns. No new task is scheduled here: a
            # fresh run_evaluator() would reconnect during shutdown.
            self.logger.info(f"Evaluator {local_id}: Sending shutdown exception...")
            for task in asyncio.all_tasks(loop):
                task.cancel()

        loop.add_signal_handler(signal.SIGTERM, shutdown_callback)
        loop.add_signal_handler(signal.SIGINT, shutdown_callback)

        try:
            loop.run_until_complete(run_evaluator())
        finally:
            loop.close()
            self.logger.info(f"Evaluator {local_id}: Event loop closed.")


# Script entry point: parse the command line, optionally back up the sources,
# configure the time/memory logger, then build a TaskManager and run it.
if __name__ == "__main__":
    base_dir = os.path.dirname(os.path.abspath(__file__))
    parser = argparse.ArgumentParser(description="Run FunSearch experiment.")

    # General settings
    parser.add_argument("--backup", action="store_true",
                        help="Enable backup of Python files before running the task.")
    parser.add_argument("--save_checkpoints_path", type=str,
                        default=os.path.join(os.getcwd(), "Checkpoints"),
                        help="Path where checkpoints should be written.")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to the checkpoint file.")
    parser.add_argument("--config-path", type=str,
                        default=os.path.join(os.getcwd(), "config.py"),
                        help="Path to the configuration file.")
    parser.add_argument("--log-dir", type=str,
                        default=os.path.join(os.getcwd(), "logs"),
                        help="Directory where logs will be stored.")
    parser.add_argument("--sandbox_base_path", type=str,
                        default=os.path.join(os.getcwd(), "sandbox"),
                        help="Path to the sandbox directory.")

    # Resource related arguments
    parser.add_argument("--no-dynamic-scaling", action="store_true",
                        help="Disable dynamic scaling (enabled by default).")
    parser.add_argument("--check_interval", type=int, default=120,
                        help="Interval (in seconds) between scaling checks.")
    parser.add_argument("--max_evaluators", type=int, default=1000,
                        help="Maximum evaluators to scale up to.")
    parser.add_argument("--max_samplers", type=int, default=1000,
                        help="Maximum samplers to scale up to.")

    # `args` is deliberately module-level: TaskManager methods read
    # args.sandbox_base_path, args.max_evaluators, args.max_samplers and
    # args.check_interval from it.
    args = parser.parse_args()

    enable_dynamic_scaling = not args.no_dynamic_scaling

    if args.backup:
        # Snapshot every Python file below the working directory into a
        # timestamped folder under the backup base directory.
        src_dir = os.getcwd()
        backup_base_dir = '/mnt/hdd_pool/userdata/'
        os.makedirs(backup_base_dir, exist_ok=True)
        backup_dir = os.path.join(backup_base_dir, datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
        os.makedirs(backup_dir, exist_ok=True)
        backup_python_files(src=src_dir, dest=backup_dir)
        print(f"Backup completed. Python files saved to: {backup_dir}")

    # Configure time and memory logger
    time_memory_logger = logging.getLogger('time_memory_logger')
    time_memory_logger.setLevel(logging.INFO)
    os.makedirs(args.log_dir, exist_ok=True)
    time_memory_log_file = os.path.join(args.log_dir, 'time_memory.log')
    file_handler = FileHandler(time_memory_log_file, mode='w')
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    time_memory_logger.addHandler(file_handler)

    async def main():
        """Load the configuration and specification, then run the TaskManager."""
        config = load_config(args.config_path)
        # Process target_solutions JSON
        try:
            if isinstance(config.prompt.target_solutions, str):
                target_solutions = json.loads(config.prompt.target_solutions)
                # JSON object keys are always strings; turn them back into
                # Python literals with ast.literal_eval.
                target_solutions = {ast.literal_eval(k): v for k, v in target_solutions.items()}
            else:
                target_solutions = config.prompt.target_solutions
        except json.JSONDecodeError as err:
            raise ValueError("Invalid JSON format for --target_solutions.") from err
        spec_path = config.prompt.spec_path
        try:
            with open(spec_path, 'r') as file:
                specification = file.read()
            if not specification.strip():
                raise ValueError("Specification must be a non-empty string.")
        except FileNotFoundError:
            print(f"Error: Specification file not found at {spec_path}")
            sys.exit(1)
        except ValueError as e:
            print(f"Error in specification: {e}")
            sys.exit(1)

        if not (len(config.prompt.s_values) == len(config.prompt.start_n) == len(config.prompt.end_n)):
            raise ValueError("The lengths of s_values, start_n, and end_n must match.")

        # One (n, s) pair for every n in the inclusive range start_n..end_n of each s.
        inputs = [(n, s)
                  for s, start_n, end_n in zip(config.prompt.s_values, config.prompt.start_n, config.prompt.end_n)
                  for n in range(start_n, end_n + 1)]

        task_manager = TaskManager(specification=specification, inputs=inputs, config=config,
                                   log_dir=args.log_dir, target_solutions=target_solutions)
        task = asyncio.create_task(
            task_manager.main_task(
                save_checkpoints_path=args.save_checkpoints_path,
                enable_scaling=enable_dynamic_scaling,
                checkpoint_file=args.checkpoint
            )
        )
        await task

    try:
        asyncio.run(main())
    except Exception as e:
        print(f"Error in asyncio.run(main()): {e}")
        # Report the failure through the exit status as well, so that calling
        # scripts and schedulers do not mistake a crashed run for a success.
        sys.exit(1)
