from argparse import Namespace
import torch
from dataset_generation.latent_generator import LatentGenerator


def _constructive_args(beta, eta, rho):
    """Build one generator configuration.

    Only ``beta``, ``eta_arg`` and ``rho_arg`` differ between the two
    configurations used in this script; every other field is fixed here.
    Each call creates fresh list objects so the configurations never share
    mutable state.
    """
    return Namespace(
        N_theta=1,
        N_mod=2,
        alpha=0.9,
        beta=beta,
        eta_arg=list(eta),
        rho_theta_arg=[0.8],
        rho_arg=list(rho),
        generation_paradigm="constructive",
    )


# Configuration handed to the first LatentGenerator (source of A and Sigma).
args = _constructive_args(beta=0.9, eta=[10.0], rho=[1, 1])

# Configuration handed to the second LatentGenerator (constructive samples).
args2 = _constructive_args(beta=0.4, eta=[5.0], rho=[1, 0.8])

# Build the generator for the first configuration, on the first CUDA device
# when one is available and on the CPU otherwise.
mod_dim = 1
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
latent_generator = LatentGenerator(args, mod_dim, device)

# Matrices exposed by the generator.
A = latent_generator.get_A_matrix()
Sigma = latent_generator.get_sigma()

# Diagnostics of Sigma: rank, eigenvalues, determinant and the generator's
# own positive-definiteness check.
Sigma_rank = torch.linalg.matrix_rank(Sigma)
Sigma_eig = torch.linalg.eigvals(Sigma)
Sigma_det = torch.linalg.det(Sigma)
is_positive_definite = latent_generator.is_positive_definite(Sigma)

# Host-side NumPy copies of the tensors above.
A_cpu, Sigma_cpu, Sigma_eig_cpu = (
    tensor.cpu().numpy() for tensor in (A, Sigma, Sigma_eig)
)



# Rescale Sigma into a correlation matrix, R = D^{-1/2} @ Sigma @ D^{-1/2},
# where D is the diagonal part of Sigma.
std_dev = torch.diag(Sigma).sqrt()              # square roots of the diagonal entries
D_inv_sqrt = torch.diag(std_dev.reciprocal())   # diagonal matrix holding 1 / std
Sigma_corr = torch.matmul(torch.matmul(D_inv_sqrt, Sigma), D_inv_sqrt)

# Host-side NumPy copy of the correlation matrix.
Sigma_corr_cpu = Sigma_corr.cpu().numpy()


# Multivariate normal built from the first generator's mean and Sigma.
mvn = torch.distributions.MultivariateNormal(
    loc=latent_generator.get_mu(),
    covariance_matrix=Sigma,
)

# Report the rank of Sigma computed earlier in the script.
print(Sigma_rank)

####### Plotting MVN Samples vs Constructive Samples ########

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # needed for projection='3d'

# ---- PARAMETERS ----
# Number of points drawn from each of the two samplers.
num_samples = 100_000

# Second generator, built from the second configuration; it is constructed
# before any sampling, as in the original statement order.
latent_generator2 = LatentGenerator(args2, mod_dim, device)

# ---- DRAW SAMPLES ----
# Both results are moved to the host as NumPy arrays for plotting.
# NOTE(review): each array is presumably (num_samples, 3) -- confirm against
# LatentGenerator.
mvn_samples = mvn.sample((num_samples,)).cpu().numpy()
constructive_samps = latent_generator2.sample(num_samples).cpu().numpy()

# ---- PLOTTING 3 SUBPLOTS ----
# One scatter panel per pair of sample columns, overlaying the MVN samples
# and the constructive samples. The pairs index columns 0..2, so both sample
# arrays need at least three columns.
from pathlib import Path

fig, axs = plt.subplots(1, 3, figsize=(18, 5))

pairs = [(0, 1), (1, 2), (0, 2)]
titles = ['Dim 1 vs Dim 2', 'Dim 2 vs Dim 3', 'Dim 1 vs Dim 3']

for ax, (i, j), title in zip(axs, pairs, titles):
    # MVN samples
    ax.scatter(
        mvn_samples[:, i], mvn_samples[:, j],
        s=10, alpha=0.5, label='MVN Sampling'
    )
    # Constructive samples
    ax.scatter(
        constructive_samps[:, i], constructive_samps[:, j],
        s=10, alpha=0.5, label='l = Au', color='orange'
    )
    ax.set_xlabel(f'Dim {i+1}')
    ax.set_ylabel(f'Dim {j+1}')
    ax.set_title(title)

# The two series are labelled identically in every panel, so a single legend
# on the last panel is enough.
axs[2].legend(loc='best')
Title = "Sampling from the MVN vs Constructive Latent Generation (Diff. Sigma)"
plt.suptitle(Title, fontsize=22)
plt.tight_layout()

# savefig does not create missing directories; make sure the target exists so
# the figure is not lost to a FileNotFoundError after all the work above.
output_dir = Path("output_dir")
output_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(output_dir / f"{Title}.png")
plt.show()
