from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Literal, Tuple, overload

import jax.numpy as jnp
import optax
from optax import Schedule

from egxc.training.utils import scale
from egxc.training.utils.apply_if_finite import apply_if_finite
from egxc.training.utils.fromage import fromage
from egxc.training.utils.lookahead import LookaheadConfig, wrap_with_lookahead
from egxc.training.utils.metropolis import MetropolisTrainingStabilizerConfig
from egxc.utils.typing import Bool, PyTree


@dataclass
class ScheduleConfig:
    """Learning-rate schedule settings consumed by ``get_schedule``.

    ``get_schedule`` builds a schedule whose peak value is 1; ``base_rate`` is not
    baked into it but applied as a separate scale factor by ``get_optimizer``.
    """

    # Peak learning rate. Unused by get_schedule itself; get_optimizer applies it via
    # `scale(...)` (non-manual mode) and OptConfig.create_restart_config rescales it.
    base_rate: float
    # Start value of the warmup and end value of the decay, on get_schedule's
    # unit-peak scale (the effective minimum is min_rate * base_rate once the
    # base_rate scale is applied). Must be 0 for 'inverse_time_decay'.
    min_rate: float
    # Length of the warmup phase; also the step at which the decay phase begins.
    warmup_steps: int
    # Transition length of the decay phase ('linear'/'cosine'); time constant of
    # 'inverse_time_decay'.
    decay_steps: int
    # None keeps the rate constant at the peak during the warmup phase.
    warmup_schedule: Literal['linear', 'quadratic'] | None
    # None keeps the rate constant at the peak after the warmup phase.
    decay_schedule: Literal['linear', 'cosine', 'inverse_time_decay'] | None


def get_schedule(config: ScheduleConfig) -> Schedule:
    """Build the learning-rate multiplier schedule described by ``config``.

    The schedule is a warmup phase lasting ``config.warmup_steps`` steps, followed
    by a decay phase. ``optax.join_schedules`` restarts the decay schedule's step
    counter at 0 at the boundary. The peak value is 1: the absolute learning rate
    ``config.base_rate`` is applied separately by the caller, so
    ``config.min_rate`` is expressed relative to that peak.

    Raises:
        ValueError: for an unknown warmup or decay schedule name, or when
            ``inverse_time_decay`` is combined with a non-zero ``min_rate``.
    """
    base_rate = 1  # scaling is handled manually

    def get_warmup_schedule(config: ScheduleConfig) -> Schedule:
        if config.warmup_schedule is None:
            return optax.constant_schedule(base_rate)
        if config.warmup_schedule == 'linear':
            return optax.linear_schedule(config.min_rate, base_rate, config.warmup_steps)
        if config.warmup_schedule == 'quadratic':
            # NOTE(review): optax's polynomial_schedule evaluates
            # (init - end) * (1 - t/T)**power + end, which rises fastest at the start
            # of warmup (ease-out), not the slow-start t**2 ramp the name may suggest.
            return optax.polynomial_schedule(
                config.min_rate, base_rate, 2.0, config.warmup_steps
            )
        raise ValueError(f'Invalid warmup schedule: {config.warmup_schedule}')

    def get_decay_schedule(config: ScheduleConfig) -> Schedule:
        if config.decay_schedule is None:
            return optax.constant_schedule(base_rate)
        if config.decay_schedule == 'linear':
            return optax.linear_schedule(base_rate, config.min_rate, config.decay_steps)
        if config.decay_schedule == 'cosine':
            # cosine_decay_schedule takes the floor as a fraction of its initial value.
            alpha = config.min_rate / base_rate
            return optax.cosine_decay_schedule(base_rate, config.decay_steps, alpha)
        if config.decay_schedule == 'inverse_time_decay':
            # base_rate / (1 + t / decay_steps) approaches 0 and has no other floor,
            # so any non-zero min_rate would be silently ignored. Validate explicitly:
            # an `assert` would be stripped under `python -O`.
            if config.min_rate != 0:
                raise ValueError(
                    'learning rate decreases to 0 for inverse time decay; '
                    f'min_rate must be 0, got {config.min_rate}'
                )

            def inverse_time_decay(step):
                out = base_rate / (1 + step / config.decay_steps)
                return jnp.array(out, dtype=jnp.float32)

            return inverse_time_decay
        raise ValueError(f'Invalid decay schedule: {config.decay_schedule}')

    warmup_schedule = get_warmup_schedule(config)
    decay_schedule = get_decay_schedule(config)
    return optax.join_schedules(
        [warmup_schedule, decay_schedule],
        boundaries=(config.warmup_steps,),
    )


