import math
from dataclasses import dataclass, field
from typing import List, Optional, Tuple
import numpy as np


@dataclass
class ExperimentConfig:
    """Settings for a single experiment run.

    Only ``omega_deg``, ``bc_type`` and ``K`` are read by the methods of this
    class. All other fields are plain stored values; the comments on them are
    inferred from their names and should be confirmed against the code that
    consumes this config.
    """

    # Wedge opening angle in degrees; ``omega`` exposes it in radians.
    omega_deg: float = 270.0
    # Two-letter boundary-condition code. "DD" and "NN" are handled as one
    # family by ``true_mu``; every other value takes the alternative branch.
    bc_type: str = "DD"
    method: str = "constraint"
    seed: int = 0

    # Number of initial mu values handed out by ``get_init_mus`` (capped at 6,
    # the length of the built-in tables).
    K: int = 6
    # NOTE(review): presumably the admissible range for mu -- confirm.
    mu_min: float = 0.1
    mu_max: float = 3.0

    # NOTE(review): presumably optimisation schedule lengths, in steps -- confirm.
    total_steps: int = 20000
    warmup_steps: int = 3000
    constraint_ramp_steps: int = 5000

    # NOTE(review): presumably sampling sizes / radial sampling controls -- confirm.
    n_arc: int = 512
    n_edge: int = 512
    r_min: float = 1e-4
    r_sample_power: float = 2.0

    # NOTE(review): presumably learning rates -- confirm.
    lr_coeff: float = 1e-2
    lr_mu_warmup: float = 1e-4
    lr_mu_main: float = 5e-4

    # NOTE(review): presumably loss-term weights -- confirm.
    w_arc: float = 50.0
    w_edge: float = 20.0
    w_constraint_max: float = 100.0
    w_l1: float = 1e-4
    w_mode_prefer_small: float = 0.01

    # NOTE(review): presumably gradient clipping thresholds -- confirm.
    grad_clip_coeff: float = 1.0
    grad_clip_mu: float = 0.2

    log_every: int = 500

    device: str = "cuda"

    @property
    def omega(self) -> float:
        """Return the opening angle ``omega_deg`` converted to radians."""
        return math.radians(self.omega_deg)

    @property
    def true_mu(self) -> float:
        """Return the reference value of mu for this angle and ``bc_type``.

        The value is ``pi / omega`` when ``bc_type`` is "DD" or "NN" and
        ``pi / (2 * omega)`` for any other ``bc_type``.

        Raises:
            ZeroDivisionError: if ``omega_deg`` is 0.
        """
        if self.bc_type in ["DD", "NN"]:
            return math.pi / self.omega
        else:
            return math.pi / (2 * self.omega)

    def get_init_mus(self) -> List[float]:
        """Return the first ``K`` initial guesses for mu.

        One of three hard-coded, ascending tables is selected:

        * ``omega > pi`` with ``bc_type`` "DN" or "ND": starts at 0.15;
        * ``omega > pi`` with any other ``bc_type``: starts at 0.3;
        * ``omega <= pi``: starts at 0.5.

        Each table holds six values, so at most six guesses are returned even
        when ``K`` is larger. The result is a fresh list on every call.
        """
        # The guesses depend only on the angle and the BC family; ``true_mu``
        # is deliberately not consulted here (it was previously computed and
        # discarded, which made this method fail for omega_deg == 0).
        reentrant = self.omega > math.pi

        if reentrant and self.bc_type in ["DN", "ND"]:
            init_mus = [0.15, 0.3, 0.5, 0.8, 1.2, 1.8]
        elif reentrant:
            init_mus = [0.3, 0.5, 0.7, 1.0, 1.4, 2.0]
        else:
            init_mus = [0.5, 0.8, 1.2, 1.6, 2.0, 2.5]

        return init_mus[: self.K]


@dataclass
class SweepConfig:
    """Describes a full sweep: angles x boundary conditions x methods x seeds."""

    # The angle axis: ``n_omega`` evenly spaced values, both endpoints included.
    n_omega: int = 30
    omega_min_deg: float = 90.0
    omega_max_deg: float = 330.0

    # The remaining sweep axes; each experiment takes one entry from every list.
    bc_types: List[str] = field(default_factory=lambda: ["DD", "NN", "DN", "ND"])

    methods: List[str] = field(default_factory=lambda: ["naive", "constraint"])

    seeds: List[int] = field(default_factory=lambda: [0, 1, 2])

    # Output locations; stored only, not used by any method of this class.
    output_dir: str = "results"
    csv_path: str = "exp7.csv"

    def get_omega_grid(self) -> np.ndarray:
        """Return the swept opening angles in degrees as a 1-D array."""
        return np.linspace(
            start=self.omega_min_deg,
            stop=self.omega_max_deg,
            num=self.n_omega,
        )

    def get_all_configs(self) -> List[ExperimentConfig]:
        """Build one ``ExperimentConfig`` per point of the sweep.

        Ordering is angle-major: the seed varies fastest, then the method,
        then the boundary-condition type, and the angle varies slowest.
        Fields not covered by the sweep keep their ``ExperimentConfig``
        defaults.
        """
        return [
            ExperimentConfig(
                omega_deg=float(angle),  # plain float rather than a numpy scalar
                bc_type=bc,
                method=solver,
                seed=run_seed,
            )
            for angle in self.get_omega_grid()
            for bc in self.bc_types
            for solver in self.methods
            for run_seed in self.seeds
        ]

    @property
    def total_experiments(self) -> int:
        """Return the sweep size: ``n_omega`` times the length of each list axis."""
        count = self.n_omega
        for axis in (self.bc_types, self.methods, self.seeds):
            count = count * len(axis)
        return count


def get_constraint_func(bc_type: str):
    """Return the constraint function ``f(mu, omega)`` for a boundary-condition code.

    For "DD" and "NN" the returned callable computes ``sin(mu * omega) ** 2``;
    for every other code it computes ``cos(mu * omega) ** 2``. Both are
    evaluated with numpy, so array arguments are handled elementwise.
    """

    def _sin_squared(mu, omega):
        return np.sin(mu * omega) ** 2

    def _cos_squared(mu, omega):
        return np.cos(mu * omega) ** 2

    return _sin_squared if bc_type in ("DD", "NN") else _cos_squared


def get_angular_basis(bc_type: str):
    """Return the angular basis name for a boundary-condition code.

    "DD" and "DN" map to ``"sin"``; any other code maps to ``"cos"``.
    """
    return "sin" if bc_type in ("DD", "DN") else "cos"


def get_second_edge_bc(bc_type: str):
    """Return the boundary-condition kind for the second edge.

    "DD" and "ND" map to ``"dirichlet"``; any other code maps to
    ``"neumann"``.
    """
    ends_in_dirichlet = bc_type in ("DD", "ND")
    if not ends_in_dirichlet:
        return "neumann"
    return "dirichlet"
