import torch
import numpy as np
from scipy.stats import wasserstein_distance
from sklearn.kernel_approximation import RBFSampler

"""
WMAD: Compute the Wasserstein distance for any two agents
"""

# Compute the Wasserstein distance using the dual form
def wasserstein_distance_dual(agent1_representations, agent2_representations):
    """Return the Wasserstein-1 distance between two empirical sample sets.

    Each input is treated as a uniform empirical distribution over its rows.
    The ground cost is the pairwise Euclidean distance, divided by the median
    pairwise distance (left unscaled when that median is zero). The optimal
    transport linear program is solved exactly; by LP duality its optimum
    equals the value of the Kantorovich dual problem.

    Args:
        agent1_representations: 2-D array, one sample per row.
        agent2_representations: 2-D array, one sample per row, with the same
            number of columns as ``agent1_representations``.

    Returns:
        The optimal transport cost as a float.

    Raises:
        ValueError: If an input is not 2-D or contains no samples.
        RuntimeError: If the linear-program solver does not converge.
    """
    # scipy.stats.wasserstein_distance only handles 1-D value arrays and has
    # no cost-matrix argument, so the transport problem is solved directly.
    from scipy.optimize import linprog

    samples1 = np.asarray(agent1_representations, dtype=float)
    samples2 = np.asarray(agent2_representations, dtype=float)
    num_samples1, _ = samples1.shape
    num_samples2, _ = samples2.shape
    if num_samples1 == 0 or num_samples2 == 0:
        raise ValueError("both agents need at least one sample")

    # Pairwise Euclidean distances via broadcasting:
    # (n1, 1, d) - (1, n2, d) -> (n1, n2, d) -> (n1, n2).
    distances = np.linalg.norm(
        samples1[:, np.newaxis, :] - samples2[np.newaxis, :, :], axis=-1
    )

    # Normalise by the median distance; skip when it is zero to avoid NaNs.
    scale = np.median(distances)
    cost_matrix = distances / scale if scale > 0 else distances

    # The transport plan P is flattened row-major: entry (i, j) -> i * n2 + j.
    # Rows of P must sum to 1/n1 and columns to 1/n2.
    row_sums = np.kron(np.eye(num_samples1), np.ones((1, num_samples2)))
    col_sums = np.kron(np.ones((1, num_samples1)), np.eye(num_samples2))
    # The last column constraint is implied by the others (total mass is 1),
    # so it is dropped to keep the constraint matrix full rank.
    constraint_matrix = np.vstack([row_sums, col_sums[:-1]])
    constraint_values = np.concatenate([
        np.full(num_samples1, 1.0 / num_samples1),
        np.full(num_samples2 - 1, 1.0 / num_samples2),
    ])

    result = linprog(
        cost_matrix.ravel(),
        A_eq=constraint_matrix,
        b_eq=constraint_values,
        bounds=(0, None),
        method="highs",
    )
    if not result.success:
        raise RuntimeError(f"optimal transport solver failed: {result.message}")
    return float(result.fun)


def wasserstein_distance_dual_rff(agent_representations1, agent_representations2, gamma, num_features, lambda_mu,
                                  lambda_nu, G, b, m, beta=0.05):
    """Evaluate the random-feature dual objective for a pair of agents.

    Both representations are mapped through ``compute_random_feature_map``
    with the shared parameters ``G``, ``b`` and ``m``. The mapped features are
    combined with ``lambda_mu`` / ``lambda_nu`` by ``np.dot``, and the result is
    the mean of ``gap - gamma * exp((gap - gamma) / beta)``, where ``gap`` is
    the difference of the two dot products.

    Args:
        agent_representations1: Representation of the first agent.
        agent_representations2: Representation of the second agent.
        gamma: Value subtracted inside the exponent and used as the multiplier
            of the exponential term.
        num_features: Not used by this function.
        lambda_mu: Weights applied to the first agent's feature map.
        lambda_nu: Weights applied to the second agent's feature map.
        G: Projection matrix forwarded to ``compute_random_feature_map``.
        b: Phase offsets forwarded to ``compute_random_feature_map``.
        m: Number of random features forwarded to ``compute_random_feature_map``.
        beta: Divisor of the exponent.

    Returns:
        The mean of the objective over all entries of ``gap``.
    """
    # NOTE(review): compute_gradient in this file subtracts a Euclidean
    # distance where this function subtracts `gamma`, and `gamma` also scales
    # the exponential term -- confirm this is the intended objective.
    features1 = compute_random_feature_map(agent_representations1, G, b, m)
    features2 = compute_random_feature_map(agent_representations2, G, b, m)

    potential1 = np.dot(features1, lambda_mu)
    potential2 = np.dot(features2, lambda_nu)
    potential_gap = potential1 - potential2

    penalty = gamma * np.exp((potential_gap - gamma) / beta)
    return np.mean(potential_gap - penalty)


# Dual function to be updated
def dual_function(agent_representations, gamma, num_features):
    """Project RBF random features of the samples onto a random direction.

    An ``RBFSampler`` (fixed ``random_state=42``) is fitted on the samples and
    used to transform them; the resulting feature matrix is multiplied by a
    weight vector freshly drawn from ``np.random.randn`` on every call.

    Args:
        agent_representations: 2-D array of samples.
        gamma: ``gamma`` argument passed to ``RBFSampler``.
        num_features: Number of random features (``n_components``) and length
            of the random weight vector.

    Returns:
        The dot product of the random-feature matrix with the random weights.
    """
    # Unpacking the shape also rejects inputs that are not 2-D arrays.
    _num_rows, _num_cols = agent_representations.shape

    sampler = RBFSampler(gamma=gamma, n_components=num_features, random_state=42)
    features = sampler.fit(agent_representations).transform(agent_representations)

    weights = np.random.randn(num_features)
    return np.dot(features, weights)


