import numpy as np
import math
import pandas as pd

from sklearn.datasets import fetch_openml
from sklearn.mixture import GaussianMixture
from itertools import combinations
from sklearn.decomposition import PCA
from scipy.stats import multivariate_t

import process_genomic_data

def add_adversarial_noise(data, epsilon, mode='uniform'):
    """
    Append unstructured noise points to a data matrix.

    Parameters:
    - data: 2-D array of shape (num_samples, d)
    - epsilon: noise fraction; int(epsilon * num_samples) noise rows are appended
    - mode: 'uniform' draws every coordinate uniformly between the per-column
      minimum and maximum of data; 'gaussian' draws from a normal distribution
      with the empirical mean and covariance of data

    Returns:
    - array holding the original rows followed by the noise rows

    Raises:
    - ValueError: if mode is neither 'uniform' nor 'gaussian'
    """
    n_rows, n_dims = data.shape
    n_noise = int(epsilon * n_rows)

    if mode == 'gaussian':
        # Gaussian centred on the data, shaped like the data
        center = np.mean(data, axis=0)
        spread = np.cov(data, rowvar=False)
        noise = np.random.multivariate_normal(center, spread, n_noise)
    elif mode == 'uniform':
        # Axis-aligned bounding box of the data
        lower = data.min(axis=0)
        upper = data.max(axis=0)
        noise = np.random.uniform(lower, upper, size=(n_noise, n_dims))
    else:
        raise ValueError(f"Unknown mode '{mode}' for adversarial noise generation.")

    return np.vstack([data, noise])


def add_adversarial_clusters_line(data, cluster_centers, weights, r, epsilon, dist='gauss', cluster_cov=None):
    """
    Add adversarial clusters placed on a straight line that starts at the
    smallest (minimum weight) cluster.

    ceil(epsilon / (2w)) clusters are added, where w is the minimum weight.
    Each one holds int(2 * w * num_samples) points, and the i-th one
    (1-based) is centered at distance i * r from the smallest cluster center
    along one random unit direction.

    Parameters:
    - data: original data, array of shape (num_samples, d)
    - cluster_centers: sequence of cluster centers
    - weights: cluster weights (list or NumPy array)
    - r: distance between consecutive adversarial clusters
    - epsilon: fraction of adversarial noise to add
    - dist: 't' samples from a multivariate t-distribution (df=5); any other
      value samples from a multivariate Gaussian
    - cluster_cov: covariance matrix of the min weight cluster (used for real
      world data); when given, the generated covariance is rescaled to have
      the same matrix norm

    Returns:
    - array with the original rows followed by all adversarial rows
    """
    num_samples, d = data.shape

    # Identify the smallest cluster by weight. np.argmin works for lists and
    # NumPy arrays alike (an ndarray has no .index method) and, like
    # list.index, returns the first minimum.
    w = min(weights)
    smallest_cluster_index = int(np.argmin(weights))
    num_adversarial = int(2 * w * num_samples)  # Samples per adversarial cluster

    # Define a fixed direction (normalized)
    fixed_direction = np.random.randn(d)
    fixed_direction /= np.linalg.norm(fixed_direction)

    # The line starts at the smallest cluster center
    start_center = cluster_centers[smallest_cluster_index]

    # The target norm does not change between iterations
    cluster_cov_norm = None if cluster_cov is None else np.linalg.norm(cluster_cov)

    # Collect the pieces and stack once instead of re-copying data every iteration
    pieces = [data]
    for i in range(int(math.ceil(epsilon / (2 * w)))):
        adversarial_center = start_center + fixed_direction * r * (i + 1)

        # Covariance elongated along the line direction (variance 5 vs. 1)
        cov = generate_covariance_matrix(fixed_direction, 5, 1)

        # Normalize cov and scale it by the min weight cluster covariance norm
        if cluster_cov_norm is not None:
            cov = cov / np.linalg.norm(cov) * cluster_cov_norm

        if dist == 't':
            adversarial_data = multivariate_t(adversarial_center, df=5, shape=cov).rvs(num_adversarial)
        else:
            adversarial_data = np.random.multivariate_normal(adversarial_center, cov, num_adversarial)
        pieces.append(adversarial_data)

    # Nothing added: hand back the input unchanged
    if len(pieces) == 1:
        return data
    return np.vstack(pieces)


