import numpy as np
import torch
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import re
import torch.nn.functional as F

def Text_matching(real_feats, gen_feats, count, text_feature, used=None):
    """Return the ``count`` generated samples closest to a text embedding.

    Both the generated features and the text feature are L2-normalised
    before the Euclidean distance is measured.

    Args:
        real_feats: Unused; kept for a uniform selector signature.
        gen_feats: 2-D NumPy array of generated features (one row per sample).
        count: Number of indices to return.
        text_feature: 1-D NumPy array holding the text embedding.
        used: Optional mask of samples that must not be preferred; their
            distance is set to infinity so they sort last.

    Returns:
        List of row indices into ``gen_feats`` ordered by increasing distance.
    """
    unit_gen = gen_feats.copy()
    unit_gen = unit_gen / np.linalg.norm(unit_gen, axis=1, keepdims=True)
    unit_text = text_feature / np.linalg.norm(text_feature)
    if used is None:
        used = np.zeros(len(unit_gen), dtype=bool)

    distances = np.linalg.norm(unit_gen - unit_text, axis=1)
    # Masked samples are pushed to the end of the ranking, not removed.
    distances[np.where(used)[0]] = np.inf

    ranked = np.argsort(distances)
    return ranked[:count].tolist()

def random_selection(gen_feats, count, used):
    """Pick ``count`` not-yet-used sample indices uniformly at random.

    Args:
        gen_feats: Unused; kept for a uniform selector signature.
        count: Number of indices to draw.
        used: 1-D torch boolean tensor; ``True`` marks samples that must not
            be selected.

    Returns:
        List of ``count`` distinct indices into the original sample array,
        all of which have ``used == False``.

    Raises:
        ValueError: If fewer than ``count`` samples are available.
    """
    available_indices = torch.where(~used)[0]
    if len(available_indices) < count:
        raise ValueError("Not enough available features to select the requested count.")

    # randperm yields positions *within* available_indices; map them back to
    # real sample indices so that masked samples can never be returned.
    positions = torch.randperm(len(available_indices))[:count]
    selected = available_indices[positions.to(available_indices.device)]
    return selected.tolist()

def prune_features_by_meandistance(real_feats, gen_feats, cutoff_percent=0.95):
    """Return the generated samples farthest from the mean real feature.

    Samples are ranked by Euclidean distance to the mean of ``real_feats``;
    everything past the first ``int(len(gen_feats) * cutoff_percent)``
    entries of that ranking is returned.

    Args:
        real_feats: Array of real features (one row per sample).
        gen_feats: NumPy array of generated features (one row per sample).
        cutoff_percent: Fraction of the closest samples that is skipped.

    Returns:
        List of indices into ``gen_feats`` ordered by increasing distance.
    """
    center = np.mean(real_feats, axis=0)
    distances = np.linalg.norm(gen_feats.copy() - center, axis=1)
    first_dropped = int(len(gen_feats) * cutoff_percent)
    order = np.argsort(distances)
    return order[first_dropped:].tolist()

def prune_features_and_tensors_by_textsimilarity(gen_feats, text_feature, cutoff_percent=0.98):
    """Return the generated samples least similar to a text embedding.

    Features and text embedding are L2-normalised, samples are ranked by
    Euclidean distance to the text embedding, and the tail of the ranking
    beyond ``int(len(gen_feats) * cutoff_percent)`` is returned.

    Args:
        gen_feats: 2-D NumPy array of generated features.
        text_feature: 1-D NumPy array holding the text embedding.
        cutoff_percent: Fraction of the closest samples that is skipped.

    Returns:
        NumPy integer array of indices into ``gen_feats``.
    """
    unit_gen = gen_feats.copy()
    unit_gen = unit_gen / np.linalg.norm(unit_gen, axis=1, keepdims=True)
    unit_text = text_feature / np.linalg.norm(text_feature)

    ranking = np.argsort(np.linalg.norm(unit_gen - unit_text, axis=1))
    first_dropped = int(len(unit_gen) * cutoff_percent)
    return ranking[first_dropped:]

