"""Scenario registry mirroring the legacy submit experiments."""

from __future__ import annotations

from typing import Dict, Tuple

import numpy as np

from .biases import continuous_biases, monotone_biases, privacy_vectors, zero_biases
from .suite import LearningRateConfig, MethodSpec, ScenarioSpec, SuiteSpec

# Shared ``chain_lookup`` table; its keys line up with the target sample sizes
# used by the scenarios below (presumably sample size -> number of chains).
DEFAULT_CHAIN_MAP = {
    20000: 8,
    100000: 10,
    200000: 20,
}
# Shared lambda grid: six log-spaced values from 1e-4 to 1e1 with 3.0 inserted
# just before the last entry, i.e. (1e-4, 1e-3, 1e-2, 1e-1, 1, 3, 10).
DEFAULT_LAMBDA_GRID = tuple(
    np.insert(np.logspace(start=-4, stop=1, num=6), obj=-1, values=3.0)
)

# Methods built on the "opt" strategy: one with ``select_lambda`` enabled and
# two with a pinned ``lambda_value``.
OPT_METHOD = MethodSpec(
    name="opt",
    strategy="opt",
    select_lambda=True,
)
INV_METHOD = MethodSpec(
    name="inv",
    strategy="opt",
    lambda_value=0.0,
)
MSE_METHOD = MethodSpec(
    name="mse",
    strategy="opt",
    lambda_value=1.0,
)

# Comparators that use their own strategies.
CONS_METHOD = MethodSpec(
    name="cons",
    strategy="cons",
    lambda_value=1.0,
)
CONSVAR_METHOD = MethodSpec(
    name="consvar",
    strategy="consvar",
    lambda_value=1.0,
)
TGT_METHOD = MethodSpec(
    name="tgt",
    strategy="target",
    z_score=6.74735,
)

# Same settings as MSE_METHOD apart from the name; referenced by _dpsgd_suite.
DPSGD_METHOD = MethodSpec(
    name="dpsgd",
    strategy="opt",
    lambda_value=1.0,
)


def _monotone_strategy(step_small_factor: float, step_big_factor: float, n_sites: int):
    """Return a bias strategy whose step sizes scale with ``1 / sqrt(n_samples)``.

    The returned callable takes the sample size and calls ``monotone_biases``
    with both step factors multiplied by ``1 / sqrt(n_samples)``, followed by
    ``n_sites``.
    """

    def _scaled_monotone(n_samples: int):
        inv_root_n = 1 / np.sqrt(n_samples)
        small_step = step_small_factor * inv_root_n
        big_step = step_big_factor * inv_root_n
        return monotone_biases(small_step, big_step, n_sites)

    return _scaled_monotone


def _continuous_strategy(end_val: float, num_points: int, n_sites: int):
    """Return a bias strategy that ignores the sample size.

    The returned callable accepts one argument (unused) and always calls
    ``continuous_biases(end_val, num_points, n_sites)``.
    """

    def _fixed_grid(_: int):
        # The sample size is accepted only to match the strategy call shape.
        biases = continuous_biases(end_val, num_points, n_sites)
        return biases

    return _fixed_grid


def _gbias_bias_strategy(n_sites: int, num_points: int = 11):
    """Return a bias strategy whose continuous range depends on the sample size.

    Sample sizes of 200000 or more use the "large" range; everything else uses
    the "small" range. Each range is an ``(end value, start exponent)`` pair
    forwarded to ``continuous_biases`` as the end value and ``start_exp``.
    """
    # Start exponent -5.0 corresponds to exp(-5).
    small_range = (1.0, -5.0)
    # Start exponent -5 - log(4) corresponds to exp(-5) / 4.
    large_range = (0.25, -5.0 - float(np.log(4.0)))

    def _range_for_size(n_samples: int):
        end_val, start_exp = large_range if n_samples >= 200000 else small_range
        return continuous_biases(end_val, num_points, n_sites, start_exp=start_exp)

    return _range_for_size


def _zero_strategy(n_sites: int):
    """Return a bias strategy that ignores the sample size.

    The returned callable accepts one argument (unused) and always calls
    ``zero_biases(n_sites)``.
    """

    def _no_bias(_: int):
        biases = zero_biases(n_sites)
        return biases

    return _no_bias


def _constant_rs(n_sites: int, value: float = 0.5) -> Tuple[Tuple[float, ...], ...]:
    """Return a one-element tuple holding ``value`` repeated ``n_sites`` times."""
    per_site = (value,) * n_sites
    return (per_site,)


