import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances
import re


def empirical_covariance(x):
    """Return the unbiased sample covariance matrix of ``x``.

    Rows of ``x`` are treated as observations and columns as variables: the
    column means are removed and the scatter matrix is divided by
    ``x.shape[0] - 1``.
    """
    num_rows = x.shape[0]
    deviations = x - np.mean(x, axis=0)
    scatter = deviations.T @ deviations
    return scatter / (num_rows - 1)
    
def incremental_cov(S, S2, n):
    """Build a covariance matrix from running sums.

    Args:
        S: 1-D tensor holding the sum of the selected vectors.
        S2: 2-D tensor holding the sum of their outer products.
        n: Number of vectors accumulated in ``S`` and ``S2``.

    Returns:
        ``S2 / n - mean mean^T`` with ``mean = S / n`` (normalised by ``n``,
        not ``n - 1``).
    """
    mean = S / n
    mean_outer = mean[:, None] @ mean[None, :]
    second_moment = S2 / n
    return second_moment - mean_outer

from sklearn.decomposition import PCA

def cover_maxmin_batched(real_feats, gen_feats, count, batch_size=1024):
    """Greedy max-min (farthest-point) selection of generated samples.

    Every generated sample starts with its L2 distance to the closest real
    sample. At each step the unused sample with the largest such distance is
    picked, and the distances of the samples that were still unused are then
    lowered to account for the newly picked point.

    Args:
        real_feats: 2-D array of real features (cast to float32).
        gen_feats: 2-D array of generated features (cast to float32).
        count: Maximum number of samples to pick.
        batch_size: Number of rows processed per distance computation.

    Returns:
        ``(used, matches)``: a boolean mask over ``gen_feats`` and the list of
        picked indices in selection order.
    """
    real_feats = real_feats.astype(np.float32)
    gen_feats = gen_feats.astype(np.float32)

    total = len(gen_feats)
    used = np.zeros(total, dtype=bool)
    matches = []

    # Phase 1: distance from each generated sample to its nearest real one,
    # computed chunk by chunk to bound the size of the pairwise tensor.
    min_dists = np.full(total, np.inf, dtype=np.float32)
    for lo in range(0, total, batch_size):
        hi = min(lo + batch_size, total)
        pairwise = np.linalg.norm(
            gen_feats[lo:hi, None, :] - real_feats[None, :, :], axis=2
        )
        min_dists[lo:hi] = np.min(pairwise, axis=1)

    # Phase 2: repeatedly take the least-covered sample.
    for _ in range(count):
        remaining = np.flatnonzero(~used)
        if remaining.size == 0:
            break

        pick = remaining[np.argmax(min_dists[remaining])]
        used[pick] = True
        matches.append(pick)

        # `remaining` was taken before marking `pick`, so it still contains
        # it; its distance drops to zero, which is harmless once it is used.
        anchor = gen_feats[pick]
        for lo in range(0, len(remaining), batch_size):
            hi = min(lo + batch_size, len(remaining))
            chunk = remaining[lo:hi]
            gaps = np.linalg.norm(gen_feats[chunk] - anchor[None, :], axis=1)
            min_dists[chunk] = np.minimum(min_dists[chunk], gaps)

    return used, matches

import torch
import numpy as np
from sklearn.decomposition import PCA