def generate_covariance_matrix(v, major_variance, minor_variance):
    """
    Generate a covariance matrix with variance major_variance along v and
    minor_variance in every orthogonal direction.

    Parameters:
    - v: non-zero direction vector of length d (it is not modified)
    - major_variance: variance along v
    - minor_variance: variance along each direction orthogonal to v

    Returns:
    - symmetric (d, d) covariance matrix

    Raises:
    - ValueError: if v has zero or non-finite norm
    """
    # Work on a float copy so the caller's array is never changed in place
    # and integer input is accepted
    direction = np.array(v, dtype=float)
    d = len(direction)
    norm = np.linalg.norm(direction)
    if not np.isfinite(norm) or norm == 0:
        raise ValueError("v must be a non-zero vector with finite entries.")
    direction /= norm

    # Extend the direction to an orthonormal basis. QR keeps the first
    # column equal to +/- direction whatever the signs of its entries are.
    completion = [np.random.randn(d) for _ in range(d - 1)]
    q, _ = np.linalg.qr(np.column_stack([direction] + completion))

    # Q diag(variances) Q^T: column 0 of Q carries the major variance
    variances = np.array([major_variance] + [minor_variance] * (d - 1), dtype=float)
    covariance_matrix = (q * variances) @ q.T

    return covariance_matrix


def gram_schmidt(vectors):
    """
    Perform Gram-Schmidt orthogonalization on a set of vectors.

    Parameters:
    - vectors: iterable of 1-D vectors of equal length

    Returns:
    - array whose rows are orthonormal vectors, in input order; vectors that
      are (numerically) linearly dependent on earlier ones are dropped
    """
    basis = []
    for v in vectors:
        v = np.asarray(v, dtype=float)
        # Remove the components along the basis vectors found so far
        w = v - sum(np.dot(v, b) * b for b in basis)
        # Test the length of the residual: a component-wise "> 0" test would
        # wrongly discard vectors whose entries are all negative
        norm = np.linalg.norm(w)
        if norm > 1e-10:
            basis.append(w / norm)
    return np.array(basis)