def Center_matching(real_feats,gen_feats, count, used=None):
    """Return the ``count`` generated samples nearest to the real mean.

    Args:
        real_feats: Array of real features (one row per sample).
        gen_feats: NumPy array of generated features (one row per sample).
        count: Number of indices to return.
        used: Optional boolean NumPy mask. Masked samples get an infinite
            distance; when supplied, the mask is updated in place so the
            returned indices are marked as used.

    Returns:
        List of indices into ``gen_feats`` ordered by increasing distance.
    """
    center = np.mean(real_feats, axis=0)
    distances = np.linalg.norm(gen_feats.copy() - center, axis=1)
    if used is not None:
        distances[np.where(used)[0]] = np.inf

    chosen = np.argsort(distances)[:count]
    if used is not None:
        # Side effect visible to the caller: record the picks in its mask.
        used[chosen] = True
    return chosen.tolist()

def empirical_covariance_torch(x):
    """Return the unbiased sample covariance of the rows of ``x``.

    Args:
        x: 2-D torch tensor with one observation per row.

    Returns:
        Square tensor: centred ``x`` transposed times itself, divided by
        ``x.shape[0] - 1``.
    """
    centered = x - x.mean(dim=0)
    dof = centered.shape[0] - 1
    return centered.T @ centered / dof

def fixed_point_eq(alpha, lambdas, n, p, n1):
    """Evaluate the residual (left side minus right side) of the alpha equation.

    Args:
        alpha: 1-D tensor of candidate values, one per batch row; it is
            divided by ``n`` before use.
        lambdas: 2-D tensor of eigenvalues, one row per batch entry.
        n: Scalar used to scale ``alpha`` and in the ratio ``p / n``.
        p: Scalar dimension term.
        n1: Scalar subtracted in the numerator of the right-hand side.

    Returns:
        1-D tensor with one residual per batch row.
    """
    ratio = p / n
    scaled = alpha.unsqueeze(-1) / n
    flat = scaled.squeeze(-1)

    denominators = lambdas * scaled + 1 - ratio - scaled
    left = (1.0 / denominators).sum(dim=1)
    right = (p + n * flat - n1) / (1 - ratio - flat)
    return left - right

def solve_alpha_batch(lambdas_batch, n, p, n1, tol=1e-5, max_iter=30):
    """Solve the alpha equation for every row of a batch by bisection.

    The search runs on the bracket ``[1e-8, n - p - 1e-6]``; each step keeps
    the half on which ``fixed_point_eq`` changes sign relative to the lower
    bound. The last midpoint, divided by ``n``, is returned.

    Args:
        lambdas_batch: Tensor of shape [B, D] with eigenvalues per candidate.
        n, p, n1: Scalars forwarded to ``fixed_point_eq``.
        tol: Stop once the widest bracket in the batch is narrower than this.
        max_iter: Maximum number of bisection steps.

    Returns:
        Tensor of shape [B] with the alpha solutions.
    """
    dev = lambdas_batch.device
    rows = lambdas_batch.shape[0]
    lo = torch.full((rows,), 1e-8, dtype=torch.float64, device=dev)
    hi = torch.full((rows,), n - p - 1e-6, dtype=torch.float64, device=dev)

    for _step in range(max_iter):
        mid = (lo + hi) / 2

        residual_mid = fixed_point_eq(mid, lambdas_batch, n, p, n1)
        residual_lo = fixed_point_eq(lo, lambdas_batch, n, p, n1)

        # A sign change between lo and mid means the root is in [lo, mid].
        root_in_lower_half = residual_mid * residual_lo < 0
        hi = torch.where(root_in_lower_half, mid, hi)
        lo = torch.where(~root_in_lower_half, mid, lo)

        if torch.max(torch.abs(hi - lo)) < tol:
            break

    return mid / n

def calculate_alpha_estimate(eigvalues, n, n1, p,max_iter=30):
    """Return the per-row alpha estimates for a batch of eigenvalue spectra.

    Thin wrapper that forwards its arguments to ``solve_alpha_batch`` (note
    the different positional order of ``n1`` and ``p``).

    Args:
        eigvalues: Tensor of shape [B, D] with eigenvalues per candidate.
        n, n1, p: Scalars forwarded by keyword to the solver.
        max_iter: Maximum number of bisection steps.

    Returns:
        Tensor of shape [B] as returned by ``solve_alpha_batch``.
    """
    return solve_alpha_batch(
        eigvalues,
        n=n,
        p=p,
        n1=n1,
        max_iter=max_iter,
    )

