import json
from copy import copy
from dataclasses import dataclass, field
from typing import Mapping

import numpy as np
import torch
from datetime import datetime
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo

from config import get_model_from_config, get_dataset_config_by_split_from_dataset_config
from mqar import MqarDimensions
from mqar.dataloaders import get_mqar_dynamic_dataloaders_by_split
from mqar.train import run_train_loop
from utils.common import set_seed


@dataclass
class GridAxis:
    """One swept axis of the grid: the values to sweep over plus the axis name."""
    axis: np.ndarray  # values along this axis
    name: str  # axis name; used when composing run titles and directory names

@dataclass
class GridAxes:
    """The pair of axes (x and y) that span a two-dimensional grid."""
    x: GridAxis  # first axis; iterated in the outer loop when pairs are built
    y: GridAxis  # second axis; iterated in the inner loop when pairs are built

@dataclass(frozen=True)
class GridConstants:
    """Dimensions that stay fixed during a grid run, with naming helpers.

    ``data`` maps a dimension name to its value. ``aliases`` maps a dimension
    name to the shorter label used when rendering names; by default
    ``N_facts`` is rendered as ``Nf``.
    """
    data: Mapping[str, int]
    aliases: Mapping[str, str] = field(default_factory=lambda: {"N_facts": "Nf"})

    def as_dict(self) -> dict[str, int]:
        """Return a shallow ``dict`` copy of the constants."""
        return dict(self.data)

    def _key(self, k: str) -> str:
        """Return the alias registered for ``k``, or ``k`` itself if none."""
        return self.aliases.get(k, k)

    def _field_name(self, k, v, safe=False) -> str:
        """Render one constant as ``label=value`` (or just ``label`` for None).

        Args:
            k: Dimension name (translated through ``aliases``).
            v: Dimension value; must be an ``int`` or ``None``.
            safe: If true, strip ``=`` so the text can be used in a path.

        Raises:
            TypeError: If ``v`` is neither an ``int`` nor ``None``.
        """
        label = self._key(k)
        if v is None:
            text = f"{label}"
        elif isinstance(v, int):
            text = f"{label}={int(v):d}"
        else:
            # Name the offending entry so a bad config is easy to track down.
            raise TypeError(
                f"grid constant {k!r} must be an int or None, "
                f"got {type(v).__name__}: {v!r}"
            )
        if safe:
            text = text.replace("=", "")  # TODO maybe improve
        return text

    @property
    def name(self) -> str:
        """Human-readable listing of all constants, e.g. ``V=128, Nf=4``."""
        return ", ".join(self._field_name(k, v) for k, v in self.data.items())

    @property
    def safe_name(self) -> str:
        """Path-friendly listing of all constants, e.g. ``V128_Nf4``."""
        return "_".join(self._field_name(k, v, safe=True) for k, v in self.data.items())


def initialize_grid_run(run_config: dict[str, Any]):
    """Create the output directory for a grid run and return the run's title.

    Composes a human-readable title and a directory name from the model
    class/variant, the grid axes and constants, and the current time. The
    directory is created under ``run_config["io"]["results_dir"]``, its path is
    stored in ``run_config["io"]["run_results_dir"]`` (``run_config`` is
    mutated), and the updated config is dumped to ``run_config.json`` inside it.

    Args:
        run_config: Run configuration. Reads ``model`` (``class``, ``variant``),
            ``grid`` and ``io`` (``results_dir``, optional ``time_zone``).

    Returns:
        The human-readable grid run name (it is also printed).
    """
    model_config = run_config["model"]
    grid_config = run_config["grid"]
    io_config = run_config["io"]

    # time
    # ZoneInfo(None) raises, so only build a zone when one is configured;
    # datetime.now(None) falls back to naive local time.
    time_zone_key = io_config.get('time_zone', None)
    time_zone = ZoneInfo(time_zone_key) if time_zone_key else None
    time_now = datetime.now(time_zone)  # local time
    title_timestamp = time_now.strftime("%d.%m.%Y %H-%M-%S")
    dir_timestamp = time_now.strftime("%d_%m_%Y__%H_%M_%S")

    base_name = "MQAR"
    model_class = model_config['class']
    model_variant = model_config['variant']

    # grid constants
    grid_axes, grid_constants = get_axes_from_grid_config(grid_config)

    # names
    grid_axes_name = f"{grid_axes.x.name}, {grid_axes.y.name}"
    grid_axes_safe_name = f"{grid_axes.x.name}_{grid_axes.y.name}"

    grid_run_name = \
        (f"{base_name} | "
         f"'{model_class}', '{model_variant}' | "
         f"{grid_axes_name}, {grid_constants.name} | "
         f"{title_timestamp}")

    grid_run_dir_name = \
        (f"{model_class}_{model_variant}__"
         f"{grid_axes_safe_name}_{grid_constants.safe_name}__"
         f"{dir_timestamp}")

    # dirs
    results_dir = Path(io_config.get("results_dir"))
    grid_run_results_dir = results_dir / grid_run_dir_name
    io_config["run_results_dir"] = str(grid_run_results_dir)

    # make dirs (parents=True also creates results_dir itself)
    grid_run_results_dir.mkdir(parents=True, exist_ok=True)

    # save configs
    (grid_run_results_dir / "run_config.json").write_text(
        json.dumps(run_config, indent=2), encoding="utf-8")

    print(f"\n\nStarting grid run:\n{grid_run_name}\n")
    print(f"Results are saved to: {grid_run_results_dir}\n")

    return grid_run_name