@dataclass
class PlateauConfig:
    """Settings for ``optax.contrib.reduce_on_plateau``, built by ``get_plateau_schedule``.

    Every field is forwarded under the same keyword name, except
    ``min_relative_improvement``, which is passed as ``rtol``. See the optax
    documentation for the exact semantics of each argument.
    """

    # Forwarded as `factor` (presumably the multiplier applied on a plateau).
    factor: float
    # Forwarded as `patience`.
    patience: int
    # Forwarded as `cooldown`.
    cooldown: int
    # Forwarded as `accumulation_size`.
    accumulation_size: int
    # Forwarded as `min_scale`.
    min_scale: float
    # Forwarded as `rtol`.
    min_relative_improvement: float


def get_plateau_schedule(config: PlateauConfig) -> optax.GradientTransformationExtraArgs:
    """Translate a ``PlateauConfig`` into optax's reduce-on-plateau transform.

    The config fields map one-to-one onto keyword arguments of
    ``optax.contrib.reduce_on_plateau``. The only rename is
    ``min_relative_improvement`` -> ``rtol``.
    """
    plateau_kwargs = dict(
        factor=config.factor,
        patience=config.patience,
        cooldown=config.cooldown,
        accumulation_size=config.accumulation_size,
        min_scale=config.min_scale,
        # optax names the relative-improvement threshold `rtol`.
        rtol=config.min_relative_improvement,
    )
    return optax.contrib.reduce_on_plateau(**plateau_kwargs)


