import time
import traceback
from dataclasses import dataclass
from enum import IntEnum, auto
from queue import Empty
from typing import Any, Tuple

import os, copy, uuid, multiprocessing as mp

import numpy as np
import torch

from experiments.grid_runs_utils import prepare_grid_pairs_to_iterate, get_parallel_settings, \
    build_and_train_model, get_axes_from_grid_config, GridConstants
from mqar import MqarDimensions
from utils.common import set_seed


# ---------- classes ----------

class TaskStatus(IntEnum):
    """
    Lifecycle status of a grid Task.

    Workers send ``status.value`` (a plain int) to the parent process, which turns it
    back into a member with ``TaskStatus(value)``. Values are assigned by ``auto()``
    and therefore start at 1.

    Deliberately a plain IntEnum, not a dataclass: ``@dataclass`` would generate an
    ``__eq__`` over zero fields that makes all members compare equal to each other.
    """
    UNINITIALIZED = auto()  # status of a freshly constructed Task
    QUEUED = auto()
    RUNNING = auto()
    NOT_CONVERGED = auto()
    COMPLETED = auto()
    TERMINATE_REQUESTED = auto()
    TERMINATED = auto()
    FAILED = auto()
    SKIPPED = auto()


@dataclass(frozen=True)
class GridPoint:
    """Immutable coordinate on one grid axis: the axis name plus an integer value."""

    name: str   # axis name
    value: int  # coordinate along that axis

    def as_dict(self) -> dict[str, int]:
        """Return the point as a single-entry mapping ``{name: value}``."""
        return dict([(self.name, self.value)])

    def __str__(self) -> str:
        """Render as ``name=value``, zero-padding the value to at least 4 characters."""
        padded_value = str(self.value).zfill(4)
        return "{}={}".format(self.name, padded_value)


class Task:
    """
    One unit of work for a worker process: a single grid point (x, y plus the
    grid constants) trained with a single seed.
    """

    def __init__(
            self,
            x: GridPoint,
            y: GridPoint,
            constants: GridConstants,
            seed: int,
    ):
        """
        Args:
            x: coordinate on the grid's x axis.
            y: coordinate on the grid's y axis.
            constants: grid values held fixed across the grid. Its ``as_dict()``
                entries are merged with ``x`` and ``y`` into the keyword arguments of
                ``MqarDimensions``. A key present in more than one of them raises TypeError.
            seed: random seed for this task.
        """
        self.id: str = str(uuid.uuid4())  # unique key used to match results back to tasks
        self.dims: MqarDimensions = MqarDimensions(**x.as_dict(), **y.as_dict(), **constants.as_dict())
        self.seed: int = seed
        self.status: TaskStatus = TaskStatus.UNINITIALIZED
        self.result: dict[str, Any] = {}

        # human-readable label, e.g. "D=0064, N=0016, seed=0"
        self.name: str = f"{x}, {y}, {seed=}"


class Worker:
    """
    Description of one worker process: the device it runs on, its CPU-thread
    budget, and a short unique display name of the form ``<device>-0x<4 hex chars>``.
    """

    def __init__(self, device: str, num_cpu_threads: int):
        self.id = uuid.uuid4()                  # random identity of this worker
        self.device = device                    # device string, e.g. 'cpu' or 'cuda:i'
        self.num_cpu_threads = num_cpu_threads  # CPU threads this worker may use
        short_hex = self.id.hex[:4]
        self.name = "{}-0x{}".format(self.device, short_hex)


# ---------- parent: process scheduler ----------

