from copy import deepcopy
from se.configs import TrainConfig, ModelConfig, PROJECT_ROOT


def _set_s_samples_per_epoch(cfg: TrainConfig) -> TrainConfig:
    """Recompute ``cfg.s_samples_per_epoch`` from the current batch size.

    Sets ``s_samples_per_epoch`` to ``s_steps_per_epoch * batch_size`` so the
    number of steps per epoch stays the same when ``batch_size`` is changed
    after the config was created.

    Args:
        cfg: Config to update; it is modified in place.

    Returns:
        The same ``cfg`` object, so the call can be chained.
    """
    steps_per_epoch = cfg.s_steps_per_epoch
    samples_per_step = cfg.batch_size
    cfg.s_samples_per_epoch = steps_per_epoch * samples_per_step
    return cfg


def _with_noise(cfg: TrainConfig, sigma_8bit: float, noise_type: str) -> TrainConfig:
    """Pin ``cfg`` to a single noise level and a noise type.

    Args:
        cfg: Config to update; it is modified in place.
        sigma_8bit: Value written to both ``min_noise`` and ``max_noise``, so
            the noise range collapses to one level.
        noise_type: Value stored on ``cfg.noise_type``.

    Returns:
        The same ``cfg`` object, so the call can be chained.
    """
    # Targets are assigned left to right: min_noise first, then max_noise.
    cfg.min_noise = cfg.max_noise = sigma_8bit
    cfg.noise_type = noise_type  # type: ignore[attr-defined]
    return cfg


# %% Gaussian dncnn experiments
# Shared base config: the experiment configs below are derived from it via
# deepcopy and then override individual fields.
cfg = TrainConfig(
    # Datasets and patch size.
    train_dataset_type="h",
    test_path=[f"{PROJECT_ROOT}/data/Set12"],
    s_patch_size=70,
    # Loss and learning rate.
    loss_type="l2",
    lr=1e-4,
    # Both halving schedules are left unset (None) here; some variants
    # below set lr_halving_steps.
    lr_halving_epochs=None,
    lr_halving_steps=None,
)

# Fixed-noise experiments: one base config per noise level (50, 25, 10), plus
# five model variants derived from each base.
#
# Variant suffixes:
#   ne  -> model_mode="norm-equiv"
#   se  -> model_mode="scale-equiv"
#   o   -> model_mode="ordinary"
#   wne -> model_mode="ordinary" with wrapper_mode="norm-equiv"


def _gaussian_variant(
    base: TrainConfig,
    model: str,
    model_mode: str,
    wrapper_mode: str = "idem",
    **overrides: object,
) -> TrainConfig:
    """Build one experiment config from a fixed-noise base config.

    Args:
        base: Config to start from; it is deep-copied and left untouched.
        model: Value assigned to ``model`` (``"fdncnn"`` or ``"dncnn"`` below).
        model_mode: ``model_mode`` passed to ``ModelConfig``.
        wrapper_mode: ``wrapper_mode`` passed to ``ModelConfig``.
        **overrides: Further attributes to set on the copy, e.g.
            ``loss_type``, ``num_steps`` or ``lr_halving_steps``.

    Returns:
        A new config that shares no state with ``base``.
    """
    variant = deepcopy(base)
    variant.model = model
    # A fresh ModelConfig per variant, so no two configs share one instance.
    variant.model_cfg = ModelConfig(
        model_mode=model_mode, pred_mode="direct", wrapper_mode=wrapper_mode
    )
    for attr, value in overrides.items():
        setattr(variant, attr, value)
    return variant


# 50 noise level
cfg_50 = deepcopy(cfg)
cfg_50.min_noise = cfg_50.max_noise = 50.0

cfg_50_fdncnn_ne = _gaussian_variant(cfg_50, "fdncnn", "norm-equiv", num_steps=900_000)
cfg_50_fdncnn_se = _gaussian_variant(cfg_50, "fdncnn", "scale-equiv", loss_type="l1")
cfg_50_fdncnn_o = _gaussian_variant(cfg_50, "fdncnn", "ordinary", loss_type="l1")
cfg_50_dncnn_wne = _gaussian_variant(
    cfg_50, "dncnn", "ordinary", wrapper_mode="norm-equiv", lr_halving_steps=100_000
)
cfg_50_dncnn_o = _gaussian_variant(cfg_50, "dncnn", "ordinary", loss_type="l1")

# 25 noise level
cfg_25 = deepcopy(cfg)
cfg_25.min_noise = cfg_25.max_noise = 25.0

