import numpy as np
import logging
import torch
import functools
import traceback


def print_exc(function):
    """
    Decorate ``function`` so that the traceback of any ``Exception`` it
    raises is printed before the exception propagates.

    A successful call returns the wrapped function's result unchanged; a
    failing call prints the traceback and then re-raises the same
    exception, so callers still see it.
    """
    def wrapper(*args, **kwargs):
        try:
            outcome = function(*args, **kwargs)
        except Exception:
            # Report the failure, then hand the exception back to the caller.
            traceback.print_exc()
            raise
        else:
            return outcome

    # Copy the wrapped function's metadata (name, docstring, ...) over.
    return functools.wraps(function)(wrapper)

class GradientNoiseScale(object):
    """
    Gradient noise scale computed from squared gradient norms measured at
    two batch sizes.

    The "small" batch size is the one passed in; the "big" batch size is
    always taken to be twice the small one. ``small_grad_sqr`` and
    ``big_grad_sqr`` are the squared gradient norms observed at those two
    batch sizes (scalars or array-likes that NumPy can reduce).
    """

    def __init__(self, optimizer, batch_size, small_grad_sqr, big_grad_sqr):
        """
        Arguments:
            optimizer: Object exposing ``param_groups``; only used by
                ``_get_preconditioner``.
            batch_size: The small batch size.
            small_grad_sqr: Squared gradient norm at the small batch size.
            big_grad_sqr: Squared gradient norm at the big batch size.
        """
        self.mixed_precision_scale = 1.0
        self.optimizer = optimizer
        self._small_batchsize = batch_size
        # The big batch is fixed at twice the small batch.
        self._big_batchsize = self._small_batchsize * 2
        self.small_grad_sqr = small_grad_sqr
        self.big_grad_sqr = big_grad_sqr

    def _get_preconditioner(self):
        """
        Build one preconditioner tensor per parameter, grouped the same way
        as the optimizer's ``param_groups``.

        Returns (list): A list (per group) of lists (per parameter).
        """
        # The constructor only stores ``self.optimizer``; ``_pre_optimizer``
        # is honoured if something assigned it externally, otherwise fall
        # back to the constructor's optimizer instead of raising
        # AttributeError.
        optimizer = getattr(self, "_pre_optimizer", self.optimizer)
        return [
            [self._calculate_preconditioner(idx, param)
             for param in group["params"]]
            for idx, group in enumerate(optimizer.param_groups)
        ]

    def _calculate_preconditioner(self, idx, param):
        """
        Return an all-ones tensor shaped like ``param`` (an identity
        preconditioner). ``idx`` is the parameter group index and is unused.
        """
        return torch.ones_like(param, memory_format=torch.preserve_format)

    def sqr_avg(self, val):
        """
        Reduce ``val`` to a single squared-norm estimate: every entry is
        clamped to be at least 0.0 and the entries are summed.

        Returns (float): Sum of the clamped entries.
        """
        return float(np.sum(np.maximum(val, 0.0)))

    def var_avg(self, val):
        """
        Reduce ``val`` to a single variance estimate: every entry is clamped
        to be at least 1e-6 and the entries are summed, so the result is
        always strictly positive.

        Returns (float): Sum of the clamped entries.
        """
        return float(np.sum(np.maximum(val, 1e-6)))

    def get_grad_sqr(self):
        """
        Estimate the squared l2-norm of the true gradient from the two
        measured squared norms:

            (B_big * big_grad_sqr - B_small * small_grad_sqr)
                / (B_big - B_small)

        The absolute value of that quantity is passed through ``sqr_avg``.

        Returns (float): Estimate of the squared l2-norm.
        """
        batch_gap = self._big_batchsize - self._small_batchsize
        weighted_diff = (self._big_batchsize * self.big_grad_sqr
                         - self._small_batchsize * self.small_grad_sqr)
        logging.debug(f"1/(self._big_batchsize - self._small_batchsize) is: {1/batch_gap}")
        logging.debug(f"self._big_batchsize * big_grad_sqr - self._small_batchsize * small_grad_sqr is: {weighted_diff}")

        grad_sqr = 1/batch_gap * weighted_diff
        grad_sqr = self.sqr_avg(abs(grad_sqr))
        logging.debug(f"grad_sqr is: {grad_sqr}")
        return grad_sqr

    def get_grad_var(self):
        """
        Estimate the trace of the gradient covariance from the two measured
        squared norms:

            (small_grad_sqr - big_grad_sqr) / (1/B_small - 1/B_big)

        The result is passed through ``var_avg``, which floors each entry
        at 1e-6.

        Returns (float): Estimate of the trace of the covariance.
        """
        inverse_gap = 1/self._small_batchsize - 1/self._big_batchsize
        grad_var = 1/inverse_gap * (self.small_grad_sqr - self.big_grad_sqr)
        grad_var = self.var_avg(grad_var)
        logging.debug(f"grad_var is: {grad_var}")
        return grad_var

    def get_gns(self):
        """
        Return the gradient noise scale: ``get_grad_var() / get_grad_sqr()``.

        Raises:
            ZeroDivisionError: If the squared-norm estimate is exactly 0.
        """
        return self.get_grad_var() / self.get_grad_sqr()

    def get_efficiency(self):
        """
        Return ``(gns + B_small) / (gns + B_big)`` for the current gradient
        noise scale.
        """
        gns = self.get_gns()
        return (gns+self._small_batchsize)/(gns+self._big_batchsize)
    


