"""Utilities for parameter sweeps and random sampling."""

import numbers
import random
from collections.abc import Mapping, Sequence
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from omegaconf import DictConfig, OmegaConf

ConfigValue = Union[str, int, float, bool, List, Dict, Tuple[float, float]]
"""Any value that can be set in a configuration."""

NumericRange = Tuple[float, float]
"""A tuple representing a constant numeric range `(min, max)`."""

ScalarRangeConfig = Dict[str, Union[float, bool]]
"""Configuration for sampling scalar values, with `min`/`max` bounds and an
optional `log_scale` flag."""

DiscreteOptionsConfig = Dict[str, List]
"""Configuration for sampling from discrete values listed under `options`."""

RangeConfig = Union[ScalarRangeConfig, DiscreteOptionsConfig, NumericRange]
"""Any valid range configuration for parameter sampling."""

SampledValue = Union[str, int, float, bool, NumericRange, List]
"""A value sampled from a range configuration. A value picked from
`options` may itself be a list (e.g. a `[min, max]` period range)."""

SampledParams = Dict[str, Dict[str, SampledValue]]
"""Dictionary of sampled parameters organized by category."""

FixedParams = Dict[str, Dict[str, ConfigValue]]
"""Dictionary of fixed parameter overrides, organized by category."""

ExperimentMetadata = Dict[str, Union[int, float, SampledParams]]
"""Metadata for a single experiment including sampled parameters."""

ExperimentConfig = Tuple[DictConfig, ExperimentMetadata]
"""A complete experiment configuration with its metadata."""


def set_nested_value(
	config: DictConfig, path: str, value: ConfigValue
) -> None:
	"""Set a nested value in a config using dot notation.

	Intermediate keys that are missing are created as empty mappings.
	Intermediate keys whose current value is `None` (e.g. a `null`
	placeholder in a YAML config) are replaced the same way. Without
	that, the walk would descend into `None` and the final assignment
	would fail. Only `in`, item access and item assignment are used, so
	plain `dict` objects work as well as `DictConfig`.

	Args:
		config (DictConfig): The config to modify in place.
		path (str): Dot-separated path, e.g.,
			`"env.train_make_kwargs.period_range"`.
		value (ConfigValue): The value to set.
	"""
	*parents, leaf = path.split(".")
	current = config

	for key in parents:
		# Treat an explicit None like a missing key so that we descend into
		# a fresh mapping instead of into None.
		if key not in current or current[key] is None:
			current[key] = {}
		current = current[key]

	current[leaf] = value


def _require_numeric_bounds(min_val: object, max_val: object) -> None:
	"""Raise `ValueError` unless both range bounds are real numbers."""
	if not isinstance(min_val, numbers.Real) or not isinstance(
		max_val, numbers.Real
	):
		raise ValueError("min and max values must be numeric for sampling")


