import torch
import numpy as np
import random
from PIL import Image
from tqdm import tqdm
import torch.nn.functional as F
from torchvision import transforms
from sklearn.metrics import pairwise_distances
from torch.utils.data import DataLoader, WeightedRandomSampler

def set_seed(seed):
    """Seed the PyTorch, NumPy and Python ``random`` global RNGs.

    Args:
        seed: Value handed unchanged to ``torch.manual_seed``,
            ``np.random.seed`` and ``random.seed`` (in that order).
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
# Runs at import time: every importer of this module gets its global RNGs
# (torch / numpy / random) re-seeded to 42 as a side effect.
set_seed(42)

def select_high_entropy_indices(classifier, valid_image_paths, num_samples, device='mps'):
    """Rank images by predictive entropy and return the most uncertain ones.

    Each image is opened, converted to RGB, resized to 224x224, normalised
    with mean/std 0.5 per channel and passed through ``classifier`` one at a
    time. The entropy of the softmax over ``dim=1`` of the logits is the
    image's score.

    Args:
        classifier: Module called as ``classifier(tensor)``; it is put in
            eval mode and moved to ``device`` (both are side effects on the
            passed object).
        valid_image_paths: Iterable of image paths accepted by
            ``PIL.Image.open``.
        num_samples: Maximum number of indices to return.
        device: Torch device used for the model and the inputs.

    Returns:
        list[int]: Positions into ``valid_image_paths``, ordered from highest
        to lowest entropy. Images that raised during loading or inference
        are reported on stdout and never returned, so the list may be
        shorter than ``num_samples``.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    classifier.eval()
    classifier.to(device)

    entropy_scores = []
    succeeded = []

    for img_path in tqdm(valid_image_paths, desc="Calculating entropy"):
        try:
            # The context manager closes the underlying file handle right
            # away; convert() returns an independent in-memory copy.
            with Image.open(img_path) as img:
                rgb_img = img.convert('RGB')
            input_tensor = transform(rgb_img).unsqueeze(0).to(device)  # shape: [1, C, H, W]

            with torch.no_grad():
                logits = classifier(input_tensor)
                probs = torch.softmax(logits, dim=1).squeeze(0)
                # 1e-10 keeps log() finite for zero probabilities.
                entropy = -torch.sum(probs * torch.log(probs + 1e-10)).item()
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole scan.
            print(f"Error processing {img_path}: {e}")
            entropy_scores.append(-1e10)
            succeeded.append(False)
        else:
            entropy_scores.append(entropy)
            succeeded.append(True)

    entropy_scores = np.asarray(entropy_scores, dtype=float)
    succeeded = np.asarray(succeeded, dtype=bool)

    # Stable sort makes the order of equal-entropy images deterministic.
    ranked = np.argsort(-entropy_scores, kind='stable')
    # Drop failed images so they cannot be picked when num_samples exceeds
    # the number of successfully scored images.
    ranked = ranked[succeeded[ranked]]
    selected_indices = ranked[:num_samples]

    return selected_indices.tolist()

