import torch
import os
import logging
import itertools
import random
# import cv2
import numpy as np
from torchvision.utils import save_image
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
def set_seed(seed):
    """Seed the torch, NumPy and built-in ``random`` generators with ``seed``.

    Args:
        seed: Value handed unchanged to each library's seeding function.
    """
    # Order matches the original: torch first, then NumPy, then stdlib random.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

# Import-time side effect: seeds the torch, NumPy and `random` generators
# with a fixed value as soon as this module is loaded.
set_seed(42)

def is_valid_image(img, black_threshold=0.98, white_threshold=0.98, entropy_threshold=1.5):
    """Decide whether an image carries enough visual information to be kept.

    The image is converted to 8-bit grayscale (mode ``"L"``) and rejected when
    it is almost entirely in the first histogram bin (black), almost entirely
    in the last histogram bin (white), or when its histogram entropy is too
    low.

    Args:
        img: Image object exposing ``convert(mode)``; the converted image must
            expose ``histogram()`` (e.g. a ``PIL.Image.Image``).
        black_threshold: Maximum tolerated fraction of pixels in the first
            histogram bin.
        white_threshold: Maximum tolerated fraction of pixels in the last
            histogram bin.
        entropy_threshold: The histogram entropy (in bits) must be strictly
            greater than this value for the image to be accepted.

    Returns:
        ``False`` if the image has no pixels or exceeds either ratio
        threshold; otherwise the result of comparing the entropy against
        ``entropy_threshold``.
    """
    img_gray = img.convert("L")
    hist = img_gray.histogram()

    # Sum once; a zero-pixel image would otherwise divide by zero below.
    total_pixels = sum(hist)
    if total_pixels == 0:
        return False

    if hist[0] / total_pixels > black_threshold:
        return False

    if hist[-1] / total_pixels > white_threshold:
        return False

    return image_entropy(img_gray) > entropy_threshold

def image_entropy(img):
    """Compute the entropy of an image's histogram (a measure of information content).

    Args:
        img: Object exposing ``histogram()`` that returns per-bin counts.

    Returns:
        The base-2 entropy of the normalised histogram.
    """
    counts = np.array(img.histogram())
    probabilities = counts / counts.sum()
    # Drop empty bins so log2 is never evaluated at zero.
    occupied = probabilities[probabilities > 0]
    return -(occupied * np.log2(occupied)).sum()

