"""
Copied and modified from https://github.com/CW-Huang/CP-Flow/blob/main/lib/logdet_estimators.py
"""
import torch
import numpy as np
from torch.utils.data import Dataset


def load_omniglot(root="data/omniglot_npy"):
    """Load the pre-split Omniglot arrays from ``.npy`` files as torch tensors.

    Args:
        root: Directory containing ``omniglot_xs_train.npy``,
            ``omniglot_ys_train.npy``, ``omniglot_xs_test.npy`` and
            ``omniglot_ys_test.npy``. Defaults to ``"data/omniglot_npy"``,
            resolved relative to the current working directory (the
            location that used to be hard-coded).

    Returns:
        Tuple ``(X_train, Y_train, X_test, Y_test)`` of tensors built with
        ``torch.from_numpy``, so each tensor shares memory with the array
        loaded from disk.
    """
    import os

    def _load(stem):
        # One file -> one tensor; torch.from_numpy does not copy the data.
        return torch.from_numpy(np.load(os.path.join(root, f"omniglot_{stem}.npy")))

    return (
        _load("xs_train"),
        _load("ys_train"),
        _load("xs_test"),
        _load("ys_test"),
    )


class OmniglotDataset(Dataset):
    """Map-style dataset pairing rows of ``features`` with rows of ``labels``.

    Both containers are indexed as ``[idx, :]``, so they are expected to be
    at least two-dimensional; ``features`` must be a torch tensor because
    each selected row is passed through ``torch.unsqueeze``.
    """

    def __init__(self, features, labels):
        # Stored as-is; no copying or validation is performed.
        self.features = features
        self.labels = labels

    def __len__(self):
        """Return the number of samples, measured on ``labels``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(feature, label)`` for sample ``idx``.

        The feature row gains a singleton axis at position 1, e.g. a
        length-``D`` row comes back with shape ``(D, 1)``.
        """
        target = self.labels[idx, :]
        sample = torch.unsqueeze(self.features[idx, :], 1)
        return sample, target
    

def batch_dot_product(a, b):
    """Compute one dot product per batch row of ``a`` and ``b``.

    ``torch.bmm`` needs 3-D inputs, so ``a`` and ``b`` are expected to be
    2-D ``(B, D)`` batches; the result keeps a trailing axis: ``(B, 1)``.
    """
    lhs = torch.unsqueeze(a, 1)   # (B, 1, D)
    rhs = torch.unsqueeze(b, 2)   # (B, D, 1)
    products = torch.bmm(lhs, rhs)  # (B, 1, 1)
    return torch.squeeze(products, 2)


def batch_matrix_vector_product(A, x):
    """Multiply each matrix in the batch ``A`` by the matching row of ``x``.

    Shapes follow from ``torch.bmm``: ``A`` is ``(B, N, M)``, ``x`` is
    ``(B, M)`` and the result is ``(B, N)``.
    """
    columns = torch.unsqueeze(x, 2)   # (B, M, 1)
    product = torch.bmm(A, columns)   # (B, N, 1)
    return torch.squeeze(product, 2)


def conjugate_gradient(
        W,
        V,
        b,
        matrix_vector_product_function,
        dot_product_function,
        iters=10,
        rtol=1e-3,
        atol=0.0,
    ):
    """Approximately solve ``A x = b`` with (batched) conjugate gradient.

    ``A`` is never formed: every product ``A @ v`` is obtained from
    ``matrix_vector_product_function(W, V, v)``, with ``W`` and ``V``
    forwarded unchanged. NOTE(review): CG is only guaranteed to converge
    for symmetric positive-definite ``A``; this is not checked here.

    Args:
        W, V: Operands forwarded as-is to ``matrix_vector_product_function``.
        b: Right-hand side. A detached copy of ``b`` is the initial guess.
        matrix_vector_product_function: ``f(W, V, v)`` returning ``A @ v``.
        dot_product_function: ``g(u, v)`` returning inner products that
            broadcast against ``b`` (they scale ``p`` and ``Ap`` directly),
            e.g. one value per batch row.
        iters: Maximum number of CG iterations.
        rtol, atol: Element-wise stopping rule
            ``|r| <= atol + rtol * |x|``, required for every element.

    Returns:
        The approximate solution ``x`` (same shape as ``b``).

    Raises:
        ArithmeticError: If the final iterate contains NaNs.
    """

    def _converged(residual, iterate):
        # Use <= rather than <: with the default atol=0.0, an exact zero
        # residual on a zero iterate (e.g. b == 0) must count as converged,
        # otherwise every remaining iteration is spent doing nothing.
        tol = atol + rtol * torch.abs(iterate)
        return bool((torch.abs(residual) <= tol).all())

    x = b.clone().detach()
    r = b - matrix_vector_product_function(W, V, x)
    if _converged(r, x):
        return x
    p = r
    r2 = dot_product_function(r, r)

    for _ in range(iters):
        Ap = matrix_vector_product_function(W, V, p)
        # The 1e-8 keeps the step size finite when p^T A p is zero (p == 0).
        alpha = r2 / (dot_product_function(p, Ap) + 1e-8)
        x = x + alpha * p
        r = r - alpha * Ap
        if _converged(r, x):
            break
        r2_new = dot_product_function(r, r)
        beta = r2_new / (r2 + 1e-8)
        r2 = r2_new
        p = r + beta * p

    if torch.isnan(x).any():
        raise ArithmeticError("CG result has nans.")

    return x