def compute_gradient(phi_kappa_c1, phi_ell_c2, lambda_mu, lambda_nu, euclidean_distance, beta=0.05):
    """Return the stacked update directions for ``lambda_mu`` and ``lambda_nu``.

    With ``gap = lambda_mu . phi_kappa_c1 - lambda_nu . phi_ell_c2`` and
    ``weight = 1 - exp((gap - euclidean_distance) / beta)``, the result is
    ``weight * phi_kappa_c1`` followed by ``-weight * phi_ell_c2``.

    Args:
        phi_kappa_c1: Feature map of the first input.
        phi_ell_c2: Feature map of the second input.
        lambda_mu: Weights paired with ``phi_kappa_c1``.
        lambda_nu: Weights paired with ``phi_ell_c2``.
        euclidean_distance: Distance subtracted from the gap in the exponent.
        beta: Divisor of the exponent.

    Returns:
        The two directions concatenated along the first axis, the
        ``lambda_mu`` part first.
    """
    potential_gap = lambda_mu.dot(phi_kappa_c1) - lambda_nu.dot(phi_ell_c2)
    weight = 1 - np.exp((potential_gap - euclidean_distance) / beta)

    # lambda_mu part first, then the lambda_nu part with the opposite sign.
    return np.concatenate([weight * phi_kappa_c1, -weight * phi_ell_c2])


def calculate_min_wasserstein_distances(all_agents_representations, gamma, num_features, learning_rate=0.01,
                                        num_iterations=100, beta=1.0):
    """Return, for each agent, its smallest dual-objective value against any other agent.

    One random feature map (``G``, ``b``) and one pair of weight vectors
    (``lambda_mu``, ``lambda_nu``) are drawn up front. For every agent the
    weights are updated ``num_iterations`` times against each other agent
    using ``compute_gradient``; the weights are not re-initialised between
    agents. Afterwards ``wasserstein_distance_dual_rff`` is evaluated against
    each other agent and the minimum is stored.

    Args:
        all_agents_representations: Sequence of per-agent arrays. The first
            axis of the first array sets the input dimension of the feature
            map. Assumes all arrays share one shape -- TODO confirm.
        gamma: Forwarded unchanged to ``wasserstein_distance_dual_rff``.
        num_features: Number of random features and length of the weight vectors.
        learning_rate: Step size of the weight updates.
        num_iterations: Number of update sweeps per agent.
        beta: Forwarded to ``compute_gradient`` and
            ``wasserstein_distance_dual_rff``.

    Returns:
        Array with one entry per agent. An entry stays ``inf`` when there is
        no other agent to compare against. Empty input yields an empty array.
    """
    num_agents = len(all_agents_representations)
    min_distances = np.zeros(num_agents)
    if num_agents == 0:
        # Nothing to compare; avoids indexing the first agent below.
        return min_distances

    # Generate random parameters for the random feature map
    m, d = num_features, all_agents_representations[0].shape[0]
    G = np.random.normal(0, 1, size=(m, d))
    b = np.random.uniform(0, 2 * np.pi, size=m)

    # Initialize coupling parameters
    lambda_mu = np.random.randn(num_features)
    lambda_nu = np.random.randn(num_features)

    # Feature maps depend only on (G, b, m) and the agent, so compute them
    # once instead of on every update iteration.
    feature_maps = [
        compute_random_feature_map(representations, G, b, m)
        for representations in all_agents_representations
    ]

    for i, current_agent_representations in enumerate(all_agents_representations):
        phi_kappa_c1 = feature_maps[i]

        # Distances to the other agents are likewise iteration-invariant.
        partners = [
            (feature_maps[j], np.linalg.norm(current_agent_representations - other_agent_representations))
            for j, other_agent_representations in enumerate(all_agents_representations)
            if j != i
        ]

        # Update the shared weights for the current agent
        for _ in range(num_iterations):
            for phi_ell_c2, euclidean_distance in partners:
                gradient_lambda = compute_gradient(phi_kappa_c1, phi_ell_c2, lambda_mu, lambda_nu,
                                                   euclidean_distance, beta)
                # First half belongs to lambda_mu, second half to lambda_nu.
                lambda_mu += learning_rate * gradient_lambda[:num_features]
                lambda_nu += learning_rate * gradient_lambda[num_features:]

        # Calculate minimum distance for current agent
        min_distance = np.inf
        for j, other_agent_representations in enumerate(all_agents_representations):
            if i != j:
                distance = wasserstein_distance_dual_rff(current_agent_representations, other_agent_representations,
                                                         gamma, num_features, lambda_mu, lambda_nu, G, b, m, beta)
                if distance < min_distance:
                    min_distance = distance

        min_distances[i] = min_distance

    return min_distances




def compute_random_feature_map(x, G, b, m):
    """Return the random Fourier feature map ``sqrt(2 / m) * cos(G x + b)``.

    Args:
        x: Input of shape ``(d,)`` (one point) or ``(d, n)`` (``n`` points
            stored as columns).
        G: Projection matrix of shape ``(m, d)``.
        b: Phase offsets of shape ``(m,)``.
        m: Number of random features, used for the ``sqrt(2 / m)`` scaling.

    Returns:
        Array of shape ``(m,)`` for a single point, or ``(m, n)`` for a batch.
    """
    projection = G.dot(x)

    # For a single point the projection is already (m,); adding a column
    # vector there would broadcast to a spurious (m, m) matrix. Only batched
    # projections need the offsets expanded along the sample axis.
    offsets = b if projection.ndim == 1 else b[:, np.newaxis]

    phi_x = np.sqrt(2 / m) * np.cos(projection + offsets)

    return phi_x