def add_adversarial_clusters(data, cluster_centers, weights, r, epsilon, dist='gauss', cluster_cov=None):
    """
    Add adversarial clusters scattered around a point near the smallest cluster.

    A reference point mu_c is placed at distance r from the smallest (minimum
    weight) cluster center in a random direction. ceil(epsilon / (2w))
    clusters, where w is the minimum weight, are then centered at distance r
    from mu_c, each in its own random direction. Every cluster holds
    int(2 * w * num_samples) points.

    Parameters:
    - data: original data, array of shape (num_samples, d)
    - cluster_centers: sequence of cluster centers
    - weights: cluster weights (list or NumPy array)
    - r: distance used for both offsets described above
    - epsilon: fraction of adversarial noise to add
    - dist: distribution type of adversarial clusters ('gauss', 't', 'genomic_cov')
    - cluster_cov: covariance matrix of the min weight cluster (used for real
      world data). For 'gauss' the identity covariance is scaled by its norm;
      for 'genomic_cov' it is used directly and is required.

    Returns:
    - array with the original rows followed by all adversarial rows

    Raises:
    - ValueError: if dist is unknown, or dist is 'genomic_cov' and
      cluster_cov is None (both checked when a cluster is generated)
    """
    num_samples, d = data.shape

    # Identify the smallest cluster by weight. np.argmin works for lists and
    # NumPy arrays alike (an ndarray has no .index method) and, like
    # list.index, returns the first minimum.
    w = min(weights)
    smallest_cluster_index = int(np.argmin(weights))
    num_adversarial = int(2 * w * num_samples)  # Samples per adversarial cluster

    direction = np.random.randn(d)
    direction /= np.linalg.norm(direction)  # Normalize to get unit direction
    mu_c = cluster_centers[smallest_cluster_index] + direction * r

    # The Gaussian covariance is identical for every cluster, build it once
    gauss_cov = None
    if dist == 'gauss':
        gauss_cov = np.eye(d)
        if cluster_cov is not None:
            gauss_cov *= np.linalg.norm(cluster_cov)

    # Collect the pieces and stack once instead of re-copying data every iteration
    pieces = [data]
    for _ in range(int(math.ceil(epsilon / (2 * w)))):
        # Fresh random direction for every adversarial cluster
        direction = np.random.randn(d)
        direction /= np.linalg.norm(direction)
        adversarial_center = mu_c + direction * r

        if dist == 't':
            # Multivariate t-distribution with the default (identity) shape
            adversarial_data = multivariate_t(adversarial_center, df=5).rvs(num_adversarial)
        elif dist == 'gauss':
            adversarial_data = np.random.multivariate_normal(adversarial_center, gauss_cov, num_adversarial)
        elif dist == 'genomic_cov':
            # Reuse the covariance of the smallest cluster in the dataset
            if cluster_cov is None:
                raise ValueError("dist='genomic_cov' requires cluster_cov to be provided.")
            adversarial_data = np.random.multivariate_normal(adversarial_center, cluster_cov, num_adversarial)
        else:
            # An exception instead of `assert False`: asserts vanish under -O
            raise ValueError("Invalid distribution type for adversarial noise")

        pieces.append(adversarial_data)

    # Nothing added: hand back the input unchanged
    if len(pieces) == 1:
        return data
    return np.vstack(pieces)


def load_noise_model(noise_model, S, true_centers, weights, r, epsilon, dist='gauss', cluster_cov=None):
    """
    Apply the selected noise model to the data and return the noisy data.

    Parameters:
    - noise_model: type of noise model to use ('gaussian', 'uniform', 'adv', 'adv_circle');
      any other value leaves the data untouched
    - S: original data
    - true_centers: list of cluster centers
    - weights: list of cluster weights
    - r: distance between adversarial clusters
    - epsilon: fraction of adversarial noise to add
    - dist: distribution type of adversarial clusters ('gauss', 't', 'genomic_cov')
    - cluster_cov: covariance matrix forwarded to the adversarial cluster generators

    Returns:
    - the data with noise appended, or S itself for an unrecognized noise_model
    """
    # Unstructured noise models
    if noise_model == 'gaussian':
        return add_adversarial_noise(S, epsilon, 'gaussian')

    if noise_model == 'uniform':
        # The uniform model works on a float64 copy of the data
        return add_adversarial_noise(S.astype(np.float64), epsilon, 'uniform')

    # Structured (clustered) noise models
    if noise_model == 'adv':
        return add_adversarial_clusters_line(
            data=S,
            cluster_centers=true_centers,
            weights=weights,
            r=r,
            epsilon=epsilon,
            dist=dist,
            cluster_cov=cluster_cov,
        )

    if noise_model == 'adv_circle':
        return add_adversarial_clusters(
            data=S,
            cluster_centers=true_centers,
            weights=weights,
            r=r,
            epsilon=epsilon,
            dist=dist,
            cluster_cov=cluster_cov,
        )

    return S