def _dpsgd_suite(lambda_grid=DEFAULT_LAMBDA_GRID) -> SuiteSpec:
    """Build a fresh "dpsgd" suite.

    The suite uses the "laplace" mechanism, the given ``lambda_grid``
    (``DEFAULT_LAMBDA_GRID`` unless overridden) and ``DPSGD_METHOD`` as its
    only method.
    """
    suite = SuiteSpec(
        name="dpsgd",
        mechanism="laplace",
        lambda_grid=lambda_grid,
        methods=(DPSGD_METHOD,),
    )
    return suite


# Registry of scenario specs keyed by scenario name; each key equals the
# ``name`` field of its spec.  Every entry below uses n_sim=1000,
# dist_type="normal", source_props=(3.0,) and base_seed=2022.
SCENARIOS: Dict[str, ScenarioSpec] = {
    # Monotone biases (scaled by 1/sqrt(n)) at two sample sizes; runs the
    # opt/inv/mse transfer suite plus the dpsgd suite.
    "case_bias": ScenarioSpec(
        name="case_bias",
        description="Discrete bias grid with heterogeneous offsets.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25,),
        target_sample_sizes=(20000, 200000),
        source_props=(3.0,),
        bias_strategy=_monotone_strategy(0.1, 100.0, 4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(
            SuiteSpec(
                name="transfer",
                lambda_grid=DEFAULT_LAMBDA_GRID,
                methods=(OPT_METHOD, INV_METHOD, MSE_METHOD),
            ),
            _dpsgd_suite(),
        ),
        output_name="case_bias.csv",
        row_fields=("bias_id",),
        base_seed=2022,
    ),
    # Same grid as "case_bias" but only TGT_METHOD, with an explicit
    # target_seed.
    "case_bias_target": ScenarioSpec(
        name="case_bias_target",
        description="Target-site focus under discrete biases.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25,),
        target_sample_sizes=(20000, 200000),
        source_props=(3.0,),
        bias_strategy=_monotone_strategy(0.1, 100.0, 4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(SuiteSpec(name="target", methods=(TGT_METHOD,)),),
        output_name="case_bias_target.csv",
        row_fields=("bias_id",),
        base_seed=2022,
        target_seed=511,
    ),
    # Continuous biases whose range depends on the sample size (see
    # _gbias_bias_strategy), at two tau values.  Note the output file name
    # drops the "case_" prefix.
    "case_gbias": ScenarioSpec(
        name="case_gbias",
        description="Continuous bias grid for heterogeneous sites.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25, 0.5),
        target_sample_sizes=(20000, 200000),
        source_props=(3.0,),
        bias_strategy=_gbias_bias_strategy(4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(
            SuiteSpec(
                name="transfer",
                lambda_grid=DEFAULT_LAMBDA_GRID,
                methods=(OPT_METHOD, MSE_METHOD, CONS_METHOD, TGT_METHOD),
            ),
        ),
        output_name="gbias.csv",
        row_fields=("bias_id",),
        base_seed=2022,
    ),
    # Same grid as "case_gbias" but only TGT_METHOD, with an explicit
    # target_seed.
    "case_gbias_target": ScenarioSpec(
        name="case_gbias_target",
        description="Target-site evaluation under continuous biases.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25, 0.5),
        target_sample_sizes=(20000, 200000),
        source_props=(3.0,),
        bias_strategy=_gbias_bias_strategy(4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(SuiteSpec(name="target", methods=(TGT_METHOD,)),),
        output_name="gbias_target.csv",
        row_fields=("bias_id",),
        base_seed=2022,
        target_seed=511,
    ),
    # Only the dpsgd suite, on the sample-size-independent continuous grid at
    # tau=0.5.  The output file name does not follow the scenario key.
    "case_gbias_alt": ScenarioSpec(
        name="case_gbias_alt",
        description="Laplace-mechanism variant at τ=0.5.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.5,),
        target_sample_sizes=(20000,),
        source_props=(3.0,),
        bias_strategy=_continuous_strategy(1.0, 11, 4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(_dpsgd_suite(),),
        output_name="case_dpsgd_bias_tau0.5.csv",
        row_fields=("bias_id",),
        base_seed=2022,
    ),
    # Uses its own chain_lookup: 20004 samples with 6 chains.
    # NOTE(review): 20004 = 6 * 3334, presumably chosen so the sample size is
    # divisible by the chain count -- confirm against the runner.
    "case_bias_smallK": ScenarioSpec(
        name="case_bias_smallK",
        description="Discrete bias grid with reduced chain count.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25,),
        target_sample_sizes=(20004,),
        source_props=(3.0,),
        bias_strategy=_monotone_strategy(0.1, 100.0, 4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup={20004: 6},
        suites=(
            SuiteSpec(
                name="transfer",
                lambda_grid=DEFAULT_LAMBDA_GRID,
                methods=(OPT_METHOD, INV_METHOD, MSE_METHOD),
            ),
            _dpsgd_suite(),
        ),
        output_name="case_bias_smallK.csv",
        row_fields=("bias_id",),
        base_seed=2022,
    ),
    # Six sites instead of four; the bias strategy and the rs vector are
    # built with the same site count.
    "case_gbias_changesites": ScenarioSpec(
        name="case_gbias_changesites",
        description="Continuous bias grid with adjustable site count.",
        n_sites=6,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25,),
        target_sample_sizes=(20000,),
        source_props=(3.0,),
        bias_strategy=_gbias_bias_strategy(6),
        rs_options=_constant_rs(6, 0.5),
        chain_lookup={20000: 8},
        suites=(
            SuiteSpec(
                name="transfer",
                lambda_grid=DEFAULT_LAMBDA_GRID,
                methods=(OPT_METHOD, INV_METHOD, MSE_METHOD, CONS_METHOD),
            ),
            _dpsgd_suite(),
        ),
        output_name="case_gbias_changesites.csv",
        row_fields=("bias_id",),
        base_seed=2022,
    ),
    # Three suites: OPT_METHOD on the default grid, the comparators with an
    # explicit LearningRateConfig (a=0.75, matching the output file name) and
    # a single lambda of 1.0, and the dpsgd suite.
    "case_gbias_sensitiveana": ScenarioSpec(
        name="case_gbias_sensitiveana",
        description="Learning-rate sensitivity analysis for comparators.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25,),
        target_sample_sizes=(20000,),
        source_props=(3.0,),
        bias_strategy=_continuous_strategy(1.0, 11, 4),
        rs_options=_constant_rs(4, 0.5),
        chain_lookup={20000: 8},
        suites=(
            SuiteSpec(
                name="transfer-opt",
                lambda_grid=DEFAULT_LAMBDA_GRID,
                methods=(OPT_METHOD,),
            ),
            SuiteSpec(
                name="transfer-comp",
                lambda_grid=(1.0,),
                lr=LearningRateConfig(c0=1.0, a=0.75, b0=0.0),
                methods=(INV_METHOD, MSE_METHOD, CONS_METHOD),
            ),
            _dpsgd_suite(),
        ),
        output_name="sensitive_ana_gbias_a_0.75.csv",
        row_fields=("bias_id",),
        base_seed=2022,
    ),
    # Zero biases; the rs vectors come from privacy_vectors and rows are
    # keyed by "rs" instead of "bias_id".
    "case_r": ScenarioSpec(
        name="case_r",
        description="Varying randomized-response rate among source sites.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25, 0.5, 0.75),
        target_sample_sizes=(20000,),
        source_props=(3.0,),
        bias_strategy=_zero_strategy(4),
        rs_options=tuple(tuple(rs) for rs in privacy_vectors(1.0, 0.25, 0.9, 9, 4)),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(
            SuiteSpec(
                name="transfer",
                lambda_grid=DEFAULT_LAMBDA_GRID,
                methods=(OPT_METHOD, MSE_METHOD, TGT_METHOD),
            ),
            _dpsgd_suite(),
        ),
        output_name="case_r_dist_normal_r_1.csv",
        row_fields=("rs",),
        base_seed=2022,
    ),
    # Same rs sweep as "case_r", with CONSVAR_METHOD in place of OPT_METHOD,
    # a single lambda of 1.0, and no dpsgd suite.
    "case_r_consvar": ScenarioSpec(
        name="case_r_consvar",
        description="Conservative variance weights for varying r.",
        n_sites=4,
        n_sim=1000,
        dist_type="normal",
        taus=(0.25, 0.5, 0.75),
        target_sample_sizes=(20000,),
        source_props=(3.0,),
        bias_strategy=_zero_strategy(4),
        rs_options=tuple(tuple(rs) for rs in privacy_vectors(1.0, 0.25, 0.9, 9, 4)),
        chain_lookup=DEFAULT_CHAIN_MAP,
        suites=(
            SuiteSpec(
                name="transfer",
                lambda_grid=(1.0,),
                methods=(CONSVAR_METHOD, MSE_METHOD, TGT_METHOD),
            ),
        ),
        output_name="case_r_dist_normal_r_1_consvar.csv",
        row_fields=("rs",),
        base_seed=2022,
    ),
}

