import torch
from spodnet.GLAD.notebooks.torch_sqrtm import MatrixSquareRoot

# Module-level shortcut for MatrixSquareRoot.apply; batch_matrix_sqrt calls it
# once per matrix.
# NOTE(review): presumably MatrixSquareRoot is a torch.autograd.Function, so
# .apply is its callable entry point — confirm in torch_sqrtm.
torch_sqrtm = MatrixSquareRoot.apply


def batch_matrix_sqrt(A):
    """Apply ``torch_sqrtm`` to one matrix or to every matrix of a batch.

    ``A`` should be PSD.

    A 2-D ``A`` is handed to ``torch_sqrtm`` and whatever that call returns
    is passed back unchanged. Any other ``A`` is walked along its first
    dimension: ``torch_sqrtm`` is called on each slice, its result is
    unpacked into four values, and the first one is written into an output
    tensor that has the shape and type of ``A``.

    NOTE(review): the batched branch keeps only the first of four unpacked
    values while the 2-D branch returns the raw result, so the two branches
    may not hand back the same kind of object — confirm against callers.
    """
    # Single matrix: nothing to iterate over.
    if len(A.shape) == 2:
        return torch_sqrtm(A)

    roots = torch.zeros(A.shape).type_as(A)
    for idx in range(A.shape[0]):
        # Only the first unpacked value is stored; the other three are unused.
        root, _, _, _ = torch_sqrtm(A[idx])
        roots[idx] = root
    return roots


def get_frobenius_norm(A, single=False):
    """Return the sum of squared entries of ``A`` (no square root is taken).

    With ``single=True`` every entry of ``A`` is squared and summed into one
    value. Otherwise the squares are summed over dimensions 1 and 2 and the
    resulting values are averaged.
    """
    squared_entries = A ** 2
    if single:
        return squared_entries.sum()
    per_matrix = squared_entries.sum(dim=(1, 2))
    return per_matrix.mean()


def glad(Sb, theta_true, model, my_args, criterion_graph, rho_vals_list=None):
    """Run ``L`` unrolled GLAD iterations on ``Sb`` and accumulate the loss.

    Every iteration computes a closed-form update ``theta_k1`` from ``Sb``,
    the current ``theta_pred`` and ``lambda_k``, feeds it to
    ``model.eta_forward`` to obtain the next ``theta_pred``, refreshes
    ``lambda_k`` through ``model.lambda_forward`` and adds
    ``criterion_graph(theta_pred, theta_true) / L`` to the running loss.

    Args:
        Sb: Tensor holding one square matrix (2-D) or a batch of them (3-D).
            A 2-D input is reshaped to a batch of one. Helper tensors are
            created on ``Sb.device``.
            NOTE(review): presumably the sample covariance — confirm with
            the caller.
        theta_true: Target handed unchanged to ``criterion_graph``.
        model: Object providing ``theta_init_offset``, ``lambda_forward``
            and ``eta_forward``.
        my_args: Sequence unpacked as ``(D, INIT_DIAG, lambda_init, L)``.
            ``INIT_DIAG == 1`` initialises theta from the inverted, offset
            diagonal of ``Sb``; any other value inverts
            ``Sb + theta_init_offset * eye(D)``. ``lambda_init`` seeds the
            first ``lambda_forward`` call and ``L`` is the iteration count.
        criterion_graph: Callable invoked as
            ``criterion_graph(theta_pred, theta_true)``.
        rho_vals_list: Forwarded to ``model.eta_forward`` and replaced by
            the value it returns at every iteration.

    Returns:
        Tuple ``(theta_k1, theta_pred, loss_glad, rho_vals_list, Theta_list,
        Z_list)`` where ``theta_k1`` and ``theta_pred`` come from the last
        iteration and the two lists hold clones of them for every iteration.

    Raises:
        ValueError: If ``L`` is smaller than 1, because no iteration would
            run and there would be nothing to return.
    """
    # NOTE: replace my_args with args parser
    D, INIT_DIAG, lambda_init, L = my_args

    if L < 1:
        # Without at least one iteration theta_k1 is never assigned.
        raise ValueError(f"glad needs at least one iteration, got L={L}")

    # if batch is 1, then reshaping Sb
    if len(Sb.shape) == 2:
        Sb = Sb.reshape(1, Sb.shape[0], Sb.shape[1])

    # Helper tensors follow the input's device instead of a hard-coded flag,
    # so the same code path serves CPU and GPU inputs.
    device = Sb.device

    # Initializing the theta
    if INIT_DIAG == 1:
        # extract batchwise diagonals, add offset and take inverse
        batch_diags = 1 / (torch.diagonal(Sb, offset=0, dim1=-2, dim2=-1)
                           + model.theta_init_offset)
        theta_init = torch.diag_embed(batch_diags)
    else:
        # (S + theta_offset * I)^-1 is used
        theta_init = torch.inverse(
            Sb + model.theta_init_offset
            * torch.eye(D).expand_as(Sb).type_as(Sb))

    theta_pred = theta_init
    identity_mat = torch.eye(Sb.shape[-1]).expand_as(Sb).to(device)
    zero = torch.Tensor([0]).to(device)

    lambda_k = model.lambda_forward(zero + lambda_init, zero, k=0)
    loss_glad = torch.Tensor([0]).to(device=device, dtype=torch.float32)

    Theta_list = []
    Z_list = []

    for k in range(L):
        # GLAD CELL
        b = 1.0 / lambda_k * Sb - theta_pred

        b2_4ac = (torch.matmul(b.transpose(-1, -2), b)
                  + 4.0 / lambda_k * identity_mat)

        sqrt_term = batch_matrix_sqrt(b2_4ac)

        theta_k1 = 1.0 / 2 * (-1 * b + sqrt_term)  # theta_k1 is their Theta

        theta_pred, rho_vals_list = model.eta_forward(
            theta_k1, Sb, k, theta_pred,
            rho_vals_list=rho_vals_list)  # theta_pred is their Z

        # update the lambda
        # NOTE(review): wrapping the value in a fresh torch.Tensor appears to
        # cut the autograd link to theta_pred/theta_k1 — confirm intended.
        residual = torch.Tensor(
            [get_frobenius_norm(theta_pred - theta_k1)]
        ).to(device=device, dtype=torch.float32)
        lambda_k = model.lambda_forward(residual, lambda_k, k)

        # Keep per-iteration snapshots of both iterates.
        Theta_list.append(theta_k1.clone())
        Z_list.append(theta_pred.clone())

        loss_glad += criterion_graph(theta_pred, theta_true) / L

    return theta_k1, theta_pred, loss_glad, rho_vals_list, Theta_list, Z_list