def prepare_grid_pairs_to_iterate(
        grid_axes: GridAxes,
        grid_constants: GridConstants,
        grid_options: dict[str, Any]
) -> tuple[list[tuple[int, int]], MqarDimensions]:
    """Enumerate every (x, y) grid point and build the shared dimensions object.

    Args:
        grid_axes: The two axes whose values are combined.
        grid_constants: Fixed dimensions, passed as keyword arguments to
            ``MqarDimensions``.
        grid_options: Accepted but not used by this function.

    Returns:
        ``(xy_pairs, dims)`` where ``xy_pairs`` lists all combinations with x
        varying slowest (for each x, every y in order).
    """
    dims = MqarDimensions(**grid_constants.as_dict())

    # Cartesian product of the two axes, x-major.
    xy_pairs: list[tuple[int, int]] = [
        (x_value, y_value)
        for x_value in grid_axes.x.axis
        for y_value in grid_axes.y.axis
    ]

    return xy_pairs, dims


def get_parallel_settings(run_config: dict[str, Any]):
    """Resolve the devices and per-process parallelism for a run.

    Parallel mode is enabled whenever ``run_config`` has a ``parallel`` entry
    that is not ``None``. In that mode ``devices_to_use`` may be a list of
    device tokens (e.g. ``[0, 1]``, ``["cuda:0"]``, ``["cpu"]``), a single
    string or int token, or empty/absent to use every visible CUDA device
    (or the CPU when CUDA is unavailable). Without a ``parallel`` entry a
    single device is taken from ``run_config["runtime"]["device"]``.

    Args:
        run_config: Run configuration; reads ``parallel``, ``wandb`` and
            ``runtime``.

    Returns:
        ``(used_devices, num_processes_per_device, num_cpu_threads_per_process)``
        where ``used_devices`` is a list of normalized device strings.

    Raises:
        AssertionError: If wandb is activated and more than one CPU thread per
            process is requested.
    """
    parallel_config = run_config.get("parallel", None)
    # A missing "wandb" section means wandb is off (same leniency as the
    # "runtime" lookup below).
    is_wandb_activated = (run_config.get("wandb") or {}).get("activate", False)

    if parallel_config is None:  # sequential
        used_devices = [_select_device(run_config.get("runtime", {}).get("device"))]
        return used_devices, 1, 1

    num_processes_per_device = int(parallel_config.get("num_processes_per_device", 1))
    num_cpu_threads_per_process = int(parallel_config.get("num_cpu_threads_per_process", 1))
    if num_cpu_threads_per_process > 1 and is_wandb_activated:
        # Raised explicitly rather than via `assert` so the check still runs
        # under `python -O`; the exception type is unchanged.
        raise AssertionError(
            f"wandb is not supported with multi-threading; "
            f"set num_cpu_threads_per_process to 1 (currently {num_cpu_threads_per_process})")

    # choose devices
    req = parallel_config.get("devices_to_use")  # e.g. [0,1,2], ["cuda:0"], None, or ["cpu"]
    # A lone token such as "cuda:1" or 1 is treated as a one-element list;
    # iterating a string would otherwise split it into single characters.
    if isinstance(req, (str, int)) and not isinstance(req, bool) and req != "":
        req = [req]

    if not req:  # auto: all GPUs or CPU
        if _select_device() != "cpu":
            used_devices = [_select_device(i) for i in range(torch.cuda.device_count())]
        else:
            used_devices = ["cpu"]
    else:
        # dedupe while keeping order; drop if you want multiple CPU workers
        used_devices = list(dict.fromkeys(_select_device(d) for d in req))

    return used_devices, num_processes_per_device, num_cpu_threads_per_process

