import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
from functools import partial
import torch
from typing import Optional
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from log import make_logger
from glob import glob
from natsort import natsorted
import seaborn as sns


# Pin the process to CUDA device 1 when the host actually exposes it.
# An unconditional set_device(1) fails at start-up on single-GPU or CPU-only
# hosts, so in that case the default device is kept and a warning is emitted.
import warnings

gpu_index = 1
if torch.cuda.is_available() and torch.cuda.device_count() > gpu_index:
    torch.cuda.set_device(gpu_index)
else:
    warnings.warn(
        f'CUDA device {gpu_index} is not available; keeping the default torch device.',
        RuntimeWarning,
    )


from conditional_evaluation import ConditionalEvaluation


# Paired image/caption dataset built from explicit path and caption lists;
# image_folder and caption_file are passed as empty strings.
dataset = TextImageFilesDataset(
    image_folder='',
    caption_file='',
    image_paths=image_paths,
    captions=captions,
    n_digits=0,
    transform=transform,
    img_format='png',
)

# NOTE(review): not referenced by the visible code (only by a commented-out
# line further down) -- confirm whether it is still needed.
path_save_visual = './'

# Sigma values handed to ConditionalEvaluation for the text and image kernels.
text_sigma, image_sigma = 0.1, 0.5

logger.info(f'text sigma: {text_sigma}, image_sigma: {image_sigma}')

# Text kernel over the first 12000 text feature rows, using the text sigma.
K_t = ConditionalEvaluation(sigma=text_sigma).gaussian_kernel(txt_feats[:12000], batchsize=64)

# Image kernel over the first 12000 image feature rows. EvalModel stays bound
# to the image-sigma evaluator for every later conditional_entropy call.
EvalModel = ConditionalEvaluation(sigma=image_sigma)
K_i = EvalModel.gaussian_kernel(img_feats[:12000], batchsize=64)

# Entropy order passed to conditional_entropy here and in the per-cluster loop.
order = 2

# The kernels are already computed, hence compute_kernel=False.
cond, mutual, joint, x, y = EvalModel.conditional_entropy(
    K_i,
    K_t,
    order=order,
    compute_kernel=False,
)
logger.info(
    f'order = {order}, for all texts and images: cond: {torch.exp(cond)}, mutual: {torch.exp(mutual)}, images: {torch.exp(x)}, text: {torch.exp(y)}'
)

# Eigendecomposition of the text kernel; only the real part of the
# eigenvalues and eigenvectors is kept.
text_eigvals, text_eigvecs = (part.real for part in torch.linalg.eigh(K_t))

# vals = text_eigvals.cpu().numpy()
# plt.scatter(vals, [0] * vals.shape[0], s=5, c='blue')
# plt.savefig(f'{folder_name}/text-eigvals.png')
# K_t_cpu = K_t.cpu().numpy()

# # Step 2: Create the heatmap using seaborn
# plt.figure(figsize=(10, 8))  # Adjust the figure size as needed
# sns.heatmap(K_t_cpu, cmap='viridis')  # You can choose a different colormap

# # Step 3: Save the figure
# plt.title('Heatmap of K_t Values')  # Optional: Add a title
# plt.savefig(f'{folder_name}/heatmap_K_t.png', dpi=300, bbox_inches='tight')

# How many leading text eigen-directions to inspect, and how many image
# eigen-directions to inspect for each of them.
top_text, top_image = 6, 5

# Largest `top_text` text eigenvalues (m) and their positions (max_id).
m, max_id = text_eigvals.topk(k=top_text)

# text_clusters = {0: [], 1: [], 2: [], 3: [], 4: [], 5: []}

def _orient_positive(vec):
    """Return ``vec`` or ``-vec``, whichever has a non-negative element sum.

    Eigenvectors are only defined up to sign; this fixes the orientation
    before the entries are ranked in descending order.
    """
    return -vec if vec.sum() < 0 else vec


# Number of captions written per text cluster.
n_cluster_prompts = 100
# Number of images per saved grid (3 x 3). Only these are loaded from the
# dataset, since nothing else is written to disk.
n_grid_images = 9

# The prompt files and image grids are written directly into folder_name, so
# make sure it exists before the first open() below.
os.makedirs(folder_name, exist_ok=True)

for i in range(top_text):
    # Eigenvector of K_t belonging to the i-th largest eigenvalue. Its entries
    # are ranked and used as dataset indices, i.e. one entry per sample.
    top_eig_text = _orient_positive(text_eigvecs[:, max_id[i]])

    # Text cluster: captions of the samples with the largest entries.
    topk_id_text = top_eig_text.argsort(descending=True)[:n_cluster_prompts]
    prompts = [dataset[idx][1] for idx in topk_id_text.cpu()]

    # Context manager closes the file even if writing fails; the captions are
    # written as one newline-separated string.
    with open(f'{folder_name}/text-cluster={i}-prompts.txt', 'w', encoding='utf-8') as prompt_file:
        prompt_file.write("\n".join(prompts))

    # Rank-one matrix built from this text direction (outer product of the
    # eigenvector with itself).
    top_eig_text = top_eig_text.reshape((-1, 1))  # column vector, [n_samples, 1]
    K_ut = top_eig_text @ top_eig_text.T
    # Element-wise product with the image kernel, normalised to unit trace.
    KoU = K_i * K_ut
    KoU = KoU / KoU.trace()

    cond, mutual, joint, x, y = EvalModel.conditional_entropy(K_i, K_ut, order=order, compute_kernel=False)
    logger.info(f'order= {order}, text-cluster={i}: cond: {torch.exp(cond)}, mutual: {torch.exp(mutual)}, images: {torch.exp(x)}, text-cluster: {torch.exp(y)}')

    # Leading eigen-directions of the text-weighted image kernel.
    img_eigvals, img_eigvecs = torch.linalg.eigh(KoU)
    img_eigvals = img_eigvals.real
    img_eigvecs = img_eigvecs.real

    _, max_id_img = img_eigvals.topk(top_image)

    for j in range(top_image):
        top_eig_img = _orient_positive(img_eigvecs[:, max_id_img[j]])
        topk_id = top_eig_img.argsort(descending=True)[:n_grid_images]

        # Per-component directory; this block creates it but writes nothing
        # into it (the grid below goes straight into folder_name).
        save_folder_name = f'{folder_name}/text-cluster={i}/{j}'
        os.makedirs(save_folder_name, exist_ok=True)

        summary = []
        for idx in topk_id.cpu():
            print(idx)
            summary.append(dataset[idx][0])

        # 3 x 3 grid of the top-ranked images for text cluster i, image component j.
        save_image(summary, os.path.join(folder_name, f'text={i}_image={j}.jpg'), nrow=3)
