# -*- coding: utf-8 -*-
"""
Created on Fri Jun 24 00:33:18 2022

@author: 10732
"""

import torch
import torch.optim as optim
import numpy as np
from torch import add, sub, mul, div, matmul
# from .optimizer import Optimizer, required

class ADOM(optim.SGD):
    """Decentralized optimizer whose ``step`` mixes updates through a weight matrix.

    The update rule follows the formulas kept as comments in ``step``
    (presumably the ADOM method for decentralized optimization -- confirm
    against the reference the project is based on).  The class derives from
    ``optim.SGD`` only to reuse its parameter-group bookkeeping; the SGD
    update itself is never applied because ``step`` is fully overridden.

    Per-parameter state is *not* stored in ``self.state``.  It lives in the
    four caller-supplied lists ``z_g_k``, ``z_f_k``, ``m_k`` and ``Δ_k``,
    which are mutated by ``step`` so the caller sees the updated values.
    Entry ``i`` of each list belongs to the ``i``-th parameter when all
    parameter groups are flattened in order.
    """

    def __init__(self, params, args, L_, λ_min, z_g_k, z_f_k, m_k, Δ_k, κ_, lr=0.01):
        """Derive the step-size constants and keep references to the state lists.

        Args:
            params: Iterable of parameters or parameter-group dicts, forwarded
                to ``optim.SGD``.
            args: Namespace-like object; ``args.lr``, ``args.world_size`` and
                ``args.dataset`` are read here.
            L_: Base constant from which ``self.L`` is derived (presumably a
                smoothness estimate -- confirm with the caller).
            λ_min: Scalar that scales ``η`` and ``τ`` (presumably the smallest
                positive eigenvalue of the mixing matrix -- TODO confirm).
            z_g_k, z_f_k, m_k, Δ_k: Lists of tensors holding the optimizer
                state, one entry per parameter.  They are stored by reference
                and updated by ``step``.
            κ_: Condition-number-like constant.  ``κ_ == 1`` makes the
                ``κ_ / (κ_ - 1)`` rescaling below divide by zero.
            lr: Accepted for backward compatibility but not read; the learning
                rate handed to the base class is ``args.lr``.
        """
        self.args = args
        # NOTE: the base class receives ``args.lr``; the ``lr`` keyword is ignored.
        super().__init__(params, args.lr)

        self.world_size = args.world_size
        # ``self.d`` is only defined for these two datasets; it is not read
        # anywhere inside this class.
        if args.dataset == 'mnist' or args.dataset == 'cifar10':
            self.d = 5

        # Rescale the supplied constant, then derive μ from it.
        self.L = (L_ / 4.) * κ_ / (κ_ - 1.)
        self.μ = self.L / κ_

        # Step-size constants used by ``step``.
        self.α = 0.5 / self.L
        self.η = 2. * λ_min * np.sqrt(self.L * self.μ) / 7.
        self.θ = self.μ
        self.σ = 1.0
        self.τ = λ_min * np.sqrt(self.μ / self.L) / 7.

        # Caller-owned state buffers (kept by reference, see class docstring).
        self.z_g_k = z_g_k
        self.z_f_k = z_f_k
        self.m_k = m_k
        self.Δ_k = Δ_k

    def __setstate__(self, state):
        """Restore pickled state and make sure every group has a ``nesterov`` key."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, weight_matrix, closure=None):
        """Perform one optimization step.

        Args:
            weight_matrix: Tensor that left-multiplies gradient-shaped tensors
                via ``matmul`` (the original notes call it
                ``W_list[k % number_of_graphs]``).  Its trailing dimension
                must therefore be compatible with the leading dimension of
                each parameter.
            closure: Optional callable that re-evaluates the model and returns
                the loss.  It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # The state lists are indexed by the parameter's position across ALL
        # groups.  Restarting the index in every group would make parameters
        # from different groups share (and clobber) the same state slot.
        flat_params = (p for group in self.param_groups for p in group['params'])
        for idx, p in enumerate(flat_params):
            if p.grad is None:
                # The index still advances so the remaining parameters keep
                # their own state slots.
                continue
            # The gradient is only read below, never modified in place, so no
            # defensive copy is needed.
            g_k = p.grad

            # z_g_k = τ * z_k + (1 - τ) * z_f_k
            self.z_g_k[idx] = add(mul(p, self.τ), mul(self.z_f_k[idx], 1 - self.τ))

            # Δ_k = σ * W @ (m_k - η * g_k)
            self.Δ_k[idx] = matmul(
                mul(weight_matrix, self.σ),
                sub(self.m_k[idx], mul(g_k, self.η)),
            )
            # m_k -= η * g_k + Δ_k
            self.m_k[idx].sub_(add(mul(g_k, self.η), self.Δ_k[idx]))
            # z_k += η * α * (z_g_k - z_k) + Δ_k
            p.add_(add(mul(sub(self.z_g_k[idx], p), self.η * self.α), self.Δ_k[idx]))

            # z_f_k = z_g_k - θ * W @ g_k
            self.z_f_k[idx] = sub(self.z_g_k[idx], matmul(mul(weight_matrix, self.θ), g_k))

        return loss
                
        