def sample_from_range(
	range_config: RangeConfig,
	rng: Optional[np.random.Generator] = None,
) -> SampledValue:
	"""Sample a value from a parameter range configuration.

	Supported forms. Any `Mapping`/`Sequence` is accepted, not only
	`dict`/`list`/`tuple`, so OmegaConf `DictConfig`/`ListConfig` nodes
	work directly:

	- `{"min": a, "max": b}`: a float drawn uniformly between `a` and
	  `b`. With a truthy `"log_scale"` entry the draw is uniform in
	  `log10` space, which requires `a > 0` and `b > 0`.
	- `{"options": [...]}`: one element of the options, returned as the
	  stored object rather than a NumPy scalar or array.
	- `(a, b)` / `[a, b]`: a float drawn uniformly between `a` and `b`.

	Args:
		range_config (RangeConfig): Configuration defining the range.
		rng (Optional[np.random.Generator]): Random number generator. By
			default, `None`, in which case a fresh unseeded generator is
			created.

	Returns:
		SampledValue: Sampled value.

	Raises:
		ValueError: If range configuration is invalid, bounds are not
			numeric, log-scale bounds are not positive, or `options` is
			empty or not a sequence.
	"""
	if rng is None:
		rng = np.random.default_rng()

	if isinstance(range_config, Mapping):
		if "min" in range_config and "max" in range_config:
			min_val = range_config["min"]
			max_val = range_config["max"]
			_require_numeric_bounds(min_val, max_val)

			if range_config.get("log_scale", False):
				# log10 of a non-positive bound is -inf/NaN, which would
				# otherwise yield garbage or an obscure NumPy error.
				if min_val <= 0 or max_val <= 0:
					raise ValueError(
						"min and max values must be positive when "
						"log_scale is enabled"
					)
				log_min = np.log10(float(min_val))
				log_max = np.log10(float(max_val))
				return float(10 ** rng.uniform(log_min, log_max))
			return float(rng.uniform(float(min_val), float(max_val)))

		if "options" in range_config:
			options = range_config["options"]
			if isinstance(options, (str, bytes)) or not isinstance(
				options, (Sequence, np.ndarray)
			):
				raise ValueError(
					f"options must be a sequence of values: {options}"
				)
			if len(options) == 0:
				raise ValueError("options must contain at least one value")
			# Draw an index instead of calling rng.choice(options). That call
			# converts the options to an ndarray and returns NumPy scalars,
			# or array rows for list-valued options, which are not plain
			# config values. Generator.choice derives its index from the
			# population size only, so seeded draws should be unchanged.
			index = int(rng.choice(len(options)))
			return options[index]

	elif (
		isinstance(range_config, Sequence)
		and not isinstance(range_config, (str, bytes))
		and len(range_config) == 2
	):
		min_val, max_val = range_config
		_require_numeric_bounds(min_val, max_val)
		return float(rng.uniform(min_val, max_val))

	raise ValueError(f"Invalid range configuration: {range_config}")


def sample_random_parameters(
	sweep_config: DictConfig, seed: Optional[int] = None
) -> SampledParams:
	"""Sample random parameters for a single experiment.

	Every entry of `sweep_config.random_params` is sampled with
	`sample_from_range`, in the order the entries appear, using one NumPy
	generator seeded with `seed`. The section is presumably shaped as
	`{category: {param_name: range_config}}` (confirm against sweep
	configs).

	An OmegaConf `random_params` node is first converted to plain Python
	containers, with interpolations resolved, so each `range_config`
	arrives as a plain `dict`/`list` instead of a
	`DictConfig`/`ListConfig`.

	Args:
		sweep_config (DictConfig): Sweep configuration.
		seed (Optional[int]): Random seed, if any. By default, `None`.
			When given, Python's global `random` module is also reseeded
			with it.

	Returns:
		SampledParams: Sampled values keyed by category and then by
			parameter name. Empty when `random_params` is absent or empty.
	"""
	# default_rng(None) draws fresh entropy, same as the unseeded case.
	rng = np.random.default_rng(seed)
	if seed is not None:
		# NOTE(review): global side effect kept for compatibility;
		# generate_sweep_experiment_configs draws later experiment seeds
		# from this module.
		random.seed(seed)

	random_params = sweep_config.get("random_params")
	if not random_params:
		return {}
	if OmegaConf.is_config(random_params):
		random_params = OmegaConf.to_container(random_params, resolve=True)

	sampled_params: SampledParams = {}
	for category, params in random_params.items():
		sampled_params[category] = {
			param_name: sample_from_range(range_config, rng)
			for param_name, range_config in (params or {}).items()
		}

	return sampled_params


def apply_sampled_parameters(
	base_config: DictConfig,
	sampled_params: SampledParams,
	fixed_params: Optional[FixedParams] = None,
) -> DictConfig:
	"""Build an experiment config by layering overrides onto a base config.

	The returned config is created with `OmegaConf.create(base_config)`.
	Sampled parameters are written first, then any fixed parameters, so a
	fixed value wins when both target the same path. Each override is
	addressed as `"<category>.<param_name>"` through `set_nested_value`.

	Args:
		base_config (DictConfig): Base configuration to modify.
		sampled_params (SampledParams): Randomly sampled parameters.
		fixed_params (Optional[FixedParams]): Fixed parameter overrides.
			Skipped when `None` or empty.

	Returns:
		DictConfig: Modified configuration.
	"""
	merged = OmegaConf.create(base_config)

	override_groups = [sampled_params]
	if fixed_params:
		override_groups.append(fixed_params)

	for overrides in override_groups:
		for category, params in overrides.items():
			for param_name, value in params.items():
				set_nested_value(merged, f"{category}.{param_name}", value)

	return merged