def frobenius_greedy_sampling(
    real_feats, gen_feats, count, matches, pca_dim=32, patience_steps=10
):
    """Randomised greedy selection driven by a covariance Frobenius distance.

    Unused generated samples are drawn uniformly at random, one per
    iteration. Each draw is scored by the Frobenius norm between the target
    covariance of ``real_feats`` and the covariance the selection would have
    if the draw were added. The best draw seen so far is kept as the
    candidate; once ``patience_steps`` consecutive draws fail to beat it, the
    candidate is accepted and the search restarts.

    Args:
        real_feats: 2-D real features; converted to a float32 tensor.
        gen_feats: 2-D generated features; converted to a float32 tensor.
        count: Number of samples to accept.
        matches: List that accepted indices are appended to (mutated).
        pca_dim: Accepted but not used by this function.
        patience_steps: Non-improving draws tolerated before accepting.

    Returns:
        ``(used, matches)``: a CPU boolean tensor mask and the ``matches`` list.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    real_t = torch.tensor(real_feats, dtype=torch.float32, device=device)
    gen_t = torch.tensor(gen_feats, dtype=torch.float32, device=device)

    # Target: unbiased covariance of the real features.
    centered = real_t - real_t.mean(dim=0, keepdim=True)
    cov_target = centered.T @ centered / (real_t.shape[0] - 1)

    used = torch.zeros(len(gen_t), dtype=torch.bool, device=device)
    running_sum = torch.zeros(gen_t.shape[1], device=device)
    running_outer = torch.zeros(gen_t.shape[1], gen_t.shape[1], device=device)

    accepted = 0
    stale_draws = 0
    best_idx = None
    best_score = float('inf')

    while accepted < count:
        pool = torch.where(~used)[0]
        if len(pool) == 0:
            break

        draw = pool[torch.randint(0, len(pool), (1,)).item()]
        x = gen_t[draw]

        # Covariance of the selection if `x` were added (normalised by n).
        n_try = accepted + 1
        sum_try = running_sum + x
        outer_try = running_outer + x[:, None] @ x[None, :]
        mean_try = sum_try / n_try
        cov_try = outer_try / n_try - mean_try[:, None] @ mean_try[None, :]
        score = torch.norm(cov_try - cov_target, p='fro').item()

        if score < best_score:
            best_score, best_idx, stale_draws = score, draw, 0
        else:
            stale_draws += 1

        if stale_draws < patience_steps:
            continue

        # Patience exhausted: commit the best draw and restart the search.
        winner = gen_t[best_idx]
        running_sum += winner
        running_outer += winner[:, None] @ winner[None, :]
        used[best_idx] = True
        matches.append(best_idx.item())
        accepted += 1
        best_idx, best_score, stale_draws = None, float('inf'), 0

    return used.to('cpu'), matches

def nearest_to_center(real_feats,gen_feats, count):
    """Pick the ``count`` generated samples closest to the mean real feature.

    Args:
        real_feats: 2-D array of real features; only its mean row is used.
        gen_feats: 2-D array of generated features (copied, never modified).
        count: Number of samples to keep.

    Returns:
        ``(used, matches)``: a boolean mask over ``gen_feats`` and the picked
        indices as a list, ordered by increasing L2 distance to the mean.
    """
    center = np.mean(real_feats, axis=0)
    candidates = gen_feats.copy()

    by_distance = np.argsort(np.linalg.norm(candidates - center, axis=1))
    closest = by_distance[:count]

    used = np.zeros(len(candidates), dtype=bool)
    used[closest] = True
    return used, closest.tolist()


# RKE OF THE SHIFT?
def cov_trace_ratio_with_pca(real_feats, gen_feats, count, matches, pca_dim=32):
    """Greedy selection maximising ``sum(pinv(cov_real) * cov_selected)``.

    Both feature sets are projected with a PCA fitted on ``real_feats``. At
    every step all unused generated samples are evaluated at once: for each,
    the covariance of the selection with that sample added is contracted
    against the pseudo-inverse of the real covariance, and the sample with
    the largest value is added.

    Args:
        real_feats: 2-D real features; the PCA is fitted on them.
        gen_feats: 2-D generated features.
        count: Maximum number of samples to pick.
        matches: List that picked indices are appended to (mutated).
        pca_dim: Number of PCA components.

    Returns:
        ``(used, matches)``: a CPU boolean tensor mask and the ``matches`` list.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Project both sets into the PCA space of the real features.
    projector = PCA(n_components=pca_dim)
    real_low = projector.fit_transform(real_feats)
    gen_low = projector.transform(gen_feats)

    real_t = torch.tensor(real_low, dtype=torch.float32, device=device)
    gen_t = torch.tensor(gen_low, dtype=torch.float32, device=device)

    # Pseudo-inverse of the unbiased covariance of the real features.
    centered = real_t - real_t.mean(dim=0, keepdim=True)
    cov_real = centered.T @ centered / (real_t.shape[0] - 1)
    cov_real_inv = torch.linalg.pinv(cov_real)

    used = torch.zeros(len(gen_t), dtype=torch.bool, device=device)
    running_sum = torch.zeros(gen_t.shape[1], device=device)
    running_outer = torch.zeros(gen_t.shape[1], gen_t.shape[1], device=device)

    # Exactly one sample is added per completed iteration, so the size of the
    # trial selection equals the 1-based iteration number.
    for n_try in range(1, count + 1):
        pool = torch.where(~used)[0]
        if len(pool) == 0:
            break

        candidates = gen_t[pool]

        # Statistics of the selection if each candidate were added.
        sum_try = running_sum.unsqueeze(0) + candidates                    # [B, D]
        outer_try = running_outer.unsqueeze(0) + (
            candidates[:, :, None] @ candidates[:, None, :]
        )                                                                  # [B, D, D]
        mean_try = sum_try / n_try
        cov_try = outer_try / n_try - mean_try[:, :, None] @ mean_try[:, None, :]

        # Element-wise contraction with the inverse real covariance, per candidate.
        traces = torch.einsum('ij,bij->b', cov_real_inv, cov_try)          # [B]
        winner = pool[torch.argmax(traces)]

        chosen = gen_t[winner]
        running_sum += chosen
        running_outer += chosen[:, None] @ chosen[None, :]
        used[winner] = True
        matches.append(winner.item())

    return used.to('cpu'), matches


