import torch
from utils.analysis_tools import *

def compute_adam_preconditioner(optimizer):
    """Return the inverse Adam preconditioner of every parameter that has state.

    For each such parameter the preconditioner is

        (1 - beta1 ** (step + 1))
            * (sqrt(exp_avg_sq / (1 - beta2 ** (step + 1))) + eps)

    and the value returned for it is the element-wise reciprocal.

    Args:
        optimizer: An optimizer whose ``param_groups`` entries provide
            ``'betas'`` and ``'eps'`` and whose per-parameter state provides
            ``'step'`` and ``'exp_avg_sq'`` (the layout used by
            ``torch.optim.Adam``).

    Returns:
        A list of tensors, one per parameter with non-empty state, in
        ``param_groups`` order. Parameters without state are skipped, so the
        list may be shorter than the number of parameters.

    Raises:
        KeyError: If a param group or a non-empty state lacks one of the keys
            listed above.
    """
    inverse_preconditioners = []

    for group in optimizer.param_groups:
        beta1, beta2 = group['betas']
        eps = group['eps']

        for p in group['params']:
            # torch optimizers keep ``state`` in a defaultdict, so indexing it
            # would insert an empty entry for every parameter that has never
            # been stepped. ``get`` keeps this helper read-only.
            state = optimizer.state.get(p)
            if not state:
                continue

            # NOTE(review): the exponent is ``step + 1``. torch.optim.Adam
            # increments ``state['step']`` before applying its own bias
            # correction, so after k calls to ``step()`` the optimizer itself
            # used ``beta ** k``. Confirm the extra +1 is intended; it is kept
            # here so existing results do not change.
            next_step = state['step'] + 1
            bias_correction1 = 1 - beta1 ** next_step
            bias_correction2 = 1 - beta2 ** next_step

            # Bias-corrected second-moment estimate.
            v_hat = state['exp_avg_sq'] / bias_correction2

            # ``eps`` broadcasts as a scalar; no need to materialise a
            # ones-tensor of the parameter's size.
            precond = bias_correction1 * (torch.sqrt(v_hat) + eps)

            # Invert right away so only one full-size temporary is alive per
            # parameter instead of a second list of preconditioners.
            inverse_preconditioners.append(1 / precond)

    return inverse_preconditioners