import torch
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
import os

from torch.nn.functional import linear

# Setup device: prefer CUDA when a GPU is visible, otherwise run on the CPU.

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Sampling function for normal distribution

def sample_normal(mu, sigma, dim, count):
    """
    Draw `count` i.i.d. points from the isotropic Gaussian N(mu, sigma^2 I_dim).

    Args:
        mu (torch.Tensor): Mean vector of length `dim`; it is moved to the
            module-level `device` before use.
        sigma (float): Standard deviation shared by every coordinate.
        dim (int): Dimensionality of each sample.
        count (int): How many samples to draw.

    Returns:
        torch.Tensor: A (count, dim) tensor on `device`, computed as
        standard-normal noise scaled by sigma and shifted by mu.
    """
    center = mu.to(device)
    noise = torch.randn(count, dim, device=device)
    return noise * sigma + center

def construct_A(mu: torch.Tensor, k: int) -> torch.Tensor:
    """
    Constructs a matrix A ∈ R^{k x d} given input mu ∈ R^d,
    such that AA^T = I_k and ||A mu|| = ||mu||.

    The first row of A is mu / ||mu||. The remaining k - 1 rows are random
    unit vectors orthogonal to it and to each other. A therefore keeps the
    full mean direction and discards the rest.

    Args:
        mu: 1-D, non-zero tensor of length d. A is created with mu's dtype
            and device.
        k: Number of rows, 1 <= k < d.

    Returns:
        Tensor of shape (k, d) with orthonormal rows.

    Raises:
        ValueError: if mu is not 1-D, if mu is zero, or if k is out of range.
    """
    # Check the rank before reading shape[0]; a 0-d tensor has no shape[0].
    if mu.ndim != 1:
        raise ValueError("mu must be 1D")
    d = mu.shape[0]
    if not 1 <= k < d:
        raise ValueError("k must satisfy 1 ≤ k < d")

    mu_norm = torch.norm(mu)
    if mu_norm == 0:
        # Normalizing a zero vector would fill the first row with NaNs.
        raise ValueError("mu must be non-zero")
    u = mu / mu_norm

    # Columns: u followed by k - 1 random Gaussian directions. Use the same
    # dtype/device as mu so every product below stays in one dtype.
    directions = torch.randn(d, k - 1, dtype=mu.dtype, device=mu.device)
    M = torch.cat([u.unsqueeze(1), directions], dim=1)

    # Reduced Householder QR gives Q of shape (d, k) with orthonormal columns.
    # Its first column spans u. One call replaces the k^2/2 small dot
    # products of a hand-written Gram-Schmidt loop.
    Q, R = torch.linalg.qr(M)

    # QR fixes each column only up to sign. Flip columns so that
    # Q[:, 0] == u exactly; then A @ mu = (||mu||, 0, ..., 0).
    signs = torch.sign(torch.diagonal(R))
    signs[signs == 0] = 1
    Q = Q * signs

    return Q.T.contiguous()  # shape (k, d)

# Compute Efficiency function

def _nearest_mean_error(test_1, test_2, mean_1, mean_2):
    """Return the balanced test error of the nearest-class-mean classifier.

    A test point counts as an error when it is strictly closer (Euclidean) to
    the other class's estimated mean; ties count as correct. The two
    per-class error rates are averaged with equal weight.
    """
    err_1 = (torch.norm(test_1 - mean_2, dim=1) < torch.norm(test_1 - mean_1, dim=1)).float().mean()
    err_2 = (torch.norm(test_2 - mean_1, dim=1) < torch.norm(test_2 - mean_2, dim=1)).float().mean()
    return (err_1 + err_2) / 2