def matching_alpha_orig(real_feats, gen_feats, count, matches, pca_dim=32,used=None):
    """Run ``matching_alpha`` with ``use_orig=True``.

    All arguments are forwarded unchanged; see ``matching_alpha`` for their
    meaning and for the return value.
    """
    return matching_alpha(
        real_feats,
        gen_feats,
        count,
        matches,
        pca_dim=pca_dim,
        used=used,
        use_orig=True,
    )

def matching_alpha(real_feats, gen_feats, count, matches, pca_dim=32,used=None,use_orig=False):
    """Greedily select generated samples that minimise the alpha estimate.

    In each of up to ``count`` rounds every still-available generated sample
    is tentatively added to the current selection, the covariance of that
    trial selection is whitened with the inverse square root of the real
    covariance, and the eigenvalues of the result are passed to
    ``calculate_alpha_estimate``. The candidate with the smallest alpha is
    added to the selection.

    Args:
        real_feats: 2-D array of real features (rows are samples).
        gen_feats: 2-D array of generated features (rows are samples).
        count: Maximum number of samples to select.
        matches: List that the selected indices are appended to (mutated in
            place and also returned).
        pca_dim: Number of PCA components fitted on ``real_feats``; ``None``
            skips PCA and uses the features as given.
        used: Optional torch boolean mask of unavailable samples; updated in
            place. NOTE(review): presumably expected on the same device as
            the feature tensors -- confirm against callers.
        use_orig: When PCA is applied, pass the original feature dimension
            (rather than ``pca_dim``) as ``p`` to the alpha estimate.

    Returns:
        The ``matches`` list, extended with the selected indices.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    orig_dim = gen_feats.shape[1]
    if pca_dim is None:
        real_feats_pca = real_feats
        gen_feats_pca = gen_feats
        pca_dim = gen_feats.shape[1]
    else:
        # The projection is fitted on the real features only and then applied
        # to the generated ones.
        pca = PCA(n_components=pca_dim)

        real_feats_pca = pca.fit_transform(real_feats)
        gen_feats_pca = pca.transform(gen_feats)
        if use_orig:
            # Only the value of p handed to the alpha estimate changes; the
            # features themselves stay PCA-reduced.
            pca_dim = orig_dim

    real_feats_torch = torch.tensor(real_feats_pca, dtype=torch.float32, device=device)
    gen_feats_torch = torch.tensor(gen_feats_pca, dtype=torch.float32, device=device)
    n_2 = real_feats_torch.shape[0]
    cov_real = empirical_covariance_torch(real_feats_torch)
    eigvals, eigvecs = torch.linalg.eigh(cov_real)
    # Clamp tiny/negative eigenvalues so the inverse square root stays finite.
    eigvals = torch.clamp(eigvals, min=1e-6)
    S_inv_sqrt = eigvecs @ torch.diag(1.0 / torch.sqrt(eigvals)) @ eigvecs.T

    used = torch.zeros(len(gen_feats_torch), dtype=torch.bool, device=device) if used is None else used
    # Running first and second moments of the selected samples, so each trial
    # covariance can be formed without revisiting earlier picks.
    sum_x = torch.zeros(gen_feats_torch.shape[1], device=device)
    sum_xxT = torch.zeros(gen_feats_torch.shape[1], gen_feats_torch.shape[1], device=device)
    selected_count = 0
    for count_idx in range(count):

        with torch.no_grad():
            available_idx = torch.where(~used)[0]
            if len(available_idx) == 0:
                # Pool exhausted before ``count`` picks were made.
                return matches

            candidates = gen_feats_torch[available_idx]

            # Batched trial statistics: one mean/covariance per candidate, as
            # if that candidate were added to the current selection.
            S_try = sum_x.unsqueeze(0) + candidates
            S2_try = sum_xxT.unsqueeze(0) + candidates.unsqueeze(2) @ candidates.unsqueeze(1)
            n_try = selected_count + 1
            mean_try = S_try / n_try
            cov_try = S2_try / n_try - mean_try.unsqueeze(2) @ mean_try.unsqueeze(1)

            # Rescale from the 1/n to the 1/(n-1) normalisation when possible.
            if n_try > 1: cov_try = cov_try * n_try / (n_try - 1)
            # Whiten each trial covariance: S^{-1/2} C S^{-1/2}.
            product = torch.matmul(S_inv_sqrt.unsqueeze(0), cov_try)
            product = torch.matmul(product, S_inv_sqrt.unsqueeze(0))
            eigvals = torch.linalg.eigvalsh(product)
            alpha_values = calculate_alpha_estimate(eigvals, n=count_idx+1+n_2, n1=count_idx+1, p=pca_dim)

            best_idx_in_batch = torch.argmin(alpha_values)
            # Map the position within the candidate batch back to a row index
            # of the generated features.
            best_idx = available_idx[best_idx_in_batch]

            x = gen_feats_torch[best_idx]
            sum_x += x
            sum_xxT += x.unsqueeze(1) @ x.unsqueeze(0)
            used[best_idx] = True
            selected_count += 1
            matches.append(best_idx.item())

    return matches

def Covariance_matching(real_feats, gen_feats, count, matches, pca_dim=32, scale=0, constraint_scale=False,used=None):
    """Greedily select generated samples whose covariance matches the real one.

    The first sample is drawn uniformly at random from the available pool.
    Every further round tentatively adds each available candidate, computes
    the covariance of the trial selection and keeps the candidate with the
    smallest Frobenius distance to the real covariance.

    Args:
        real_feats: 2-D array of real features (rows are samples).
        gen_feats: 2-D array of generated features (rows are samples).
        count: Number of samples to select; nothing is selected when
            ``count <= 0``.
        matches: List that the selected indices are appended to (mutated in
            place and also returned).
        pca_dim: Number of PCA components fitted on ``real_feats``; ``None``
            skips PCA.
        scale: Weight of the scale term (only used with ``constraint_scale``).
        constraint_scale: If true, both covariances are normalised to unit
            Frobenius norm before comparison and the loss is reduced by
            ``scale / 10_000`` times the trial covariance norm.
        used: Optional torch boolean mask of unavailable samples; updated in
            place.

    Returns:
        The ``matches`` list, extended with at most ``count`` indices (fewer
        if the available pool runs out).
    """
    # Honour the requested count: previously one sample was always taken.
    if count <= 0:
        return matches

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if pca_dim is None:
        real_feats_pca = real_feats
        gen_feats_pca = gen_feats
    else:
        # Fit the projection on the real features, then apply it to both sets.
        pca = PCA(n_components=pca_dim)
        real_feats_pca = pca.fit_transform(real_feats)
        gen_feats_pca = pca.transform(gen_feats)

    real_feats_torch = torch.tensor(real_feats_pca, dtype=torch.float32, device=device)
    gen_feats_torch = torch.tensor(gen_feats_pca, dtype=torch.float32, device=device)

    X_centered = real_feats_torch - real_feats_torch.mean(dim=0, keepdim=True)
    cov_target = X_centered.T @ X_centered / (real_feats_torch.shape[0] - 1)
    if constraint_scale:
        cov_target = cov_target / torch.norm(cov_target, p='fro')

    used = used if used is not None else torch.zeros(len(gen_feats_torch), dtype=torch.bool, device=device)
    dim = gen_feats_torch.shape[1]
    # Running first and second moments of the current selection.
    sum_x = torch.zeros(dim, device=device)
    sum_xxT = torch.zeros(dim, dim, device=device)
    selected_count = 0

    available_idx = torch.where(~used)[0]
    if len(available_idx) == 0:
        # Nothing to choose from; avoid torch.randint(0, 0, ...) raising.
        return matches

    # Seed the selection with a random sample: a single point has no
    # covariance, so the matching criterion cannot rank the first pick.
    first_idx = available_idx[torch.randint(0, len(available_idx), (1,)).item()]
    x = gen_feats_torch[first_idx]
    sum_x += x
    sum_xxT += x.unsqueeze(1) @ x.unsqueeze(0)
    used[first_idx] = True
    selected_count += 1
    matches.append(first_idx.item())

    for _ in range(1, count):
        with torch.no_grad():
            available_idx = torch.where(~used)[0]
            if len(available_idx) == 0:
                break

            candidates = gen_feats_torch[available_idx]

            # Batched trial statistics, one per candidate.
            S_try = sum_x.unsqueeze(0) + candidates
            S2_try = sum_xxT.unsqueeze(0) + candidates.unsqueeze(2) @ candidates.unsqueeze(1)
            n_try = selected_count + 1
            mean_try = S_try / n_try
            cov_try = S2_try / n_try - mean_try.unsqueeze(2) @ mean_try.unsqueeze(1)
            # Rescale from the 1/n to the 1/(n-1) normalisation.
            if n_try > 1:
                cov_try = cov_try * n_try / (n_try - 1)

            if constraint_scale:
                scale_loss = torch.norm(cov_try, p='fro', dim=(1, 2))
                cov_try = cov_try / scale_loss.unsqueeze(1).unsqueeze(2)

            diff = cov_try - cov_target.unsqueeze(0)
            loss = torch.norm(diff, p='fro', dim=(1, 2))
            if constraint_scale:
                loss = loss - 1/10_000 * scale * scale_loss

            best_idx = available_idx[torch.argmin(loss)]

            x = gen_feats_torch[best_idx]
            sum_x += x
            sum_xxT += x.unsqueeze(1) @ x.unsqueeze(0)
            used[best_idx] = True
            selected_count += 1
            matches.append(best_idx.item())

    return matches

def Center_sampling(real_feats, gen_feats, count, used=None):
    """Sample generated indices with probability tied to the real mean.

    Each sample's weight is its dot product with the mean real feature,
    clipped from below at ``1e-8``; masked samples get weight zero. Indices
    are then drawn without replacement from the normalised weights.

    Args:
        real_feats: NumPy array of real features (rows are samples).
        gen_feats: NumPy array of generated features (rows are samples).
        count: Number of indices to draw.
        used: Optional mask of samples that must not be drawn.

    Returns:
        List of ``count`` distinct indices into ``gen_feats``.
    """
    center = np.mean(real_feats.copy(), axis=0)
    weights = np.clip(np.dot(gen_feats.copy(), center), a_min=1e-8, a_max=None)
    if used is not None:
        weights[np.where(used)[0]] = 0

    probs = weights / np.sum(weights)
    drawn = np.random.choice(len(gen_feats), size=count, replace=False, p=probs)
    return drawn.tolist()

def Text_sampling(real_feats, gen_feats, count, text_feature, used=None):
    """Sample generated indices with probability tied to a text embedding.

    Each sample's weight is its dot product with ``text_feature``, clipped
    from below at ``1e-8``; masked samples get weight zero. Indices are then
    drawn without replacement from the normalised weights.

    Args:
        real_feats: Unused; kept for a uniform selector signature.
        gen_feats: NumPy array of generated features (rows are samples).
        count: Number of indices to draw.
        text_feature: 1-D NumPy array holding the text embedding.
        used: Optional mask of samples that must not be drawn.

    Returns:
        List of ``count`` distinct indices into ``gen_feats``.
    """
    query = text_feature.copy()
    weights = np.clip(np.dot(gen_feats.copy(), query), a_min=1e-8, a_max=None)
    if used is not None:
        weights[np.where(used)[0]] = 0

    probs = weights / np.sum(weights)
    drawn = np.random.choice(len(gen_feats), size=count, replace=False, p=probs)
    return drawn.tolist()

def augpaper_closest_to_genmean(real_feats, gen_feats, count, used=None):
    """Return the ``count`` generated samples nearest to the generated mean.

    Args:
        real_feats: Unused; kept for a uniform selector signature.
        gen_feats: NumPy array of generated features (rows are samples).
        count: Number of indices to return.
        used: Optional mask; masked samples get an infinite distance and
            therefore sort last.

    Returns:
        List of indices into ``gen_feats`` ordered by increasing distance.
    """
    feats = gen_feats.copy()
    centroid = np.mean(feats, axis=0)
    if used is None:
        used = np.zeros(len(feats), dtype=bool)

    distances = np.linalg.norm(feats - centroid, axis=1)
    distances[np.where(used)[0]] = np.inf
    return np.argsort(distances)[:count].tolist()


def K_mean(real_feats, gen_feats, count, used=None):
    """Select one generated sample per k-means cluster centre.

    K-means with ``count`` clusters is fitted on the available generated
    features; then, centre by centre, the closest sample that has not been
    picked for an earlier centre is selected.

    Args:
        real_feats: Unused; kept for a uniform selector signature.
        gen_feats: Array of generated features (rows are samples).
        count: Number of clusters, and therefore of returned indices.
        used: Optional boolean NumPy mask of samples that are excluded from
            clustering and given an infinite distance to every centre.

    Returns:
        List of ``count`` indices into ``gen_feats``.
    """
    feats = torch.tensor(gen_feats, dtype=torch.float32).clone()
    if used is None:
        free_idx = np.arange(len(feats))
        blocked_idx = []
    else:
        free_idx = np.where(~used)[0]
        blocked_idx = np.where(used)[0]

    clusterer = KMeans(n_clusters=count, n_init='auto', random_state=0)
    clusterer.fit(feats[free_idx].cpu().numpy())
    centroids = torch.tensor(clusterer.cluster_centers_, dtype=feats.dtype, device=feats.device)

    # Distances are taken against *all* samples so indices stay global.
    center_to_sample = torch.cdist(centroids, feats, p=2)
    if used is not None:
        center_to_sample[:, blocked_idx] = float('inf')

    taken = torch.zeros(len(feats), dtype=torch.bool, device=feats.device)
    matches = []
    for row in center_to_sample:
        # Restrict to samples no earlier centre has claimed.
        open_slots = (~taken).nonzero(as_tuple=True)[0]
        pick = open_slots[torch.argmin(row[open_slots])]
        taken[pick] = True
        matches.append(pick.item())

    return matches

def ds3(real_feats, gen_feats, count, used=None,n_clusters=200):
    """Select generated samples from clusters that real samples fall into.

    The available generated features are clustered with k-means. Every real
    sample is assigned to its nearest centre, and all available generated
    samples belonging to any such centre become candidates. If there are
    more than ``count`` candidates, ``count`` are drawn at random; if there
    are fewer, the procedure is retried with half as many clusters.

    Args:
        real_feats: Array or tensor of real features (rows are samples).
        gen_feats: Array or tensor of generated features (rows are samples).
        count: Number of indices to return.
        used: Optional boolean mask (NumPy array or torch tensor) of samples
            that must not be selected.
        n_clusters: Number of k-means clusters for this attempt.

    Returns:
        List of ``count`` integer indices into ``gen_feats``.

    Raises:
        ValueError: If even a single cluster does not yield ``count``
            candidates, i.e. too few samples are available.
    """
    used = used.cpu().numpy() if isinstance(used, torch.Tensor) else used
    # as_tensor accepts both arrays and the tensors passed on recursion
    # without the copy-construct warning torch.tensor() emits for tensors.
    gen_feats = torch.as_tensor(gen_feats, dtype=torch.float32).clone()
    real_feats = torch.as_tensor(real_feats, dtype=torch.float32)

    if used is None:
        available_indices = np.arange(len(gen_feats))
        unavailable_indices = []
    else:
        available_indices = np.where(~used)[0]
        unavailable_indices = np.where(used)[0]

    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=0)
    kmeans.fit(gen_feats[available_indices].cpu().numpy())
    centers = torch.tensor(kmeans.cluster_centers_, dtype=gen_feats.dtype, device=gen_feats.device)

    # Assign every generated sample to its nearest centre; masked samples get
    # the sentinel -1 so they can never match a real centre id.
    cluster_assignments = torch.argmin(torch.cdist(centers, gen_feats, p=2), dim=0)
    if used is not None:
        cluster_assignments[unavailable_indices] = -1

    # Centres that are the nearest centre of at least one real sample.
    closest_centers = set(torch.argmin(torch.cdist(centers, real_feats, p=2), dim=0).tolist())

    chunks = []
    for center_idx in closest_centers:
        members = torch.where(cluster_assignments == center_idx)[0]
        if len(members) > 0:
            chunks.append(members)
    if chunks:
        candidates = torch.cat(chunks, dim=0).cpu().numpy()
    else:
        candidates = np.empty(0, dtype=np.int64)

    if len(candidates) > count:
        return np.random.choice(candidates, size=count, replace=False).tolist()
    if len(candidates) < count:
        if n_clusters <= 1:
            # With one cluster every available sample is already a candidate,
            # so halving again (to zero clusters) cannot help.
            raise ValueError("Not enough available features to select the requested count.")
        return ds3(real_feats, gen_feats, count, used=used, n_clusters=n_clusters // 2)
    # Exactly ``count`` candidates: return a list like the other branches do.
    return candidates.tolist()

def match_greedy(real_feats, gen_feats,count, distance='l2-near',zero_centered=False,\
                 text_feature=None, prune = False, gen_clip_features = None,using_clip_features=False,\
                    all_features_separately=None):
    """Select ``count`` generated samples using the requested strategy.

    Args:
        real_feats: NumPy array of real features (rows are samples).
        gen_feats: NumPy array of generated features (rows are samples).
        count: Number of samples to select.
        distance: Strategy name. Handled values are ``"random"``,
            ``"Covariance_matching"``, ``"Covariance_matching_nopca"``,
            ``"Matching_alpha"``, ``"Matching_alpha_nopca"``,
            ``"Center_matching"``, ``"Text_matching"``, ``"Center_sampling"``,
            ``"Text_sampling"``, ``"K_mean"`` and ``"ds3"``. Any other value,
            including the default, selects nothing.
        zero_centered: Subtract the per-set mean from both feature sets.
        text_feature: Optional text embedding; L2-normalised when given.
        prune: If true, mark the samples returned by
            ``prune_features_and_tensors_by_textsimilarity`` as unavailable
            before selecting.
        gen_clip_features: Features used for pruning.
        using_clip_features: L2-normalise every row of both feature sets.
        all_features_separately: Unused.

    Returns:
        Tuple ``(matches, used)``: the selected indices, and a boolean NumPy
        mask over ``gen_feats`` that is true exactly at those indices.
    """
    real_feats = real_feats.copy()
    gen_feats = gen_feats.copy()
    if zero_centered:
        real_feats -= np.mean(real_feats,axis=0)
        gen_feats -= np.mean(gen_feats,axis=0)

    if using_clip_features:
        real_feats = real_feats / np.linalg.norm(real_feats, axis=1, keepdims=True)
        gen_feats = gen_feats / np.linalg.norm(gen_feats, axis=1, keepdims=True)
    if text_feature is not None:
        text_feature = text_feature / np.linalg.norm(text_feature)

    # Use the same device rule as the torch-based selectors so the mask lives
    # next to their feature tensors and CPU-only hosts do not crash.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    used = torch.zeros(len(gen_feats), dtype=torch.bool, device=device)

    if prune:
        removing_indices = prune_features_and_tensors_by_textsimilarity(gen_clip_features, text_feature)
        used[removing_indices] = True

    matches = []

    if distance == "random":
        matches = random_selection(gen_feats, count, used=used)
    elif distance == "Covariance_matching":
        matches = Covariance_matching(real_feats, gen_feats, count, [], pca_dim=32,used=used)

    elif distance == "Covariance_matching_nopca":
        matches = Covariance_matching(real_feats, gen_feats, count, [], pca_dim=None,used=used)

    elif distance == "Matching_alpha":
        matches = matching_alpha(real_feats, gen_feats, count, [], pca_dim=32,used=used)

    elif distance == "Matching_alpha_nopca":
        matches = matching_alpha(real_feats, gen_feats, count, [],pca_dim=None, used=used)

    elif distance == "Center_matching":
        matches = Center_matching(real_feats, gen_feats, count, used=used.cpu().numpy())

    elif distance == 'Text_matching':
        matches = Text_matching(real_feats, gen_feats, count, text_feature, used=used.cpu().numpy())

    elif distance == 'Center_sampling':
        matches = Center_sampling(real_feats, gen_feats, count, used=used.cpu().numpy())

    elif distance == 'Text_sampling':
        matches = Text_sampling(real_feats, gen_feats, count, text_feature, used=used.cpu().numpy())

    elif distance == 'K_mean':
        matches = K_mean(real_feats, gen_feats, count, used=used.cpu().numpy())

    elif distance == 'ds3':
        matches = ds3(real_feats, gen_feats, count, used=used,n_clusters=20)

    # The returned mask reflects only the final selection, not the pruning.
    used = np.zeros(len(gen_feats), dtype=bool)
    used[matches] = True

    return matches,used
