import os
import sys
import time
import math

import torch
import numpy as np
import torch.optim

class AdapSGD(torch.optim.Optimizer):
    """SGD whose step size is normalised by a running global gradient norm.

    Every parameter is updated as ``p -= lr * g / (sqrt(mhatsq) + epsilon)``.
    ``g`` is the parameter's gradient, with ``weight_decay * p`` added when
    ``weight_decay`` is non-zero. ``mhatsq`` is an exponential moving average,
    shared by all parameter groups, of the squared L2 norm of all raw
    gradients taken together. The update uses the estimate from *before* the
    current step; on the first step the estimate is seeded with the current
    squared norm.

    Arguments:
        params (iterable): parameters to optimise, or dicts defining
            parameter groups.
        lr (float): learning rate (default: 0.1).
        beta (float): decay of the moving average ``mhatsq``, in [0, 1]
            (default: 0.95).
        epsilon (float): non-negative term added to ``sqrt(mhatsq)`` in the
            denominator (default: 0.01).
        weight_decay (float): coefficient of the L2 penalty added to each
            gradient (default: 0).

    Raises:
        ValueError: if any hyper-parameter is out of range.
    """

    def __init__(self, params, lr=0.1, beta=0.95, epsilon=0.01,
                 weight_decay=0):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= beta <= 1.0:
            raise ValueError("Invalid beta value: {}".format(beta))
        if epsilon < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(epsilon))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        # Record the values actually in use. These were previously hard-coded
        # to 0.95/0.01, so defaults/param_groups misreported custom settings.
        defaults = dict(lr=lr, beta=beta, epsilon=epsilon,
                        weight_decay=weight_decay)
        super(AdapSGD, self).__init__(params, defaults)

        # Flat list of the parameters supplied at construction time.
        self.params = [p for group in self.param_groups for p in group['params']]
        # EMA of the squared global gradient norm; 0 means "not yet seeded".
        self.mhatsq = 0
        # The update reads these instance-level values, which all groups
        # share; it does not read the per-group 'beta'/'epsilon' entries.
        self.beta = beta
        self.epsilon = epsilon

    def _global_grad_norm(self):
        """Return the L2 norm over all present gradients, or None if none exist."""
        # Iterate param_groups (not self.params) so groups added later via
        # add_param_group are included too.
        norms = [torch.norm(p.grad.detach())
                 for group in self.param_groups
                 for p in group['params']
                 if p.grad is not None]
        if not norms:
            return None
        return torch.norm(torch.stack(norms)).detach()

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
            If no parameter has a gradient, nothing is updated.
        """
        loss = None
        if closure is not None:
            # Re-evaluate first so the norm below uses the fresh gradients.
            with torch.enable_grad():
                loss = closure()

        total_norm = self._global_grad_norm()
        if total_norm is None:
            # No gradients at all: nothing to update (previously a crash in
            # torch.stack on an empty list).
            return loss

        if float(self.mhatsq) == 0:
            # First step (or a zero estimate): seed the EMA with the current norm.
            self.mhatsq = total_norm ** 2
        # Compute the denominator once per step as a Python float rather than
        # building a tensor step size for every parameter.
        denom = float(self.mhatsq) ** 0.5 + self.epsilon

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            alpha = -group['lr'] / denom
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)
                p.add_(d_p, alpha=alpha)

        # Fold the current squared norm into the estimate for the next step.
        self.mhatsq = self.beta * self.mhatsq + (1 - self.beta) * total_norm ** 2
        return loss


    

def compute_noise(stoc_grads, true_grads):
    """Measure how far stochastic gradients lie from reference gradients.

    Arguments:
        stoc_grads (dict): name -> tensor; its keys drive the iteration.
        true_grads (dict): name -> tensor, indexed with every key of
            ``stoc_grads``.

    Returns:
        tuple: ``(sum_k ||stoc_grads[k] - true_grads[k]||_2 ** 2,
        sum_k ||stoc_grads[k]||_2 ** 2)``, both ``0`` when ``stoc_grads``
        is empty.
    """
    pairs = [(stoc_grads[key], true_grads[key]) for key in stoc_grads.keys()]
    noise_sq = sum((noisy - ref).norm(2).item() ** 2 for noisy, ref in pairs)
    grad_sq = sum(noisy.norm(2).item() ** 2 for noisy, _ in pairs)
    return noise_sq, grad_sq

def compute_norm(grads):
    """Return the sum of squared L2 norms of the tensors in ``grads``.

    Arguments:
        grads (dict): name -> tensor.

    Returns:
        float: ``sum_k ||grads[k]||_2 ** 2``, or ``0`` for an empty dict.
    """
    return sum(tensor.norm(2).item() ** 2 for tensor in grads.values())

def clone_grad(net, true_grads):
    """Store a detached copy of each existing parameter gradient of ``net``.

    Parameters whose ``.grad`` is None are skipped. ``true_grads`` is updated
    in place, keyed by ``named_parameters()`` name; nothing is returned.
    """
    for name, param in net.named_parameters():
        if param.grad is not None:
            # Copy, so later in-place gradient updates do not leak into the snapshot.
            true_grads[name] = param.grad.detach().clone()
        
        

def coord_noise(stoc_grads, true_grads):
    """Return every coordinate of ``stoc_grads - true_grads`` as one flat array.

    Arguments:
        stoc_grads (dict): name -> tensor; its key order fixes the output order.
        true_grads (dict): name -> tensor of matching shape for each key.
            Tensors are assumed to hold floating-point gradients.

    Returns:
        numpy.ndarray: 1-D float64 array of the per-key differences, each
        flattened and concatenated in ``stoc_grads`` key order. Empty
        (shape ``(0,)``) when ``stoc_grads`` is empty.
    """
    # Flatten each difference straight into numpy instead of round-tripping
    # every element through a Python list of floats (plus a redundant copy).
    chunks = [(stoc_grads[k] - true_grads[k]).cpu().numpy().ravel()
              for k in stoc_grads.keys()]
    if not chunks:
        # np.concatenate rejects an empty sequence; match np.array([]).
        return np.array([], dtype=np.float64)
    # The previous list-of-Python-floats path always yielded float64.
    return np.concatenate(chunks).astype(np.float64, copy=False)

def repackage_hidden(h):
    """Detach hidden state(s) from the autograd history that produced them.

    A tensor is returned detached. Any other iterable is walked recursively,
    and its elements are returned as a tuple of detached results.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()


def batchify(data, bsz, args):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    if args.cuda:
        data = data.cuda()
    return data


def get_batch(source, i, args, seq_len=None, evaluation=False):
    """Slice an (input, target) window starting at row ``i`` of ``source``.

    The window length is ``seq_len``, or ``args.bptt`` when ``seq_len`` is
    falsy. It is clipped so the target (shifted one row down) stays inside
    ``source``. ``evaluation`` is accepted but unused.

    Returns:
        tuple: ``source[i:i+n]`` and ``source[i+1:i+1+n]`` flattened with
        ``view(-1)``, where ``n`` is the clipped length.
    """
    window = seq_len or args.bptt
    window = min(window, len(source) - 1 - i)
    stop = i + window
    inputs = source[i:stop]
    targets = source[i + 1:stop + 1].view(-1)
    return inputs, targets