def Compute_Efficiency(gamma, S, k, d, *, sigma=1, repetitions=100, start=0,
                       end=10000, num_points=150, samples_test_count=50000,
                       verbose=True):
    """
    Computes the efficiency of dimensionality reduction from d to 1 <= k < d
    on binary Gaussian Mixture Model (GMM) data.

    For each training size N on an evenly spaced grid:
    1. Draw a random mean mu_d with ||mu_d|| = sigma * sqrt(S).
    2. Sample class 1 from N(-mu_d, sigma^2 I) and class 2 from N(mu_d, sigma^2 I).
    3. Evaluate a nearest-class-mean classifier in R^d, and again after
       projecting every point with construct_A(mu_d, k).
    The gain is 100 * (error_d - error_k) / error_d, averaged over
    `repetitions` independent runs.

    Parameters:
    - gamma = N2 / N1: the class imbalance factor of the training set.
    - S = ||mu||^2 / sigma^2: the signal-to-noise ratio (SNR) of the GMM data.
    - k: embedded dimension; d: ambient dimension.
    - sigma: noise standard deviation (default 1).
    - repetitions: number of independent runs averaged per grid point (>= 1).
    - start, end, num_points: the N_train grid of num_points (>= 2) evenly
      spaced values from start to end.
    - samples_test_count: number of test points drawn per class in every run.
    - verbose: if True, print the grid index as progress output.

    Returns:
    - (prec_list, n_values): the mean efficiency gain (%) at each grid point,
      and the grid of N_train values (floats).
      - Grid points with N == 0 get 0.0.
      - Grid points where a class would receive no training samples get NaN.
      - NOTE: a run whose R^d error is exactly 0 yields an inf/NaN gain.
    """
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    if repetitions < 1:
        raise ValueError("repetitions must be at least 1")

    step = (end - start) / (num_points - 1)
    n_values = [start + i * step for i in range(num_points)]
    prec_list = []

    for i, n_value in enumerate(n_values):
        if verbose:
            print(i)
        N = int(n_value)

        if N == 0:
            prec_list.append(0.0)
            continue

        N_1 = int(N / (1 + gamma))
        N_2 = int(gamma * N / (1 + gamma))
        if N_1 == 0 or N_2 == 0:
            # An empty class has no mean estimate (the mean of an empty tensor
            # is NaN), so no meaningful gain exists at this N.
            prec_list.append(float("nan"))
            continue

        # Randomize mu_d in R^d with ||mu_d|| = sigma * sqrt(S):
        v_d = torch.randn(d)  # v_d ~ N(0, I_d)
        mu_d = sigma * float(np.sqrt(S)) * (v_d / torch.norm(v_d))

        # Build A from the CPU copy of mu_d, then move both to the device once
        # instead of on every sampling call.
        A = construct_A(mu_d, k).to(device)
        mu_d = mu_d.to(device)

        gains = []
        with torch.no_grad():
            for _ in range(repetitions):
                train_1 = sample_normal(-mu_d, sigma, d, N_1)
                train_2 = sample_normal(mu_d, sigma, d, N_2)
                test_1 = sample_normal(-mu_d, sigma, d, samples_test_count)
                test_2 = sample_normal(mu_d, sigma, d, samples_test_count)

                # Classifier in the original space R^d.
                error_noisy = _nearest_mean_error(
                    test_1, test_2, train_1.mean(dim=0), train_2.mean(dim=0))

                # Same classifier after projecting train and test points with A.
                error_embedded = _nearest_mean_error(
                    linear(test_1, A),
                    linear(test_2, A),
                    linear(train_1, A).mean(dim=0),
                    linear(train_2, A).mean(dim=0),
                )

                gains.append((error_noisy - error_embedded) / error_noisy * 100)

        # Average over the repetitions.
        prec_list.append(torch.stack(gains).mean().item())

    return prec_list, n_values


# Plot precision gain function