def process_dm_generated_images(images, model, size=None, device='mps'):
    """Preprocess a list of images and encode them with ``model``.

    Each image is resized to 224x224, converted to a tensor and normalised
    per channel with mean 0.5 and std 0.5. The results are stacked into one
    batch, and both the batch and the model are moved to ``device`` before
    ``model.encode`` runs with gradients disabled.

    Args:
        images: Iterable of images accepted by the torchvision transforms
            used here (presumably 3-channel PIL images — confirm with caller).
        model: Object providing ``to(device)`` and ``encode(batch)``.
        size: Accepted for interface compatibility; currently unused, the
            resize target is fixed at (224, 224).
        device: Device the batch and model are moved to.

    Returns:
        Whatever ``model.encode`` returns for the stacked batch.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    batch = torch.stack([preprocess(img) for img in images], dim=0).to(device)
    encoder = model.to(device)

    # Inference only: no autograd graph is needed for the encoding.
    with torch.no_grad():
        return encoder.encode(batch)

def dm_interpolated_images(
    topk_indices_dataset, topk_labels, data, dm_pipeline, classifier, device, config_dm, logdir, gen_dir, number_pairs=20
):
    """Interpolate between differently-labelled sample pairs and encode the results.

    Candidate pairs are all index pairs ``(i, j)`` into ``topk_labels`` whose
    labels differ. If there are fewer candidates than ``number_pairs`` all of
    them are used in order; otherwise ``number_pairs`` of them are drawn with
    ``random.sample``. For every selected pair the two training images are
    passed to ``dm_pipeline``, and all images it returns are finally encoded
    with ``classifier`` via ``process_dm_generated_images``.

    Args:
        topk_indices_dataset: Sequence mapping a position in ``topk_labels``
            to an index into the training dataset.
        topk_labels: Sequence of labels, aligned with ``topk_indices_dataset``.
        data: Object exposing ``dataloaders['train'].dataset``; indexing the
            dataset must return a sequence whose first item has ``.numpy()``
            (assumed to be a channel-first image tensor — confirm).
        dm_pipeline: Callable invoked with the keyword arguments below and
            returning ``(images, latents)``.
        classifier: Model passed to ``process_dm_generated_images``.
        device: Device passed to ``process_dm_generated_images``.
        config_dm: Config object providing ``use_adain``, ``use_reschedule``,
            ``lamb``, ``num_frames``, ``save_inter`` and ``no_lora``.
        logdir: Root directory for LoRA weights and generated images.
        gen_dir: Sub-directory name appended under ``logdir``.
        number_pairs: Maximum number of pairs to process.

    Returns:
        Tuple ``(latents_dm_encoded, images_dm_generated)``: the classifier
        encodings as a NumPy array and the list of generated images.

    Raises:
        ValueError: If fewer than two samples are given, or if no pair ends
            up selected (all labels identical, or ``number_pairs`` is 0).
    """
    if len(topk_labels) < 2:
        raise ValueError("Not enough samples in topk_indices to form pairs.")

    # Only pairs whose two samples carry different labels are candidates.
    all_pairs = [
        (i, j)
        for i, j in itertools.combinations(range(len(topk_labels)), 2)
        if topk_labels[i] != topk_labels[j]
    ]
    if len(all_pairs) < number_pairs:
        selected_pairs = all_pairs
    else:
        selected_pairs = random.sample(all_pairs, number_pairs)

    # Fail early with a clear error: with no pairs the loop below would not
    # run and the encoding step would read an unbound ``image_0``.
    if not selected_pairs:
        raise ValueError(
            "No pairs selected for interpolation: need at least one pair of "
            "samples with different labels and number_pairs > 0."
        )
    print(f"Selected {len(selected_pairs)} pairs for interpolation.")

    # Loop-invariant lookups.
    train_dataset = data.dataloaders['train'].dataset
    lora_dir = os.path.join(logdir, "trained_lora", gen_dir)

    images_dm_generated = []

    for i, j in tqdm(selected_pairs, desc="Processing random pairs"):
        dataset_idx_0 = topk_indices_dataset[i]
        dataset_idx_1 = topk_indices_dataset[j]

        # Move axis 0 to the end (presumably CHW -> HWC) for the pipeline.
        image_0 = np.transpose(train_dataset[dataset_idx_0][0].numpy(), (1, 2, 0))
        image_1 = np.transpose(train_dataset[dataset_idx_1][0].numpy(), (1, 2, 0))

        unique_id = f"{dataset_idx_0}_vs_{dataset_idx_1}"

        # The pipeline also returns latents; only the images are used here.
        images_dm, _ = dm_pipeline(
            img_0=image_0,
            img_1=image_1,
            idx_0=dataset_idx_0,
            idx_1=dataset_idx_1,
            label_0=topk_labels[i],
            label_1=topk_labels[j],
            prompt_0="",
            prompt_1="",
            save_lora_dir=lora_dir,
            load_lora_path_0=None,
            load_lora_path_1=None,
            use_adain=config_dm.use_adain,
            use_reschedule=config_dm.use_reschedule,
            lamd=config_dm.lamb,
            output_path=os.path.join(logdir, "generated_images", gen_dir, unique_id),
            num_frames=config_dm.num_frames,
            fix_lora=None,
            save_intermediates=config_dm.save_inter,
            use_lora=not config_dm.no_lora
        )

        images_dm_generated.extend(images_dm)

    # ``image_0`` is bound because at least one pair was processed.
    # NOTE(review): shape[:1] yields only the first dimension; the callee
    # currently ignores ``size`` — confirm whether (H, W) was intended.
    latents_dm_encoded = process_dm_generated_images(
        images_dm_generated, classifier, image_0.shape[:1], device
    ).cpu().numpy()

    return latents_dm_encoded, images_dm_generated

