import argparse
import copy
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import numpy as np
import yaml

from src.utils import dsigmoid
from src.experiments.runner import ExperimentBuilder, run as run_builder
from src.experiments.simulations import (
    LinearBanditSimulation,
    KernelBanditSimulation,
    ContextBanditSimulation,
)

try:
    from plot import plot_regret, plot_times  # type: ignore
except Exception:
    plot_regret = plot_times = None


@dataclass
class BanditTemplateResult:
    """Pair returned by the ``bandit_*_template`` functions in this module."""

    # Zero-argument callable; each call builds a new bandit instance.
    factory: Callable[[], Any]
    # Derived settings (e.g. 'T', 'num_actions', 'seed') read by the algorithm templates.
    context: Dict[str, Any]


@dataclass
class AlgorithmTemplateResult:
    """Pair returned by the ``algorithm_*`` template functions in this module."""

    # Zero-argument callable; each call builds a new algorithm instance.
    factory: Callable[[], Any]
    # Display name; the factories also set it as the ``name`` attribute of each instance.
    name: str


def sample_theta(rng: np.random.Generator, dim: int, bound: float) -> np.ndarray:
    """Draw a ``dim``-dimensional parameter vector limited in norm by ``bound``.

    Every coordinate is sampled uniformly between ``-bound`` and ``bound``.
    When the Euclidean norm of the draw exceeds ``bound`` the vector is
    rescaled so that its norm equals ``bound``.
    """
    vector = rng.uniform(low=-bound, high=bound, size=dim)
    length = float(np.linalg.norm(vector))
    too_long = length > 0 and length > bound
    if not too_long:
        return vector
    vector *= bound / length
    return vector