def generate_separated_centers(num_clusters, separation, d=100):
    """
    Generate cluster centers that are separated by more than 'separation'.

    The first center is a random Gaussian vector scaled by separation. Each
    further center is proposed at distance 2 * separation from a randomly
    chosen existing center and accepted only if it is farther than
    separation from every existing center.

    Parameters:
    - num_clusters: number of centers to generate
    - separation: minimum pairwise distance between centers
    - d: dimension of the centers

    Returns:
    - list of center vectors (NumPy arrays of length d)

    The minimum and maximum pairwise center distances and the number of
    centers are printed.
    """
    # Start with a random center
    current_center = np.random.randn(d) * separation
    centers = [current_center]
    max_dist = 0.
    min_dist = 10e20
    for _ in range(num_clusters - 1):
        while True:
            # Propose a center at distance 2 * separation in a random direction
            direction = np.random.randn(d)
            direction /= np.linalg.norm(direction)  # Normalize to get unit direction
            next_center = current_center + direction * separation * 2

            dists = [np.linalg.norm(next_center - c_) for c_ in centers]
            candidate_min = min(min_dist, min(dists))
            if candidate_min > separation:
                break

        centers.append(next_center)
        min_dist = candidate_min
        # Track the true maximum (the running minimum must not leak in here)
        max_dist = max(max_dist, max(dists))
        # Continue growing from a randomly chosen existing center
        current_center = centers[np.random.choice(len(centers))]

    print("Minimum Cluster distance: ",min_dist, "Max Cluster distance", max_dist,len(centers))
    return centers


def generate_mixture_data_with_separated_centers(num_samples=1000, separation=5, num_clusters=2, weights=None, d=100, dist='gauss'):
    """
    Sample a mixture whose component centers are guaranteed to be separated.

    Parameters:
    - num_samples: total sample budget; cluster i receives int(weights[i] * num_samples) points
    - separation: minimum distance between cluster centers
    - num_clusters: number of mixture components
    - weights: component weights summing to 1 (uniform when None)
    - d: dimension of the data
    - dist: 'gauss' for identity-covariance Gaussians, 't' for multivariate
      t-distributions with 5 degrees of freedom

    Returns:
    - (data, cluster_centers): the stacked samples, ordered by cluster, and
      the list of centers

    For every cluster the 99.9% quantile of the distances to its center is printed.
    """
    if weights is None:
        weights = [1/num_clusters] * num_clusters

    assert len(weights) == num_clusters, "Number of weights must match number of clusters."
    assert np.isclose(sum(weights), 1), "Weights must sum to 1."

    # Per-cluster sample counts (truncated towards zero)
    samples_per_cluster = [int(weight * num_samples) for weight in weights]

    cluster_centers = generate_separated_centers(num_clusters, separation, d=d)

    identity_cov = np.eye(d)
    chunks = []
    for center, count in zip(cluster_centers, samples_per_cluster):
        if dist == 't':
            points = multivariate_t(center, df=5).rvs(count)
        elif dist == 'gauss':
            points = np.random.multivariate_normal(center, identity_cov, count)
        else:
            assert False, "Invalid distribution type"

        radius_999 = np.quantile(np.linalg.norm(points - center, axis=1), 0.999)
        print("99.9% of the data lies within ", radius_999)
        chunks.append(points)

    return np.vstack(chunks), cluster_centers


def get_genomic_cluster_cov(data, colors_centers, num_clusters):
    """
    Estimate the covariance matrix of the smallest true cluster in the genomic dataset.

    Parameters:
    - data: genomic data pandas dataframe with color/country column
    - colors_centers: list of cluster colors / country names sorted by weight (descending)
    - num_clusters: number of clusters to consider

    Returns:
    - covariance matrix of the eigenvalue-scaled points whose color is the
      last of the first num_clusters entries of colors_centers

    Raises:
    - IndexError: if no color is selected (e.g. num_clusters is 0)
    """
    eigenvalues = pd.read_csv(process_genomic_data.EIGEN_VALUES_FILE_PATH, header=None)
    eigenvalues = eigenvalues[:20].values.squeeze()

    # Only the last selected color (lowest weight, given the descending
    # order) is needed, so the other clusters' covariances are not computed
    smallest_color = colors_centers[:num_clusters][-1]

    cluster_scores = data.loc[data['color'] == smallest_color].drop(['color'], axis=1, inplace=False)
    # Scale the PCA scores by the eigenvalues, as done for the dataset itself
    cluster_points = cluster_scores.to_numpy() * eigenvalues

    return np.cov(cluster_points, rowvar=False)