def matching_pursuit_cov_frobenius_with_pca_batches(
    real_feats, gen_feats, count, matches, pca_dim=32, batch_size=200
):
    """Greedy covariance matching in PCA space, scanning candidates in chunks.

    At every step each unused generated sample is scored by the Frobenius
    norm between the real covariance and the covariance the selection would
    have with that sample added; the lowest-scoring sample is added.
    Candidates are evaluated ``batch_size`` at a time to bound the size of
    the per-candidate covariance tensor.

    Args:
        real_feats: 2-D real features; the PCA is fitted on them.
        gen_feats: 2-D generated features.
        count: Maximum number of selection steps.
        matches: List that picked indices are appended to (mutated).
        pca_dim: Number of PCA components.
        batch_size: Number of candidates scored per chunk.

    Returns:
        ``(used, matches)``: a CPU boolean tensor mask and the ``matches`` list.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    projector = PCA(n_components=pca_dim)
    real_low = projector.fit_transform(real_feats)
    gen_low = projector.transform(gen_feats)

    real_t = torch.tensor(real_low, dtype=torch.float32, device=device)
    gen_t = torch.tensor(gen_low, dtype=torch.float32, device=device)

    # Target: unbiased covariance of the projected real features.
    centered = real_t - real_t.mean(dim=0, keepdim=True)
    cov_target = centered.T @ centered / (real_t.shape[0] - 1)

    used = torch.zeros(len(gen_t), dtype=torch.bool, device=device)
    running_sum = torch.zeros(gen_t.shape[1], device=device)
    running_outer = torch.zeros(gen_t.shape[1], gen_t.shape[1], device=device)
    picked = 0

    for _ in range(count):
        pool = torch.where(~used)[0]
        if len(pool) == 0:
            break

        winner = None
        winner_score = float('inf')
        n_try = picked + 1

        with torch.no_grad():
            for lo in range(0, len(pool), batch_size):
                hi = min(lo + batch_size, len(pool))
                chunk_ids = pool[lo:hi]
                chunk = gen_t[chunk_ids]

                # Statistics of the selection if each chunk member were added.
                sum_try = running_sum.unsqueeze(0) + chunk
                outer_try = running_outer.unsqueeze(0) + chunk[:, :, None] @ chunk[:, None, :]
                mean_try = sum_try / n_try
                cov_try = outer_try / n_try - mean_try[:, :, None] @ mean_try[:, None, :]

                frob_norms = torch.norm(
                    cov_try - cov_target.unsqueeze(0), p='fro', dim=(1, 2)
                )
                chunk_min = torch.min(frob_norms, dim=0)
                chunk_score = chunk_min.values.item()

                # Strict comparison: an earlier chunk wins ties.
                if chunk_score < winner_score:
                    winner_score = chunk_score
                    winner = chunk_ids[chunk_min.indices]

        if winner is None:
            continue

        chosen = gen_t[winner]
        running_sum += chosen
        running_outer += chosen[:, None] @ chosen[None, :]
        used[winner] = True
        picked += 1
        matches.append(winner.item())

    return used.to('cpu'), matches

def matching_pursuit_cov_frobenius_with_pca(real_feats, gen_feats, count, matches, pca_dim=32):
    """Greedy covariance matching in PCA space, scoring all candidates at once.

    Both feature sets are projected with a PCA fitted on ``real_feats``. At
    every step each unused generated sample is scored by the Frobenius norm
    between the real covariance and the covariance the selection would have
    with that sample added; the lowest-scoring sample is added.

    Args:
        real_feats: 2-D real features; the PCA is fitted on them.
        gen_feats: 2-D generated features.
        count: Maximum number of samples to pick.
        matches: List that picked indices are appended to (mutated).
        pca_dim: Number of PCA components.

    Returns:
        ``(used, matches)``: a CPU boolean tensor mask and the ``matches``
        list. The pair is returned on every path, including when the pool of
        generated samples runs out before ``count`` is reached.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    pca = PCA(n_components=pca_dim)
    real_feats_pca = pca.fit_transform(real_feats)
    gen_feats_pca = pca.transform(gen_feats)

    real_feats_torch = torch.tensor(real_feats_pca, dtype=torch.float32, device=device)
    gen_feats_torch = torch.tensor(gen_feats_pca, dtype=torch.float32, device=device)

    # Target: unbiased covariance of the projected real features.
    X = real_feats_torch
    X_centered = X - X.mean(dim=0, keepdim=True)
    cov_target = X_centered.T @ X_centered / (X.shape[0] - 1)

    dim = gen_feats_torch.shape[1]
    used = torch.zeros(len(gen_feats_torch), dtype=torch.bool, device=device)
    sum_x = torch.zeros(dim, device=device)
    sum_xxT = torch.zeros(dim, dim, device=device)
    selected_count = 0

    with torch.no_grad():
        for _ in range(count):
            available_idx = torch.where(~used)[0]
            if len(available_idx) == 0:
                # Pool exhausted; fall through to the single tuple return.
                break

            candidates = gen_feats_torch[available_idx]

            # Statistics of the selection if each candidate were added.
            S_try = sum_x.unsqueeze(0) + candidates
            S2_try = sum_xxT.unsqueeze(0) + candidates.unsqueeze(2) @ candidates.unsqueeze(1)
            n_try = selected_count + 1
            mean_try = S_try / n_try
            cov_try = S2_try / n_try - mean_try.unsqueeze(2) @ mean_try.unsqueeze(1)

            diff = cov_try - cov_target.unsqueeze(0)
            frob_norms = torch.norm(diff, p='fro', dim=(1, 2))

            best_idx = available_idx[torch.argmin(frob_norms)]

            # Commit the winner to the running statistics.
            x = gen_feats_torch[best_idx]
            sum_x += x
            sum_xxT += x.unsqueeze(1) @ x.unsqueeze(0)
            used[best_idx] = True
            selected_count += 1
            matches.append(best_idx.item())

    return used.to('cpu'), matches
    
