#Code that generates Figure 3 of the paper

import matplotlib
# The line below is often used in notebooks/headless environments, 
# but for typical Python script running with a GUI, it might be removed 
# or changed to 'TkAgg' or similar if interactive plotting is desired.
# Keeping it as 'Agg' ensures the script can run and save the plot without a display.
matplotlib.use('Agg')
import numpy as np
import math
import os
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA

# Print NumPy arrays with 3 decimal places (setting carried over from the notebook).
np.set_printoptions(precision=3)
# The notebook's `%matplotlib inline` magic is Jupyter-only (not valid Python)
# and has been dropped; this script saves the figure to disk via savefig instead.

# --- Function Definitions ---

def generate_points_mixture0(dim, L, sigma, z0, z1):
    """
    Draw L samples from an equal-weight mixture of two isotropic Gaussians.

    For each sample a component index is picked with probability 1/2 each
    (0 -> mean z0, 1 -> mean z1), then a point is drawn from
    Normal(mean, sigma**2 * I_dim). Randomness comes from the global
    ``np.random`` state, so seed it beforehand for reproducible output.

    Parameters
    ----------
    dim : dimension of each point.
    L : number of points to draw.
    sigma : standard deviation of every coordinate (shared by both components).
    z0, z1 : length-``dim`` mean vectors of the two components.

    Returns
    -------
    (points, centroids, clusters) where
        points    -- (L, dim) array of sampled points,
        centroids -- (L, dim) array; row i is the mean used for point i,
        clusters  -- (L,) float array of component indices (0.0 or 1.0).
    """
    shared_cov = (sigma ** 2) * np.eye(dim)
    means = (z0, z1)

    samples = np.zeros((L, dim))
    sample_means = np.zeros((L, dim))
    labels = np.zeros(L)

    for idx in range(L):
        # One component draw followed by one Gaussian draw per sample keeps
        # the order in which the global RNG is consumed unchanged.
        component = np.random.choice([0, 1], p=[0.5, 0.5])
        mean = means[component]
        samples[idx] = np.random.multivariate_normal(mean, shared_cov)
        sample_means[idx] = mean
        labels[idx] = component
    return samples, sample_means, labels

def c1(n, sigma=0.1):
    """Return the helper term 1 + n * sigma**2 used by optimalT."""
    return 1+n*sigma**2

def c2(n, sigma=0.1, dim=5):
    """Return the dimension-dependent helper term 1 + sigma**2 * (dim + n) used by optimalT."""
    return 1+sigma**2 * (dim+n)

def optimalT(L=100, sigma=0.1, *, c_sigma=0.1, c_dim=5):
    """
    Closed-form temperature T that the script feeds into the attention head H.

    The expression is transcribed from the original notebook and evaluates
    T = C1 / (C2 + C3) from L, sigma and the c1/c2 helper terms.

    Parameters
    ----------
    L : number of tokens (the script passes its total number of points).
    sigma : value used in the sigma**2 prefactors (the script passes the
        sampling standard deviation).
    c_sigma, c_dim : keyword-only; the sigma / dim handed to c1 and c2.
        Historically those helpers were always called with their own
        defaults (sigma=0.1, dim=5) regardless of the ``sigma`` argument;
        these defaults reproduce that exactly, so existing calls return the
        same value. Pass ``c_sigma=sigma, c_dim=<data dim>`` to evaluate
        every term with the data's actual parameters.
        NOTE(review): which variant the paper intends is not visible here.

    Returns
    -------
    float
        The temperature T.
    """
    C1 = (2/L)*c2(4, c_sigma, c_dim)+((L-1)/L)*c1(4, c_sigma)
    C2 = ((16*sigma**2/(L**2))*c2(6, c_sigma, c_dim)
          + (8*sigma**2*(L-1)/(L**2))*c1(6, c_sigma)
          + (4*sigma**2*(L-1)/(L**2))*c2(3, c_sigma, c_dim)
          + (sigma**2*(L-1)*(L-2)/(L**2))*c1(6, c_sigma))
    A0 = ((2/(L**2))*c2(8, c_sigma, c_dim)
          + (2*(L-1)/(L**2))*c1(5, c_sigma)
          + ((L-1)/(L**2))*c2(4, c_sigma, c_dim)
          + ((L-1)*(L-2)/(2*L**2))*c1(4, c_sigma))
    C0 = 4*sigma**2*(L-1)/(L**2)
    C3 = 2*A0+C0
    return C1/(C2+C3)

# The following two functions originally read the globals 'L' and 'T' from
# the notebook's execution block; they now take them as explicit arguments
# so that this script is self-contained.