def load_genomic_dataset(weight_min = 0.5, radius_quantile = 95):
    """
    Load the genomic PCA scores, scale them by the eigenvalues and derive
    per-color cluster centers, weights and a reference covariance.

    Parameters:
    - weight_min: only clusters with weight >= weight_min are kept in the
      returned centers and weights
    - radius_quantile: percentile (0-100) used for the per-cluster radii;
      the radii are computed but currently not returned or printed

    Returns:
    - data: array of all samples (PCA scores scaled by the eigenvalues)
    - true_centers_values: list of center arrays of the kept clusters,
      sorted by descending weight
    - weights: array with the weights of the kept clusters
    - cluster_cov: covariance matrix returned by get_genomic_cluster_cov for
      the kept clusters

    The minimum and maximum distance between all cluster centers (kept or
    not) are printed.

    NOTE(review): the column positions used below (0 = sample ID, 1 and 22
    dropped) presumably match the layout of the PCA score file -- confirm
    against process_genomic_data.
    """
    # Raw string: '\s' in a normal string literal is an invalid escape sequence
    pca_data = pd.read_csv(process_genomic_data.PCA_SCORE_FILE_PATH, sep=r'\s+', header=None, skiprows=1)
    true_centers = pca_data.drop([pca_data.columns[0], pca_data.columns[22]], axis=1, inplace=False)
    data = pca_data.drop([pca_data.columns[0], pca_data.columns[1], pca_data.columns[22]], axis=1, inplace=False)

    ### Include color information
    ID_color_matching = process_genomic_data.read_ID_color_matches()
    colors = list()
    for sample_id in pca_data[pca_data.columns[0]]:
        color_name = ID_color_matching.loc[ID_color_matching['ID'] == sample_id, 'color'].values[0]
        colors.append(color_name)

    ### Merge centers and colors
    true_centers['color'] = colors
    true_centers = true_centers.drop([true_centers.columns[0]], axis=1, inplace=False)
    # Per-sample scores with their color, kept for the covariance estimate
    cov_estimation_data = true_centers

    ### empirical centers by color
    true_centers = true_centers.groupby('color').agg(['mean'])

    ### Count number of appearances of each color
    # Look the counts up on the Series itself: the column name that
    # pd.DataFrame(value_counts()) produces differs between pandas versions
    color_counts = ID_color_matching['color'].value_counts()
    counts = [color_counts.loc[color] for color in true_centers.index]

    true_centers['weight'] = np.asarray(counts) / sum(counts)

    ### sort true_centers by weight
    true_centers = true_centers.sort_values(by='weight', ascending=False)
    weights = true_centers['weight'].values
    colors_centers = true_centers.index
    true_centers.reset_index(drop=True, inplace=True)

    ### Reading eigenvalues, representing the variance explained by the corresponding principal component
    eigenvalues = pd.read_csv(process_genomic_data.EIGEN_VALUES_FILE_PATH, header=None)
    eigenvalues = eigenvalues[:20].values

    ### Scale pca_scores by eigenvalues to obtain dataset
    # Column 20 is the weight column appended above
    true_centers = true_centers.drop([true_centers.columns[20]], axis=1, inplace=False)
    true_centers = true_centers.to_numpy() * eigenvalues.squeeze()
    data = data.to_numpy() * eigenvalues.squeeze()

    ### Identify radius of clusters covering radius_quantile % of the data
    center_radius_dict = {color: [] for color in colors_centers}
    color_center_dict = {color: center for color, center in zip(colors_centers, true_centers)}

    for point, color in zip(data, colors):
        center = color_center_dict[color]
        dist = np.linalg.norm(point - center)
        center_radius_dict[color].append(dist)

    radii = [np.percentile(center_radius_dict[color], radius_quantile) for color in colors_centers]

    ### Convert to list of np arrays format
    true_centers_values = [np.array(center) for center in true_centers]

    ### Keep the clusters with weight >= weight_min
    num_clusters = int(np.sum(weights >= weight_min))
    true_centers_values = true_centers_values[:num_clusters]
    weights = weights[:num_clusters]
    # NOTE: radii are only diagnostic information and are not returned
    radii = radii[:num_clusters]

    ### Distance statistics over all centers (not only the kept ones)
    max_dist = 0.
    min_dist = 10e20
    for center_a, center_b in combinations(true_centers, 2):
        dist = np.linalg.norm(center_a - center_b)
        max_dist = max(max_dist, dist)
        min_dist = min(min_dist, dist)
    print("Minimum Cluster distance: ",min_dist, "Max Cluster distance", max_dist)

    cluster_cov = get_genomic_cluster_cov(cov_estimation_data, colors_centers, num_clusters)

    return data, true_centers_values, weights, cluster_cov