@dataclass
class OptConfig:
    """Complete optimizer / training-loop configuration.

    Usually built from a plain (possibly nested) mapping via :meth:`create`, which
    converts the nested sections into their config classes.
    """

    name: Literal['adam', 'adabelief', 'fromage', 'polyak', 'radam', 'lamb', 'muon']
    weight_decay: float
    decay_only_graph_readout: bool
    schedule: ScheduleConfig
    plateau_handling: PlateauConfig | None
    metropolis_stabilizer: MetropolisTrainingStabilizerConfig
    lookahead: LookaheadConfig | None
    apply_every: int
    clip_grad_max_norm: float | None
    skip_nans: int
    # Additional optimizer-specific parameters (e.g. 'b1'/'b2' overrides read by
    # get_optimizer for some optimizers).
    additional_params: Dict[str, Any] | None
    epochs: int
    ema_decay: float
    early_stopping_patience: int
    early_stopping_min_relative_improvement: float
    # Delta epochs between restarts, consumed from the left.
    restart_epochs: List[int]
    # LR scaling factor for each restart, consumed from the left in step with
    # restart_epochs.
    restart_lr_scales: List[float]

    @classmethod
    def create(cls, **kwargs) -> 'OptConfig':
        """Build an ``OptConfig`` from keyword arguments (e.g. a parsed config dict).

        The nested sections ``schedule``, ``plateau_handling``,
        ``metropolis_stabilizer`` and ``lookahead`` are given as mappings and
        converted to their config classes.

        The following keys are optional:
        - ``plateau_handling`` and ``lookahead`` default to ``None``.
        - A missing ``metropolis_stabilizer`` yields
          ``MetropolisTrainingStabilizerConfig.turn_off()``.
        - A missing or ``None`` ``restart_epochs`` / ``restart_lr_scales``
          means no restarts.

        Every other field is required; a missing key raises ``KeyError``.

        Raises:
            ValueError: if ``restart_epochs`` and ``restart_lr_scales`` differ in
                length.
        """
        plateau_dict = kwargs.get('plateau_handling', None)
        plateau_conf: PlateauConfig | None = None
        if plateau_dict is not None:
            plateau_conf = PlateauConfig(**plateau_dict)

        stabilizer_dict = kwargs.get('metropolis_stabilizer', None)
        if stabilizer_dict is not None:
            stabilizer = MetropolisTrainingStabilizerConfig(**stabilizer_dict)
        else:
            stabilizer = MetropolisTrainingStabilizerConfig.turn_off()

        lookahead_dict = kwargs.get('lookahead', None)
        lookahead_conf: LookaheadConfig | None = None
        if lookahead_dict is not None:
            lookahead_conf = LookaheadConfig(**lookahead_dict)

        # Copy so the config never aliases the caller's lists.
        restart_epochs = list(kwargs.get('restart_epochs') or [])
        restart_lr_scales = list(kwargs.get('restart_lr_scales') or [])
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(restart_epochs) != len(restart_lr_scales):
            raise ValueError(
                'restart_epochs and restart_lr_scales must have the same length'
            )
        return cls(
            name=kwargs['name'],
            weight_decay=kwargs['weight_decay'],
            decay_only_graph_readout=kwargs['decay_only_graph_readout'],
            schedule=ScheduleConfig(**kwargs['schedule']),
            plateau_handling=plateau_conf,
            metropolis_stabilizer=stabilizer,
            lookahead=lookahead_conf,
            apply_every=kwargs['apply_every'],
            clip_grad_max_norm=kwargs['clip_grad_max_norm'],
            skip_nans=kwargs['skip_nans'],
            additional_params=kwargs['additional_params'],
            epochs=kwargs['epochs'],
            ema_decay=kwargs['ema_decay'],
            early_stopping_patience=kwargs['early_stopping_patience'],
            early_stopping_min_relative_improvement=kwargs[
                'early_stopping_min_relative_improvement'
            ],
            restart_epochs=restart_epochs,
            restart_lr_scales=restart_lr_scales,
        )

    def create_restart_config(self) -> 'OptConfig':
        """Return the configuration to continue with after the next restart.

        This consumes the first entry of ``restart_epochs`` and
        ``restart_lr_scales``:
        - ``epochs`` shrinks by that many epochs.
        - ``schedule.base_rate`` is multiplied by the scale. Scales therefore
          compound across successive restarts.

        ``self`` is left unchanged.

        The method round-trips through ``asdict``/``create``, so the nested configs
        must be reconstructible from their ``asdict`` output.
        NOTE(review): confirm this holds for MetropolisTrainingStabilizerConfig and
        LookaheadConfig.

        Raises:
            ValueError: if no restarts remain.
        """
        if not self.restart_epochs:
            raise ValueError('no restarts remaining')
        epochs_until_restart, *remaining_epochs = self.restart_epochs
        lr_scale, *remaining_lr_scales = self.restart_lr_scales
        kwargs = asdict(self)  # deep copy: nested dataclasses become dicts
        kwargs['epochs'] -= epochs_until_restart
        kwargs['schedule']['base_rate'] = self.schedule.base_rate * lr_scale
        kwargs['restart_epochs'] = remaining_epochs
        kwargs['restart_lr_scales'] = remaining_lr_scales
        return self.create(**kwargs)

    @property
    def with_restarts(self) -> bool:
        """Whether at least one restart is still scheduled."""
        return len(self.restart_epochs) > 0


# Any single transform returned by get_optimizer: either a plain optax transform or
# one with the extra-args update signature (get_plateau_schedule is annotated as
# returning the latter).
Optimizer = optax.GradientTransformation | optax.GradientTransformationExtraArgs


