import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from functools import partial
import torch
from typing import Optional
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from log import make_logger
from glob import glob
from natsort import natsorted
import seaborn as sns


# Run everything on the first CUDA device.
torch.cuda.set_device(0)


from conditional_evaluation import ConditionalEvaluation

# NOTE(review): TextImageFilesDataset, image_paths, captions, transform,
# logger, txt_feats and img_feats are not defined in the visible part of this
# file; they are expected to already be in scope when this section runs.
dataset = TextImageFilesDataset(
    image_folder='',
    caption_file='',
    image_paths=image_paths,
    captions=captions,
    n_digits=6,
    transform=transform,
    img_format='JPEG',
)

path_save_visual = './'

# Kernel bandwidths for the text and image features, and the entropy order.
text_sigma = 0.06
image_sigma = 60
order = 1

# Only the first `max_samples` feature rows take part in the analysis.
max_samples = 12000

logger.info(f'text sigma: {text_sigma}, image_sigma: {image_sigma}')

# Text kernel, built with the text bandwidth.
K_t = ConditionalEvaluation(sigma=text_sigma).gaussian_kernel(
    txt_feats[:max_samples], batchsize=64
)

# Image kernel; this evaluator (image bandwidth) is the one kept in
# `EvalModel` for all later conditional_entropy calls.
EvalModel = ConditionalEvaluation(sigma=image_sigma)
K_i = EvalModel.gaussian_kernel(img_feats[:max_samples], batchsize=64)

# compute_kernel=False: K_t and K_i are passed as ready-made kernels
# (presumably so they are not recomputed -- confirm in ConditionalEvaluation).
cond, mutual, joint, x, y = EvalModel.conditional_entropy(
    K_t, K_i, order=order, compute_kernel=False
)
logger.info('conditoinal entorpy of text given image')
logger.info(f'order = {order}, for all texts and images: cond: {torch.exp(cond)}, mutual: {torch.exp(mutual)}, text: {torch.exp(x)}, image: {torch.exp(y)}')

# Eigendecomposition of the image kernel. torch.linalg.eigh treats its input
# as symmetric/Hermitian and returns the eigenvalues in ascending order, with
# the matching eigenvectors as the columns of the second result.
img_eigvals, img_eigvecs = torch.linalg.eigh(K_i)
img_eigvals = img_eigvals.real
img_eigvecs = img_eigvecs.real

# Plot the image-kernel spectrum as points on a line. Draw on a dedicated
# figure and close it afterwards: plotting on the implicit global pyplot
# figure would leave these points behind, and every later scatter in this
# script would be drawn on top of them.
vals = img_eigvals.cpu().numpy()
fig, ax = plt.subplots()
ax.scatter(vals, [0] * vals.shape[0], s=5, c='blue')
fig.savefig(f'{folder_name}/img-eigvals.png')
plt.close(fig)
# # Optional debugging aid: heatmap of the image kernel K_i.
# # Step 1: Move the image kernel to the CPU as a NumPy array
# K_i_cpu = K_i.cpu().numpy()

# # Step 2: Create the heatmap using seaborn
# plt.figure(figsize=(10, 8))  # Adjust the figure size as needed
# sns.heatmap(K_i_cpu, cmap='viridis')  # You can choose a different colormap

# # Step 3: Save the figure
# plt.title('Heatmap of K_i Values')  # Optional: Add a title
# plt.savefig(f'{folder_name}/heatmap_K_i.png', dpi=300, bbox_inches='tight')

import random  # only needed here, to shuffle the preview images of a cluster

# Number of leading text / image eigen-directions ("clusters") to inspect.
top_text = 7
top_image = 7

# Never ask topk for more eigenpairs than the kernel actually has.
n_img_clusters = min(top_image, img_eigvals.shape[0])
_, max_id = img_eigvals.topk(n_img_clusters)