cfg_25_fdncnn_ne = _gaussian_variant(cfg_25, "fdncnn", "norm-equiv", num_steps=900_000)
cfg_25_fdncnn_se = _gaussian_variant(cfg_25, "fdncnn", "scale-equiv", loss_type="l1")
cfg_25_fdncnn_o = _gaussian_variant(cfg_25, "fdncnn", "ordinary", loss_type="l1")
cfg_25_dncnn_wne = _gaussian_variant(
    cfg_25, "dncnn", "ordinary", wrapper_mode="norm-equiv", lr_halving_steps=100_000
)
cfg_25_dncnn_o = _gaussian_variant(cfg_25, "dncnn", "ordinary", loss_type="l1")

# 10 noise level
cfg_10 = deepcopy(cfg)
cfg_10.min_noise = cfg_10.max_noise = 10.0

cfg_10_fdncnn_ne = _gaussian_variant(cfg_10, "fdncnn", "norm-equiv", num_steps=900_000)
cfg_10_fdncnn_se = _gaussian_variant(cfg_10, "fdncnn", "scale-equiv", loss_type="l1")
cfg_10_fdncnn_o = _gaussian_variant(cfg_10, "fdncnn", "ordinary", loss_type="l1")
cfg_10_dncnn_wne = _gaussian_variant(
    cfg_10, "dncnn", "ordinary", wrapper_mode="norm-equiv", lr_halving_steps=100_000
)
cfg_10_dncnn_o = _gaussian_variant(cfg_10, "dncnn", "ordinary", loss_type="l1")

# %% non-Gaussian experiments
# Noise levels are derived from a target input PSNR:
#     input PSNR -> sigma (in [0, 1]) -> sigma_8bit (= 255 * sigma)
#     22 dB -> ~0.0794 -> ~20.2
#     25 dB -> ~0.0562 -> ~14.3
# Rayleigh noise has a nonzero mean (mean/std ≈ 1.913), so for 17 dB the
# sigma is scaled down by sqrt(1 + 1.913^2) ≈ sqrt(4.659):
#     17 dB -> target sigma ≈ 0.0655 -> sigma_8bit ≈ 16.7
# The constants below are these values rounded to the nearest integer.

sigma22_laplace = 20  # ~20.2
sigma25_uniform = 14  # ~14.3
sigma17_rayleigh = 17  # ~16.7

# Base configs for non-Gaussian experiments: a copy of the shared base with
# min_noise == max_noise == sigma and the matching noise_type.
cfg_laplace_22 = _with_noise(
    deepcopy(cfg), sigma_8bit=sigma22_laplace, noise_type="laplace"
)
cfg_uniform_25 = _with_noise(
    deepcopy(cfg), sigma_8bit=sigma25_uniform, noise_type="uniform"
)
cfg_rayleigh_17 = _with_noise(
    deepcopy(cfg), sigma_8bit=sigma17_rayleigh, noise_type="rayleigh"
)
# Only the Rayleigh base overrides psnr_eval_sigma_values: the listed levels
# divided by 255.
cfg_rayleigh_17.psnr_eval_sigma_values = [
    level / 255.0 for level in (1, 2, 3, 5, 8, 12, 17, 25, 36, 50, 70, 95)
]

# Laplace @22 dB
# NE: model_mode="norm-equiv", loss_type="l1", 900k steps.
cfg_laplace_22_fdncnn_ne = deepcopy(cfg_laplace_22)
cfg_laplace_22_fdncnn_ne.model = "fdncnn"
cfg_laplace_22_fdncnn_ne.model_cfg = ModelConfig(
    model_mode="norm-equiv",
    pred_mode="direct",
    wrapper_mode="idem",
)
cfg_laplace_22_fdncnn_ne.loss_type = "l1"
cfg_laplace_22_fdncnn_ne.num_steps = 900_000

# SE: model_mode="scale-equiv"; loss_type and num_steps are not overridden,
# so they keep whatever cfg_laplace_22 carries.
cfg_laplace_22_fdncnn_se = deepcopy(cfg_laplace_22)
cfg_laplace_22_fdncnn_se.model = "fdncnn"
cfg_laplace_22_fdncnn_se.model_cfg = ModelConfig(
    model_mode="scale-equiv",
    pred_mode="direct",
    wrapper_mode="idem",
)

# Uniform @25 dB
# Both variants copy the shared cfg_uniform_25 base (the same pattern the
# Laplace and Rayleigh sections use) rather than rebuilding the noise setup
# from cfg, so any later change to cfg_uniform_25 reaches them as well.

# NE: model_mode="norm-equiv", loss_type="l1", 900k steps.
cfg_uniform_25_fdncnn_ne = deepcopy(cfg_uniform_25)
cfg_uniform_25_fdncnn_ne.model = "fdncnn"
cfg_uniform_25_fdncnn_ne.num_steps = 900_000
cfg_uniform_25_fdncnn_ne.model_cfg = ModelConfig(
    model_mode="norm-equiv", pred_mode="direct", wrapper_mode="idem"
)
cfg_uniform_25_fdncnn_ne.loss_type = "l1"