def _select_device(token=None) -> str:
    """
    Normalize a device token to 'cpu' or 'cuda:<idx>'.
    Accepts: None, int, '3', 'cuda:3', 'cpu'.
    Never raises; falls back to 'cpu' if CUDA is unusable/out-of-range.
    """

    def cuda_ok() -> bool:
        try:
            return torch.cuda.is_available() and torch.cuda.device_count() > 0
        except Exception:
            return False

    if token is None:
        return "cuda:0" if cuda_ok() else "cpu"

    s = str(token).strip().lower()
    if s == "cpu":
        return "cpu"

    if s.isdigit():
        idx = int(s)
    elif s.startswith("cuda:"):
        try:
            idx = int(s.split(":", 1)[1])
        except Exception:
            return "cpu"
    else:
        return "cpu"

    if cuda_ok() and 0 <= idx < torch.cuda.device_count():
        return f"cuda:{idx}"
    return "cpu"


def get_axes_from_grid_config(grid_config: dict[str, Any]) -> tuple[GridAxes, GridConstants]:
    """Split a grid config into its two swept axes and the fixed dimensions.

    ``grid_config['x_axis']`` and ``grid_config['y_axis']`` name the swept
    dimensions. The entry stored under each of those names gives the axis
    values, either as a Python expression string (evaluated with ``eval``) or
    directly as a scalar / sequence / array. Values are cast to ``int``.

    Of the known dimension names (``V``, ``L``, ``N_facts``, ``D``, ``N``),
    those present in ``grid_config`` and not used as an axis are returned as
    constants, in ``grid_config`` order.

    Args:
        grid_config: The ``grid`` section of the run configuration.

    Returns:
        ``(grid_axes, grid_constants)``.
    """
    x_name = grid_config['x_axis']
    y_name = grid_config['y_axis']

    # SECURITY: eval() executes arbitrary code from the config; only use with
    # trusted config files. The expression is evaluated in this function's
    # scope, so it can reference module globals such as `np`.
    # np.atleast_1d lets expressions that yield a list, range or scalar work,
    # and non-string specs (already-materialized values) skip eval entirely.
    x_spec = grid_config[x_name]
    x_axis = np.atleast_1d(eval(x_spec) if isinstance(x_spec, str) else x_spec).astype(int)
    y_spec = grid_config[y_name]
    y_axis = np.atleast_1d(eval(y_spec) if isinstance(y_spec, str) else y_spec).astype(int)

    x = GridAxis(name=x_name, axis=x_axis)
    y = GridAxis(name=y_name, axis=y_axis)

    dim_names = ('V', 'L', 'N_facts', 'D', 'N')
    axes_names = (x_name, y_name)

    constants = {
        name: value
        for name, value in grid_config.items()
        if name in dim_names and name not in axes_names
    }

    grid_constants = GridConstants(constants)
    grid_axes = GridAxes(x=x, y=y)

    return grid_axes, grid_constants


# backward compatibility
def _get_axes_from_grid_config(grid_config: dict[str, Any]):
    """Legacy axis loader: evaluate the ``D_axis`` and ``N_axis`` entries.

    Both entries are passed to ``eval`` and the results are returned as-is
    (no integer cast, no wrapping in ``GridAxis``).

    NOTE(review): ``eval`` executes arbitrary code from the config — only use
    with trusted config files.

    Returns:
        ``(D_axis, N_axis)``: the two evaluated values.
    """

    # Expressions are evaluated in this function's scope, so the second one
    # can see the local `D_axis` as well as module globals.
    D_axis = eval(grid_config['D_axis'])
    N_axis = eval(grid_config['N_axis'])

    return D_axis, N_axis