def execute_grid_run(
        grid_run_name: str,
        run_config: dict[str, Any],
) -> dict[str, Any]:
    """
    Process-parallel grid runner using multiprocessing ('spawn' by default).
    - Multi-GPU, with per-device concurrency.
    - Each worker builds its own dataloaders.

    One Task is created per (seed, grid point) and put on a shared task queue.
    Worker processes (``_run_worker``) pull tasks and report one
    ``(task_id, status_int, result)`` tuple per task on a result queue.

    Args:
        grid_run_name: used as the wandb project name. It is written into
            ``run_config['wandb']['project_name']``, so the dict is mutated in place.
        run_config: full run configuration. Sections read here:
            - 'parallel': 'mp_start_method', plus optional 'result_poll_timeout_s'
              (seconds between worker-health checks, default 10).
            - 'grid' and 'grid_options'.
            - 'runtime': 'seed'.
            - 'wandb'.
            - optional 'best_of_n_seeds': 'n_seeds' (default 1) and 'start_seed' (default 0).

    Returns:
        ``grid_results_dict``. It is currently never filled and is always ``{}``;
        per-task statuses and results are recorded only on the Task objects.
        NOTE(review): presumably meant to collect results — confirm with callers.

    Raises:
        RuntimeError: if a worker process exits with a non-zero exit code, or if
            all workers have exited while task results are still missing.
    """

    parallel_config = run_config["parallel"]
    grid_config = run_config["grid"]
    grid_options = run_config["grid_options"]
    parent_seed = run_config['runtime']['seed']

    set_seed(parent_seed, verbose=True)

    # wandb
    run_config['wandb']['project_name'] = grid_run_name

    # -------- parallel settings --------

    used_devices, num_processes_per_device, num_cpu_threads_per_process = get_parallel_settings(run_config)
    num_devices = len(used_devices)

    # build device tokens with per-device concurrency (at least one worker per device)
    workers_list: list[Worker] = []
    for device in used_devices:
        for _ in range(max(1, num_processes_per_device)):
            workers_list.append(Worker(device=device, num_cpu_threads=num_cpu_threads_per_process))
    num_workers = len(workers_list)
    num_threads = int(num_workers * num_cpu_threads_per_process)

    # NOTE: always True here, because run_config["parallel"] was already indexed above
    is_parallel_enabled = run_config.get("parallel", None) is not None
    is_parallel_enabled_str = "parallelism enabled" if is_parallel_enabled else "parallelism disabled"

    print(is_parallel_enabled_str)
    print(f"using {num_devices} devices")
    print(f"using {num_workers} workers ({num_processes_per_device} per device)")
    print(f"using {num_threads} threads ({num_cpu_threads_per_process} per worker process)")

    if num_cpu_threads_per_process > 1:
        assert not run_config["wandb"].get("activate", False), "wandb is not supported with multi-threading"
        os.environ.setdefault("WANDB_MODE", "disabled")

    # -------- multiprocessing setup --------

    # create task and result queues
    start_method = parallel_config.get("mp_start_method", "spawn")
    ctx = mp.get_context(start_method)
    task_queue: mp.Queue = ctx.Queue()
    result_queue: mp.Queue = ctx.Queue()

    # how long the parent waits for a result before checking on worker health
    poll_timeout_s = max(0.1, float(parallel_config.get("result_poll_timeout_s", 10.0)))

    # seeds: exactly n_seeds consecutive seeds, start_seed .. start_seed + n_seeds - 1
    seeds_config = run_config.get('best_of_n_seeds', {})
    n_seeds = seeds_config.get('n_seeds', 1)
    start_seed = seeds_config.get('start_seed', 0)
    seeds_to_run = range(start_seed, start_seed + n_seeds)

    # set up grid/lines pairs
    grid_axes, grid_constants = get_axes_from_grid_config(grid_config)
    xy_pairs, _ = prepare_grid_pairs_to_iterate(grid_axes, grid_constants, grid_options)
    grid_results_dict: dict[str, dict] = {}

    # optionally shuffle grid iteration order (reproducible via the parent seed)
    if grid_options.get('shuffle', False):
        rng = np.random.default_rng(parent_seed)
        xy_pairs = rng.permutation(xy_pairs).tolist()

    tasks = _build_grid_tasks(xy_pairs, seeds_to_run, grid_axes, grid_constants, grid_config)
    tasks_by_id: dict[str, Task] = {t.id: t for t in tasks}

    # enqueue all tasks up front; sentinels are added later, so retries could still be appended
    for t in tasks:
        t.status = TaskStatus.QUEUED
        task_queue.put(t)
    n_enqueued = len(tasks)

    # launch workers
    procs: list[mp.Process] = []
    for worker in workers_list:
        p = ctx.Process(target=_run_worker, args=(worker, task_queue, result_queue, run_config))
        p.daemon = False
        p.start()
        procs.append(p)

    # collect results; this also watches for dead workers instead of hanging
    workers_released = _collect_task_results(
        tasks_by_id=tasks_by_id,
        procs=procs,
        workers_list=workers_list,
        task_queue=task_queue,
        result_queue=result_queue,
        n_expected=n_enqueued,
        poll_timeout_s=poll_timeout_s,
    )

    # all tasks accounted for; release the workers unless that already happened
    if not workers_released:
        _release_workers(task_queue, num_workers)

    # join children
    for i, (p, w) in enumerate(zip(procs, workers_list)):
        p.join()
        if p.exitcode not in (0, None):
            raise RuntimeError(f"worker {i} ({w.name}) exited with error: code {p.exitcode}")

    return grid_results_dict


