import os
import psutil
import pynvml
import logging
import asyncio
import multiprocessing as mp
from logging import FileHandler
import socket
import statistics

def sampler_process_entry(target_function, args, gpu_device):
    """Run ``target_function`` with ``args`` plus the device as the last argument.

    Used as the ``multiprocessing.Process`` target for sampler workers so the
    device chosen by the resource manager (a ``"cuda..."`` string or ``None``)
    is appended after the caller-supplied positional arguments.
    """
    call_args = (*args, gpu_device)
    target_function(*call_args)

class ResourceManager:
    """Monitor host resources and scale evaluator/sampler worker processes.

    System statistics come from ``psutil``; GPU utilisation, memory and GPU
    selection come from NVML via ``pynvml``. If NVML cannot be initialised the
    manager switches to CPU-only mode and samplers are started without a GPU.
    """

    def __init__(self, log_dir=None, resource_logger=None, cpu_only=False):
        """Set up logging and, unless ``cpu_only``, NVML.

        Args:
            log_dir: Directory for a per-process log file. Required when
                ``resource_logger`` is not supplied.
            resource_logger: Existing logger to use instead of a file logger.
            cpu_only: Never query NVML or assign GPUs.

        Raises:
            ValueError: If neither ``resource_logger`` nor ``log_dir`` is given.
        """
        self.hostname = socket.gethostname()
        self.cpu_only = cpu_only
        # PID -> device string handed to the sampler: "cuda:<i>", "cuda"
        # (multi-GPU) or None (CPU-only).
        self.process_to_device_map = {}
        # PID -> tuple of host GPU indices reserved by that sampler. Needed
        # because a multi-GPU sampler is recorded above only as "cuda".
        self.process_to_host_gpus = {}
        if resource_logger is None:
            if log_dir is None:
                raise ValueError("Either resource_logger or log_dir must be provided")
            self.resource_logger = self._initialize_resource_logger(log_dir)
        else:
            self.resource_logger = resource_logger
        if not self.cpu_only:
            try:
                self._initialize_nvml()
            except Exception as e:
                self.resource_logger.warning(f"Failed to initialize NVML: {e}")
                self.cpu_only = True
                self.resource_logger.info("Switching to CPU-only mode.")

    def _initialize_nvml(self):
        """Initialise NVML so GPU utilisation and memory can be queried."""
        pynvml.nvmlInit()

    def _initialize_resource_logger(self, log_dir):
        """Return a DEBUG-level logger writing to ``resources_<host>_pid<pid>.log`` in ``log_dir``.

        The logger name is keyed on the PID, so repeated calls in one process
        return the same logger; a file handler for the same path is attached
        only once to avoid duplicated lines and leaked file handles.
        """
        pid = os.getpid()
        log_file_name = f"resources_{self.hostname}_pid{pid}.log"
        log_file_path = os.path.join(log_dir, log_file_name)
        logger = logging.getLogger(f'resource_logger_{pid}')
        logger.setLevel(logging.DEBUG)
        os.makedirs(log_dir, exist_ok=True)
        target_path = os.path.abspath(log_file_path)
        already_attached = any(
            isinstance(h, FileHandler) and h.baseFilename == target_path
            for h in logger.handlers
        )
        if not already_attached:
            handler = FileHandler(log_file_path)
            handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
        logger.propagate = False
        logger.info(f"Resource logger initialized for PID {pid}. Log file: {log_file_path}")
        return logger

    async def log_resource_stats_periodically(self, interval=60, sample_duration=10, sample_interval=1):
        """Log an averaged system resource summary forever.

        Each cycle takes ``sample_duration // sample_interval`` samples (at
        least one), logs one summary line, then sleeps ``interval`` seconds.
        Every sample also spends about one second measuring CPU usage before
        the ``sample_interval`` pause, so the real window is longer than
        ``sample_duration``.

        Args:
            interval: Seconds between log entries.
            sample_duration: Nominal length of the sampling window (seconds).
            sample_interval: Pause between samples within the window (seconds).
        """
        while True:
            try:
                summary = await self._collect_resource_summary(sample_duration, sample_interval)
                self.resource_logger.info(summary)
            except Exception as e:
                self.resource_logger.error(f"Error logging resource stats: {e}", exc_info=True)
            await asyncio.sleep(interval)

    async def _collect_resource_summary(self, sample_duration, sample_interval):
        """Sample system metrics over one window and return the formatted summary line."""
        if sample_interval > 0:
            # int() so float arguments still give a valid range() bound.
            num_samples = max(1, int(sample_duration // sample_interval))
        else:
            num_samples = 1
        loop = asyncio.get_running_loop()

        # Context switches and disk bytes are cumulative counters, so they are
        # reported as rates over the window instead of being averaged.
        ctx_start = await asyncio.to_thread(lambda: psutil.cpu_stats().ctx_switches)
        disk_start = await asyncio.to_thread(psutil.disk_io_counters)
        window_start = loop.time()

        cpu_samples, io_wait_samples, load_samples = [], [], []
        mem_samples, swap_samples, d_state_samples, gpu_samples = [], [], [], []
        for _ in range(num_samples):
            # Measure CPU% and the CPU-time split concurrently over the same 1 s.
            cpu_pct, cpu_times = await asyncio.gather(
                self.async_get_cpu_usage(),
                asyncio.to_thread(psutil.cpu_times_percent, interval=1),
            )
            cpu_samples.append(cpu_pct)
            # ``iowait`` is not reported on every platform.
            io_wait_samples.append(getattr(cpu_times, "iowait", 0.0))
            load_samples.append((await asyncio.to_thread(os.getloadavg))[0])
            memory = await asyncio.to_thread(psutil.virtual_memory)
            swap = await asyncio.to_thread(psutil.swap_memory)
            mem_samples.append(memory.percent)
            swap_samples.append(swap.percent)
            d_state_samples.append(await asyncio.to_thread(self._count_disk_sleep_processes))
            if not self.cpu_only:
                gpu_samples.append(await self.async_get_gpu_usage())
            await asyncio.sleep(sample_interval)

        ctx_end = await asyncio.to_thread(lambda: psutil.cpu_stats().ctx_switches)
        disk_end = await asyncio.to_thread(psutil.disk_io_counters)
        elapsed = max(loop.time() - window_start, 1e-9)

        ctx_rate = (ctx_end - ctx_start) / elapsed
        if disk_start is not None and disk_end is not None:
            read_rate = (disk_end.read_bytes - disk_start.read_bytes) / 1e6 / elapsed
            write_rate = (disk_end.write_bytes - disk_start.write_bytes) / 1e6 / elapsed
        else:
            # psutil may return None when no disk counters are available.
            read_rate = write_rate = 0.0

        log_message = (
            f"Avg CPU: {statistics.mean(cpu_samples):.2f}%, Load: {statistics.mean(load_samples):.2f}, "
            f"I/O Wait: {statistics.mean(io_wait_samples):.2f}%, "
            f"Ctx Switches: {ctx_rate:.0f}/s, Disk Read/Write: {read_rate:.2f}/{write_rate:.2f} MB/s, "
            f"Mem Usage: {statistics.mean(mem_samples):.2f}%, Swap: {statistics.mean(swap_samples):.2f}%, "
            f"D-State Processes: {statistics.mean(d_state_samples):.2f}"
        )
        if not self.cpu_only:
            avg_gpu = statistics.mean(gpu_samples) if gpu_samples else 0
            log_message += f", GPU Usage: {avg_gpu:.2f}%"
        return log_message

    @staticmethod
    def _count_disk_sleep_processes():
        """Count processes in uninterruptible disk sleep (the Linux "D" state)."""
        # psutil reports "D" as STATUS_DISK_SLEEP ("disk-sleep"), never as the letter "D".
        return sum(
            1 for proc in psutil.process_iter(['status'])
            if proc.info['status'] == psutil.STATUS_DISK_SLEEP
        )

    async def async_get_cpu_usage(self):
        """Return system-wide CPU usage (%) measured over 1 s in a worker thread."""
        return await asyncio.to_thread(psutil.cpu_percent, interval=1)

    async def async_get_gpu_usage(self):
        """Return utilisation (%) of NVML GPU index 0, or 0 in CPU-only mode or on error."""
        if self.cpu_only:
            return 0
        try:
            handle = await asyncio.to_thread(pynvml.nvmlDeviceGetHandleByIndex, 0)
            utilization = await asyncio.to_thread(pynvml.nvmlDeviceGetUtilizationRates, handle)
            return utilization.gpu
        except Exception as e:
            self.resource_logger.warning(f"GPU monitoring failed: {e}")
            return 0

    async def run_scaling_loop(self, evaluator_queue=None, sampler_queue=None, evaluator_processes=None,
                               sampler_processes=None, evaluator_function=None, sampler_function=None,
                               evaluator_args=(), sampler_args=(), max_evaluators=10000, min_evaluators=1,
                               max_samplers=1000, min_samplers=1, check_interval=120):
        """Scale evaluator and sampler processes from queue depth and free resources, forever.

        Every ``check_interval`` seconds the loop does the following:

        * It stops tracking processes that have exited.
        * It reads both queue depths.
        * For each pool, it starts one process when the queue is deep
          (evaluators: > 10 messages, samplers: > 50) and resources allow.
        * It stops the oldest process when the queue is empty and the pool
          is above its minimum.

        A ``max_*`` of ``None`` or ``<= 0`` disables that pool. Supplied
        ``evaluator_processes`` / ``sampler_processes`` lists are updated in
        place.
        """
        self.resource_logger.info("Starting scaling loop")
        # ``is None`` rather than ``or`` so an empty list supplied by the caller
        # is the one that gets updated.
        evaluator_processes = [] if evaluator_processes is None else evaluator_processes
        sampler_processes = [] if sampler_processes is None else sampler_processes
        evaluator_args = evaluator_args or ()
        sampler_args = sampler_args or ()
        max_evaluators = max_evaluators if max_evaluators is not None else 0
        max_samplers = max_samplers if max_samplers is not None else 0

        while True:
            try:
                self._reap_dead_processes(evaluator_processes, "Evaluator")
                self._reap_dead_processes(sampler_processes, "Sampler")
                evaluator_message_count = await self.get_queue_message_count(evaluator_queue) if evaluator_queue else 0
                sampler_message_count = await self.get_queue_message_count(sampler_queue) if sampler_queue else 0
                self.resource_logger.info(f"Message counts are {evaluator_message_count} and {sampler_message_count}")

                evaluator_scaled = False
                if evaluator_queue and max_evaluators > 0:
                    evaluator_scaled = await self._scale_evaluators(
                        evaluator_message_count, evaluator_processes, max_evaluators,
                        min_evaluators, evaluator_function, evaluator_args,
                    )

                sampler_scaled = False
                if sampler_queue and max_samplers > 0:
                    sampler_scaled = await self._scale_samplers(
                        sampler_message_count, sampler_processes, max_samplers,
                        min_samplers, sampler_function, sampler_args,
                    )

                if not evaluator_scaled and not sampler_scaled:
                    self.resource_logger.info("No scaling action taken in this iteration.")
            except Exception as e:
                self.resource_logger.error(f"Scaling loop encountered an error: {e}", exc_info=True)
            await asyncio.sleep(check_interval)

    async def _scale_evaluators(self, message_count, processes, max_count, min_count, function, args):
        """Start or stop at most one evaluator; return True if the pool changed."""
        if message_count > 10 and len(processes) < max_count:
            # The CPU check samples for several seconds; run it only when a scale-up is wanted.
            if not await self.can_scale_evaluator():
                return False
            self.resource_logger.info(f"Scaling evaluators: queue has {message_count} messages")
            self.start_evaluator_process(function, args, processes, "Evaluator")
            return True
        if message_count == 0 and len(processes) > min_count:
            self.resource_logger.info("Zero messages in evaluator queue; terminating one evaluator process")
            # terminate_process can block for its join timeout; keep the event loop responsive.
            await asyncio.to_thread(self.terminate_process, processes, "Evaluator")
            return True
        return False

    async def _scale_samplers(self, message_count, processes, max_count, min_count, function, args):
        """Start or stop at most one sampler; return True if the pool changed."""
        if message_count > 50 and len(processes) < max_count:
            # CPU-only: True/False from the CPU check; otherwise a GPU assignment tuple or None.
            if self.cpu_only:
                assignment = await self.can_scale_evaluator()
            else:
                assignment = await self.can_scale_up_samplers()
            self.resource_logger.info(f"Sampler GPU assignment: {assignment}")
            if not assignment:
                if self.cpu_only:
                    self.resource_logger.info("Insufficient CPU headroom. Skipping sampler scale-up.")
                else:
                    self.resource_logger.info("No available GPU found. Skipping sampler scale-up.")
                return False
            self.resource_logger.info(f"Scaling samplers: queue has {message_count} messages")
            return self.start_sampler_process(function, args, processes, "Sampler", assignment=assignment)
        if message_count == 0 and len(processes) > min_count:
            self.resource_logger.info("Zero messages in sampler queue; terminating one sampler process")
            await asyncio.to_thread(self.terminate_process, processes, "Sampler")
            return True
        return False

    def _reap_dead_processes(self, processes, process_name):
        """Drop processes that are no longer running from ``processes`` and release their GPUs."""
        for proc in [p for p in processes if not p.is_alive()]:
            processes.remove(proc)
            self._forget_process(proc.pid)
            self.resource_logger.warning(
                f"{process_name} process (PID: {proc.pid}) exited with code {proc.exitcode}; no longer tracking it."
            )

    def _forget_process(self, pid):
        """Remove any device reservation recorded for ``pid``."""
        self.process_to_device_map.pop(pid, None)
        self.process_to_host_gpus.pop(pid, None)

    def _spawn(self, target, args, processes, process_name):
        """Start ``target(*args)`` in a new process, append it to ``processes`` and return it."""
        proc = mp.Process(target=target, args=args, name=f"{process_name}-{len(processes)}")
        proc.start()
        processes.append(proc)
        return proc

    def start_evaluator_process(self, target_function, args, processes, process_name):
        """Start ``target_function(*args)`` as a new evaluator process and append it to ``processes``."""
        proc = self._spawn(target_function, args, processes, process_name)
        self.resource_logger.info(f"Started {process_name} process (PID: {proc.pid})")

    async def get_smoothed_cpu_usage(self, duration=10, interval=1):
        """Return CPU usage samples (%) taken over roughly ``duration`` seconds.

        Takes ``duration / interval`` samples (at least one). Each sample
        blocks in a worker thread for ``interval`` seconds and is the mean of
        the per-CPU usage percentages.
        """
        samples = []
        for _ in range(max(1, int(duration / interval))):
            per_cpu = await asyncio.to_thread(psutil.cpu_percent, interval, True)
            samples.append(sum(per_cpu) / len(per_cpu) if per_cpu else 0)
        return samples

    async def can_scale_up_samplers(self):
        """Return a GPU assignment from ``assign_gpu_device``, or None (always None in CPU-only mode)."""
        if self.cpu_only:
            return None
        return self.assign_gpu_device()

    @staticmethod
    def _available_cores():
        """Number of CPUs this process may run on (affinity-aware where supported)."""
        try:
            return len(os.sched_getaffinity(0))
        except AttributeError:  # sched_getaffinity is not available on every platform
            return os.cpu_count() or 0

    async def can_scale_evaluator(self, required_cores=4, cpu_usage_threshold=99, normalized_load_threshold=0.99, duration=10, interval=1):
        """Return True if smoothed CPU usage and per-core 1-minute load are below thresholds.

        ``required_cores`` is accepted for compatibility but currently unused.
        """
        smoothed_usage = await self.get_smoothed_cpu_usage(duration, interval)
        avg_cpu_usage = sum(smoothed_usage) / len(smoothed_usage) if smoothed_usage else 0
        load_avg_1 = (await asyncio.to_thread(os.getloadavg))[0]
        available_cores = self._available_cores()
        normalized_load = load_avg_1 / available_cores if available_cores > 0 else load_avg_1
        self.resource_logger.info(
            f"{self.hostname}: Smoothed Avg CPU: {avg_cpu_usage:.2f}% | Normalized Load: {normalized_load:.2f}"
        )
        return (avg_cpu_usage < cpu_usage_threshold) and (normalized_load < normalized_load_threshold)

    def start_sampler_process(self, target_function, args, processes, process_name, assignment):
        """Start a sampler for ``assignment`` and record its device reservation.

        Args:
            assignment: ``True`` to run without a GPU (device ``None``), or a
                ``(host_gpu_or_list, container_device)`` tuple from
                ``assign_gpu_device``. Any falsy value starts nothing.

        Returns:
            True if a process was started, otherwise False.
        """
        if not assignment:
            return False
        if assignment is True:
            container_device, host_gpus = None, ()
        else:
            host_gpu, container_device = assignment
            host_gpus = tuple(host_gpu) if isinstance(host_gpu, (list, tuple)) else (host_gpu,)
        proc = self._spawn(sampler_process_entry, (target_function, args, container_device), processes, process_name)
        if assignment is True:
            self.resource_logger.info(f"Started {process_name} process (PID: {proc.pid}) in CPU-only mode.")
        else:
            self.resource_logger.info(f"Started {process_name} process (PID: {proc.pid}) on GPU {container_device} (host GPU: {host_gpu})")
        self.process_to_device_map[proc.pid] = container_device
        self.process_to_host_gpus[proc.pid] = host_gpus
        return True

    def _visible_gpu_indices(self):
        """Host GPU indices visible to this process, or None if CUDA_VISIBLE_DEVICES is unparsable.

        An unset variable means every NVML device; a set-but-empty variable
        means no devices.
        """
        raw = os.environ.get("CUDA_VISIBLE_DEVICES")
        if raw is None:
            return list(range(pynvml.nvmlDeviceGetCount()))
        try:
            return [int(tok) for tok in (part.strip() for part in raw.split(",")) if tok]
        except ValueError:
            self.resource_logger.error("Failed to parse CUDA_VISIBLE_DEVICES.")
            return None

    def _reserved_container_devices(self, container_of):
        """Container devices held by tracked samplers, given a host-index -> "cuda:<i>" mapping."""
        reserved = {dev for dev in self.process_to_device_map.values() if dev}
        for hosts in self.process_to_host_gpus.values():
            reserved.update(container_of[h] for h in hosts if h in container_of)
        return reserved

    def assign_gpu_device(self, min_free_memory_gib=34, max_utilization=50, assigned_gpus=None):
        """Pick GPU(s) with enough combined free memory for a new sampler.

        GPUs in ``assigned_gpus`` (by default, every GPU reserved by a tracked
        sampler) are skipped, as are GPUs with utilisation ``>= max_utilization``.
        The rest are taken in order of most free memory until their combined
        free memory reaches ``min_free_memory_gib``. The chosen container
        device(s) are added to ``assigned_gpus`` in place.

        Returns:
            ``(host_gpu, "cuda:<i>")`` when one GPU suffices.
            ``([host_gpu, ...], "cuda")`` when several are needed; the sampler
            is then expected to spread across all visible GPUs.
            ``None`` in CPU-only mode, on error, or when no set of free GPUs
            has enough memory.
        """
        if self.cpu_only:
            return None
        try:
            visible_devices = self._visible_gpu_indices()
            if visible_devices is None:
                return None
            # Host (NVML) index -> device name as seen by CUDA inside this process.
            container_of = {host: f"cuda:{idx}" for idx, host in enumerate(visible_devices)}
            if assigned_gpus is None:
                assigned_gpus = self._reserved_container_devices(container_of)

            available = []
            for host_gpu in visible_devices:
                container_device = container_of[host_gpu]
                if container_device in assigned_gpus:
                    continue
                handle = pynvml.nvmlDeviceGetHandleByIndex(host_gpu)
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                free_mem = mem.free / (1024 ** 3)
                if util.gpu < max_utilization:
                    available.append((host_gpu, container_device, free_mem, util.gpu))

            # Most free memory first (stable for ties).
            available.sort(key=lambda gpu: gpu[2], reverse=True)

            selected = []
            total_free = 0
            for gpu in available:
                selected.append(gpu)
                total_free += gpu[2]
                if total_free >= min_free_memory_gib:
                    break

            if not selected or total_free < min_free_memory_gib:
                return None

            if len(selected) == 1:
                host_gpu, container_device, free, util = selected[0]
                assigned_gpus.add(container_device)
                self.resource_logger.info(
                    f"Assigning GPU {host_gpu} ({container_device}): {free:.2f} GiB free, {util}% util"
                )
                return host_gpu, container_device

            host_gpus = [g[0] for g in selected]
            container_devices = [g[1] for g in selected]
            assigned_gpus.update(container_devices)
            self.resource_logger.info(
                f"Assigning multiple GPUs {host_gpus} (containers {container_devices}) "
                f"with total free {total_free:.2f} GiB"
            )
            return host_gpus, "cuda"  # "cuda" sets device map to auto in sampler for HF to spread across available GPUs

        except Exception as e:
            self.resource_logger.error(f"Error in assign_gpu_device: {e}")
            return None

    def terminate_process(self, processes, process_name, timeout=30):
        """Stop the oldest process in ``processes`` and release its GPU reservation.

        Sends terminate, waits up to ``timeout`` seconds, then kills the process
        if it is still alive. Does nothing when ``processes`` is empty. Blocks
        the calling thread while waiting.
        """
        if not processes:
            return
        proc = processes.pop(0)
        proc.terminate()
        self.resource_logger.info(f"Sent SIGTERM to {process_name} process (PID: {proc.pid}). Waiting for termination...")
        proc.join(timeout)
        if proc.is_alive():
            self.resource_logger.warning(f"{process_name} process (PID: {proc.pid}) did not terminate in {timeout}s. Sending SIGKILL.")
            proc.kill()
            proc.join()
        self._forget_process(proc.pid)
        self.resource_logger.info(f"Terminated {process_name} process (PID: {proc.pid})")

    async def get_queue_message_count(self, queue):
        """Return the queue's current message count via a passive re-declare, or 0 on None/error.

        NOTE(review): assumes an aio-pika-style queue exposing ``.channel`` and
        ``.name`` — confirm against the caller.
        """
        if queue is None:
            return 0
        try:
            declared_queue = await queue.channel.declare_queue(queue.name, passive=True)
            return declared_queue.declaration_result.message_count
        except Exception as e:
            self.resource_logger.error(f"Error getting message count for queue '{queue.name}': {e}")
            return 0