def calculate_asynchronicity_rate(period_range: NumericRange) -> float:
	"""Estimate the asynchronicity rate implied by a period range.

	The estimate is the reciprocal of the midpoint of `period_range`,
	capped at `1`. A midpoint that is not strictly positive (including
	NaN) yields `0`. This heuristic uses the period range alone; simulating
	the signal dynamics would give a more accurate, empirical figure.

	Args:
		period_range (NumericRange): Range for signal renewal periods.

	Returns:
		float: Expected asynchronicity rate, with `0` corresponding to
			the fully synchronous case and `1` to the fully asynchronous
			one.
	"""
	low, high = period_range[0], period_range[1]
	midpoint = (low + high) / 2.0
	if midpoint > 0:
		return min(1.0, 1.0 / midpoint)
	return 0.0


def _as_numeric_pair(value: object) -> Optional[NumericRange]:
	"""Return `value` as a `(float, float)` tuple, or `None` if it is not a
	two-element, non-string sequence or array of numbers."""
	if isinstance(value, (str, bytes)) or not isinstance(
		value, (Sequence, np.ndarray)
	):
		return None
	try:
		first, second = value
		return (float(first), float(second))
	except (TypeError, ValueError):
		return None


def generate_sweep_experiment_configs(
	base_config: DictConfig, sweep_config: DictConfig
) -> List[ExperimentConfig]:
	"""Generate configurations for all experiments in a sweep.

	For each of `sweep_config.num_experiments` experiments (default 10):

	1. Draw a seed from Python's global `random` module.
	2. Sample parameters with that seed (`sample_random_parameters`).
	3. Apply the sampled values and `sweep_config.fixed_params` to
	   `base_config` (`apply_sampled_parameters`).
	4. Set the config's `seed` and `exp_name` (`sweep_exp_000`, ...).

	The metadata holds `experiment_id`, `seed` and `sampled_params`. When
	the sampled `env.period_range` is a numeric pair (a list, tuple or
	array, as produced by `options` sampling), the metadata also holds
	`asynchronicity_rate`.

	Args:
		base_config (DictConfig): Base experiment configuration.
		sweep_config (DictConfig): Sweep configuration.

	Returns:
		List[ExperimentConfig]: List of `(experiment_config, metadata)`
			tuples.
	"""
	experiments: List[ExperimentConfig] = []
	num_experiments = sweep_config.get("num_experiments", 10)
	# Loop-invariant; apply_sampled_parameters only reads it.
	fixed_params = sweep_config.get("fixed_params", {})

	for i in range(num_experiments):
		exp_seed = random.randint(0, 2**31 - 1)
		sampled_params = sample_random_parameters(sweep_config, exp_seed)

		exp_config = apply_sampled_parameters(
			base_config, sampled_params, fixed_params
		)
		exp_config.seed = exp_seed
		exp_config.exp_name = f"sweep_exp_{i:03d}"

		metadata: ExperimentMetadata = {
			"experiment_id": i,
			"seed": exp_seed,
			"sampled_params": sampled_params,
		}

		# Sampled period ranges are lists, ListConfigs or arrays, never
		# tuples, so accept any numeric pair rather than only `tuple`.
		env_params = sampled_params.get("env") or {}
		period_range = _as_numeric_pair(env_params.get("period_range"))
		if period_range is not None:
			metadata["asynchronicity_rate"] = calculate_asynchronicity_rate(
				period_range
			)

		experiments.append((exp_config, metadata))

	return experiments