#Compute attention head
def H(dim, m, T, points, L):
    """
    Apply the linear attention head with direction vector m to the tokens.

    For each token x_l (l < L) the output row is

        M[l] = (2/L) * sum_k T * <m, x_l> * <m, x_k> * x_k
             = (2*T/L) * <m, x_l> * v,   where v = sum_k <m, x_k> * x_k,

    so the full output is a scaled outer product. It is computed in
    O(L * dim) instead of the notebook's O(L**2 * dim) Python double loop
    (results agree up to floating-point rounding).

    Parameters
    ----------
    dim : number of output columns (token dimension).
    m : length-``dim`` direction vector.
    T : temperature scaling factor.
    points : array whose first L rows are the tokens.
    L : number of tokens used (also the 2/L normalisation).

    Returns
    -------
    (L, dim) float array M.
    """
    tokens = np.asarray(points, dtype=float)[:L]
    proj = tokens @ np.asarray(m, dtype=float)   # <m, x_k> for every token
    weighted_sum = proj @ tokens                 # v = sum_k <m, x_k> * x_k
    M = np.zeros((L, dim))
    M[:] = (2 / L) * T * np.outer(proj, weighted_sum)
    return M

# Combined output of the two attention heads (one per centroid).
def pred(dim, m0, m1, points, T, L):
    """
    Return H(m0) + H(m1): the summed outputs of the heads for m0 and m1.

    Both heads are evaluated on the same ``points`` with the same
    temperature ``T`` and token count ``L``; the result has the shape that
    H returns. The computation is deterministic (no sampling happens here).
    """
    head_outputs = [H(dim, direction, T, points, L) for direction in (m0, m1)]
    return head_outputs[0] + head_outputs[1]

# --- Main Execution Block ---

# Parameters
dim = 10
sigma = 0.3
L = 500  # Total number of points

# Seed the global NumPy RNG (used by generate_points_mixture0) so that the
# sampled point cloud, and therefore the saved figure, is reproducible from
# run to run. The value is arbitrary; the notebook did not seed at all.
SEED = 0
np.random.seed(SEED)

# Centroids (mu_0 and mu_1): z0 = e_{dim-1}, z1 = -e_0.
z0 = np.zeros(dim)
z0[dim - 1] = 1
z1 = np.zeros(dim)
z1[0] = -1

# Sample the token cloud. `centroids` and `clusters` are not used below but
# are kept available for inspection.
points, centroids, clusters = generate_points_mixture0(dim, L, sigma, z0, z1)

# Temperature T from the closed-form expression, evaluated with this run's
# L and sigma. NOTE(review): optimalT evaluates its c1/c2 terms with those
# helpers' default sigma/dim, not the sigma/dim set above.
T = optimalT(L, sigma)

# Other T values that were tried in the notebook:
# T = 1.3
# T = 2
# T = 10

actual = points
# Sum of the two attention heads (one per centroid) applied to every token.
predictions = pred(dim, z0, z1, points, T, L)

# Project inputs and outputs onto one 2-D PCA basis fitted on the inputs.
pca = PCA(n_components=2)
actual_2d = pca.fit_transform(actual)
predictions_2d = pca.transform(predictions)
# Row 0 is mu_0 and row 1 is mu_1, in the same PCA basis.
centroids_2d = pca.transform(np.vstack([z0, z1]))

# --- Plotting ---

output_dir = os.path.join('cloud_points', 'results')
os.makedirs(output_dir, exist_ok=True)

# Global font family/size. NOTE: sns.set_context below rescales the rcParams
# font sizes (font_scale=4), which supersedes size=26 for most text; the
# sans-serif family setting is unaffected.
plt.rc('font', family='sans-serif', size=26)

fig = plt.figure(figsize=(30, 20))
sns.set_context("notebook", font_scale=4)

# Input tokens: hollow circles.
plt.scatter(actual_2d[:, 0], actual_2d[:, 1], marker='o',
            facecolors="None", edgecolors='deepskyblue',
            label="Input tokens",
            alpha=1, zorder=1, s=700, linewidth=4)

# Transformed tokens: filled, semi-transparent circles.
plt.scatter(predictions_2d[:, 0], predictions_2d[:, 1], marker='o',
            color='dodgerblue',
            label="Transformed tokens",
            alpha=0.7, zorder=1, s=700, linewidth=4)

# One segment from each input token to its transformed counterpart.
for (x_in, y_in), (x_out, y_out) in zip(actual_2d, predictions_2d):
    plt.plot([x_in, x_out], [y_in, y_out], color="gray", alpha=0.5)

# Centroids: both drawn as stars; only the first carries the legend label.
plt.scatter(centroids_2d[0, 0], centroids_2d[0, 1], marker='*',
            color='darkorange',
            s=900, label='$\\mu_0^\\star, \\mu_1^\\star$', zorder=3, linewidth=4)
plt.scatter(centroids_2d[1, 0], centroids_2d[1, 1], marker='*',
            color='darkorange',
            s=900, zorder=4, linewidth=4)

# plt.title("Comparison between input/transformed tokens")  # omitted in the paper figure
plt.xlabel("PCA Component 1")
plt.ylabel("PCA Component 2")
plt.legend()
plt.grid()

fig.savefig(os.path.join(output_dir, "cloud_points.pdf"))

# plt.show()  # uncomment when running with an interactive backend

plt.close(fig)  # release the figure's memory