import torch 
import torch.nn as nn 


class DVOptimizer(torch.optim.Optimizer):
    """Plain SGD, or per-player running-average gradient descent with warmup.

    When :meth:`step` is called without a ``player``, each parameter is
    updated as ``p -= lr * grad``. When a ``player`` is given, each parameter
    keeps a separate running average of its gradient for that player, stored
    in ``self.state[p][f"mom_{player}"]``. The parameter is only moved by that
    average once ``step`` reaches ``warmup_epochs``; before that the average is
    accumulated without touching the parameter.

    Args:
        params: Iterable of parameters or param-group dicts (as for any
            ``torch.optim.Optimizer``).
        lr: Learning rate, stored as the param-group default ``'lr'``.
        momentum: Stored as ``self.momentum``; not read by :meth:`step`.
        num_iters: Used to derive the default ``warmup_epochs`` and
            ``self.total_weight``.
        warmup_epochs: Value of ``step`` from which averaged updates are
            applied. Defaults to ``num_iters // 2``.
    """

    def __init__(self, params, lr=1e-2, momentum=0.5, num_iters=0, warmup_epochs=None):
        super().__init__(params, defaults={'lr': lr})
        self.momentum = momentum  # kept as a public attribute; unused by step()
        self.num_iters = num_iters
        self.warmup_epochs = num_iters // 2 if warmup_epochs is None else warmup_epochs
        # Harmonic number H(num_iters) = sum_{i=1..num_iters} 1/i; not read
        # inside this class.
        self.total_weight = sum(1 / i for i in range(1, num_iters + 1))
        # ``self.state`` is the base class's per-parameter state mapping; it is
        # intentionally not replaced here.

    @torch.no_grad()
    def step(self, player=None, step=2, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            player: If ``None``, apply a plain SGD step. Otherwise, identifier
                of the running-average slot to update (key ``f"mom_{player}"``).
            step: 1-based index used for the running average
                ``avg = (step - 1) / step * avg + grad / step``. It is also
                compared against ``warmup_epochs`` to decide whether the
                parameter is moved. Only used when ``player`` is given.
            closure: Optional callable that re-evaluates the model and returns
                the loss, as in the ``torch.optim.Optimizer.step`` contract.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.

        Raises:
            ValueError: If ``player`` is given and ``step < 1``. Such a step
                would divide by zero or give meaningless averaging weights.
        """
        if player is not None and step < 1:
            raise ValueError(f"step must be >= 1 when player is given, got {step}")

        loss = None
        if closure is not None:
            # The closure typically calls backward(), so it needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:
                if p.grad is None or not p.requires_grad:
                    continue
                grad = p.grad.detach()
                param_state = self.state.setdefault(p, {})

                if player is None:
                    p.sub_(lr * grad)
                    continue

                key = f"mom_{player}"
                # On first use, seed with the current gradient, so the first
                # average equals (up to rounding) that gradient.
                prev = param_state.get(key, grad)
                # Out-of-place arithmetic: the seed may alias p.grad and must
                # not be modified in place.
                avg = (step - 1) / step * prev + grad / step
                param_state[key] = avg
                if step >= self.warmup_epochs:
                    p.sub_(lr * avg)
        return loss