def _build_grid_tasks(xy_pairs, seeds_to_run, grid_axes, grid_constants, grid_config) -> list[Task]:
    """Create one Task per (seed, (x, y)) pair, seed-major, optionally scaling N with D."""
    scale_n_with_d = grid_config.get('scale_N_with_D', False)
    tasks: list[Task] = []
    for seed in seeds_to_run:
        for (x, y) in xy_pairs:
            x_pt = GridPoint(name=grid_axes.x.name, value=int(x))
            y_pt = GridPoint(name=grid_axes.y.name, value=int(y))
            task = Task(x=x_pt, y=y_pt, constants=grid_constants, seed=seed)
            if scale_n_with_d:
                task.dims.N = int(task.dims.D / float(grid_config['D_to_N_ratio']))
            tasks.append(task)
    return tasks


def _release_workers(task_queue: mp.Queue, num_workers: int) -> None:
    """Queue one ``None`` sentinel per worker, behind any tasks still waiting in the queue."""
    for _ in range(num_workers):
        task_queue.put(None)


def _collect_task_results(
        tasks_by_id: dict[str, Task],
        procs: list[mp.Process],
        workers_list: list[Worker],
        task_queue: mp.Queue,
        result_queue: mp.Queue,
        n_expected: int,
        poll_timeout_s: float,
) -> bool:
    """
    Wait for ``n_expected`` ``(task_id, status_int, result)`` tuples from the workers.

    The result queue is polled with a timeout so that dead workers are noticed
    instead of blocking forever. When a worker is seen to have exited with an
    error, the per-worker sentinels are queued right away. The queue is FIFO, so
    the surviving workers first finish the remaining tasks and then exit.

    Returns:
        True if the sentinels were already queued by this function.

    Raises:
        RuntimeError: if every worker has exited while results are still missing.
    """
    workers_released = False
    all_exited_seen = False
    received = 0

    while received < n_expected:
        try:
            task_id, status_int, result = result_queue.get(timeout=poll_timeout_s)
        except Empty:
            failed = [
                f"worker {i} ({w.name}) code {p.exitcode}"
                for i, (p, w) in enumerate(zip(procs, workers_list))
                if p.exitcode not in (0, None)
            ]
            if failed and not workers_released:
                _release_workers(task_queue, len(workers_list))
                workers_released = True

            if not any(p.is_alive() for p in procs):
                # A normally exiting worker flushes its queued results before it
                # terminates, so only a second empty poll after everyone has exited
                # proves the missing results will never arrive.
                if all_exited_seen:
                    # unconsumed tasks must not block interpreter exit on the feeder thread
                    task_queue.cancel_join_thread()
                    details = f"; failed workers: {', '.join(failed)}" if failed else ""
                    raise RuntimeError(
                        f"all workers exited but {n_expected - received} of {n_expected} "
                        f"task results are missing{details}"
                    )
                all_exited_seen = True
            continue

        received += 1  # count every finished task
        finished_task = tasks_by_id[task_id]
        finished_task.status = TaskStatus(status_int)
        finished_task.result = result
        print(f"\n{finished_task.name}: completed with status {TaskStatus(status_int).name}", flush=True)

    return workers_released