# For each leading eigenvector of the image kernel:
#   1. save preview grids of samples with the largest eigenvector entries,
#   2. weight the text kernel by the rank-one matrix of that eigenvector,
#   3. eigendecompose the weighted text kernel and, for each of its leading
#      eigenvectors, write out the prompts with the largest entries.
for i in range(n_img_clusters):
    top_eig_img = img_eigvecs[:, max_id[i]]

    # Image cluster. An eigenvector is only defined up to sign, so orient it
    # to have a non-negative sum before ranking its entries.
    if top_eig_img.sum() < 0:
        top_eig_img = -top_eig_img
    topk_id_img = top_eig_img.argsort(descending=True)[:350]

    save_folder_name = f'{folder_name}/image-cluster={i}/'
    os.makedirs(save_folder_name, exist_ok=True)

    # Preview a random subset of the top-350 samples. Shuffle the indices
    # first and load only the 36 items that end up in a grid, instead of
    # loading all 350 and discarding most of them. Indices are converted to
    # plain ints before indexing the dataset.
    candidate_ids = topk_id_img.cpu().tolist()
    random.shuffle(candidate_ids)
    summary = [dataset[idx][0] for idx in candidate_ids[:36]]
    save_image(summary, os.path.join(save_folder_name, f'image={i}_summary.jpg'), nrow=6)
    save_image(summary[:16], os.path.join(save_folder_name, f'image={i}_16.jpg'), nrow=4)
    save_image(summary[:9], os.path.join(save_folder_name, f'image={i}.jpg'), nrow=3)

    # Column vector of shape [K_i.shape[0], 1] (one entry per row of K_i, not
    # per feature dimension). Its outer product is a rank-one matrix with the
    # same shape as K_i, which must also match K_t for the element-wise
    # product below.
    top_eig_img = top_eig_img.reshape((-1, 1))
    K_ui = top_eig_img @ top_eig_img.T
    KoU = K_t * K_ui
    KoU = KoU / KoU.trace()  # normalise to unit trace

    cond, mutual, joint, x, y = EvalModel.conditional_entropy(K_i, K_ui, order=order, compute_kernel=False)
    logger.info(f'order= {order}, img-cluster={i}: cond: {torch.exp(cond)}, mutual: {torch.exp(mutual)}, images: {torch.exp(x)}, text-cluster: {torch.exp(y)}')

    # Spectrum of the weighted text kernel for this image cluster.
    top_txt_eigvals, top_txt_eigvecs = torch.linalg.eigh(KoU)
    top_txt_eigvals = top_txt_eigvals.real
    top_txt_eigvecs = top_txt_eigvecs.real
    vals = top_txt_eigvals.cpu().numpy()

    # Use a fresh figure per cluster and close it; otherwise each saved plot
    # would also contain the points of every plot drawn before it.
    fig, ax = plt.subplots()
    ax.scatter(vals, [0] * vals.shape[0], s=5, c='blue')
    fig.savefig(f'{folder_name}/txt-eigvals={i}.png')
    plt.close(fig)

    n_txt_clusters = min(top_text, top_txt_eigvals.shape[0])
    _, max_id_txt = top_txt_eigvals.topk(n_txt_clusters)

    for j in range(n_txt_clusters):
        top_eig_txt = top_txt_eigvecs[:, max_id_txt[j]]
        if top_eig_txt.sum() < 0:
            top_eig_txt = -top_eig_txt
        topk_id = top_eig_txt.argsort(descending=True)[:100]

        # Second element of a dataset item is written out as the prompt text.
        prompts = [dataset[idx][1] for idx in topk_id.cpu().tolist()]

        # One prompt per line; the context manager closes the file even if
        # writing fails.
        prompt_path = os.path.join(save_folder_name, f'img-cluster={i}-text={j}-prompts.txt')
        with open(prompt_path, 'w', encoding='utf-8') as prompt_file:
            prompt_file.write("\n".join(prompts))
