import torch
from matplotlib import pyplot as plt

from src.models.modules import MixtureGaussian, MultiChannelMixtureGaussian

# Configuration for this visual sanity check of the mixture-Gaussian heads.
K, T = 5, 100  # mixture components, number of input points along [0, 1]
DIM_Y_MO = 3  # output channels of the multi-channel model
DIM_MODEL, DIM_FEEDFORWARD = 32, 64  # passed through as dim_model / dim_feedforward


# Both heads use identical dim_model / dim_feedforward / num_components;
# only the output dimensionality differs (1 channel vs. DIM_Y_MO channels).
_shared_kwargs = dict(
    dim_model=DIM_MODEL,
    dim_feedforward=DIM_FEEDFORWARD,
    num_components=K,
)
MGSO = MixtureGaussian(dim_y=1, **_shared_kwargs)
MGMO = MultiChannelMixtureGaussian(dim_y=DIM_Y_MO, **_shared_kwargs)
del _shared_kwargs  # keep the module namespace unchanged


# Inputs: T evenly spaced points on [0, 1], shaped [1, T, 1].
X = torch.linspace(0, 1, T).reshape(1, T, 1)

# Component means are cosines cos(a * x + b) with a random frequency
# a in [1, 11) and phase b in [0, pi), drawn per component and channel.
_param_shape = (1, 1, K, DIM_Y_MO)
a = 10 * torch.rand(_param_shape) + 1
b = torch.pi * torch.rand(_param_shape)
means = (a * X[..., None, :] + b).cos()  # [1, T, K, DIM_Y_MO]

# Constant-in-x standard deviations, kept strictly positive by the 0.001 floor.
# The broadcast to T happens before the arithmetic, so `stds` is a fresh,
# contiguous [1, T, K, DIM_Y_MO] tensor rather than an expanded view.
_std_noise = torch.randn(_param_shape).abs()
stds = 0.001 + 0.2 * _std_noise.expand(1, T, K, DIM_Y_MO)  # [1, T, K, DIM_Y_MO]

# Mixture weights: softmax over the component axis (dim -2), shared across x.
weights_so = torch.softmax(torch.randn(1, 1, K, 1), dim=-2).expand(1, T, K, 1)
weights_mo = torch.softmax(torch.randn(_param_shape), dim=-2).expand(
    1, T, K, DIM_Y_MO
)
del _param_shape, _std_noise

# Draw one sample per input point from each synthetic GMM.
# Single-channel: only channel 0 of means/stds is used, and the inputs are
# flattened with the model's private `_flat` before `_sample_mixture`.
_flat_inputs_so = (
    MGSO._flat(means[..., :1]),
    MGSO._flat(stds[..., :1]),
    MGSO._flat(weights_so),
)
samples_so = MGSO._sample_mixture(*_flat_inputs_so, num_sample=1)
samples_so = samples_so.permute(1, 0, 2).view(1, T, 1, 1)  # [1, T, 1, 1]
del _flat_inputs_so

# Multi-channel: `_sample` is called directly on the [1, T, K, DIM_Y_MO] tensors.
samples_mo = MGMO._sample(means, stds, weights_mo, num_samples=1)  # [1, T, 1, DIM_Y_MO]

# visualize


def _plot_mixture(ax, x, mu, sigma, samples, title):
    """Draw every mixture component as a mean line with a +/- one std band, then overlay samples.

    Args:
        ax: matplotlib Axes to draw on.
        x: 1-D horizontal coordinates of length T.
        mu: component means, indexed as mu[:, k] (shape [T, num_components]).
        sigma: component standard deviations, same layout as ``mu``.
        samples: sampled values plotted against ``x`` (shape [T, num_samples]).
        title: axes title.
    """
    for k in range(mu.shape[-1]):
        color = f'C{k}'
        ax.plot(x, mu[:, k], '-', color=color, label=f'Comp {k+1}')
        ax.fill_between(
            x,
            mu[:, k] - sigma[:, k],
            mu[:, k] + sigma[:, k],
            color=color, alpha=0.2
        )
    ax.plot(x, samples, 'x', color='black', label='Sample')
    ax.set_title(title)


x_plot = X.squeeze()  # [T]; hoisted out of the plotting loops
fig, axs = plt.subplots(2, DIM_Y_MO, squeeze=False, figsize=(5*DIM_Y_MO, 3*2))

# Top row: single-channel model (channel 0 only).
_plot_mixture(
    axs[0, 0], x_plot,
    means[0, :, :, 0], stds[0, :, :, 0], samples_so[0, :, :, 0],
    f'Single-Channel Mixture Gaussian\nWeights {weights_so[0, 0, :, 0]}',
)
axs[0, 0].legend()
# The single-channel plot only occupies the first top-row panel; hide the
# remaining ones instead of leaving empty framed axes.
for ax in axs[0, 1:]:
    ax.set_axis_off()

# Bottom row: one panel per channel of the multi-channel model.
for dy in range(DIM_Y_MO):
    _plot_mixture(
        axs[1, dy], x_plot,
        means[0, :, :, dy], stds[0, :, :, dy], samples_mo[0, :, :, dy],
        f'Multi-Channel Mixture Gaussian; Dim {dy+1}\nWeights {weights_mo[0, 0, :, dy]}',
    )
axs[1, 0].legend()

# The two-line titles otherwise overlap the tick labels of the row above.
fig.tight_layout()
plt.show()