def _should_skip_grid_point(dims: MqarDimensions) -> bool:
    """
    Decide whether a grid point has an inconsistent combination of dimensions.

    Returns True (skip) as soon as one of these conditions holds, checked in order:
      1. ``N > D``            -- mismatch with model dims (presumably N is the SSM
                                 state size and D the model width; confirm)
      2. ``L >= V``           -- mismatch with MQAR dims (zoology)
      3. ``L < 4 * N_facts``  -- mismatch with MQAR dims (zoology)
    Otherwise returns False.
    """
    invalid = (
        dims.N > dims.D
        or dims.L >= dims.V
        or dims.L < (4 * dims.N_facts)
    )
    return bool(invalid)



def _run_worker(
    worker: Worker,
    task_queue: mp.Queue,
    result_queue: mp.Queue,
    base_run_config: dict,
):
    """
    Entry point of one worker process.

    Pulls Task objects from ``task_queue`` until a ``None`` sentinel arrives.
    Every task taken from the queue is answered with exactly one
    ``(task.id, TaskStatus value, result dict)`` tuple on ``result_queue``:

    - SKIPPED (empty result) if ``_should_skip_grid_point`` rejects the task's dims.
    - COMPLETED with the value returned by ``build_and_train_model``, or ``{}`` if it returned None.
    - FAILED (empty result) if anything raised while preparing or running the task.
      The exception is then re-raised, so this process exits with an error.

    Args:
        worker: device and CPU-thread budget for this process.
        task_queue: queue of Task objects, terminated by ``None`` sentinels.
        result_queue: queue the parent reads status tuples from.
        base_run_config: run configuration. It is deep-copied, so per-task edits
            (seed, wandb run name, device) stay local to this process.
    """

    print(f"worker {worker.name}: started")
    time.sleep(1)  # optional (to make print less mixed up)

    # keep each process lean on CPU threads (avoid oversubscription)
    num_cpu_threads = max(1, worker.num_cpu_threads)
    # NOTE(review): torch is already imported when this runs, so these env vars may come
    # too late to affect its thread pools; torch.set_num_threads below covers torch itself
    for env_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
        os.environ.setdefault(env_var, str(num_cpu_threads))
    torch.set_num_threads(num_cpu_threads)

    # per-process copy of the config, pinned to this worker's device
    local_run_config = copy.deepcopy(base_run_config)
    local_run_config.setdefault('runtime', {})
    local_run_config['runtime']['device'] = worker.device  # always 'cpu' or 'cuda:i'
    local_run_config.setdefault('wandb', {})

    # set device in this process
    if worker.device.startswith("cuda"):
        torch.cuda.set_device(int(worker.device.split(":")[1]))

    def _report(task: Task, task_status: TaskStatus, result=None) -> None:
        """Send the single status tuple for ``task`` to the parent."""
        print(f"worker {worker.name}: exited with status {task_status.name}")
        if result is None:
            result = {}
        result_queue.put((task.id, task_status.value, result))

    while True:

        # get a task from the queue
        task = task_queue.get()

        # end of queue (reached sentinel)
        if task is None:
            break

        print(f"worker {worker.name}: received task {task.name}")

        run_result = None
        try:
            skip = _should_skip_grid_point(dims=task.dims)
            if not skip:
                # MOST IMPORTANT! set task seed
                local_run_config['runtime']['seed'] = task.seed  # to be used inside build_and_train_model
                set_seed(task.seed, verbose=True)  # just to be on the safe side

                # build, train and save model
                print(f"\nstarting task {task.name}")
                run_name = f"{task.name} @ {worker.name}"
                local_run_config['wandb']['run_name'] = run_name
                run_result = build_and_train_model(
                    dims=task.dims,
                    run_config=local_run_config,
                    run_name=run_name,
                )
        except Exception:
            # report before dying, so the parent still receives one result for this task
            print(f"\n\nworker {worker.name} encountered an error:\n\n")
            traceback.print_exc()
            _report(task, TaskStatus.FAILED)
            raise

        # report outside the try, so a task can never be reported twice
        if skip:
            _report(task, TaskStatus.SKIPPED)
            print(f"task {task.name} skipped")
        else:
            _report(task, TaskStatus.COMPLETED, result=run_result)