def ensure_directory(path: Path) -> None:
    """Create the parent directory of ``path``, including missing ancestors.

    Nothing happens when the directory already exists; ``path`` itself is not
    created.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)


def generate_cb_policies(context: Dict[str, Any], count: int, seed: Optional[int]) -> Tuple[List[Any], List[np.ndarray]]:
    """Create ``count`` random ``LinearArgmaxPolicy`` objects and their weights.

    Each weight matrix has shape ``(context['num_actions'], context['d_ctx'])``
    and is drawn from a standard normal using a generator seeded with ``seed``.

    Returns:
        ``(policies, weight_bank)``; every policy is built from its own copy of
        the matching matrix in ``weight_bank``.
    """
    from src.algorithms.ADAILTCBp import LinearArgmaxPolicy

    generator = np.random.default_rng(seed)
    shape = (context['num_actions'], context['d_ctx'])
    weight_bank: List[np.ndarray] = [generator.normal(size=shape) for _ in range(count)]
    # Policies get a private copy so the bank and the policies never share memory.
    policies: List[Any] = [LinearArgmaxPolicy(weights.copy()) for weights in weight_bank]
    return policies, weight_bank


def _format_scalar(value: Any) -> str:
    if isinstance(value, float):
        text = f"{value:.6g}"
        return text.rstrip("0").rstrip(".") if "." in text else text
    return str(value)


def _sanitize_token(token: str) -> str:
    cleaned = "".join(ch if ch.isalnum() or ch in ("-",) else "_" for ch in str(token))
    return cleaned or "value"


def build_output_stem(builder: ExperimentBuilder) -> str:
    """Compose the underscore-joined name stem describing an experiment.

    Tokens are appended in this order when available: sanitized bandit name,
    ``T``, ``d``, ``d_ctx``, number of actions (``A``), the ``cont``/``xi``
    marker, ``builder.n_mc`` (``N``) and ``seed``. Values come from
    ``builder.context`` first, with ``builder.environment_kwargs`` as the
    fallback for ``T``, ``num_actions`` and ``continuous``.
    """
    ctx = getattr(builder, "context", {})
    env = builder.environment_kwargs

    label = ctx.get("bandit_name")
    if not label:
        # Fall back to the class name of a freshly built bandit; any failure
        # while building it yields the generic label.
        try:
            label = builder.bandit_factory().__class__.__name__
        except Exception:
            label = "bandit"
    tokens: List[str] = [_sanitize_token(label)]

    horizon = ctx.get("T") or env.get("T")
    if horizon is not None:
        tokens.append(f"T_{_format_scalar(horizon)}")

    for key in ("d", "d_ctx"):
        dimension = ctx.get(key)
        if dimension is not None:
            tokens.append(f"{key}_{_format_scalar(dimension)}")

    n_actions = ctx.get("num_actions") or env.get("num_actions")
    if n_actions is not None:
        tokens.append(f"A_{_format_scalar(n_actions)}")

    xi = ctx.get("xi")
    is_continuous = ctx.get("continuous")
    if is_continuous is None:
        is_continuous = env.get("continuous", False)
    if is_continuous:
        tokens.append("cont" if xi is None else f"cont_{_format_scalar(xi)}")
    elif xi is not None:
        tokens.append(f"xi_{_format_scalar(xi)}")

    tokens.append(f"N_{_format_scalar(builder.n_mc)}")

    seed = ctx.get("seed")
    if seed is not None:
        tokens.append(f"seed_{_format_scalar(seed)}")

    return "_".join(tokens)





def bandit_glb_template(params: Dict[str, Any]) -> BanditTemplateResult:
    """Build the ``NSGenLinearBandit`` factory and its derived context.

    Args:
        params: Raw settings. ``T``, ``num_actions``, ``d``,
            ``noise_variance``, ``theta_bound``, ``action_bound``,
            ``mean_low`` and ``mean_high`` are mandatory; ``reward_bound``
            (default 1.0), ``xi``, ``seed``, ``continuous``, ``theta`` and
            ``change_prob`` are optional.

    Returns:
        A ``BanditTemplateResult`` whose factory creates a new bandit per call
        and whose context holds the constants read by the GLB algorithm
        templates.

    Raises:
        ValueError: If a mandatory key is missing from ``params``.
    """
    mandatory = (
        'T', 'num_actions', 'd', 'noise_variance',
        'theta_bound', 'action_bound', 'mean_low', 'mean_high',
    )
    absent = next((name for name in mandatory if name not in params), None)
    if absent is not None:
        raise ValueError(f"Missing parameter '{absent}' for GLB bandit.")

    horizon = int(params['T'])
    n_actions = int(params['num_actions'])
    dim = int(params['d'])
    noise_var = float(params['noise_variance'])
    theta_bound = float(params['theta_bound'])
    action_bound = float(params['action_bound'])
    mean_low = float(params['mean_low'])
    mean_high = float(params['mean_high'])
    reward_bound = float(params.get('reward_bound', 1.0))
    xi = params.get('xi')
    seed = params.get('seed')
    continuous = bool(params.get('continuous', False))

    rng = np.random.default_rng(seed)
    theta0 = params.get('theta')
    if theta0 is not None:
        theta0 = np.asarray(theta0, dtype=float)
    else:
        theta0 = sample_theta(rng, dim, theta_bound)
    theta0 = theta0.astype(float, copy=True)

    # Without an explicit change probability, derive it from xi as T ** (-xi).
    change_prob = params.get('change_prob')
    if xi is not None and change_prob is None:
        change_prob = float(horizon ** (-float(xi)))

    from src.bandits.NSGenLinearBandit import NSGenLinearBandit

    fixed_kwargs = dict(
        num_actions=n_actions,
        noise_variance=noise_var,
        d=dim,
        mean_low=mean_low,
        mean_high=mean_high,
        reward_bound=reward_bound,
        theta_bound=theta_bound,
        action_bound=action_bound,
        continuous=continuous,
    )

    def factory():
        # Every bandit receives its own copy of the initial theta.
        return NSGenLinearBandit(theta=theta0.copy(), **fixed_kwargs)

    noise_scale = math.sqrt(noise_var)
    k_mu = 0.25
    c_mu = float(dsigmoid(action_bound * theta_bound))
    delta = 1.0 / (2.0 * horizon)

    context = {
        'T': horizon,
        'num_actions': n_actions,
        'd': dim,
        'noise_variance': noise_var,
        'S': theta_bound,
        'L': action_bound,
        'R': noise_scale,
        'k_mu': k_mu,
        'c_mu': c_mu,
        'delta': delta,
        'r_lambda_glb': dim,
        'r_lambda_weight': dim / (c_mu ** 2),
        'reward_bound': reward_bound,
        'xi': xi,
        'seed': seed,
        'change_prob': change_prob,
        'continuous': continuous,
        'bandit_name': 'NSGenLinearBandit',
    }
    return BanditTemplateResult(factory=factory, context=context)


def bandit_lb_template(params: Dict[str, Any]) -> BanditTemplateResult:
    """Build the ``NSLinearBandit`` factory and its derived context.

    Args:
        params: Raw settings. ``T``, ``num_actions``, ``d``,
            ``noise_variance``, ``theta_bound`` and ``action_bound`` are
            mandatory. ``reward_bound`` defaults to
            ``theta_bound * action_bound``, ``mean_low``/``mean_high`` to
            -1.0/1.0; ``xi``, ``seed``, ``continuous``, ``theta`` and
            ``change_prob`` are optional.

    Returns:
        A ``BanditTemplateResult`` whose factory creates a new bandit per call
        and whose context holds the constants read by the LB algorithm
        templates.

    Raises:
        ValueError: If a mandatory key is missing from ``params``.
    """
    mandatory = ('T', 'num_actions', 'd', 'noise_variance', 'theta_bound', 'action_bound')
    absent = next((name for name in mandatory if name not in params), None)
    if absent is not None:
        raise ValueError(f"Missing parameter '{absent}' for LB bandit.")

    horizon = int(params['T'])
    n_actions = int(params['num_actions'])
    dim = int(params['d'])
    noise_var = float(params['noise_variance'])
    theta_bound = float(params['theta_bound'])
    action_bound = float(params['action_bound'])
    reward_bound = float(params.get('reward_bound', theta_bound * action_bound))
    mean_low = float(params.get('mean_low', -1.0))
    mean_high = float(params.get('mean_high', 1.0))
    xi = params.get('xi')
    seed = params.get('seed')
    continuous = bool(params.get('continuous', False))

    rng = np.random.default_rng(seed)
    theta0 = params.get('theta')
    if theta0 is not None:
        theta0 = np.asarray(theta0, dtype=float)
    else:
        theta0 = sample_theta(rng, dim, theta_bound)
    theta0 = theta0.astype(float, copy=True)

    # Without an explicit change probability, derive it from xi as T ** (-xi).
    change_prob = params.get('change_prob')
    if xi is not None and change_prob is None:
        change_prob = float(horizon ** (-float(xi)))

    from src.bandits.NSLinearBandit import NSLinearBandit

    fixed_kwargs = dict(
        num_actions=n_actions,
        noise_variance=noise_var,
        d=dim,
        mean_low=mean_low,
        mean_high=mean_high,
        reward_bound=reward_bound,
        theta_bound=theta_bound,
        action_bound=action_bound,
        continuous=continuous,
    )

    def factory():
        # Every bandit receives its own copy of the initial theta.
        return NSLinearBandit(theta=theta0.copy(), **fixed_kwargs)

    noise_scale = math.sqrt(noise_var)
    delta = 1.0 / (2.0 * horizon)

    context = {
        'T': horizon,
        'num_actions': n_actions,
        'd': dim,
        'noise_variance': noise_var,
        'S': theta_bound,
        'L': action_bound,
        'R': noise_scale,
        'delta': delta,
        'r_lambda': dim,
        'reward_bound': reward_bound,
        'xi': xi,
        'seed': seed,
        'change_prob': change_prob,
        'continuous': continuous,
        'bandit_name': 'NSLinearBandit',
    }
    return BanditTemplateResult(factory=factory, context=context)


def bandit_lb_ball_template(params: Dict[str, Any]) -> BanditTemplateResult:
    """Build the ``NSLinearBanditBall`` factory and its derived context.

    Args:
        params: Raw settings. ``T``, ``num_actions``, ``d``,
            ``noise_variance``, ``theta_bound``, ``action_bound`` and
            ``radius`` are mandatory. ``reward_bound`` defaults to
            ``theta_bound * action_bound``, ``mean_low``/``mean_high`` to
            -1.0/1.0; ``xi``, ``seed``, ``theta`` and ``change_prob`` are
            optional.

    Returns:
        A ``BanditTemplateResult`` whose factory creates a new bandit per call
        (always with ``continuous=True``) and whose context holds the
        constants read by the LB algorithm templates.

    Raises:
        ValueError: If a mandatory key is missing from ``params``.
    """
    required = ['T', 'num_actions', 'd', 'noise_variance', 'theta_bound', 'action_bound', 'radius']
    for key in required:
        if key not in params:
            raise ValueError(f"Missing parameter '{key}' for LB ball bandit.")

    T = int(params['T'])
    num_actions = int(params['num_actions'])
    d = int(params['d'])
    noise = float(params['noise_variance'])
    theta_bound = float(params['theta_bound'])
    action_bound = float(params['action_bound'])
    radius = float(params['radius'])
    reward_bound = float(params.get('reward_bound', theta_bound * action_bound))
    # Resolved once here, as floats, so the factory does not re-read the
    # caller's mutable ``params`` dict on every call.
    mean_low = float(params.get('mean_low', -1.0))
    mean_high = float(params.get('mean_high', 1.0))
    xi = params.get('xi')
    seed = params.get('seed')

    rng = np.random.default_rng(seed)
    theta = params.get('theta')
    if theta is None:
        theta = sample_theta(rng, d, theta_bound)
    else:
        theta = np.asarray(theta, dtype=float)
    theta = theta.astype(float, copy=True)

    # Explicit change_prob wins; otherwise derive T ** (-xi); otherwise 0.0,
    # because this bandit takes the probability as a constructor argument.
    change_prob = params.get('change_prob')
    if change_prob is None and xi is not None:
        change_prob = float(T ** (-float(xi)))
    if change_prob is None:
        change_prob = 0.0

    from src.bandits.NSLinearBanditBall import NSLinearBanditBall

    def factory():
        return NSLinearBanditBall(
            num_actions=num_actions,
            noise_variance=noise,
            d=d,
            theta=theta.copy(),  # each bandit gets its own copy
            mean_low=mean_low,
            mean_high=mean_high,
            reward_bound=reward_bound,
            theta_bound=theta_bound,
            action_bound=action_bound,
            continuous=True,
            radius=radius,
            prob=change_prob,
        )

    S = theta_bound
    L = action_bound
    R = math.sqrt(noise)
    delta = 1.0 / (2.0 * T)

    context = {
        'T': T,
        'num_actions': num_actions,
        'd': d,
        'noise_variance': noise,
        'S': S,
        'L': L,
        'R': R,
        'delta': delta,
        'r_lambda': d,
        'reward_bound': reward_bound,
        'xi': xi,
        'seed': seed,
        'change_prob': change_prob,
        'continuous': True,
        'bandit_name': 'NSLinearBanditBall',
    }
    return BanditTemplateResult(factory=factory, context=context)


def bandit_scb_template(params: Dict[str, Any]) -> BanditTemplateResult:
    """Build the ``NSSCBBandit`` factory and its derived context.

    Args:
        params: Raw settings. ``T``, ``num_actions``, ``d``,
            ``noise_variance``, ``theta_bound`` and ``action_bound`` are
            mandatory. ``reward_bound`` defaults to 1.0,
            ``mean_low``/``mean_high`` to -1.0/1.0; ``xi``, ``seed``,
            ``continuous``, ``theta`` and ``change_prob`` are optional.

    Returns:
        A ``BanditTemplateResult`` whose factory creates a new bandit per call
        and whose context holds the constants read by the SCB algorithm
        templates.

    Raises:
        ValueError: If a mandatory key is missing from ``params``.
    """
    mandatory = ('T', 'num_actions', 'd', 'noise_variance', 'theta_bound', 'action_bound')
    absent = next((name for name in mandatory if name not in params), None)
    if absent is not None:
        raise ValueError(f"Missing parameter '{absent}' for SCB bandit.")

    horizon = int(params['T'])
    n_actions = int(params['num_actions'])
    dim = int(params['d'])
    noise_var = float(params['noise_variance'])
    theta_bound = float(params['theta_bound'])
    action_bound = float(params['action_bound'])
    reward_bound = float(params.get('reward_bound', 1.0))
    mean_low = float(params.get('mean_low', -1.0))
    mean_high = float(params.get('mean_high', 1.0))
    xi = params.get('xi')
    seed = params.get('seed')
    continuous = bool(params.get('continuous', False))

    rng = np.random.default_rng(seed)
    theta0 = params.get('theta')
    if theta0 is not None:
        theta0 = np.asarray(theta0, dtype=float)
    else:
        theta0 = sample_theta(rng, dim, theta_bound)
    theta0 = theta0.astype(float, copy=True)

    # Without an explicit change probability, derive it from xi as T ** (-xi).
    change_prob = params.get('change_prob')
    if xi is not None and change_prob is None:
        change_prob = float(horizon ** (-float(xi)))

    from src.bandits.NSSCBBandit import NSSCBBandit

    fixed_kwargs = dict(
        num_actions=n_actions,
        noise_variance=noise_var,
        d=dim,
        mean_low=mean_low,
        mean_high=mean_high,
        reward_bound=reward_bound,
        theta_bound=theta_bound,
        action_bound=action_bound,
        continuous=continuous,
    )

    def factory():
        # Every bandit receives its own copy of the initial theta.
        return NSSCBBandit(theta=theta0.copy(), **fixed_kwargs)

    c_mu = float(dsigmoid(action_bound * theta_bound))
    delta_scb = 1.0 / horizon
    dal_delta = horizon ** (-1.0 / 6.0)
    log_scaled_dim = dim * math.log(horizon)
    # Fall back to d when c_mu is not positive, which avoids dividing by zero.
    if c_mu > 0:
        r_lambda_scb = log_scaled_dim / (4.0 * c_mu)
    else:
        r_lambda_scb = dim

    context = {
        'T': horizon,
        'num_actions': n_actions,
        'd': dim,
        'noise_variance': noise_var,
        'S': theta_bound,
        'L': action_bound,
        'R': 1.0,
        'k_mu': 0.25,
        'c_mu': c_mu,
        'delta_master': 2.0 * delta_scb,
        'delta_scb': delta_scb,
        'dal_delta': dal_delta,
        'r_lambda_master': log_scaled_dim,
        'r_lambda_scb': r_lambda_scb,
        'reward_bound': reward_bound,
        'failure_delta': delta_scb,
        'xi': xi,
        'seed': seed,
        'change_prob': change_prob,
        'continuous': continuous,
        'bandit_name': 'NSSCBBandit',
    }
    return BanditTemplateResult(factory=factory, context=context)


def bandit_kb_template(params: Dict[str, Any]) -> BanditTemplateResult:
    """Build the ``NSKernelBandit`` factory and its derived context.

    Args:
        params: Raw settings. ``T``, ``num_actions``, ``d``,
            ``noise_variance`` and ``reward_bound`` are mandatory.
            ``mean_low``/``mean_high`` default to -1.0/1.0,
            ``reward_generation_method`` to ``'kernel_sum'``, ``tol`` to 0.1,
            ``window`` to 0 and ``kernel_length_scale`` to 0.2; ``xi``,
            ``seed``, ``continuous`` and ``change_prob`` are optional.

    Returns:
        A ``BanditTemplateResult`` whose context includes a ``kernel_config``
        dict (RBF kernel) read by the kernel algorithm templates.

    Raises:
        ValueError: If a mandatory key is missing from ``params``.
    """
    mandatory = ('T', 'num_actions', 'd', 'noise_variance', 'reward_bound')
    absent = next((name for name in mandatory if name not in params), None)
    if absent is not None:
        raise ValueError(f"Missing parameter '{absent}' for KB bandit.")

    horizon = int(params['T'])
    n_actions = int(params['num_actions'])
    dim = int(params['d'])
    noise_var = float(params['noise_variance'])
    reward_bound = float(params['reward_bound'])
    mean_low = float(params.get('mean_low', -1.0))
    mean_high = float(params.get('mean_high', 1.0))
    seed = params.get('seed')
    xi = params.get('xi')
    reward_method = params.get('reward_generation_method', 'kernel_sum')

    tolerance = float(params.get('tol', 0.1))
    window_size = int(params.get('window', 0))
    length_scale = float(params.get('kernel_length_scale', 0.2))

    # Without an explicit change probability, derive it from xi as T ** (-xi).
    change_prob = params.get('change_prob')
    if xi is not None and change_prob is None:
        change_prob = float(horizon ** (-float(xi)))

    from src.bandits.NSKernelBandit import NSKernelBandit

    bandit_kwargs = dict(
        num_actions=n_actions,
        noise_variance=noise_var,
        d=dim,
        mean_low=mean_low,
        mean_high=mean_high,
        reward_bound=reward_bound,
        reward_generation_method=reward_method,
    )

    def factory():
        return NSKernelBandit(**bandit_kwargs)

    kernel_config = {
        'tol': tolerance,
        'window': window_size,
        'kernel': {'name': 'rbf', 'length_scale': length_scale},
        'lambda': noise_var,
        'v': math.sqrt(noise_var),
        'seed': seed,
    }

    context = {
        'T': horizon,
        'num_actions': n_actions,
        'd': dim,
        'noise_variance': noise_var,
        'reward_bound': reward_bound,
        'xi': xi,
        'kernel_config': kernel_config,
        'seed': seed,
        'change_prob': change_prob,
        'continuous': bool(params.get('continuous', False)),
        'bandit_name': 'NSKernelBandit',
    }
    return BanditTemplateResult(factory=factory, context=context)


def bandit_cb_template(params: Dict[str, Any]) -> BanditTemplateResult:
    """Build the ``NSContextBandit`` factory and its derived context.

    Args:
        params: Raw settings. ``T``, ``num_actions``, ``d``, ``d_ctx``,
            ``noise_variance`` and ``reward_bound`` are mandatory.
            ``action_bound`` defaults to 1.0, ``context_mode`` to
            ``'finite'`` and ``policy_count`` to 100; ``xi``, ``seed``,
            ``continuous`` and ``change_prob`` are optional. A fixed set of
            extra keys (context distribution and mean settings) is forwarded
            to the bandit constructor when present.

    Returns:
        A ``BanditTemplateResult`` whose factory creates a new bandit per call.

    Raises:
        ValueError: If a mandatory key is missing from ``params``.
    """
    mandatory = ('T', 'num_actions', 'd', 'd_ctx', 'noise_variance', 'reward_bound')
    absent = next((name for name in mandatory if name not in params), None)
    if absent is not None:
        raise ValueError(f"Missing parameter '{absent}' for CB bandit.")

    horizon = int(params['T'])
    n_actions = int(params['num_actions'])
    dim = int(params['d'])
    ctx_dim = int(params['d_ctx'])
    noise_var = float(params['noise_variance'])
    reward_bound = float(params['reward_bound'])
    action_bound = float(params.get('action_bound', 1.0))
    context_mode = params.get('context_mode', 'finite')
    xi = params.get('xi')
    seed = params.get('seed')
    continuous = bool(params.get('continuous', False))

    # Without an explicit change probability, derive it from xi as T ** (-xi).
    change_prob = params.get('change_prob')
    if xi is not None and change_prob is None:
        change_prob = float(horizon ** (-float(xi)))

    # Keys passed through unchanged to NSContextBandit when the caller set them.
    passthrough_keys = (
        'finite_contexts',
        'finite_probs',
        'gaussian_mean',
        'gaussian_scale',
        'finite_N_min',
        'finite_N_max',
        'ctx_mean_std',
        'ctx_scale_low',
        'ctx_scale_high',
        'mean_low',
        'mean_high',
    )

    from src.bandits.NSContextBandit import NSContextBandit

    def factory():
        bandit_kwargs = {
            'num_actions': n_actions,
            'noise_variance': noise_var,
            'd': dim,
            'd_c': ctx_dim,
            'reward_bound': reward_bound,
            'context_mode': context_mode,
            'action_bound': action_bound,
            'continuous': continuous,
            'seed': seed,
        }
        # ``params`` is consulted at call time, exactly once per built bandit.
        bandit_kwargs.update({key: params[key] for key in passthrough_keys if key in params})
        return NSContextBandit(**bandit_kwargs)

    delta = 1.0 / math.sqrt(horizon)
    policy_count = int(params.get('policy_count', 100))

    context = {
        'T': horizon,
        'num_actions': n_actions,
        'd': dim,
        'd_ctx': ctx_dim,
        'noise_variance': noise_var,
        'reward_bound': reward_bound,
        'action_bound': action_bound,
        'delta': delta,
        'policy_count': policy_count,
        'xi': xi,
        'seed': seed,
        'change_prob': change_prob,
        'continuous': continuous,
        'context_mode': context_mode,
        'bandit_name': 'NSContextBandit',
    }
    return BanditTemplateResult(factory=factory, context=context)


# Registry mapping a bandit-type key to the template function that builds it.
BANDIT_TEMPLATES: Dict[str, Callable[[Dict[str, Any]], BanditTemplateResult]] = dict(
    glb=bandit_glb_template,
    lb=bandit_lb_template,
    lb_ball=bandit_lb_ball_template,
    scb=bandit_scb_template,
    kb=bandit_kb_template,
    cb=bandit_cb_template,
)





def algorithm_master_glb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``MASTER`` configured with ``model='GLB'``.

    Constructor arguments default to values from ``context``; entries in
    ``overrides['params']`` replace them and ``overrides['name']`` replaces
    the display name ``'MASTER(GLB)'``.
    """
    from src.algorithms.MASTER import MASTER

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_glb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        model='GLB',
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'MASTER(GLB)')

    def factory():
        instance = MASTER(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def _build_dal_glb_algorithm(
    context: Dict[str, Any],
    overrides: Optional[Dict[str, Any]],
    dal_cls,
    default_name: str,
) -> AlgorithmTemplateResult:
    """Shared builder for the DAL(GLB) algorithm templates.

    ``dal_cls`` is instantiated with GLB constants from ``context`` merged
    with ``overrides['params']``; the display name is ``overrides['name']``
    or ``default_name``.
    """
    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        noise_variance=context['noise_variance'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_glb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', default_name)

    def factory():
        instance = dal_cls(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_dal_glb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``DAL_GLB`` with default display name ``'DAL(GLB)'``."""
    from src.algorithms.DAL_GLBs import DAL_GLB

    return _build_dal_glb_algorithm(
        context=context,
        overrides=overrides,
        dal_cls=DAL_GLB,
        default_name='DAL(GLB)',
    )


def algorithm_dal_glb_gsr(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``DAL_GLB_GSR`` with default display name ``'DAL(GLB)-GSR'``."""
    from src.algorithms.DAL_GLBs_GSR import DAL_GLB_GSR

    return _build_dal_glb_algorithm(
        context=context,
        overrides=overrides,
        dal_cls=DAL_GLB_GSR,
        default_name='DAL(GLB)-GSR',
    )


def algorithm_glb_weight_ucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``GLB_WeightUCB``.

    Uses ``context['r_lambda_weight']`` as ``r_lambda``. Entries in
    ``overrides['params']`` replace the defaults and ``overrides['name']``
    replaces the display name ``'GLB-WeightUCB'``.
    """
    from src.algorithms.GLB_WeightUCB_new import GLB_WeightUCB

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_weight'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'GLB-WeightUCB')

    def factory():
        instance = GLB_WeightUCB(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_master_lb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``MASTER`` configured with ``model='LB'``.

    Constructor arguments default to values from ``context``; entries in
    ``overrides['params']`` replace them and ``overrides['name']`` replaces
    the display name ``'MASTER(LB)'``.
    """
    from src.algorithms.MASTER import MASTER

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        model='LB',
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'MASTER(LB)')

    def factory():
        instance = MASTER(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def _build_dal_lb_algorithm(
    context: Dict[str, Any],
    overrides: Optional[Dict[str, Any]],
    dal_cls,
    default_name: str,
) -> AlgorithmTemplateResult:
    """Shared builder for the DAL(LB) algorithm templates.

    ``dal_cls`` is instantiated with LB constants from ``context`` merged with
    ``overrides['params']``; the display name is ``overrides['name']`` or
    ``default_name``.
    """
    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        noise_variance=context['noise_variance'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', default_name)

    def factory():
        instance = dal_cls(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_dal_lb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``DAL_LB`` with default display name ``'DAL(LB)'``."""
    from src.algorithms.DAL_LBs import DAL_LB

    return _build_dal_lb_algorithm(
        context=context,
        overrides=overrides,
        dal_cls=DAL_LB,
        default_name='DAL(LB)',
    )


def algorithm_dal_lb_gsr(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``DAL_LB_GSR`` with default display name ``'DAL(LB)-GSR'``."""
    from src.algorithms.DAL_LBs_GSR import DAL_LB_GSR

    return _build_dal_lb_algorithm(
        context=context,
        overrides=overrides,
        dal_cls=DAL_LB_GSR,
        default_name='DAL(LB)-GSR',
    )


def algorithm_lb_weight_ucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``LB_WeightUCB``.

    Constructor arguments default to values from ``context``; entries in
    ``overrides['params']`` replace them and ``overrides['name']`` replaces
    the display name ``'LB-WeightUCB'``.
    """
    from src.algorithms.LB_WeightUCB_new import LB_WeightUCB

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'LB-WeightUCB')

    def factory():
        instance = LB_WeightUCB(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_opkb_lin(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``OPKB`` from ``src.algorithms.OPKBLin``.

    ``B`` defaults to ``context['reward_bound']``. Entries in
    ``overrides['params']`` replace the defaults and ``overrides['name']``
    replaces the display name ``'ADA-OPKB (Lin)'``.
    """
    from src.algorithms.OPKBLin import OPKB

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        B=context['reward_bound'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'ADA-OPKB (Lin)')

    def factory():
        instance = OPKB(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)
def algorithm_master_scb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``MASTER`` configured with ``model='SCB'``.

    Uses ``context['delta_master']`` and ``context['r_lambda_master']``.
    Entries in ``overrides['params']`` replace the defaults and
    ``overrides['name']`` replaces the display name ``'MASTER(SCB)'``.
    """
    from src.algorithms.MASTER import MASTER

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta_master'],
        r_lambda=context['r_lambda_master'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        model='SCB',
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'MASTER(SCB)')

    def factory():
        instance = MASTER(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_scb_weight_ucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``SCB_WeightUCB``.

    Uses ``context['delta_scb']`` and ``context['r_lambda_scb']``. Entries in
    ``overrides['params']`` replace the defaults and ``overrides['name']``
    replaces the display name ``'SCB-WeightUCB'``.
    """
    from src.algorithms.SCB_WeightUCB_new import SCB_WeightUCB

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta_scb'],
        r_lambda=context['r_lambda_scb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'SCB-WeightUCB')

    def factory():
        instance = SCB_WeightUCB(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def _build_dal_ofu_algorithm(
    context: Dict[str, Any],
    overrides: Optional[Dict[str, Any]],
    dal_cls,
    default_name: str,
) -> AlgorithmTemplateResult:
    """Shared builder for DAL wrappers that run on top of ``OFUGLB``.

    ``overrides['base_params']`` adjusts the ``OFUGLB`` constructor arguments,
    ``overrides['params']`` adjusts the arguments of ``dal_cls`` (which
    receives a ``base_factory`` callable), and ``overrides['name']`` replaces
    ``default_name``.
    """
    from src.algorithms.OFUGLB_fast import OFUGLB

    settings = overrides if overrides else {}
    inner_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        dim=context['d'],
        param_norm_ub=context['S'],
        arm_norm_ub=context['L'],
        failure_level=context['failure_delta'],
    )
    inner_kwargs.update(settings.get('base_params', {}))

    def base_factory():
        return OFUGLB(**inner_kwargs)

    outer_kwargs = dict(
        T=context['T'],
        delta=context['dal_delta'],
        base_factory=base_factory,
    )
    outer_kwargs.update(settings.get('params', {}))
    label = settings.get('name', default_name)

    def factory():
        instance = dal_cls(**outer_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_dal_ofu(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``DAL_SCB`` with default display name ``'DAL-SCB(OFU-GLB)'``."""
    from src.algorithms.DAL_SCB import DAL_SCB

    return _build_dal_ofu_algorithm(
        context=context,
        overrides=overrides,
        dal_cls=DAL_SCB,
        default_name='DAL-SCB(OFU-GLB)',
    )


def algorithm_dal_ofu_gsr(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``DAL_SCB_GSR`` with default display name ``'DAL-SCB_GSR(OFU-GLB)'``."""
    from src.algorithms.DAL_SCB_GSR import DAL_SCB_GSR

    return _build_dal_ofu_algorithm(
        context=context,
        overrides=overrides,
        dal_cls=DAL_SCB_GSR,
        default_name='DAL-SCB_GSR(OFU-GLB)',
    )


def algorithm_opkb_kb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``OPKB`` from ``src.algorithms.OPKB``.

    ``B`` defaults to ``context['reward_bound']``. Entries in
    ``overrides['params']`` replace the defaults and ``overrides['name']``
    replaces the display name ``'ADA-OPKB'``.
    """
    from src.algorithms.OPKB import OPKB

    settings = overrides if overrides else {}
    kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        B=context['reward_bound'],
    )
    kwargs.update(settings.get('params', {}))
    label = settings.get('name', 'ADA-OPKB')

    def factory():
        instance = OPKB(**kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def algorithm_gp_ucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``GP_UCB_W``.

    Args:
        context: Must provide ``num_actions``, ``T``, ``reward_bound`` and
            ``kernel_config``; ``xi`` is optional and falls back to 0.5 when
            missing or ``None``.
        overrides: Optional dict; ``overrides['params']`` replaces constructor
            arguments and ``overrides['name']`` replaces ``'GP-UCB-W'``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds each instance from
        an independent deep copy of the arguments.
    """
    from src.algorithms.GP_UCB_W import GP_UCB_W

    cfg = overrides or {}
    # bandit_kb_template always stores an 'xi' entry (None when unset), so
    # ``context.get('xi', 0.5)`` alone would hand None to the algorithm.
    xi = context.get('xi')
    if xi is None:
        xi = 0.5
    params = {
        'num_actions': context['num_actions'],
        'horizon': context['T'],
        'xi': xi,
        'B': context['reward_bound'],
        'config': copy.deepcopy(context['kernel_config']),
    }
    params.update(cfg.get('params', {}))
    name = cfg.get('name', 'GP-UCB-W')

    def factory():
        # One deep copy covers the nested 'config' dict too, so instances
        # never share mutable state.
        alg = GP_UCB_W(**copy.deepcopy(params))
        setattr(alg, 'name', name)
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)

def algorithm_gp_ucb_r(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``GP_UCB_R``.

    Args:
        context: Must provide ``num_actions``, ``T``, ``reward_bound`` and
            ``kernel_config``; ``xi`` is optional and falls back to 0.5 when
            missing or ``None``.
        overrides: Optional dict; ``overrides['params']`` replaces constructor
            arguments and ``overrides['name']`` replaces ``'GP-UCB-R'``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds each instance from
        an independent deep copy of the arguments.
    """
    from src.algorithms.GP_UCB_R import GP_UCB_R

    cfg = overrides or {}
    # bandit_kb_template always stores an 'xi' entry (None when unset), so
    # ``context.get('xi', 0.5)`` alone would hand None to the algorithm.
    xi = context.get('xi')
    if xi is None:
        xi = 0.5
    params = {
        'num_actions': context['num_actions'],
        'horizon': context['T'],
        'xi': xi,
        'B': context['reward_bound'],
        'config': copy.deepcopy(context['kernel_config']),
    }
    params.update(cfg.get('params', {}))
    name = cfg.get('name', 'GP-UCB-R')

    def factory():
        # One deep copy covers the nested 'config' dict too, so instances
        # never share mutable state.
        alg = GP_UCB_R(**copy.deepcopy(params))
        setattr(alg, 'name', name)
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)

def algorithm_gp_ucb_sw(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``GP_UCB_SW``.

    Args:
        context: Must provide ``num_actions``, ``T``, ``reward_bound`` and
            ``kernel_config``; ``xi`` is optional and falls back to 0.5 when
            missing or ``None``.
        overrides: Optional dict; ``overrides['params']`` replaces constructor
            arguments and ``overrides['name']`` replaces ``'GP-UCB-SW'``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds each instance from
        an independent deep copy of the arguments.
    """
    from src.algorithms.GP_UCB_SW import GP_UCB_SW

    cfg = overrides or {}
    # bandit_kb_template always stores an 'xi' entry (None when unset), so
    # ``context.get('xi', 0.5)`` alone would hand None to the algorithm.
    xi = context.get('xi')
    if xi is None:
        xi = 0.5
    params = {
        'num_actions': context['num_actions'],
        'horizon': context['T'],
        'xi': xi,
        'B': context['reward_bound'],
        'config': copy.deepcopy(context['kernel_config']),
    }
    params.update(cfg.get('params', {}))
    name = cfg.get('name', 'GP-UCB-SW')

    def factory():
        # One deep copy covers the nested 'config' dict too, so instances
        # never share mutable state.
        alg = GP_UCB_SW(**copy.deepcopy(params))
        setattr(alg, 'name', name)
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)



def algorithm_lb_restart(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Template for ``LB_RestartUCB``.

    Args:
        context: LB context providing ``num_actions``, ``T``, ``d``,
            ``delta``, ``r_lambda``, ``S``, ``L`` and ``R``.
        overrides: Optional dict. ``overrides['H']`` is forwarded as ``H``
            (``None`` when absent), ``overrides['params']`` replaces
            constructor arguments, and ``overrides['name']`` replaces the
            display name ``'LB-RestartUCB'``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds a named instance.
    """
    from src.algorithms.LB_RestartUCB import LB_RestartUCB

    cfg = overrides or {}
    params = {
        'num_actions': context['num_actions'],
        'horizon': context['T'],
        'd': context['d'],
        'delta': context['delta'],
        'r_lambda': context['r_lambda'],
        'S': context['S'],
        'L': context['L'],
        'R': context['R'],
        'H': cfg.get('H'),  # if None, it will auto-tune using PT
    }
    # Honor overrides['params'] like the other algorithm templates do; it was
    # previously ignored here.
    params.update(cfg.get('params', {}))
    name = cfg.get('name', 'LB-RestartUCB')

    def factory():
        alg = LB_RestartUCB(**params)
        setattr(alg, 'name', name)
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)

def algorithm_lb_window(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``LB_WindowUCB``.

    All constructor arguments except ``w`` come from ``context``. Only
    ``'name'`` and ``'w'`` are read from ``overrides``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``LB_WindowUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.LB_WindowUCB import LB_WindowUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        # None is forwarded when 'w' is not overridden; per the original
        # note the algorithm then auto-tunes it (confirm in LB_WindowUCB).
        w=options.get('w'),
    )
    label = options.get('name', 'LB-WindowUCB')

    def factory():
        instance = LB_WindowUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def algorithm_lb_dlin(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``LB_DLinUCB``.

    All constructor arguments except ``gamma`` come from ``context``. Only
    ``'name'`` and ``'gamma'`` are read from ``overrides``. Unlike the
    restart and window variants, no ``L`` argument is passed.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``LB_DLinUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.LB_DLinUCB import LB_DLinUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda'],
        S=context['S'],
        R=context['R'],
        # None is forwarded when 'gamma' is not overridden; per the original
        # note the algorithm then auto-tunes it (confirm in LB_DLinUCB).
        gamma=options.get('gamma'),
    )
    label = options.get('name', 'LB-DLinUCB')

    def factory():
        instance = LB_DLinUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)


def _build_dal_kb_algorithm(
    context: Dict[str, Any],
    overrides: Optional[Dict[str, Any]],
    dal_cls,
    default_name: str,
) -> AlgorithmTemplateResult:
    """Shared builder for the kernel-bandit DAL templates.

    Args:
        context: Bandit context; ``num_actions``, ``T``, ``noise_variance``,
            ``reward_bound`` and ``kernel_config`` are read from it.
        overrides: Optional mapping; ``'params'`` is merged over the
            context-derived constructor arguments and ``'name'`` replaces
            ``default_name``.
        dal_cls: Callable invoked with the constructor arguments as keywords.
        default_name: Display name used when no override is supplied.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds a new instance
        from deep copies of the prepared arguments on every call.
    """
    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        noise_variance=context['noise_variance'],
        B=context['reward_bound'],
        config=copy.deepcopy(context['kernel_config']),
    )
    ctor_kwargs.update(options.get('params', {}))
    name = options.get('name', default_name)

    def factory():
        # Copy per call so separate instances never share argument objects.
        fresh_kwargs = copy.deepcopy(ctor_kwargs)
        fresh_kwargs['config'] = copy.deepcopy(ctor_kwargs['config'])
        alg = dal_cls(**fresh_kwargs)
        alg.name = name
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)

def algorithm_glb_restart(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``GLB_RestartUCB``.

    Constructor arguments come from ``context``; the regularisation value is
    read from ``context['r_lambda_glb']``. Only ``'name'`` and ``'H'`` are
    read from ``overrides``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``GLB_RestartUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.GLB_RestartUCB import GLB_RestartUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_glb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
        # None is forwarded when 'H' is not overridden; per the original
        # note the algorithm then auto-tunes it (confirm in GLB_RestartUCB).
        H=options.get('H'),
    )
    label = options.get('name', 'GLB-RestartUCB')

    def factory():
        instance = GLB_RestartUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def algorithm_glb_dglucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``GLB_dGLUCB``.

    Constructor arguments come from ``context``; the regularisation value is
    read from ``context['r_lambda_glb']``. Only ``'name'`` and ``'gamma'``
    are read from ``overrides``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``GLB_dGLUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.GLB_dGLUCB import GLB_dGLUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_glb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
        # None is forwarded when 'gamma' is not overridden; per the original
        # note the algorithm then auto-tunes it (confirm in GLB_dGLUCB).
        gamma=options.get('gamma'),
    )
    label = options.get('name', 'GLB-dGLUCB')

    def factory():
        instance = GLB_dGLUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def algorithm_glb_swglucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``GLB_SWGLUCB``.

    Constructor arguments come from ``context``; the regularisation value is
    read from ``context['r_lambda_glb']``. Only ``'name'`` and ``'w'`` are
    read from ``overrides``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``GLB_SWGLUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.GLB_SWGLUCB import GLB_SWGLUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_glb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
        # None is forwarded when 'w' is not overridden; per the original
        # note the algorithm then auto-tunes it (confirm in GLB_SWGLUCB).
        w=options.get('w'),
    )
    label = options.get('name', 'GLB-SWGLUCB')

    def factory():
        instance = GLB_SWGLUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def algorithm_glb_bvd(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``GLB_BVD_GLM_UCB``.

    Constructor arguments come from ``context``; the regularisation value is
    read from ``context['r_lambda_glb']``. Only ``'name'`` and ``'gamma'``
    are read from ``overrides``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``GLB_BVD_GLM_UCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.GLB_BVD_GLM_UCB import GLB_BVD_GLM_UCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        delta=context['delta'],
        r_lambda=context['r_lambda_glb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
        # None is forwarded when 'gamma' is not overridden; per the original
        # note the algorithm then auto-tunes it (confirm in GLB_BVD_GLM_UCB).
        gamma=options.get('gamma'),
    )
    label = options.get('name', 'GLB-BVD-GLM-UCB')

    def factory():
        instance = GLB_BVD_GLM_UCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)



def algorithm_scb_dglucb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``SCB_dGLUCB``.

    Constructor arguments come from ``context``. Only ``'name'`` and
    ``'gamma'`` are read from ``overrides``; ``gamma`` is forwarded as
    ``None`` when not overridden.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``SCB_dGLUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.SCB_dGLUCB import SCB_dGLUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        # The SCB-specific confidence level is used here, not context['delta'].
        delta=context['delta_scb'],
        # NOTE(review): this reads 'r_lambda_master' while algorithm_scb_restart
        # reads 'r_lambda_scb' - confirm the difference is intentional.
        r_lambda=context['r_lambda_master'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
        gamma=options.get('gamma'),
    )
    label = options.get('name', 'SCB-dGLUCB')

    def factory():
        instance = SCB_dGLUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def algorithm_scb_restart(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``SCB_RestartUCB``.

    Constructor arguments come from ``context``. Only ``'name'`` and ``'H'``
    are read from ``overrides``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory creates a new
        ``SCB_RestartUCB`` on every call and sets its ``name`` attribute.
    """
    from src.algorithms.SCB_RestartUCB import SCB_RestartUCB

    options = overrides or {}
    ctor_kwargs = dict(
        num_actions=context['num_actions'],
        horizon=context['T'],
        d=context['d'],
        # 'delta_scb' is required: there is no fallback to context['delta'],
        # so a context without it raises KeyError.
        delta=context['delta_scb'],
        r_lambda=context['r_lambda_scb'],
        S=context['S'],
        L=context['L'],
        R=context['R'],
        k_mu=context['k_mu'],
        c_mu=context['c_mu'],
        # None is forwarded when 'H' is not overridden; per the original
        # note that enables auto-tuning (confirm in SCB_RestartUCB).
        H=options.get('H'),
    )
    label = options.get('name', 'SCB-RestartUCB')

    def factory():
        instance = SCB_RestartUCB(**ctor_kwargs)
        instance.name = label
        return instance

    return AlgorithmTemplateResult(factory=factory, name=label)

def algorithm_dal_kb(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the ``DAL_KB`` template via the shared kernel-bandit builder (default name ``'DAL(KB)'``)."""
    from src.algorithms.DAL_KBs import DAL_KB as algorithm_cls

    return _build_dal_kb_algorithm(
        context,
        overrides,
        dal_cls=algorithm_cls,
        default_name='DAL(KB)',
    )


def algorithm_dal_kb_gsr(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the ``DAL_KB_GSR`` template via the shared kernel-bandit builder (default name ``'DAL(KB)-GSR'``)."""
    from src.algorithms.DAL_KBs_GSR import DAL_KB_GSR as algorithm_cls

    return _build_dal_kb_algorithm(
        context,
        overrides,
        dal_cls=algorithm_cls,
        default_name='DAL(KB)-GSR',
    )


def _build_dal_context_algorithm(
    context: Dict[str, Any],
    overrides: Optional[Dict[str, Any]],
    dal_cls,
    default_names: Dict[str, str],
    fallback: str,
) -> AlgorithmTemplateResult:
    """Shared builder for the contextual-bandit DAL templates.

    Args:
        context: Bandit context; ``num_actions``, ``T``, ``delta``,
            ``noise_variance`` and optionally ``seed`` are read from it.
        overrides: Optional mapping. Recognised keys are ``base`` (default
            ``'square'``, lower-cased), ``base_params``, ``seed``, ``delta``,
            ``explore_coef``, ``rng_seed`` (defaults to the seed) and ``name``.
        dal_cls: Callable invoked with the assembled keyword arguments.
        default_names: Display name per base-learner key.
        fallback: Format string with a ``{base}`` field, used as the name when
            the base key is not in ``default_names``.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds a new instance on
        every call.
    """
    from src.algorithms.DAL_CBs import RegCB, SquareCB, Cover

    options = overrides or {}
    base_name = options.get('base', 'square').lower()
    base_params = options.get('base_params', {})
    seed = options.get('seed', context.get('seed'))

    def base_factory():
        # Any base key other than 'regcb' or 'cover' selects SquareCB.
        if base_name == 'regcb':
            learner_cls = RegCB
            learner_defaults = {'mode': 'elimination'}
        elif base_name == 'cover':
            learner_cls = Cover
            learner_defaults = {'m': 24, 'psi': 1.0, 'nounif': True, 'cb_type': 'mtr'}
        else:
            learner_cls = SquareCB
            learner_defaults = {'gamma_scale': 200.0, 'gamma_exponent': 1.0, 'elim': False, 'mellowness': 0.005}

        learner_kwargs = {'num_actions': context['num_actions'], 'horizon': context['T']}
        if seed is not None:
            learner_kwargs['seed'] = seed
        # User-supplied base_params take precedence over the per-learner defaults.
        learner_defaults.update(base_params)
        learner_kwargs.update(learner_defaults)
        return learner_cls(**learner_kwargs)

    delta = options.get('delta', context['delta'])
    explore_coef = options.get('explore_coef')
    rng_seed = options.get('rng_seed', seed)
    fallback_name = fallback.format(base=base_name)
    name = options.get('name', default_names.get(base_name, fallback_name))

    def factory():
        ctor_kwargs = {
            'T': context['T'],
            'delta': delta,
            'noise_variance': context['noise_variance'],
            'base_factory': base_factory,
        }
        # Optional arguments are only passed when a value was configured.
        if rng_seed is not None:
            ctor_kwargs['rng'] = np.random.default_rng(rng_seed)
        if explore_coef is not None:
            ctor_kwargs['explore_coef'] = explore_coef
        alg = dal_cls(**ctor_kwargs)
        alg.name = name
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)


def algorithm_dal_context(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the ``DALContext`` template via the shared contextual builder."""
    from src.algorithms.DAL_CBs import DALContext

    # Display names keyed by base learner; other bases use the fallback pattern.
    labels_by_base = {
        'square': 'DAB(Square)',
        'regcb': 'DAB(RegCB)',
        'cover': 'DAB(Cover)',
    }
    return _build_dal_context_algorithm(
        context,
        overrides,
        DALContext,
        labels_by_base,
        'DALContext({base})',
    )


def algorithm_dal_context_gsr(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the ``DALContextGSR`` template via the shared contextual builder."""
    from src.algorithms.DAL_CBs_GSR import DALContextGSR

    # Display names keyed by base learner; other bases use the fallback pattern.
    labels_by_base = {
        'square': 'DAB_GSR(Square)',
        'regcb': 'DAB_GSR(RegCB)',
        'cover': 'DAB_GSR(Cover)',
    }
    return _build_dal_context_algorithm(
        context,
        overrides,
        DALContextGSR,
        labels_by_base,
        'DALContextGSR({base})',
    )


def algorithm_adailtcb_plus(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``ADAILTCBPlusBandit``.

    Recognised ``overrides`` keys:
        erm_mode: ``'finite'`` selects finite-class ERM; any other value is
            run as ``'vw'`` (the default).
        name: Display name; defaults to a label derived from the ERM mode
            that is actually used.
        policy_count: Number of random policies to generate when no
            ``policies`` are supplied (falls back to
            ``context['policy_count']``, then 100).
        seed: Seed for policy generation (falls back to ``context['seed']``).
        delta: Passed to the algorithm; defaults to 0.05.
        policies: Optional iterable of weight arrays or of objects exposing a
            ``W`` attribute.
        vw_erm_kwargs: Merged over the default VW settings (VW mode only).
        rng_seed: When set, a fresh ``numpy`` generator seeded with it is
            passed to every constructed instance.

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds a new algorithm
        with its own copies of the policy weights on every call.

    Raises:
        ValueError: From the factory, when the policy bank is empty.
    """
    from src.algorithms.ADAILTCBp import ADAILTCBPlusBandit, LinearArgmaxPolicy

    cfg = overrides or {}
    # Normalise once: every value other than 'finite' is executed in VW mode,
    # so the default label must be derived from the normalised mode. Previously
    # an unrecognised mode ran as VW but was labelled 'ADA-ILTCB+ finite'.
    requested_mode = cfg.get('erm_mode', 'vw').lower()
    erm_mode = 'finite' if requested_mode == 'finite' else 'vw'
    default_name = 'ADA-ILTCB+ VW-ERM' if erm_mode == 'vw' else 'ADA-ILTCB+ finite'
    name = cfg.get('name', default_name)
    policy_count = int(cfg.get('policy_count', context.get('policy_count', 100)))
    seed = cfg.get('seed', context.get('seed'))
    delta = cfg.get('delta', 0.05)

    supplied = cfg.get('policies')
    if supplied is None:
        _, weight_sources = generate_cb_policies(context, policy_count, seed)
    else:
        # Accept either policy objects (weights in `.W`) or raw weight arrays.
        weight_sources = [getattr(item, 'W', item) for item in supplied]
    # Keep private float copies so later mutation by the caller has no effect.
    base_policy_weights: List[np.ndarray] = [np.asarray(w, float).copy() for w in weight_sources]

    vw_config = {'passes': 1, 'l2': 1e-8, 'learning_rate': 0.5, 'interactions': ('ax',)}
    vw_config.update(cfg.get('vw_erm_kwargs', {}))

    def factory():
        params = dict(
            num_actions=context['num_actions'],
            horizon=context['T'],
            erm_mode=erm_mode,
            delta=delta,
            featurize_ctx=None,
        )
        rng_seed = cfg.get('rng_seed')
        if rng_seed is not None:
            params['rng'] = np.random.default_rng(rng_seed)
        # Each instance gets its own policy objects backed by fresh weight copies.
        policy_bank = [LinearArgmaxPolicy(np.asarray(w, float).copy()) for w in base_policy_weights]
        if not policy_bank:
            raise ValueError("ADAILTCB+ requires a non-empty policy bank.")
        if erm_mode == 'vw':
            params['vw_erm_kwargs'] = dict(vw_config)
        # The policy bank is passed in both modes.
        params['policies'] = policy_bank
        alg = ADAILTCBPlusBandit(**params)
        setattr(alg, 'name', name)
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)


def algorithm_master_context(context: Dict[str, Any], overrides: Optional[Dict[str, Any]] = None) -> AlgorithmTemplateResult:
    """Build the algorithm template for ``MASTERContext``.

    Recognised ``overrides`` keys: ``base`` (default ``'cover'``, lower-cased),
    ``base_params``, ``policy_count``, ``seed``, ``delta`` (default 0.05),
    ``c2`` (default 2.0), ``name``, ``policies`` and ``policy_seed_offset``
    (default 123, added to the seed when policies are generated).

    Returns:
        An ``AlgorithmTemplateResult`` whose factory builds a new
        ``MASTERContext`` with its own policy objects on every call.

    Raises:
        ValueError: From the factory, when the policy bank is empty.
    """
    from src.algorithms.MASTERContext import MASTERContext
    from src.algorithms.DAL_CBs import Cover, SquareCB, RegCB
    from src.algorithms.ADAILTCBp import LinearArgmaxPolicy

    options = overrides or {}
    base_name = options.get('base', 'cover').lower()
    base_params = options.get('base_params', {})
    policy_count = int(options.get('policy_count', context.get('policy_count', 100)))
    seed = options.get('seed', context.get('seed'))
    delta = options.get('delta', 0.05)
    c2 = options.get('c2', 2.0)
    labels_by_base = {
        'cover': 'MASTER (Cover base, finite Pi)',
        'square': 'MASTER (Square base)',
        'regcb': 'MASTER (RegCB base)',
    }
    name = options.get('name', labels_by_base.get(base_name, f'MASTERContext({base_name})'))

    supplied_policies = options.get('policies')
    if supplied_policies is None:
        # Offset the seed used for generating the policy bank.
        if seed is None:
            policy_seed = None
        else:
            policy_seed = seed + options.get('policy_seed_offset', 123)
        _, weight_bank = generate_cb_policies(context, policy_count, policy_seed)
        master_policy_weights: List[np.ndarray] = [np.asarray(w, float).copy() for w in weight_bank]
    else:
        # Accept either policy objects (weights in `.W`) or raw weight arrays.
        master_policy_weights = [
            np.asarray(item.W if hasattr(item, 'W') else item, float).copy()
            for item in supplied_policies
        ]

    def base_factory():
        # Any base key other than 'regcb' or 'square' selects Cover.
        if base_name == 'regcb':
            learner_cls = RegCB
            learner_defaults = {'mode': 'elimination'}
        elif base_name == 'square':
            learner_cls = SquareCB
            learner_defaults = {'gamma_scale': 200.0, 'gamma_exponent': 1.0, 'elim': False, 'mellowness': 0.005}
        else:
            learner_cls = Cover
            learner_defaults = {'m': 8, 'psi': 1.0, 'nounif': False}

        learner_kwargs = {'num_actions': context['num_actions'], 'horizon': context['T']}
        if seed is not None:
            learner_kwargs['seed'] = seed
        # User-supplied base_params take precedence over the per-learner defaults.
        learner_defaults.update(base_params)
        learner_kwargs.update(learner_defaults)
        return learner_cls(**learner_kwargs)

    def factory():
        if not master_policy_weights:
            raise ValueError("MASTERContext requires a non-empty policy bank.")
        # Fresh policy objects per instance, each with its own weight copy.
        bank = [LinearArgmaxPolicy(np.asarray(w, float).copy()) for w in master_policy_weights]
        alg = MASTERContext(
            num_actions=context['num_actions'],
            horizon=context['T'],
            delta=delta,
            base_factory=base_factory,
            policies=bank,
            c2=c2,
            seed=seed,
        )
        alg.name = name
        return alg

    return AlgorithmTemplateResult(factory=factory, name=name)

# Registry of algorithm template builders, keyed by the template name used in
# an experiment's `algorithms` list. parse_algorithm_entry looks the key up
# here and calls the builder with (context, overrides).
# NOTE(review): 'scb_dglb' maps to algorithm_scb_dglucb, whereas the GLB
# counterpart is keyed 'glb_dglucb' - confirm whether 'scb_dglucb' was
# intended; existing configs may already rely on the current spelling.
ALGORITHM_TEMPLATES: Dict[str, Callable[[Dict[str, Any], Optional[Dict[str, Any]]], AlgorithmTemplateResult]] = {
    'master_glb': algorithm_master_glb,
    'dal_glb': algorithm_dal_glb,
    'dal_glb_gsr': algorithm_dal_glb_gsr,
    'glb_weight_ucb': algorithm_glb_weight_ucb,
    'master_lb': algorithm_master_lb,
    'dal_lb': algorithm_dal_lb,
    'dal_lb_gsr': algorithm_dal_lb_gsr,
    'lb_weight_ucb': algorithm_lb_weight_ucb,
    'opkb_lin': algorithm_opkb_lin,
    'master_scb': algorithm_master_scb,
    'scb_weight_ucb': algorithm_scb_weight_ucb,
    'dal_ofu_glb': algorithm_dal_ofu,
    'dal_ofu_glb_gsr': algorithm_dal_ofu_gsr,
    'opkb_kb': algorithm_opkb_kb,
    'gp_ucb': algorithm_gp_ucb,
    'dal_kb': algorithm_dal_kb,
    'dal_kb_gsr': algorithm_dal_kb_gsr,
    'dal_context': algorithm_dal_context,
    'dal_context_gsr': algorithm_dal_context_gsr,
    'adailtcb_plus': algorithm_adailtcb_plus,
    'master_context': algorithm_master_context,
    'lb_restart': algorithm_lb_restart,
    'lb_window': algorithm_lb_window,
    'lb_dlin': algorithm_lb_dlin,
    'glb_restart': algorithm_glb_restart,
    'glb_dglucb': algorithm_glb_dglucb,
    'glb_swglucb': algorithm_glb_swglucb,
    'glb_bvd': algorithm_glb_bvd,
    'scb_dglb': algorithm_scb_dglucb,
    'scb_restart': algorithm_scb_restart,
    'gp_ucb_r': algorithm_gp_ucb_r,
    'gp_ucb_sw': algorithm_gp_ucb_sw,
}

# Short names accepted by resolve_simulation for an experiment's `simulation`
# field; any other value is treated there as a dotted 'module.ClassName' path.
SIMULATION_ALIASES = {
    'linear': LinearBanditSimulation,
    'kernel': KernelBanditSimulation,
    'contextual': ContextBanditSimulation,
}
def parse_algorithm_entry(entry: Any, context: Dict[str, Any]) -> AlgorithmTemplateResult:
    """Turn one item of an experiment's ``algorithms`` list into a template result.

    Args:
        entry: Either a template key (``str``) or a mapping with a required
            ``'template'`` key and an optional ``'overrides'`` mapping.
        context: Bandit context forwarded to the template builder.

    Returns:
        The ``AlgorithmTemplateResult`` produced by the registered builder.

    Raises:
        ValueError: If ``entry`` has an unsupported type, lacks a template
            name, or names a template that is not registered.
    """
    if not isinstance(entry, (str, dict)):
        raise ValueError(f"Invalid algorithm entry: {entry!r}")

    if isinstance(entry, str):
        # Bare string: the template key itself, with no overrides.
        template, overrides = entry, None
    else:
        template = entry.get('template')
        if not template:
            raise ValueError("Algorithm entry must include a 'template'.")
        overrides = entry.get('overrides')

    builder = ALGORITHM_TEMPLATES.get(template)
    if builder is None:
        raise ValueError(f"Unknown algorithm template '{template}'.")
    return builder(context, overrides)


def build_bandit(cfg: Dict[str, Any], default_problem: Optional[str]) -> BanditTemplateResult:
    """Instantiate the bandit template selected by an experiment configuration.

    The template name is ``cfg['bandit']['template']`` when given, otherwise
    ``default_problem``. Template parameters are the experiment-level
    ``cfg['parameters']`` updated with ``cfg['bandit']['params']``.

    Raises:
        ValueError: If no template name is available or it is not registered
            in ``BANDIT_TEMPLATES``.
    """
    section = cfg.get('bandit', {})
    chosen = section.get('template', default_problem)
    if not chosen:
        raise ValueError('Bandit template must be provided.')
    if chosen not in BANDIT_TEMPLATES:
        raise ValueError(f"Unknown bandit template '{chosen}'.")

    # Bandit-specific params win over the experiment-wide parameters.
    merged = dict(cfg.get('parameters', {}))
    merged.update(section.get('params', {}))
    template_builder = BANDIT_TEMPLATES[chosen]
    return template_builder(merged)


def resolve_simulation(sim_key: Optional[str]):
    """Resolve the ``simulation`` config value to a simulation class.

    Args:
        sim_key: ``None`` (selects ``LinearBanditSimulation``), a key of
            ``SIMULATION_ALIASES``, or a dotted ``'package.module.Name'`` path.

    Returns:
        The object found for the key; for dotted paths, the attribute ``Name``
        of the imported module.

    Raises:
        ValueError: If ``sim_key`` is neither a known alias nor a dotted path.
        ImportError: If the module part of a dotted path cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    # importlib.import_module is the documented replacement for calling
    # __import__ directly; it returns the named (leaf) module.
    import importlib

    if sim_key is None:
        return LinearBanditSimulation
    if sim_key in SIMULATION_ALIASES:
        return SIMULATION_ALIASES[sim_key]
    module_path, _, attr_name = sim_key.rpartition('.')
    if not module_path:
        raise ValueError(f"Unknown simulation alias '{sim_key}'.")
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


def build_experiment(name: str, cfg: Dict[str, Any]) -> ExperimentBuilder:
    """Assemble an ``ExperimentBuilder`` from one experiment configuration.

    Args:
        name: Experiment name (not used when building).
        cfg: Experiment configuration. The keys ``problem``, ``environment``,
            ``simulation``, ``algorithms``, ``runs`` and ``output`` are read
            here; ``bandit`` and ``parameters`` are consumed by
            ``build_bandit``.

    Returns:
        The builder, additionally carrying ``output``, ``context`` and
        ``problem`` attributes for later use.
    """
    problem = cfg.get('problem')
    bandit_result = build_bandit(cfg, problem)
    context = bandit_result.context

    # Explicit environment settings win; the bandit context fills the gaps.
    environment = dict(cfg.get('environment', {}))
    context_defaults = (
        ('T', context['T']),
        ('continuous', context.get('continuous', False)),
        ('reward_bound', context.get('reward_bound', 1.0)),
        ('change_prob', context.get('change_prob')),
    )
    for key, default_value in context_defaults:
        environment.setdefault(key, default_value)
    environment['simulation_cls'] = resolve_simulation(cfg.get('simulation'))

    templates = [parse_algorithm_entry(item, context) for item in cfg.get('algorithms', [])]
    factories = [template.factory for template in templates]
    labels = [template.name for template in templates]

    runs = cfg.get('runs', {})
    monte_carlo_runs = int(runs.get('n_mc', 1))
    parallel_jobs = int(runs.get('n_jobs', -1))

    output_cfg = cfg.get('output', {})

    builder = ExperimentBuilder(
        bandit_factory=bandit_result.factory,
        algorithm_factories=factories,
        algorithm_names=labels,
        environment_kwargs=environment,
        n_mc=monte_carlo_runs,
        n_jobs=parallel_jobs,
    )
    # Extra attributes attached to the builder after construction.
    builder.output = output_cfg
    builder.context = context
    builder.problem = problem
    return builder


def load_config(path: Path) -> Dict[str, Any]:
    """Load the experiments mapping from a YAML configuration file.

    Args:
        path: Location of the YAML file.

    Returns:
        The value of the top-level ``experiments`` key.

    Raises:
        ValueError: If the document is not a mapping, has no ``experiments``
            key, or ``experiments`` is not itself a mapping.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding,
    # so the same file parses identically on every machine.
    data = yaml.safe_load(path.read_text(encoding='utf-8'))
    if not isinstance(data, dict) or 'experiments' not in data:
        raise ValueError("Configuration must define a top-level 'experiments' mapping.")
    experiments = data['experiments']
    if not isinstance(experiments, dict):
        raise ValueError("'experiments' must be a mapping.")
    return experiments


def plot_results(stem: str, summary: Dict[str, Any], horizon: int) -> None:
    """Plot regret and timing curves for one experiment, if plotting is available.

    Does nothing when the optional plotting helpers could not be imported.
    Otherwise ensures the ``fig`` directory exists and calls both helpers
    with ``stem`` as the experiment name.
    """
    plotting_available = plot_regret is not None and plot_times is not None
    if plotting_available:
        # ensure_directory creates the parent of the given path, i.e. 'fig'.
        ensure_directory(Path('fig') / 'placeholder')
        plot_regret(summary['avg_regret'], summary['std_regret'], horizon, exp_name=stem)
        plot_times(summary['timings'], exp_name=stem)


def save_results(path: Path, payload: Dict[str, Any]) -> None:
    """Pickle ``payload`` to ``path``, creating the parent directory first."""
    import pickle

    ensure_directory(path)
    with path.open(mode='wb') as sink:
        pickle.Pickler(sink).dump(payload)
def main(argv: Optional[Iterable[str]] = None) -> None:
    """Command-line entry point: run the experiments defined in a YAML file.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use the
            process arguments.

    Raises:
        ValueError: If an experiment requested with ``--experiment`` is not
            defined in the configuration.
    """
    parser = argparse.ArgumentParser(description='Run DAL experiments from YAML configuration.')
    parser.add_argument('config', type=Path, help='Path to the configuration YAML file.')
    parser.add_argument('--experiment', '-e', action='append', dest='experiments', help='Experiment name to run (can repeat).')
    parser.add_argument('--list', action='store_true', help='List available experiments and exit.')
    args = parser.parse_args(list(argv) if argv is not None else None)

    experiments = load_config(args.config)

    if args.list:
        for key in experiments:
            print(key)
        return

    # Without an explicit selection, run every experiment in file order.
    selected = args.experiments or list(experiments)
    unknown = [name for name in selected if name not in experiments]
    if unknown:
        raise ValueError(f"Unknown experiment(s): {', '.join(unknown)}")

    for name in selected:
        cfg = experiments[name]
        builder = build_experiment(name, cfg)
        summary = run_builder(builder)
        horizon = builder.environment_kwargs.get('T', 0)
        output_stem = build_output_stem(builder)

        # `output:` may be absent or explicitly null in the YAML; treat both as empty.
        output_cfg = getattr(builder, 'output', None) or {}
        if output_cfg.get('plot', True):
            plot_results(output_stem, summary, horizon)

        save_spec = output_cfg.get('save')
        if save_spec:
            base_path = Path(save_spec)
            # A '*.pkl' spec contributes only its directory; any other spec is
            # used as the directory itself. The file name is always derived
            # from the output stem. A Path never stringifies to '' (an empty
            # path is '.'), so no empty-directory fallback is needed.
            results_dir = base_path.parent if base_path.suffix == '.pkl' else base_path
            save_results(results_dir / f"{output_stem}.pkl", summary)

        final_regrets = {alg: float(reg[-1]) for alg, reg in summary['avg_regret'].items()}
        print(f"Completed experiment '{name}' -> final regrets: {final_regrets}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