def _match_sequential(real_feats, gen_feats, count, use_cosine, pick_farthest):
    """Cycle through the real samples, giving each its best unused generated sample.

    Real samples are visited in order, wrapping around to the first one,
    until ``count`` generated samples are picked or none are left.

    Args:
        real_feats: 2-D array of anchor features.
        gen_feats: 2-D array of candidate features.
        count: Maximum number of candidates to pick.
        use_cosine: Use cosine distance when true, L2 distance otherwise.
        pick_farthest: Pick the farthest candidate when true, else the nearest.

    Returns:
        ``(used, matches)``: boolean mask over ``gen_feats`` and picked indices.
    """
    used = np.zeros(len(gen_feats), dtype=bool)
    matches = []
    num_real = len(real_feats)
    if num_real == 0:
        # Without anchors there is nothing to match against; return an empty
        # selection instead of looping forever.
        return used, matches

    select = np.argmax if pick_farthest else np.argmin
    step = 0
    while len(matches) < count:
        available_idx = np.flatnonzero(~used)
        if available_idx.size == 0:
            break  # pool exhausted before `count` was reached

        r = real_feats[step % num_real].reshape(1, -1)
        step += 1
        available_gen = gen_feats[available_idx]

        if use_cosine:
            dists = cosine_distances(r, available_gen).flatten()
        else:
            dists = np.linalg.norm(available_gen - r, axis=1)

        best_global = available_idx[select(dists)]
        used[best_global] = True
        matches.append(best_global)

    return used, matches