def build_and_train_model(
        dims: MqarDimensions,
        run_config: dict[str, Any],
        run_name: str = None,
):
    """Build a model for ``dims``, optionally rescale LR / batch size, and train it.

    Steps: seed the RNGs, build the model from ``run_config['model']``, apply
    the optional ``run_config['training']['scaling']`` expressions, build the
    dataloaders, run the training loop, and release cached CUDA memory before
    and after training when the configured device is a CUDA device.

    ``run_config`` is mutated: a scaled learning rate / batch size overwrites
    ``training.optimizer.learning_rate`` / ``dataset.batch_size``.
    NOTE(review): if the same ``run_config`` object is reused across calls, the
    "base" values read on the next call are the already-scaled ones — confirm
    that callers pass a fresh copy per grid point.

    Args:
        dims: Problem dimensions; ``V``, ``D`` and ``N`` are used to build the
            model, and ``dims`` is forwarded to the dataloader builder.
        run_config: Full run configuration (``runtime``, ``model``,
            ``training``, ``dataset``).
        run_name: Optional label, used as the prefix of the scaling summary and
            passed to the training loop as ``debug_text``.

    Returns:
        Whatever ``run_train_loop`` returns.
    """
    set_seed(run_config['runtime']['seed'], verbose=True)

    model_config = run_config['model']

    # build model
    model = get_model_from_config(
        V=dims.V, D=dims.D, N=dims.N,
        model_class=model_config["class"],
        model_variant=model_config["variant"],
        dropout_rate=run_config["training"].get("dropout_rate", 0),
    )

    train_config = run_config['training']

    base_lr = train_config['optimizer']['learning_rate']
    base_bs = run_config['dataset']['batch_size']

    # optional: scale batch size and learning rate (larger for smaller models)
    # SECURITY: the scaling expressions are eval()'d — trusted configs only.
    # They run in this function's scope and can therefore reference its locals
    # (such as `base_lr`, `base_bs`, `dims`); do not rename locals here without
    # checking the configs that use them.
    if (scaling_config := run_config['training'].get('scaling', None)) is not None:

        # run_name defaults to None; start from "" so the appends below work.
        debug_text = "" if run_name is None else run_name

        if (scaled_lr_expr := scaling_config.get('learning_rate', None)) is not None:
            scaled_lr = eval(scaled_lr_expr)  # TODO - maybe base * eval(factor)
            train_config['optimizer']['learning_rate'] = scaled_lr
            debug_text += f"LR scaling: {base_lr} -> {scaled_lr} | "

        if (scaled_bs_expr := scaling_config.get('batch_size', None)) is not None:
            scaled_bs = eval(scaled_bs_expr)  # TODO - maybe base * eval(factor)
            run_config['dataset']['batch_size'] = scaled_bs
            debug_text += f"BS scaling: {base_bs} -> {scaled_bs} | "

        # print summary
        print(debug_text)

    # now build dataloaders (using optionally scaled batch size)
    dataloaders = _build_dataloaders(dims=dims, run_config=run_config)

    device = run_config['runtime']['device']
    # str() so that torch.device objects (and None) are handled as well as strings
    on_cuda = str(device).startswith("cuda")

    # start from a clean CUDA cache
    if on_cuda:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    torch.set_float32_matmul_precision("high")  # TEMP, TODO?

    # train model
    results = run_train_loop(
        model=model, dataloaders=dataloaders,
        run_config=run_config, debug_text=run_name,
    )

    # cleanup
    if on_cuda:
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    return results


def _build_dataloaders(dims: MqarDimensions, run_config: dict[str, Any]):
    """Create the MQAR dataloaders for one grid point.

    Copies ``V``, ``L`` and ``N_facts`` from ``dims`` into
    ``run_config['dataset']`` (mutating it), fills in a default train split
    size when none is configured, and builds the per-split dataloaders seeded
    with ``run_config['runtime']['seed']``.

    Returns:
        The result of ``get_mqar_dynamic_dataloaders_by_split``.
    """
    seed = run_config['runtime']['seed']

    # dataloaders are dynamic, so they are built *inside* the child
    ds_cfg = run_config['dataset']
    training_cfg = run_config['training']

    # copy the dataset-relevant dims into the dataset config
    for dim_name in ('V', 'L', 'N_facts'):
        ds_cfg[dim_name] = getattr(dims, dim_name)

    # default train split size: one batch per training step
    split_sizes = ds_cfg['split_size']
    if split_sizes.get('train', None) is None:
        split_sizes['train'] = training_cfg['max_num_steps'] * ds_cfg['batch_size']

    # prepare dataloaders
    per_split_cfg = get_dataset_config_by_split_from_dataset_config(ds_cfg, seed=seed)
    return get_mqar_dynamic_dataloaders_by_split(per_split_cfg)