# SE: model_mode="scale-equiv"; loss_type and num_steps are not overridden,
# so they keep whatever cfg_uniform_25 carries.
cfg_uniform_25_fdncnn_se = deepcopy(cfg_uniform_25)
cfg_uniform_25_fdncnn_se.model = "fdncnn"
cfg_uniform_25_fdncnn_se.model_cfg = ModelConfig(
    model_mode="scale-equiv", pred_mode="direct", wrapper_mode="idem"
)

# Rayleigh @17 dB
# All three variants copy cfg_rayleigh_17, so they also carry its
# psnr_eval_sigma_values override.

# NE: model_mode="norm-equiv", loss_type="l1", 900k steps.
cfg_rayleigh_17_fdncnn_ne = deepcopy(cfg_rayleigh_17)
cfg_rayleigh_17_fdncnn_ne.model = "fdncnn"
cfg_rayleigh_17_fdncnn_ne.model_cfg = ModelConfig(
    model_mode="norm-equiv",
    pred_mode="direct",
    wrapper_mode="idem",
)
cfg_rayleigh_17_fdncnn_ne.loss_type = "l1"
cfg_rayleigh_17_fdncnn_ne.num_steps = 900_000

# WNE: ordinary model with wrapper_mode="norm-equiv", loss_type="l1",
# 900k steps, lr_halving_steps=100k.
cfg_rayleigh_17_fdncnn_wne = deepcopy(cfg_rayleigh_17)
cfg_rayleigh_17_fdncnn_wne.model = "fdncnn"
cfg_rayleigh_17_fdncnn_wne.model_cfg = ModelConfig(
    model_mode="ordinary",
    pred_mode="direct",
    wrapper_mode="norm-equiv",
)
cfg_rayleigh_17_fdncnn_wne.loss_type = "l1"
cfg_rayleigh_17_fdncnn_wne.lr_halving_steps = 100_000
cfg_rayleigh_17_fdncnn_wne.num_steps = 900_000

# SE: model_mode="scale-equiv"; loss_type and num_steps are not overridden,
# so they keep whatever cfg_rayleigh_17 carries.
cfg_rayleigh_17_fdncnn_se = deepcopy(cfg_rayleigh_17)
cfg_rayleigh_17_fdncnn_se.model = "fdncnn"
cfg_rayleigh_17_fdncnn_se.model_cfg = ModelConfig(
    model_mode="scale-equiv",
    pred_mode="direct",
    wrapper_mode="idem",
)

# %% SwinIR experiments

# SwinIR (lite denoising) configs: ordinary (_o) and WNE (_wne)
swinir_patch_size = 64
swinir_batch_size = 32


def _swinir_variant(
    base: TrainConfig, wrapper_mode: str, **overrides: object
) -> TrainConfig:
    """Build a SwinIR config from a fixed-noise base config.

    Args:
        base: Config to start from; it is deep-copied and left untouched.
        wrapper_mode: ``wrapper_mode`` passed to ``ModelConfig``
            (``model_mode`` is always ``"ordinary"`` here).
        **overrides: Further attributes to set on the copy, e.g.
            ``lr_halving_steps``.

    Returns:
        A new config using ``swinir_patch_size`` and ``swinir_batch_size``,
        with ``s_samples_per_epoch`` recomputed for that batch size.
    """
    variant = deepcopy(base)
    variant.model = "swinir"
    variant.model_cfg = ModelConfig(
        model_mode="ordinary", pred_mode="direct", wrapper_mode=wrapper_mode
    )
    for attr, value in overrides.items():
        setattr(variant, attr, value)
    variant.s_patch_size = swinir_patch_size
    variant.batch_size = swinir_batch_size
    # s_samples_per_epoch is computed from batch_size, so refresh it last.
    return _set_s_samples_per_epoch(variant)


# Ordinary (_o): wrapper_mode="idem", lr_halving_steps left as inherited.
cfg_50_swinir_o = _swinir_variant(cfg_50, "idem")
cfg_25_swinir_o = _swinir_variant(cfg_25, "idem")
cfg_10_swinir_o = _swinir_variant(cfg_10, "idem")

# WNE (_wne): wrapper_mode="norm-equiv", lr_halving_steps=100k.
cfg_50_swinir_wne = _swinir_variant(cfg_50, "norm-equiv", lr_halving_steps=100_000)
cfg_25_swinir_wne = _swinir_variant(cfg_25, "norm-equiv", lr_halving_steps=100_000)
cfg_10_swinir_wne = _swinir_variant(cfg_10, "norm-equiv", lr_halving_steps=100_000)