

import numpy as np
import torch
from tqdm import tqdm


if __name__ == '__main__':
    HASH = 'AE_NRR_z=16'
    MODEL_TYPE = 'lowest_total_loss_with_final_kl_model'
    INPUT_TYPE = 'NRR-avg_sqrt_power'

    arrays = np.load('../runs/mnist/local_equiv_fibers/{}/results_arrays/inference-{}-split=test-input_type={}.npz'.format(HASH, MODEL_TYPE, INPUT_TYPE))
    invariants_ND = arrays['invariants_ND']
    learned_frames_N9 = arrays['learned_frames_N9']
    labels_N = arrays['labels_N']
    rotations_N9 = arrays['rotations_N9']
    images_NF = torch.tensor(arrays['images_NF'])
    rec_images_NF = torch.tensor(arrays['rec_images_NF'])

    print(torch.nn.functional.mse_loss(images_NF, rec_images_NF).item())

    mse = []
    for i in tqdm(range(images_NF.shape[0])):
        mse.append(torch.nn.functional.mse_loss(images_NF[i], rec_images_NF[i]).item())
    print(np.mean(mse))

    from e3nn import o3
    data_irreps = o3.Irreps.spherical_harmonics(10, 1)
    ls_indices = torch.cat([torch.tensor([l]).repeat(2*l+1) for l in data_irreps.ls])

    for l in sorted(list(set(list(ls_indices.numpy())))):
        print(l, torch.nn.functional.mse_loss(images_NF[:, ls_indices == l], rec_images_NF[:, ls_indices == l]).item())
        print()