def plot_precision_gain(S, k, d, gamma_values, efficiency_dict, saving_path, *, n_max=10000):
    """
    Plots theoretical and simulated efficiency gains (%) versus training sample size.

    Parameters:
    - S: signal-to-noise ratio ||mu||^2 / sigma^2 used by the theory curves.
    - k, d: embedded and ambient dimensions.
    - gamma_values: imbalance factors to plot, one colour per value.
    - efficiency_dict: maps each gamma to a pair whose [0] entry holds the
      simulated efficiencies and whose [1] entry holds the matching N_train
      values, e.g. the tuple returned by Compute_Efficiency.
    - saving_path: output file, passed to save_and_show_plot.
    - n_max: upper end of the N_train range of the theory curves (default 10000).
    """

    def Q(x):
        # Gaussian tail probability P(Z > x); norm.sf avoids the cancellation of 1 - cdf.
        return norm.sf(x)

    def theory_error(x, g, dim):
        # Theoretical error formula evaluated in dimension `dim`:
        # d for the original GMM in R^d, k after reduction to R^k.
        # x is the class-1 training size N / (1 + g), so g * x is the class-2 size.
        A = (1 / (4 * x)) * ((1 - g) / g) * (dim / np.sqrt(S))
        B = (1 / (4 * x)) * ((1 + g) / g) * (dim / S)
        C = (1 / (8 * x**2)) * ((1 + g**2) / g**2) * (dim / S)
        D1 = (1 / (g * x)) + 1
        D2 = (1 / x) + 1

        z1 = (np.sqrt(S) + A) / np.sqrt(B + C + D1)
        z2 = (np.sqrt(S) - A) / np.sqrt(B + C + D2)
        return 0.5 * (Q(z1) + Q(z2))

    colors = plt.get_cmap('tab10')
    # Marker for the simulated points of each gamma; other gammas fall back to a circle.
    markers = {
        0.25: 's',   # Square
        0.5: '^',    # Triangle Up
        1: 'o',      # Circle
    }

    x_vals = np.linspace(0.01, n_max, 500)

    fig, ax = plt.subplots(figsize=(10, 6))
    for idx, g in enumerate(gamma_values):
        color = colors(idx)
        class_1_sizes = (1 / (1 + g)) * x_vals

        noisy_error_vals = theory_error(class_1_sizes, g, d)
        embedded_error_vals = theory_error(class_1_sizes, g, k)
        prec = (noisy_error_vals - embedded_error_vals) / noisy_error_vals * 100

        ax.plot(x_vals, prec, label=f'γ = {g} (theory)', color=color, linewidth=2)

        # Simulated points.
        efficiency_values = efficiency_dict[g][0]
        curr_N_vals = efficiency_dict[g][1]
        ax.scatter(curr_N_vals, efficiency_values, color=color, edgecolor='black',
                   label=f'γ = {g} (simulation)', zorder=5, marker=markers.get(g, 'o'))

    ax.set_xlabel('N_train', fontsize=16)
    ax.set_ylabel('Efficiency (%)', fontsize=16)
    # One figure is produced per S, so label it with the parameters it depends on.
    ax.set_title(f'S = {S:g}, k = {k}, d = {d}', fontsize=14)
    ax.tick_params(axis='both', labelsize=14)
    ax.legend(title='Gamma values', loc='best', frameon=False, fontsize=12, title_fontsize=13)
    ax.grid(True)
    fig.tight_layout()

    # Save and show plot
    save_and_show_plot(fig, saving_path)


# Function to save and display the plot

def save_and_show_plot(fig, filename):
    """
    Save `fig` to `filename` (creating parent directories) and display it.

    An empty filename means no output path was configured. Matplotlib would
    otherwise infer the format from rcParams and silently write a hidden
    ".png" file in the working directory. In that case saving is skipped
    and the figure is only shown.
    """
    if filename:
        dir_path = os.path.dirname(filename)
        if dir_path:
            os.makedirs(dir_path, exist_ok=True)
        fig.savefig(filename)
    else:
        print("No saving path given; figure was not saved.")

    plt.show()

# Simulated-point markers are predefined for gamma in {0.25, 0.5, 1}.
gamma_list = [0.25, 0.5, 1]

# k, d values - can modify
k = 1000
d = 2000

# S values - can modify
S_list = [0.75 * 0.75, 1, 1.5 * 1.5]

# Directory for the saved figures ("" = current working directory).
output_dir = ""

if __name__ == "__main__":
    for S in S_list:
        # One file per S (and k, d) so each plot does not overwrite the previous one.
        path = os.path.join(output_dir, f"efficiency_S_{S:g}_k_{k}_d_{d}.png")
        print(f"Computing efficiencies for S = {S}...")

        Final_dict = {}
        for gamma in gamma_list:
            print(f"Processing Gamma = {gamma}...")
            curr_prec_list, curr_N_values = Compute_Efficiency(gamma, S, k, d)
            Final_dict[gamma] = (curr_prec_list, curr_N_values)

        print("Plotting precision gain...")
        plot_precision_gain(S, k, d, gamma_list, Final_dict, path)