def load_and_project_mnist_with_clustering(num_samples=1000, k=100, num_clusters=10, quantile=0.99):
    """
    Load MNIST, project a random subset with PCA and fit a Gaussian mixture.

    Parameters:
    - num_samples: number of MNIST rows sampled (without replacement)
    - k: number of PCA components
    - num_clusters: number of Gaussian mixture components
    - quantile: quantile (0-1) of the point-to-center distances reported as
      the radius of each cluster

    Returns:
    - (X_pca, cluster_centers): the projected subset and the means of the
      fitted mixture components

    The explained variance ratios, the cluster centers, the per-cluster radii
    and the minimum / maximum distance between centers are printed.
    """
    # Load MNIST data (downloads the dataset through fetch_openml)
    mnist = fetch_openml('mnist_784', version=1)
    X = mnist['data'].to_numpy()

    # Randomly select a subset of data
    indices = np.random.choice(len(X), num_samples, replace=False)
    X_subset = X[indices]

    # Perform PCA
    pca = PCA(n_components=k)
    X_pca = pca.fit_transform(X_subset)

    print("Explained variance ratio:", pca.explained_variance_ratio_)
    print("99.9% of the data lies within", np.quantile(np.linalg.norm(X_pca, axis=1), 0.999))

    # Cluster with a Gaussian mixture model (fixed seed for the fit)
    gmm = GaussianMixture(n_components=num_clusters, random_state=42)
    cluster_labels = gmm.fit_predict(X_pca)
    cluster_centers = gmm.means_

    # Print cluster centers
    print("Cluster Centers:")
    for i, center in enumerate(cluster_centers):
        print(f"Cluster {i+1}:", center)

    # Radius of each cluster containing the requested quantile of its points
    for i in range(num_clusters):
        members = X_pca[cluster_labels == i]
        if len(members) == 0:
            # A component may end up with no assigned points
            print(f"Cluster {i+1} has no assigned points; radius undefined")
            continue
        distances = np.linalg.norm(members - cluster_centers[i], axis=1)
        radius = np.quantile(distances, quantile)
        # Report the quantile actually used instead of a hard-coded 99%
        print(f"Cluster {i+1} radius (containing {quantile * 100:g}% of the data): {radius}")

    # Calculate min and max distances between cluster centers
    center_distances = [np.linalg.norm(a - b) for a, b in combinations(cluster_centers, 2)]
    if center_distances:
        print("Minimum distance between cluster centers:", min(center_distances))
        print("Maximum distance between cluster centers:", max(center_distances))

    return X_pca, cluster_centers