# -*- coding: utf-8 -*-
"""
Created on Fri Jun 24 00:33:18 2022

@author: 10732
"""

import torch
import torch.optim as optim
import numpy as np
from torch import add, sub, mul, div, matmul

class mudag(optim.SGD):
    """Accelerated (gradient-tracking + extrapolation) optimizer on top of SGD.

    Each call to :meth:`step` applies, per parameter::

        x_k2 = y_k1 + x_k1 - y_k0 - η * (g_k1 - g_k0)
        y_k2 = x_k2 + (x_k2 - x_k1) * (1 - α) / (1 + α)

    and writes ``y_k2`` back into the parameter. Here ``η = 1 / L`` and
    ``α = sqrt(μ / L)``.

    NOTE(review): this looks like the Mudag update. The multi-consensus
    gossip loop is commented out, so ``weight_matrix`` and ``inner_iter``
    are currently unused. Confirm against the reference algorithm.

    The iterate/gradient histories are caller-owned, index-addressable
    containers. They hold one slot per parameter, counted in order across
    all param groups. This optimizer reads and overwrites them on every
    ``step``. The SGD learning rate (``args.lr``) is stored in the param
    groups but is not used by :meth:`step`; the step size is ``1 / L_``.

    Args:
        params: iterable of tensors or param-group dicts, as for ``SGD``.
        args: config object; ``lr``, ``world_size`` and ``dataset`` are read.
        L_: smoothness constant ``L``.
        λ_min: only used to derive ``self.inner_iter``.
        x_k1, x_k2, y_k0, y_k1, y_k2, g_k0, g_k1: per-parameter state slots.
        κ_: condition number; ``μ = L_ / κ_``.
    """

    def __init__(self, params, args, L_, λ_min, x_k1, x_k2, y_k0, y_k1, y_k2, g_k0, g_k1, κ_):
        self.args = args
        super().__init__(params, args.lr)

        self.world_size = args.world_size
        # NOTE(review): self.d is only defined for these datasets; any other
        # dataset leaves the attribute unset.
        if args.dataset in ('mnist', 'cifar10'):
            self.d = 10
        self.L = L_
        self.μ = self.L / κ_

        self.x_k1 = x_k1
        self.x_k2 = x_k2
        self.y_k0 = y_k0
        self.y_k1 = y_k1
        self.y_k2 = y_k2
        self.g_k0 = g_k0
        self.g_k1 = g_k1

        self.η = 1. / self.L
        self.α = np.sqrt(self.μ / self.L)
        # Number of consensus rounds per step (currently unused, see step()).
        self.inner_iter = (1 + int(1. / λ_min * np.log(2304. * np.power(self.L / self.μ, 1.5))))

    def __setstate__(self, state):
        super(mudag, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, weight_matrix, closure=None):
        """Perform one optimization step on every parameter with a gradient.

        Args:
            weight_matrix: mixing matrix for the consensus loop. It is
                currently unused because that loop is disabled.
            closure: optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given. This matches ``torch.optim.Optimizer.step``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Extrapolation coefficient is identical for every parameter.
        beta = (1. - self.α) / (1. + self.α)

        # Number parameters across *all* groups so that every parameter owns a
        # distinct state slot (enumerating per group would make the first
        # parameter of each group share slot 0). Skipped parameters still
        # consume their index.
        all_params = (p for group in self.param_groups for p in group['params'])
        for idx, p in enumerate(all_params):
            if p.grad is None:
                print('p.grad is None')
                continue
            # Snapshot the current point and its gradient (g_k1 at y_k1).
            # The gradient is cloned because it may be zeroed in place later.
            self.y_k1[idx] = p.clone()
            self.g_k1[idx] = p.grad.clone()
            # x_k2 = y_k1 + x_k1 - y_k0 - η * (g_k1 - g_k0)
            self.x_k2[idx] = sub(sub(add(self.y_k1[idx], self.x_k1[idx]), self.y_k0[idx]),
                                 mul(sub(self.g_k1[idx], self.g_k0[idx]), self.η))

            # NOTE(review): multi-consensus loop is disabled; originally
            #   for t in range(self.inner_iter):
            #       x_k2 -= weight_matrix @ x_k2

            # y_k2 = x_k2 + (x_k2 - x_k1) * (1 - α) / (1 + α)
            self.y_k2[idx] = add(self.x_k2[idx],
                                 mul(sub(self.x_k2[idx], self.x_k1[idx]), beta))

            # Write the new iterate into the parameter in place (keeps the
            # tensor object and dtype that the model holds).
            p.copy_(self.y_k2[idx])

            # Shift histories for the next step; clones keep the caller-owned
            # slots free of aliasing with each other.
            self.y_k0[idx] = self.y_k1[idx].clone()
            self.y_k1[idx] = self.y_k2[idx].clone()
            self.x_k1[idx] = self.x_k2[idx].clone()
            self.g_k0[idx] = self.g_k1[idx].clone()

        return loss
                
                
    
  
                


        