def explore_latent_space_becs(model, data_loader, k_per_class=30, entropy_weight=0.7, vae=None, device="mps"):
    """Encode a dataset and pick per-class samples from latents and entropy.

    The dataset behind ``data_loader`` is re-wrapped in a new ``DataLoader``
    whose ``WeightedRandomSampler`` draws ``len(dataset)`` indices without
    replacement. The weights come from the original sampler when it is a
    ``WeightedRandomSampler`` and are uniform otherwise. For every batch the
    function stores ``model.encode(images)``, the labels, the sample ids and
    the entropy of ``softmax(model(images), dim=1)``. The concatenated arrays
    are then handed to ``select_representative_per_class_with_entropy`` with
    the cosine metric.

    Args:
        model: Object exposing ``eval()``, ``encode(images)`` and
            ``__call__(images)``. It is switched to eval mode but NOT moved
            to ``device`` here.
        data_loader: Loader whose batches unpack as
            ``(images, labels, sample_ids)``; labels and sample ids must
            support ``.cpu().numpy()``.
        k_per_class: Forwarded to the per-class selection routine.
        entropy_weight: Forwarded to the per-class selection routine.
        vae: Unused by this function.
        device: Device the image batches are moved to.

    Returns:
        tuple: ``(count, selected_ids, selected_sample_ids, selected_latents,
        selected_labels, all_latents, all_labels)``. ``selected_ids`` index
        into ``all_latents`` / ``all_labels`` (the order in which the new
        loader produced samples); ``selected_sample_ids`` is a Python list of
        the matching ids taken from the batches.
    """
    model.eval()

    source_dataset = data_loader.dataset
    n_items = len(source_dataset)

    prior_sampler = data_loader.sampler
    if isinstance(prior_sampler, WeightedRandomSampler):
        sampling_weights = prior_sampler.weights
    else:
        sampling_weights = torch.ones(n_items)

    # Weighted random permutation of the whole dataset (no replacement).
    permuted_loader = DataLoader(
        source_dataset,
        batch_size=data_loader.batch_size,
        sampler=WeightedRandomSampler(sampling_weights, n_items, replacement=False),
        num_workers=data_loader.num_workers,
        drop_last=data_loader.drop_last,
    )

    latent_chunks = []
    label_chunks = []
    id_chunks = []
    entropy_chunks = []

    with torch.no_grad():
        for batch_images, batch_labels, batch_ids in tqdm(permuted_loader, desc="Processing"):
            batch_images = batch_images.to(device)
            batch_latents = model.encode(batch_images)
            batch_logits = model(batch_images)

            class_probs = F.softmax(batch_logits, dim=1)
            # 1e-6 keeps log() finite for zero probabilities.
            batch_entropy = -(class_probs * torch.log(class_probs + 1e-6)).sum(dim=1)

            latent_chunks.append(batch_latents.cpu().numpy())
            label_chunks.append(batch_labels.cpu().numpy())
            id_chunks.append(batch_ids.cpu().numpy())
            entropy_chunks.append(batch_entropy.cpu().numpy())

    all_latents = np.concatenate(latent_chunks)
    all_labels = np.concatenate(label_chunks)
    all_sample_ids = np.concatenate(id_chunks)
    entropy_scores = np.concatenate(entropy_chunks)

    selected_ids = select_representative_per_class_with_entropy(
        all_latents=all_latents,
        all_labels=all_labels,
        entropy_scores=entropy_scores,  # shape: (N,)
        k_per_class=k_per_class,
        metric='cosine',
        entropy_weight=entropy_weight,
    )

    picked_sample_ids = all_sample_ids[selected_ids].tolist()
    picked_latents = all_latents[selected_ids]
    picked_labels = all_labels[selected_ids]

    return (
        len(selected_ids),
        selected_ids,
        picked_sample_ids,
        picked_latents,
        picked_labels,
        all_latents,
        all_labels,
    )

def select_representative_per_class_with_entropy(all_latents, all_labels,
                                    entropy_scores, entropy_weight=0.7,
                                    k_per_class=3, metric='cosine'):
    """Greedily pick up to ``k_per_class`` samples for every label.

    For each distinct label, the first sample of that class (in input order)
    seeds the selection. Every further pick maximises

        (1 - entropy_weight) * dist_score + entropy_weight * ent_norm

    where ``ent_norm`` is the class-wise min-max normalised entropy and
    ``dist_score`` is one minus the min-max normalised distance to the
    nearest already-selected sample.

    Args:
        all_latents: Array-like with one latent per sample along axis 0;
            trailing dimensions are flattened per sample.
        all_labels: Array-like of labels, aligned with ``all_latents``.
        entropy_scores: Array-like of per-sample entropies, same alignment.
        entropy_weight: Weight of the entropy term in the combined score.
        k_per_class: Maximum number of samples per label. Values ``<= 0``
            yield an empty result.
        metric: Metric name forwarded to
            ``sklearn.metrics.pairwise_distances``.

    Returns:
        list: Indices into ``all_latents`` (NumPy integer scalars), grouped
        by label in ``np.unique`` order. A class contributes fewer than
        ``k_per_class`` indices when it has fewer members.
    """
    selected_indices = []
    if k_per_class <= 0:
        # Without this guard the seed sample would be returned for every
        # class even though no samples were requested.
        return selected_indices

    all_latents = np.asarray(all_latents)
    all_labels = np.asarray(all_labels)
    entropy_scores = np.asarray(entropy_scores)

    for label in tqdm(np.unique(all_labels), desc="computing per-class selection"):

        class_indices = np.where(all_labels == label)[0]
        class_latents = all_latents[class_indices].reshape(len(class_indices), -1)
        class_entropy = entropy_scores[class_indices]

        # The entropy term does not depend on what has been selected so far,
        # so normalise it once per class instead of once per pick.
        ent_norm = (class_entropy - class_entropy.min()) / (np.ptp(class_entropy) + 1e-8)

        # Seed with the first member of the class.
        init_idx = 0
        selected = [class_indices[init_idx]]
        selected_mask = np.zeros(len(class_indices), dtype=bool)
        selected_mask[init_idx] = True

        # Distance of every class member to its nearest selected sample.
        min_dists = pairwise_distances(class_latents, class_latents[[init_idx]], metric=metric).min(axis=1)

        while len(selected) < k_per_class:
            # NOTE(review): this rewards candidates that are CLOSE to the
            # selected set (small distance -> high score) — confirm that
            # proximity rather than spread is the intended criterion.
            dist_score = 1 - (min_dists - min_dists.min()) / (np.ptp(min_dists) + 1e-8)
            scores = (1 - entropy_weight) * dist_score + entropy_weight * ent_norm
            scores[selected_mask] = -np.inf

            next_idx = np.argmax(scores)
            if scores[next_idx] == -np.inf:
                # Every member of the class has already been taken.
                break
            selected.append(class_indices[next_idx])
            selected_mask[next_idx] = True

            # ravel() keeps the result 1-D regardless of class size.
            new_dist = pairwise_distances(class_latents, class_latents[[next_idx]], metric=metric).ravel()
            min_dists = np.minimum(min_dists, new_dist)

        selected_indices.extend(selected)

    return selected_indices

