from sklearn.model_selection import train_test_split
import matplotlib as mpl
import scipy
import seaborn as sns
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader

from aux import *

# Use the Tk backend for interactive figure windows.
mpl.use('TkAgg')

# Shared plot styling constants.
# NOTE(review): length_ticks, horizontal_size and vertical_size are not used in
# this file — confirm before removing.
length_ticks = 3
font_size = 10
linewidth = 1.2
scatter_size = 0.5
horizontal_size = 5
vertical_size = 5

# Hide the top/right spines and apply the shared font size and line width.
# Each rcParam is set exactly once.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size,
    'ytick.labelsize': font_size,
})

# Circle case: isotropic 2D Gaussian over (reward, orientation).
mean = [50, 30]
cov = [[2, 0], [0, 2]]

# No samples are given with reward between left_reward and right_reward.
# The band can be rotated by `angle` (radians) via `rotation_matrix`;
# angle = 0 means the band is NOT rotated (plain vertical band in reward).
# NOTE(review): presumably the generator rotates about `mean`, as the plotting
# helper does — confirm in aux.make_incomplete_gaussian_gratings_rotate.
left_reward = 49.5
right_reward = 50.5
angle = 0
rotation_matrix = np.array([[np.cos(angle), -np.sin(angle)],
                            [np.sin(angle), np.cos(angle)]])

# Generate data
nPix = 20  # gratings are nPix x nPix pixels
# NOTE(review): the positional 1 and 4 look like the image size in degrees and
# the spatial frequency (wDeg / sf used for the probe gratings) — TODO confirm.
images, rewards, orientations = make_incomplete_gaussian_gratings_rotate(
    1000, 1, nPix, 4, left_reward, right_reward, mean, cov, rotation_matrix)

# Hold out a third of the samples for testing; the fixed seed keeps the split
# reproducible across executions.
(images_train, images_test,
 rewards_train, rewards_test,
 orientations_train, orientations_test) = train_test_split(
    images, rewards, orientations, test_size=.33, random_state=26)
batch_size = 200

# Wrap each split in a Data dataset and a shuffling DataLoader.
train_data = Data(images_train, rewards_train, orientations_train)
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_data = Data(images_test, rewards_test, orientations_test)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

# Model sizing: the network emits one reward quantile per level, and the
# particle grid pairs n_neurons_each_dim orientation quantiles with them.
n_neurons_each_dim = 10
n_neurons = n_neurons_each_dim ** 2
input_dim = nPix ** 2  # flattened nPix x nPix image

# Quantile levels tau_k = k / n_neurons_each_dim for k = 1..n_neurons_each_dim,
# as a float64 tensor built directly from the NumPy grid.
taus = torch.tensor(np.linspace(1.0 / n_neurons_each_dim, 1, n_neurons_each_dim))

# Empirical "true" joint distribution from a fresh sample drawn with
# make_gaussian_gratings (presumably without the excluded band — confirm in aux),
# histogrammed on 1-unit bins: reward in [40, 60], orientation in [20, 40].
_, rewards_true, orientations_true = make_gaussian_gratings(1000, 1, nPix, 4, mean, cov)
n_bins = 20
reward_edges = np.linspace(40, 60, n_bins + 1)
orientation_edges = np.linspace(20, 40, n_bins + 1)
true_joint, _, _ = np.histogram2d(rewards_true, orientations_true, [reward_edges, orientation_edges])
true_joint = true_joint / np.sum(true_joint)
# histogram2d indexes [reward_bin, orientation_bin]; transpose to
# [orientation_bin, reward_bin] to match the meshgrid layout of `pos` below.
true_joint = np.transpose(true_joint)

# Evaluation points for the estimated pdf: the n_bins bin CENTRES, so each
# estimated value sits in the middle of the histogram bin it is compared with
# in the KL divergence (n_bins points per axis, hence "one fewer than edges").
reward_flat = 0.5 * (reward_edges[:-1] + reward_edges[1:])
orientation_flat = 0.5 * (orientation_edges[:-1] + orientation_edges[1:])
reward_mesh, orientation_mesh = np.meshgrid(reward_flat, orientation_flat)
pos = np.dstack((reward_mesh, orientation_mesh))  # (n_bins, n_bins, 2), [orientation, reward]

# Per-run training / evaluation settings
n_runs = 10
kl_runs = []  # KL divergence of each run
std = 2  # used as the diagonal of each particle's Gaussian kernel covariance (a variance)
learning_rate = 0.0001  # 0.0005
num_epochs = 1000

# Probe gratings whose orientations sit at the tau-quantiles of the training
# orientations; the model's reward quantiles on these form the particle grid.
wDeg = 1  # size of image (in degrees)
sf = 4  # spatial frequency in sin(2*pi*sf*ramp), ramp measured in wDeg units
grid_1d = np.linspace(-wDeg / 2, wDeg / 2, nPix)
x, y = np.meshgrid(grid_1d, grid_1d)
orientations_unif = np.quantile(orientations_train, taus)
n_samples = len(orientations_unif)
orientations = orientations_unif[:, np.newaxis, np.newaxis]  # (n_samples, 1, 1)
O = np.full((n_samples, nPix, nPix), orientations)
theta = O * np.pi / 180  # degrees -> radians
ramp = np.sin(theta) * x - np.cos(theta) * y
grating_np = (np.sin(2 * np.pi * sf * ramp) + 1) / 2  # rescale [-1, 1] to [0, 1]
gratings = torch.from_numpy(grating_np.astype(np.float32))

