from typing import Dict, List, Tuple, Union, Iterable
import torch
from torch.optim.optimizer import Optimizer
from omegaconf import DictConfig

# Public names exported by ``from <this module> import *``.
__all__ = ["PadoOptimizer", "PadoOptimizerList"]


class PadoOptimizer(Optimizer):
    """
    Base class for Optimizers in Pado framework.

    Thin subclass of ``torch.optim.Optimizer`` that adds:

    * a learning-rate accessor over all param groups,
    * a tensor-reduction helper for subclasses,
    * a config-based constructor hook.
    """

    def __init__(self, params, defaults: Dict) -> None:
        """
        Args:
            params: parameters or parameter groups, forwarded unchanged to
                ``torch.optim.Optimizer``.
            defaults: default hyper-parameters for every param group,
                forwarded unchanged to ``torch.optim.Optimizer``.
        """
        super().__init__(params, defaults)

    def current_lrs(self) -> Tuple[float, ...]:
        """Return the ``"lr"`` entry of every param group, in ``self.param_groups`` order."""
        return tuple(group["lr"] for group in self.param_groups)

    @staticmethod
    def _compute_mean_over_dim(g: torch.Tensor, dim: int = 0) -> torch.Tensor:
        """
        Average ``g`` over every dimension except ``dim``, keeping all dims.

        Each reduced dimension keeps size 1 (``keepdim=True``), so the result
        broadcasts back against ``g``. ``dim`` may be negative and is counted
        from the end, as in torch (``dim=-1`` keeps the last axis).

        Args:
            g: input tensor.
            dim: the dimension to keep, i.e. the one that is NOT averaged over.

        Returns:
            A tensor shaped like ``g`` but with size 1 on every dim other than
            ``dim``. When there is no other dim to average over (0-d or 1-d
            ``g``), a copy of ``g`` is returned unchanged.

        Raises:
            IndexError: if ``dim`` is out of range for ``g``.
        """
        ndim = g.ndim
        # Allow dim 0 / -1 on 0-d tensors, mirroring torch's scalar dim wrapping.
        bound = max(ndim, 1)
        if not -bound <= dim < bound:
            raise IndexError(f"dim {dim} is out of range for a tensor with {ndim} dims")
        # Normalise negative dims; otherwise range(dim) is empty and the wrong
        # set of axes would be reduced.
        dim %= bound
        reduce_dims = [d for d in range(ndim) if d != dim]
        if not reduce_dims:
            # Do not pass an empty dim list to torch.mean: torch has historically
            # treated it as "reduce over all dims", which would collapse a 1-D
            # tensor to its global mean instead of leaving it per-element.
            return g.clone()
        return torch.mean(g, dim=reduce_dims, keepdim=True)

    @classmethod
    def from_config(cls, cfg: "DictConfig", params):
        """
        Build an optimizer for ``params`` from the config ``cfg``.

        The base class does not implement this; concrete optimizers are
        presumably expected to override it.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError


class PadoOptimizerList(object):
    """
    Ordered container that fans out calls to several PadoOptimizers.

    Every method forwards to each wrapped optimizer in the order the
    optimizers were supplied at construction time.
    """

    def __init__(self, optimizers: Union[PadoOptimizer, Iterable[PadoOptimizer]]):
        """
        Args:
            optimizers: a single PadoOptimizer, or an iterable of them.
                A lone optimizer is wrapped into a one-element list.
        """
        self.optimizers = [optimizers] if isinstance(optimizers, PadoOptimizer) else list(optimizers)

    def __len__(self) -> int:
        """Number of wrapped optimizers."""
        return len(self.optimizers)

    def current_lrs(self) -> List[Tuple[float, ...]]:
        """Return one tuple of per-group learning rates for each optimizer, in order."""
        return [optimizer.current_lrs() for optimizer in self.optimizers]

    def step(self, closure=None) -> None:
        """
        Call ``step`` on every optimizer in order, forwarding ``closure`` to each.

        NOTE: the same closure is handed to every optimizer, so any optimizer
        that evaluates it will run it again.
        """
        for optimizer in self.optimizers:
            optimizer.step(closure=closure)

    def state_dict(self) -> List:
        """Return a list of the wrapped optimizers' state dicts, in container order."""
        return [optimizer.state_dict() for optimizer in self.optimizers]

    def load_state_dict(self, state_dict: List) -> None:
        """
        Restore each optimizer from the matching entry of ``state_dict``.

        Entries are matched by position, so ``state_dict`` must come from a
        container holding the same optimizers in the same order.

        Raises:
            ValueError: if the number of saved states differs from the number
                of wrapped optimizers.
        """
        n_saved = len(state_dict)
        n_held = len(self)
        if n_saved != n_held:
            raise ValueError(f"#optimizers in state dict mismatch, {n_held} vs {n_saved}.")
        for optimizer, saved in zip(self.optimizers, state_dict):
            optimizer.load_state_dict(saved)

    def zero_grad(self, set_to_none: bool = True) -> None:
        """Clear gradients on every optimizer, forwarding ``set_to_none`` unchanged."""
        for optimizer in self.optimizers:
            optimizer.zero_grad(set_to_none=set_to_none)