def normsqr_groups(grads, pinvs):
    """
    Compute, for every parameter group, the summed squared norm of its
    gradients after dividing each gradient by its preconditioner.

    Arguments:
        grads: One list per group of gradient tensors; ``None`` entries are
            skipped.
        pinvs: One list per group of preconditioner tensors, paired
            positionally with ``grads``.

    Returns (np.ndarray): One value per group; 0.0 for a group with no
    gradients.
    """
    totals = []
    for grad_group, scale_group in zip(grads, pinvs):
        squared_norms = []
        for grad, scale in zip(grad_group, scale_group):
            if grad is None:
                continue
            scaled = grad / scale
            # Accumulate in float64 regardless of the gradient's dtype.
            squared_norms.append(scaled.pow(2).sum(dtype=torch.float64))
        if squared_norms:
            totals.append(sum(squared_norms).item())
        else:
            totals.append(0.0)
    return np.array(totals)

def calculate_preconditioner(idx, param):
    """
    Return an all-ones tensor shaped like ``param`` (an identity
    preconditioner). ``idx`` is accepted but not used.
    """
    identity_scale = torch.ones_like(
        param,
        memory_format=torch.preserve_format,
    )
    return identity_scale

def get_preconditioner(optimizer):
    """
    Build one preconditioner per parameter, grouped the same way as
    ``optimizer.param_groups``.

    Returns (list): A list (per group) of lists (per parameter) holding the
    result of ``calculate_preconditioner`` for each parameter.
    """
    return [
        [calculate_preconditioner(group_idx, p) for p in param_group["params"]]
        for group_idx, param_group in enumerate(optimizer.param_groups)
    ]

# Another way to calculate grad_sqr: directly from the gradients stored on the optimizer's parameters.
def get_grad_sqr(optimizer, mixed_precision_scale=1.0):
    """
    Compute the total squared gradient norm over every parameter held by
    ``optimizer``.

    Arguments:
        optimizer: Object exposing ``param_groups``; each group's
            ``"params"`` entries are read for their ``.grad``.
        mixed_precision_scale: Factor each gradient is divided by before
            its norm is taken. Defaults to 1.0 (no rescaling).

    Returns (float): Sum over all groups of the preconditioned squared
    gradient norms. Parameters without a gradient contribute nothing.
    """
    grads = []
    for group in optimizer.param_groups:
        # Keep ``None`` placeholders so gradients stay positionally aligned
        # with the per-parameter preconditioners.
        grads.append([
            None if param.grad is None
            else param.grad.detach().float() / mixed_precision_scale
            for param in group["params"]
        ])
    preconditioner = get_preconditioner(optimizer)
    grad_sqr = float(np.sum(normsqr_groups(grads, preconditioner)))
    logging.debug(f'grad_sqr is: {grad_sqr}')
    return grad_sqr