import torch
from torch.optim.optimizer import Optimizer, required, _use_grad_for_differentiable
import torch.autograd as ta
from collections import defaultdict
from typing import Callable, List, Optional, Union


def wrap_optimizer(optimizer, kappa=0.7, gamma=0.5):
    """Wrap ``optimizer`` in a ``KFOptimizer`` that smooths gradients with a
    simplified Kalman-filter style update.

    The returned object subclasses the wrapped optimizer's class (so
    ``isinstance`` checks keep working) but forwards ``param_groups``,
    ``state`` and ``defaults`` to the original instance.

    Usage::

        optimizer = wrap_optimizer(torch.optim.SGD(model.parameters(), lr=0.1),
                                   kappa=0.7, gamma=0.5)
        # before the first step of gradient accumulation:
        if t % acc_step == 0 and hasattr(optimizer, 'prestep'):
            optimizer.prestep(closure)

    Args:
        optimizer: the optimizer instance to wrap.
        kappa: smoothing factor; the running gradient estimate is updated as
            ``m <- kappa * m + (1 - kappa) * g``.
        gamma: size of the look-ahead perturbation used by ``prestep``.  ``0``
            (or any value within 1e-3 of ``(1 - kappa) / kappa``) selects
            ``(1 - kappa) / kappa``, which needs only one closure evaluation
            per ``prestep``.

    Returns:
        The wrapping ``KFOptimizer`` instance.
    """
    if isinstance(optimizer, Optimizer):
        class_type = [optimizer.__class__]
    else:
        class_type = [Optimizer, optimizer.__class__]

    class KFOptimizer(*class_type):
        """Gradient-smoothing wrapper around ``original_optimizer``.

        ``Optimizer.__init__`` is deliberately not called: ``param_groups``,
        ``state`` and ``defaults`` are properties forwarding to the wrapped
        optimizer, and every other missing attribute is looked up on it via
        ``__getattr__``.
        """

        def __init__(self, optimizer: Optimizer, kappa=0.7, gamma=0.5):
            if gamma == 0 or abs(gamma - (1 - kappa) / kappa) < 1e-3:
                # With gamma == (1 - kappa) / kappa the mixing coefficient
                # below is 0, so prestep only evaluates the closure once.
                gamma = (1 - kappa) / kappa
                self.compute_grad = False
                self.scaling_factor = 0.0
            else:
                # Weight of the gradient at the current point relative to the
                # gradient at the look-ahead point (used by prestep/step).
                self.scaling_factor = (gamma * kappa + kappa - 1) / (1 - kappa)
                self.compute_grad = True
            self.original_optimizer = optimizer
            self.original_optimizer.defaults['kappa'] = kappa
            self.original_optimizer.defaults['gamma'] = gamma

        def __getattr__(self, item):
            # Only reached for attributes missing on the wrapper.  Guard the
            # delegate itself so a half-initialised instance (e.g. one made by
            # copy/pickle through __new__) raises AttributeError instead of
            # recursing forever.
            if item == 'original_optimizer':
                raise AttributeError(item)
            return getattr(self.original_optimizer, item)

        @property
        def param_groups(self) -> List[dict]:
            """The wrapped optimizer's parameter groups."""
            return self.original_optimizer.param_groups

        @param_groups.setter
        def param_groups(self, param_groups: List[dict]):
            """Replace the wrapped optimizer's parameter groups."""
            self.original_optimizer.param_groups = param_groups

        @property
        def state(self) -> defaultdict:
            """The wrapped optimizer's per-parameter state."""
            return self.original_optimizer.state

        @state.setter
        def state(self, state: defaultdict):
            """Replace the wrapped optimizer's per-parameter state."""
            self.original_optimizer.state = state

        @property
        def defaults(self) -> dict:
            """The wrapped optimizer's defaults (also holds kappa/gamma)."""
            return self.original_optimizer.defaults

        @defaults.setter
        def defaults(self, defaults: dict):
            """Replace the wrapped optimizer's defaults."""
            self.original_optimizer.defaults = defaults

        def _kf_perturb(self, gamma):
            """Move every parameter with KF state by ``gamma * kf_d_t`` and,
            when ``compute_grad`` is set, multiply its gradient by
            ``scaling_factor``.

            Raises:
                RuntimeError: ``compute_grad`` is set and a parameter with KF
                    state has neither ``private_grad`` nor ``grad``.
            """
            with torch.no_grad():
                for group in self.param_groups:
                    for p in group['params']:
                        state = self.state[p]
                        if 'kf_d_t' not in state:
                            continue
                        p.data.add_(state['kf_d_t'], alpha=gamma)
                        if self.compute_grad:
                            if hasattr(p, 'private_grad'):
                                p.private_grad.mul_(self.scaling_factor)
                            elif p.grad is not None:
                                p.grad.mul_(self.scaling_factor)
                            else:
                                raise RuntimeError("Must have either grad or private_grad!")

        def _kf_restore(self, gamma):
            """Undo ``_kf_perturb``: move parameters back by
            ``-gamma * kf_d_t`` and divide gradients by ``scaling_factor``."""
            with torch.no_grad():
                for group in self.param_groups:
                    for p in group['params']:
                        state = self.state[p]
                        if 'kf_d_t' not in state:
                            continue
                        p.data.add_(state['kf_d_t'], alpha=-gamma)
                        if self.compute_grad:
                            if hasattr(p, 'private_grad'):
                                p.private_grad.div_(self.scaling_factor)
                            elif p.grad is not None:
                                p.grad.div_(self.scaling_factor)

        def prestep(self, closure=required):
            """Evaluate ``closure`` at the look-ahead point before ``step``.

            Parameters that already carry KF state are moved to
            ``theta + gamma * kf_d_t`` (``kf_d_t`` is the last update made by
            ``step``).  ``closure`` is evaluated there, and the parameters are
            then moved back, even if ``closure`` raises.

            When ``compute_grad`` is set, ``closure`` is first also evaluated
            at the current point.  The existing gradient is scaled by
            ``scaling_factor`` before the look-ahead evaluation, so the two
            gradients are mixed; ``step`` finishes the normalisation.

            NOTE(review): this assumes ``closure`` runs backward and
            accumulates into ``grad``/``private_grad`` without zeroing them.
            Confirm this with the training loop.

            Args:
                closure: callable that re-evaluates the model and returns the
                    loss.  Required.

            Returns:
                The loss returned by the first ``closure`` call.

            Raises:
                TypeError: if no closure is given.
            """
            if closure is required or closure is None:
                raise TypeError("prestep() requires a closure that re-evaluates the model")
            gamma = self.defaults['gamma']
            loss = None
            if self.compute_grad:
                with torch.enable_grad():
                    loss = closure()  # gradient at the current point
            self._kf_perturb(gamma)
            try:
                with torch.enable_grad():
                    lookahead_loss = closure()
            finally:
                # Always restore the parameters, even if the closure failed.
                self._kf_restore(gamma)
            return loss if self.compute_grad else lookahead_loss

        @_use_grad_for_differentiable
        def step(self, closure=None):
            """Smooth the gradients, then run the wrapped optimizer's step.

            For every parameter with a gradient (``private_grad`` takes
            precedence over ``grad``), the running estimate ``kf_m_t`` is
            updated as ``kappa * kf_m_t + (1 - kappa) * g`` and written back as
            the gradient.

            If no parameter used ``private_grad``, the smoothed gradients are
            also rescaled to unit global L2 norm.

            After the wrapped step, ``kf_d_t`` holds the parameter change that
            step made; ``prestep`` uses it as the look-ahead direction.

            Arguments:
                closure (callable, optional): A closure that reevaluates the
                    model and returns the loss; forwarded to the wrapped
                    optimizer.

            Returns:
                Whatever the wrapped optimizer's ``step`` returns.
            """
            kappa = self.defaults['kappa']
            # (param, fresh KF state or None) for every parameter touched
            # below.  Fresh buffers are only inserted into self.state after
            # the wrapped step, so optimizers that initialise their own state
            # when it is empty still do so.  For example, torch's Adam checks
            # ``len(state) == 0``.
            processed = []
            grad_sq_norm = 0.0
            used_private_grad = False
            for group in self.param_groups:
                for p in group['params']:
                    if hasattr(p, 'private_grad'):
                        grad, is_private = p.private_grad, True
                    elif p.grad is not None:
                        grad, is_private = p.grad, False
                    else:
                        continue
                    used_private_grad = used_private_grad or is_private
                    if self.compute_grad:
                        # Finish the weighted average started in prestep.
                        grad.div_(1 + 1 / self.scaling_factor)
                    state = self.state[p]
                    fresh = None
                    if 'kf_d_t' not in state:
                        state = fresh = {'kf_m_t': grad.clone().to(p.data)}
                    state['kf_m_t'].lerp_(grad, weight=1 - kappa)
                    smoothed = state['kf_m_t'].clone().to(p.data)
                    if is_private:
                        p.private_grad = smoothed
                    else:
                        p.grad = smoothed
                        grad_sq_norm += smoothed.norm().pow(2)
                    # Holds -theta_old for now; theta_new is added after the
                    # wrapped step so kf_d_t becomes the applied update.
                    state['kf_d_t'] = -p.data.clone()
                    processed.append((p, fresh))
            if grad_sq_norm > 0 and not used_private_grad:
                grad_norm = grad_sq_norm.sqrt()
                for group in self.param_groups:
                    for p in group['params']:
                        if p.grad is not None:
                            p.grad.div_(grad_norm)
            loss = self.original_optimizer.step(closure)
            # Finalise exactly the parameters handled above, so private-grad
            # parameters are included and fresh state is never attached to
            # the wrong parameter.
            for p, fresh in processed:
                if fresh is not None:
                    self.state[p].update(fresh)
                self.state[p]['kf_d_t'].add_(p.data)
            return loss

        def __repr__(self):
            return self.original_optimizer.__repr__()

        def state_dict(self):
            """Return the wrapped optimizer's state dict (includes KF buffers)."""
            return self.original_optimizer.state_dict()

        def load_state_dict(self, state_dict) -> None:
            """Load ``state_dict`` into the wrapped optimizer."""
            self.original_optimizer.load_state_dict(state_dict)

    return KFOptimizer(optimizer, kappa=kappa, gamma=gamma)