def get_topk_latents(model, data_loader, topk_ids, device):
    """Encode every batch and pull out the latents of the requested samples.

    Args:
        model: Object exposing ``eval()`` and ``encode(images)``; it is put
            in eval mode but not moved to ``device``.
        data_loader: Iterable whose batches unpack as
            ``(images, labels, sample_ids)``; labels and sample ids must
            support ``.cpu().numpy()``.
        topk_ids: Collection of sample ids (as emitted in ``sample_ids``) to
            extract. Lists, tuples, NumPy arrays, tensors and sets are
            accepted.
        device: Device the image batches are moved to before encoding.

    Returns:
        tuple of ``np.ndarray``:
            ``topk_indices_dataset`` — matched sample ids,
            ``topk_indices`` — their positions inside ``all_latents``,
            ``topk_latents`` — their latents,
            ``topk_labels`` — their labels,
            ``all_latents`` — latents of every sample, in loader order,
            ``all_labels`` — labels of every sample, in loader order.

    Raises:
        ValueError: If the loader yields no batches or no sample id matches
            ``topk_ids`` (``np.concatenate`` / ``np.stack`` on an empty list).
    """
    model.eval()

    # Build the membership structure once: O(1) lookups instead of scanning
    # topk_ids for every sample.
    if isinstance(topk_ids, (set, frozenset, dict)):
        topk_lookup = topk_ids
    else:
        if torch.is_tensor(topk_ids):
            topk_ids = topk_ids.detach().cpu()
        topk_lookup = set(np.asarray(topk_ids).ravel().tolist())

    all_latents = []
    all_labels = []
    topk_latents = []
    topk_labels = []
    topk_indices = []  # positions in all_latents
    topk_indices_dataset = []  # ids as reported by the dataset

    global_idx = 0
    with torch.no_grad():
        for images, labels, sample_ids in tqdm(data_loader, desc="Extracting latent vectors"):
            images = images.to(device)

            latents_np = model.encode(images).cpu().numpy()
            labels_np = labels.cpu().numpy()
            ids_np = sample_ids.cpu().numpy()

            all_latents.append(latents_np)
            all_labels.append(labels_np)

            for i, sample_id in enumerate(ids_np):
                if sample_id.item() in topk_lookup:
                    topk_latents.append(latents_np[i])
                    topk_labels.append(labels_np[i].item())
                    topk_indices.append(global_idx + i)
                    topk_indices_dataset.append(sample_id)

            # Must advance once per batch; otherwise every batch after the
            # first would record batch-local offsets as global positions.
            global_idx += len(ids_np)

    return (
        np.stack(topk_indices_dataset, axis=0),  # topk_indices_dataset (dataset index)
        np.stack(topk_indices, axis=0),  # topk_indices (all_latents index)
        np.stack(topk_latents, axis=0),  # topk_latents
        np.stack(topk_labels, axis=0),  # topk_labels
        np.concatenate(all_latents),  # all_latents
        np.concatenate(all_labels)  # all_labels
    )

