import numpy as np
from functools import partial
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from functools import partial
import torch
import torch
from typing import Optional
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image
from log import make_logger
from glob import glob
from natsort import natsorted
from conditional_evaluation import ConditionalEvaluation
import seaborn as sns

import numpy as np

from utils import load_and_concatenate_all




animals = ["dog", "cat", "bird", "horse", "cow", "sheep", "elephant", "bear", "zebra", "giraffe"]


# Load the image paths, captions and precomputed image/text features in one call.
# NOTE(review): image_folders, captions_files, image_feats_paths and
# text_feats_paths are not assigned anywhere in the visible part of this file —
# confirm they are defined before this point.
image_paths, captions, img_feats, txt_feats = load_and_concatenate_all(
    image_folders,
    captions_files,
    image_feats_paths,
    text_feats_paths,
    img_feats_key='dino_features',
    txt_feats_key='txt_feats',
    img_format='jpg',
)
print(img_feats.shape)

# Cut both feature arrays into 10 equal chunks along the first axis
# (np.split raises if the row count is not divisible by 10).
txt_arr = np.split(txt_feats, 10)
print(txt_arr[0].shape)
img_arr = np.split(img_feats, 10)

# Re-assemble the chunks in a fixed permutation; text and image features use
# the same permutation so row k of one still lines up with row k of the other.
chunk_order = (4, 5, 6, 9, 8, 7, 0, 1, 2, 3)
txt_feats = np.concatenate([txt_arr[k] for k in chunk_order], axis=0)
img_feats = np.concatenate([img_arr[k] for k in chunk_order], axis=0)

# print(x.shape)
# exit()
# Sweep over how many leading image rows are used (500, 1000, ..., 5000) and,
# for each size, log and record the exponentiated outputs of
# ConditionalEvaluation.conditional_entropy computed from an image kernel and
# a text kernel.

# `logger` is used below but never assigned in the visible part of this file.
# NOTE(review): `make_logger` is imported at the top but its signature is not
# visible here, so fall back to stdlib logging only when no logger exists yet;
# replace with the project logger if that was the intent.
if "logger" not in globals():
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

TOTAL_SAMPLES = 5000   # number of rows every image kernel is built from
STEP = 500             # increment of the sweep; the log counts one "breed" per STEP rows
text_sigma = 1.1
image_sigma = 65
order = 1

# The text kernel depends only on txt_feats and text_sigma, both constant over
# the sweep, so it is computed once instead of once per iteration.
# NOTE(review): assumes conditional_entropy does not modify K_t in place.
K_t = ConditionalEvaluation(sigma=text_sigma).gaussian_kernel(txt_feats, batchsize=64)

cond_breed = []
vendi = []
for i in range(STEP, TOTAL_SAMPLES + 1, STEP):
    logger.info(f'i = {i}, {i // STEP} breeds')
    logger.info(f'text sigma: {text_sigma}, image_sigma: {image_sigma}')

    # Cycle through the first i image rows until TOTAL_SAMPLES rows are
    # produced. Same rows, same order as stacking copies and truncating, but
    # without materialising rows that would be thrown away.
    selected_img_feats = img_feats[:i]
    cyclic_rows = np.arange(TOTAL_SAMPLES) % len(selected_img_feats)
    concatenated_feats = selected_img_feats[cyclic_rows]

    EvalModel = ConditionalEvaluation(sigma=image_sigma)
    K_i = EvalModel.gaussian_kernel(concatenated_feats, batchsize=64)

    cond, mutual, _joint, img_term, txt_term = EvalModel.conditional_entropy(
        K_i, K_t, order=order, compute_kernel=False
    )
    logger.info(f'order = {order}, for all texts and images: cond: {torch.exp(cond)}, mutual: {torch.exp(mutual)}, images: {torch.exp(img_term)}, text: {torch.exp(txt_term)}')
    # order = 2
    # cond, mutual, _joint, img_term, txt_term = EvalModel.conditional_entropy(K_i, K_t, order=order, compute_kernel=False)
    # logger.info(f'order = {order}, for all texts and images: cond: {torch.exp(cond)}, mutual: {torch.exp(mutual)}, images: {torch.exp(img_term)}, text: {torch.exp(txt_term)}')

    vendi.append(float(torch.exp(img_term).cpu()))
    cond_breed.append(float(torch.exp(cond).cpu()))

print(vendi)
print(cond_breed)
# text_eigvals, text_eigvecs = torch.linalg.eigh(K_t)
# text_eigvals = text_eigvals.real
# text_eigvecs = text_eigvecs.real

# vals = text_eigvals.cpu().numpy()
# plt.scatter(vals, [0] * vals.shape[0], s=5, c='blue')
# plt.savefig(f'{folder_name}/text-eigvals.png')
# K_t_cpu = K_t.cpu().numpy()

# # Step 2: Create the heatmap using seaborn
# plt.figure(figsize=(10, 8))  # Adjust the figure size as needed
# sns.heatmap(K_t_cpu, cmap='viridis')  # You can choose a different colormap

# # Step 3: Save the figure
# plt.title('Heatmap of K_t Values')  # Optional: Add a title
# plt.savefig(f'{folder_name}/heatmap_K_t.png', dpi=300, bbox_inches='tight')