@overload
def get_optimizer(
    config: OptConfig,
    manual_scaling: Literal[False] = False,
    graph_readout_decay_mask: PyTree[Bool] | None = None,
) -> Optimizer: ...
@overload
def get_optimizer(
    config: OptConfig,
    manual_scaling: Literal[True],
    graph_readout_decay_mask: PyTree[Bool] | None = None,
) -> Tuple[Optimizer, Optimizer]: ...
def get_optimizer(
    config: OptConfig,
    manual_scaling: bool = False,
    graph_readout_decay_mask: PyTree[Bool] | None = None,
) -> Optimizer | Tuple[Optimizer, Optimizer]:
    """Build the optax optimizer described by ``config``.

    The core optimizer follows the unit-peak schedule from ``get_schedule``. The
    absolute learning rate ``config.schedule.base_rate`` is applied separately.

    Args:
        config: Optimizer configuration.
        manual_scaling: Selects the return form.
            - If False, return one transform: core optimizer, then
              ``scale(base_rate)``, then the post-processing (clipping /
              ``apply_every``).
            - If True, return ``(optimizer, post_processing)``. The optimizer is
              NOT scaled by ``base_rate``, so the learning rate can be applied
              manually between the two parts. This avoids recompiling the model
              for learning-rate ablations.
        graph_readout_decay_mask: Optional boolean pytree selecting which
            parameters receive weight decay. ``None`` decays every parameter
            (optax default). Presumably derived by the caller from
            ``config.decay_only_graph_readout``; confirm.

    Returns:
        A single transform, or a ``(optimizer, post_processing)`` tuple when
        ``manual_scaling`` is True.

    Raises:
        ValueError: if ``config.name`` is not a supported optimizer.
    """
    schedule = get_schedule(config.schedule)
    extra_params = config.additional_params or {}

    gradient_transforms = []
    if config.name == 'adam':
        gradient_transforms.append(
            optax.adamw(
                schedule,
                weight_decay=config.weight_decay,
                b1=extra_params.get('b1', 0.9),
                b2=extra_params.get('b2', 0.999),
                mask=graph_readout_decay_mask,
            )
        )
    elif config.name == 'lamb':
        # Honour the decay mask and beta overrides like 'adam' and 'muon' do;
        # the defaults equal optax.lamb's own defaults.
        gradient_transforms.append(
            optax.lamb(
                learning_rate=schedule,
                b1=extra_params.get('b1', 0.9),
                b2=extra_params.get('b2', 0.999),
                weight_decay=config.weight_decay,
                mask=graph_readout_decay_mask,
            )
        )
    elif config.name == 'muon':
        gradient_transforms.append(
            optax.contrib.muon(
                learning_rate=schedule,
                adam_b1=extra_params.get('b1', 0.9),
                adam_b2=extra_params.get('b2', 0.999),
                weight_decay=config.weight_decay,
                weight_decay_mask=graph_readout_decay_mask,
            )
        )
    else:
        if config.name == 'adabelief':
            optimizer = optax.adabelief(learning_rate=schedule)
        elif config.name == 'fromage':
            optimizer = fromage(learning_rate=schedule)
        elif config.name == 'polyak':  # TODO: fix
            optimizer = optax.polyak_sgd(
                scaling=schedule,
            )
        elif config.name == 'radam':
            optimizer = optax.radam(learning_rate=schedule)
        else:
            raise ValueError(f'Invalid optimizer: {config.name}')
        gradient_transforms.append(optimizer)

        if config.weight_decay > 0.0:
            # The optimizers above already end with optax's sign-flipping
            # learning-rate scaling, so their output is a descent step. The decay
            # term must therefore be SUBTRACTED here: adding +weight_decay * params
            # after that point would grow the weights every step.
            # This decay is scaled by base_rate (via the final scale) but not by
            # the schedule.
            # NOTE(review): assumes egxc's `fromage` also returns an already
            # sign-flipped update, like the optax optimizers; confirm.
            gradient_transforms.append(
                optax.add_decayed_weights(-config.weight_decay, graph_readout_decay_mask)
            )

    if config.plateau_handling is not None:
        gradient_transforms.append(get_plateau_schedule(config.plateau_handling))

    post_process_transforms = []
    if config.clip_grad_max_norm is not None:
        post_process_transforms.append(
            optax.clip_by_global_norm(config.clip_grad_max_norm)
        )
    if config.apply_every > 1:
        post_process_transforms.append(optax.apply_every(config.apply_every))

    if not manual_scaling:
        gradient_transforms.append(scale(config.schedule.base_rate))
        out = optax.chain(*gradient_transforms, *post_process_transforms)
        # Wrap with lookahead before apply_if_finite so both fast and slow weights are protected
        if config.lookahead is not None:
            out = wrap_with_lookahead(out, config.lookahead)
        if config.skip_nans > 0:  # make resilient to occasional NaNs
            out = apply_if_finite(out, max_consecutive_errors=config.skip_nans)
        return out

    # manual_scaling: the caller applies the learning rate between the two parts.
    opt = optax.chain(*gradient_transforms)
    # Wrap with lookahead before apply_if_finite so both fast and slow weights are protected
    if config.lookahead is not None:
        opt = wrap_with_lookahead(opt, config.lookahead)
    if post_process_transforms:
        post_processing = optax.chain(*post_process_transforms)
    else:
        post_processing = optax.identity()
    if config.skip_nans > 0:  # make resilient to occasional NaNs
        opt = apply_if_finite(opt, max_consecutive_errors=config.skip_nans)
        post_processing = apply_if_finite(
            post_processing, max_consecutive_errors=config.skip_nans
        )
    return opt, post_processing