# ---- Loop-invariant plotting assets (identical for every run, built once) ----
cmap = sns.color_palette("coolwarm", as_cmap=True)  # colormap for the pdf

# Dense grid for drawing the analytic (true) Gaussian pdf behind the particles.
reward_bins_pdf = np.linspace(40, 60, n_bins * 5)
orientation_bins_pdf = np.linspace(20, 40, n_bins * 5)
reward_mesh_pdf, orientation_mesh_pdf = np.meshgrid(reward_bins_pdf, orientation_bins_pdf)
pos_pdf = np.dstack((reward_mesh_pdf, orientation_mesh_pdf))
# meshgrid(reward, orientation) yields arrays indexed [orientation, reward], so
# rows already run along the y (orientation) axis: imshow needs no transpose.
pdf_true = scipy.stats.multivariate_normal(mean, cov).pdf(pos_pdf)
pdf_true = pdf_true / np.sum(pdf_true)


def draw_rotated_vline(x_value, ymin, ymax, ax, **kwargs):
    """Draw the vertical segment x = x_value (ymin..ymax) rotated about `mean`.

    Uses the module-level `rotation_matrix` (the one passed to the data
    generator). Extra keyword arguments are forwarded to `ax.plot`.
    """
    p1 = np.array([x_value, ymin])
    p2 = np.array([x_value, ymax])
    # Shift to the origin, rotate, then shift back.
    p1_rot = rotation_matrix @ (p1 - mean) + mean
    p2_rot = rotation_matrix @ (p2 - mean) + mean
    ax.plot([p1_rot[0], p2_rot[0]], [p1_rot[1], p2_rot[1]], **kwargs)


for run in range(n_runs):

    # Fresh model per run: flattened image -> n_neurons_each_dim reward quantiles.
    model = nn.Sequential(nn.Flatten(),
                          nn.Linear(input_dim, 1024),
                          nn.ReLU(),
                          nn.Linear(1024, 256),
                          nn.ReLU(),
                          nn.Linear(256, n_neurons_each_dim))

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_values = []

    for epoch in range(num_epochs):
        for images_e, rewards_e, orientation_e in train_dataloader:
            optimizer.zero_grad()
            pred = model(images_e)
            loss = dist_loss(pred, rewards_e.unsqueeze(-1), taus)
            loss_values.append(loss.item())
            loss.backward()
            optimizer.step()

    # Reward quantiles for each probe grating: row k <-> k-th orientation
    # quantile, column j <-> reward quantile level taus[j]. No autograd needed.
    with torch.no_grad():
        pred_after = model(gratings).numpy().astype(np.float64)

    # One particle per (orientation quantile, reward quantile) pair, flattened
    # row-major so particles [N*k, N*(k+1)) come from probe grating k. Every
    # row is used, so no particle is left at a (0, 0) placeholder.
    particles_x_after = pred_after.reshape(-1)
    particles_y_after = np.repeat(orientations[:, 0, 0], n_neurons_each_dim)

    # Kernel density estimate: a Gaussian with covariance diag(std, std) per
    # particle, evaluated on the grid `pos` and normalized to sum to 1.
    kernel_cov = [[std, 0], [0, std]]
    joint_pdf = np.zeros((n_bins, n_bins))
    for px, py in zip(particles_x_after, particles_y_after):
        joint_pdf += scipy.stats.multivariate_normal(mean=[px, py], cov=kernel_cov).pdf(pos)
    joint_pdf = joint_pdf / np.sum(joint_pdf)

    # KL(true || estimated) summed over bins (rel_entr treats 0*log(0/y) as 0).
    kl = np.sum(scipy.special.rel_entr(true_joint.ravel(), joint_pdf.ravel()))
    print("KL: ", kl)
    kl_runs.append(kl)

    # Figure with true pdf and learnt quantiles
    fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))
    ax.spines['left'].set_linewidth(linewidth)
    ax.spines['bottom'].set_linewidth(linewidth)
    ax.tick_params(width=0)
    ax.set_box_aspect(1)
    ax.set_ylabel("Orientation (" + u"\u00b0" + ")")
    ax.set_xlabel("Reward")
    im = ax.imshow(pdf_true,
                   extent=[reward_bins_pdf[0], reward_bins_pdf[-1], orientation_bins_pdf[-1], orientation_bins_pdf[0]],
                   cmap=cmap)
    ax.scatter(particles_x_after, particles_y_after, color="k", s=scatter_size)

    # Dashed lines mark the reward band in which no samples are given.
    draw_rotated_vline(left_reward, 25, 35, ax, ls="--", color="k")
    draw_rotated_vline(right_reward, 25, 35, ax, ls="--", color="k")
    ax.set_xlim([45.5, 54.5])
    ax.set_ylim([25.5, 34.5])
    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(im, ax=ax, shrink=0.5, ticks=[])
    fig.savefig("quantiles_circle_" + str(run) + ".pdf")
    # Release the figure so one does not stay open per run.
    plt.close(fig)


# Save the KL divergence of every run as a 1-D array (one entry per run).
kl_runs = np.array(kl_runs)
np.save("quantiles_circle_kl_runs.npy", kl_runs)