import h5py
import numpy as np
from CKA import linear_CKA, kernel_CKA
import os

def find_embedding_files(directory='.', prefix="new_laion2b_mscoco_llama", suffix='.h5'):
    """Return the names of entries in *directory* that start with *prefix* and end with *suffix*.

    The result is sorted so the script's output order is reproducible
    (``os.listdir`` returns entries in arbitrary order).

    Args:
        directory: Directory to scan (default: current working directory).
        prefix: Required file-name prefix.
        suffix: Required file-name suffix.

    Returns:
        Sorted list of matching entry names (names only, not full paths).
    """
    return sorted(
        name for name in os.listdir(directory)
        if name.startswith(prefix) and name.endswith(suffix)
    )


def load_embeddings(path):
    """Read the three embedding datasets from the HDF5 file at *path*.

    The file is opened read-only and closed before returning; the datasets are
    copied into in-memory numpy arrays first, so they stay valid afterwards.

    Args:
        path: Path to an HDF5 file containing the datasets
            ``clip_image_embeddings``, ``clip_text_embeddings`` and
            ``llama_text_embeddings``.

    Returns:
        Tuple ``(clip_image, clip_text, llama_text)`` of numpy arrays.

    Raises:
        KeyError: If one of the datasets is missing (``f.get`` would silently
            return None and fail later with an unrelated error).
        ValueError: If the arrays do not share the same first-axis length
            (presumably one row per paired sample -- TODO confirm).
    """
    with h5py.File(path, "r") as f:
        clip_image = np.array(f["clip_image_embeddings"])
        clip_text = np.array(f["clip_text_embeddings"])
        llama_text = np.array(f["llama_text_embeddings"])

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (clip_image.shape[0] == clip_text.shape[0] == llama_text.shape[0]):
        raise ValueError(
            f"{path}: mismatched first dimensions "
            f"{clip_image.shape[0]}, {clip_text.shape[0]}, {llama_text.shape[0]}"
        )
    return clip_image, clip_text, llama_text


def cka_summary(clip_image, clip_text, llama_text):
    """Return all pairwise linear-CKA scores, in the order the script prints them.

    The order is: clip_image vs llama_text, clip_text vs llama_text,
    clip_image vs clip_text, then the three self-comparisons (clip_image,
    clip_text, llama_text). The self-comparisons serve as a sanity check.
    """
    return (
        linear_CKA(clip_image, llama_text),
        linear_CKA(clip_text, llama_text),
        linear_CKA(clip_image, clip_text),
        linear_CKA(clip_image, clip_image),
        linear_CKA(clip_text, clip_text),
        linear_CKA(llama_text, llama_text),
    )


def main(directory='.', prefix="new_laion2b_mscoco_llama"):
    """For each matching ``.h5`` file, print its name and six CKA scores.

    Output per file is two lines: the file name, then the six scores from
    :func:`cka_summary`, each formatted to two decimals and separated by ", ".
    To analyse a different set of embedding files, pass a different *prefix*.
    """
    for file_name in find_embedding_files(directory, prefix):
        scores = cka_summary(*load_embeddings(os.path.join(directory, file_name)))
        print(file_name)
        print(", ".join(f"{score:.2f}" for score in scores))


if __name__ == '__main__':
    main()