def match_greedy(real_feats, gen_feats,count, distance='l2-near',zero_centered=False):
    """Select up to ``count`` generated samples using the strategy ``distance``.

    Supported strategies:
        ``'MPfast'``, ``'CovTrace'``, ``'greedy-frobenius'``,
        ``'cov-frobenius'``: covariance-matching selections.
        ``'nearest-to-center'``: closest to the mean real feature.
        ``'cover-maxmin'``, ``'cover-maxmin-improved'``,
        ``'cover-maxmin-batched'``: samples far from the real set.
        ``'kmeans-diverse'``: round-robin random draws over k-means clusters.
        ``'l2-near-pca-<dim>'``: nearest L2 match after PCA to ``<dim>``.
        ``'cosine-near'``, ``'cosine-far'``, ``'l2-near'``, ``'l2-far'``:
        each real sample in turn takes its nearest / farthest unused match.

    Args:
        real_feats: 2-D NumPy array of real features (not modified).
        gen_feats: 2-D NumPy array of generated features (not modified).
        count: Maximum number of generated samples to select.
        distance: Strategy name, see above.
        zero_centered: Subtract each set's own mean before matching.

    Returns:
        ``(matches, used)``: the picked indices and a boolean mask over
        ``gen_feats``. Note the order is the reverse of the helper functions.
        Fewer than ``count`` indices are returned when the pool runs out.

    Raises:
        ValueError: If ``distance`` is not a known strategy, or the PCA
            dimension cannot be parsed from an ``'l2-near-pca-'`` name.
    """
    real_feats = real_feats.copy()
    gen_feats = gen_feats.copy()
    if zero_centered:
        # Out-of-place subtraction so integer inputs are promoted to float
        # instead of failing on an in-place cast.
        real_feats = real_feats - np.mean(real_feats, axis=0)
        gen_feats = gen_feats - np.mean(gen_feats, axis=0)

    used = np.zeros(len(gen_feats), dtype=bool)
    matches = []

    if distance == "MPfast":
        used, matches = matching_pursuit_cov_frobenius_with_pca(real_feats, gen_feats, count, [], pca_dim=32)

    elif distance == "nearest-to-center":
        used, matches = nearest_to_center(real_feats, gen_feats, count)

    elif distance == "CovTrace":
        used, matches = cov_trace_ratio_with_pca(real_feats, gen_feats, count, [], pca_dim=32)

    elif distance == "greedy-frobenius":
        used, matches = frobenius_greedy_sampling(real_feats, gen_feats, count, [])

    elif distance == 'cov-frobenius':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        real_feats = torch.tensor(real_feats, dtype=torch.float32, device=device)
        gen_feats = torch.tensor(gen_feats, dtype=torch.float32, device=device)

        X_centered = real_feats - real_feats.mean(dim=0, keepdim=True)
        cov_real = X_centered.T @ X_centered / (real_feats.shape[0] - 1)

        used = torch.zeros(len(gen_feats), dtype=torch.bool, device=device)
        S = torch.zeros(gen_feats.shape[1], device=device)
        S2 = torch.zeros(gen_feats.shape[1], gen_feats.shape[1], device=device)
        n_sel = 0
        sample_cap = 800  # at most this many random candidates are scored per round

        for _ in range(count):
            available_idx = torch.where(~used)[0]
            if len(available_idx) == 0:
                break
            if len(available_idx) > sample_cap:
                sample_idx = torch.randperm(len(available_idx), device=device)[:sample_cap]
                available_idx = available_idx[sample_idx]

            best_dist = float('inf')
            best_idx = None

            for i in available_idx:
                x = gen_feats[i]
                cov_try = incremental_cov(S + x, S2 + x.unsqueeze(1) @ x.unsqueeze(0), n_sel + 1)
                frob = torch.norm(cov_try - cov_real, p='fro').item()
                if frob < best_dist:
                    best_dist = frob
                    best_idx = i

            if best_idx is not None:
                x = gen_feats[best_idx]
                S += x
                S2 += x.unsqueeze(1) @ x.unsqueeze(0)
                n_sel += 1
                used[best_idx] = True
                matches.append(best_idx.item())
        used = used.to('cpu')

    elif distance == 'cover-maxmin-batched':
        used, matches = cover_maxmin_batched(real_feats, gen_feats, count, batch_size=512)

    elif distance == 'kmeans-diverse':
        real_feats = real_feats.astype(np.float32)
        gen_feats = gen_feats.astype(np.float32)

        # Never ask for more clusters than there are samples to cluster.
        k = min(count, 6, len(gen_feats))
        if k > 0:
            kmeans = KMeans(n_clusters=k, n_init='auto', random_state=42)
            labels = kmeans.fit_predict(gen_feats)
            cluster_to_indices = {i: np.where(labels == i)[0].tolist() for i in range(k)}

            while len(matches) < count:
                progressed = False
                for cluster_id in range(k):
                    candidates = [idx for idx in cluster_to_indices[cluster_id] if not used[idx]]
                    if candidates:
                        chosen = np.random.choice(candidates)
                        used[chosen] = True
                        matches.append(chosen)
                        progressed = True
                        if len(matches) >= count:
                            break
                if not progressed:
                    # Every cluster is empty: stop rather than loop forever.
                    break

    elif distance == 'cover-maxmin-improved':
        real_feats = real_feats.astype(np.float32)
        gen_feats = gen_feats.astype(np.float32)

        dists = np.linalg.norm(gen_feats[:, None, :] - real_feats[None, :, :], axis=2)
        min_dists = np.min(dists, axis=1)

        for _ in range(count):
            available_idx = np.where(~used)[0]
            if len(available_idx) == 0:
                break

            best_global = available_idx[np.argmax(min_dists[available_idx])]
            used[best_global] = True
            matches.append(best_global)

            # Lower coverage distances to account for the newly picked point.
            x_new = gen_feats[best_global]
            new_dists = np.linalg.norm(gen_feats[available_idx] - x_new[None, :], axis=1)
            min_dists[available_idx] = np.minimum(min_dists[available_idx], new_dists)

    elif distance == 'cover-maxmin':
        dists = np.linalg.norm(
            real_feats[None, :, :] - gen_feats[:, None, :], axis=2
        )
        min_dists = np.min(dists, axis=1)

        for _ in range(count):
            available_idx = np.where(~used)[0]
            if len(available_idx) == 0:
                break

            best_global = available_idx[np.argmax(min_dists[available_idx])]
            used[best_global] = True
            matches.append(best_global)
            min_dists[best_global] = -np.inf

    elif distance.startswith('l2-near-pca-'):
        match = re.search(r'l2-near-pca-(\d+)', distance)
        if not match:
            raise ValueError(f"Could not extract PCA dimension from distance string: {distance}")
        pca = PCA(n_components=int(match.group(1)))
        real_feats_pca = pca.fit_transform(real_feats)
        gen_feats_pca = pca.transform(gen_feats)
        used, matches = _match_sequential(
            real_feats_pca, gen_feats_pca, count, use_cosine=False, pick_farthest=False
        )

    else:
        # (use_cosine, pick_farthest) for each sequential strategy.
        sequential_modes = {
            'cosine-near': (True, False),
            'cosine-far': (True, True),
            'l2-near': (False, False),
            'l2-far': (False, True),
        }
        if distance not in sequential_modes:
            raise ValueError(f"Distance Is Not Defined: {distance}")
        use_cosine, pick_farthest = sequential_modes[distance]
        used, matches = _match_sequential(
            real_feats, gen_feats, count, use_cosine=use_cosine, pick_farthest=pick_farthest
        )

    return matches, used