import numpy as np
import torch

from matplotlib import pyplot as plt
from src.models.benchmarks.modules import (BarDistribution, FullSupportBarDistribution, get_bucket_borders)

# Fix the RNG so the prior samples (and therefore the bucket borders) and the
# random logits are identical on every run, making the plot reproducible.
SEED = 0
torch.manual_seed(SEED)

Batch_size = 1
Nc, Nt = 1000, 200  # number of contexts, number of targets (Nt is used in the sampling note below)
Nbars = 20  # number of buckets / num of bins of pfn Riemann distribution
N_grid = 1000  # number of evaluation points on each density curve

# Generate prior data to determine bucket borders. The clamp to [-2, 2]
# presumably bounds the resulting borders to that interval — TODO confirm
# against get_bucket_borders.
y_prior = torch.randn([1000]).clamp(-2, 2)
bucket_borders = get_bucket_borders(Nbars, ys=y_prior)
# Histogram distribution without tails, and its full-support variant with
# half-normal tails, sharing the same borders.
bar_dist = BarDistribution(bucket_borders)
fs_bar_dist = FullSupportBarDistribution(bucket_borders)

random_logits = torch.randn([Batch_size, Nbars])  # Batch_size distributions with Nbars logits each

# Evaluation grids, each of shape (Batch_size, N_grid). The full-support grid
# is wider than [-2, 2] so that the tails beyond the outer buckets are visible.
y_bar_true = torch.linspace(-2, 2, N_grid).unsqueeze(0).repeat(Batch_size, 1)
y_fs_bar_true = torch.linspace(-3, 3, N_grid).unsqueeze(0).repeat(Batch_size, 1)
# To evaluate at sampled points instead of on a grid, use e.g.
#   bar_dist.sample(random_logits, num_samples=Nt)  # (Batch_size, Nt)

fig, ax = plt.subplots(figsize=(10, 5))

# Each entry: (distribution, evaluation grid, legend label).
curves = [
    (bar_dist, y_bar_true, 'without tails'),
    (fs_bar_dist, y_fs_bar_true, 'with half-normal tails'),
]
for dist, ys, label in curves:
    # log_prob yields log-densities; exponentiate to plot densities.
    densities = dist.log_prob(random_logits, ys).exp()
    # Transpose so each batch row becomes one column, i.e. one line in matplotlib.
    ax.plot(
        ys.numpy().T,
        densities.numpy().T,
        '-',
        alpha=0.5,
        label=label,
    )

ax.set_title("Bar Distribution")
ax.axhline(0, color='k', linewidth=1.0, linestyle="-", alpha=0.5)
ax.set_xlabel("y values")
ax.set_ylabel("Density")
ax.legend()

plt.show()


