import pdb
from aux import *
from sklearn.model_selection import train_test_split
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import optim
from collections import OrderedDict
import scipy
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
# Select the Tk-based matplotlib backend before any figure is created.
mpl.use('TkAgg')

# Shared styling constants. Only font_size, linewidth and scatter_size are
# referenced in the visible part of this script; the others are kept as-is.
length_ticks = 3
font_size = 10
linewidth = 1.2
scatter_size = 0.5
horizontal_size = 5
vertical_size = 5

# Global matplotlib style: hide the top/right spines and use one font size
# and one line width everywhere.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'font.size': font_size,
    'lines.linewidth': linewidth,
    'xtick.labelsize': font_size,
    'ytick.labelsize': font_size,
})

# Circle case: Gaussian over (reward, orientation) with an isotropic
# covariance. (An ellipse case would use a different covariance.)
mean = [50, 30]
cov = [[2, 0],
       [0, 2]]

# Rewards between left_reward and right_reward are not given as samples.
# With angle = 0 the rotation matrix is the identity, so the withheld band
# is not rotated in this case.
left_reward, right_reward = 49.5, 50.5
angle = 0
cos_a, sin_a = np.cos(angle), np.sin(angle)
rotation_matrix = np.array([[cos_a, -sin_a],
                            [sin_a, cos_a]])

# Generate the grating images together with their rewards and orientations.
# The positional arguments (1000, 1, ..., 4) are defined by the helper itself.
nPix = 20  # size of gratings
images, rewards, orientations = make_incomplete_gaussian_gratings_rotate(
    1000, 1, nPix, 4, left_reward, right_reward, mean, cov, rotation_matrix
)

# Hold out 33% of the samples for testing (fixed seed for a repeatable split).
(images_train, images_test,
 rewards_train, rewards_test,
 orientations_train, orientations_test) = train_test_split(
    images, rewards, orientations, test_size=0.33, random_state=26
)

# The loss receives reward and orientation together: stack them and
# transpose so that (assuming 1-D inputs) each row is one sample.
reward_orientation_train = np.vstack([rewards_train, orientations_train]).T
reward_orientation_test = np.vstack([rewards_test, orientations_test]).T

# Wrap both splits in shuffled mini-batch loaders.
batch_size = 200
loader_kwargs = dict(batch_size=batch_size, shuffle=True)

train_data = Data(images_train, reward_orientation_train, orientations_train)
train_dataloader = DataLoader(dataset=train_data, **loader_kwargs)

test_data = Data(images_test, reward_orientation_test, orientations_test)
test_dataloader = DataLoader(dataset=test_data, **loader_kwargs)

# Reference ("true") joint distribution over (reward, orientation): a
# normalised 2-D histogram of samples drawn with the same mean/cov by
# make_gaussian_gratings (presumably without the withheld reward band).
n_bins = 20
_, rewards_true, orientations_true = make_gaussian_gratings(1000, 1, nPix, 4, mean, cov)
reward_edges = np.linspace(40, 60, n_bins + 1)
orientation_edges = np.linspace(20, 40, n_bins + 1)
true_joint, _, _ = np.histogram2d(rewards_true, orientations_true,
                                  [reward_edges, orientation_edges])
# np.histogram2d bins its first argument (reward) along axis 0, whereas the
# evaluation grid `pos` below comes from np.meshgrid with the default 'xy'
# indexing, i.e. rows = orientation and columns = reward. Transpose so both
# arrays share the [orientation, reward] layout before they are compared
# element by element.
true_joint = true_joint.T
true_joint = true_joint / np.sum(true_joint)

# Evaluate densities at the centres of the histogram bins so that grid point
# (i, j) corresponds to histogram bin (i, j). A plain linspace(40, 60, n_bins)
# would put the points up to half a bin away from the bins they are compared to.
reward_flat = 0.5 * (reward_edges[:-1] + reward_edges[1:])
orientation_flat = 0.5 * (orientation_edges[:-1] + orientation_edges[1:])
reward_mesh, orientation_mesh = np.meshgrid(reward_flat, orientation_flat)
# pos[i, j] = (reward_flat[j], orientation_flat[i])
pos = np.dstack((reward_mesh, orientation_mesh))


n_neurons = 100  # particles per image; the network emits n_neurons (x, y) pairs
n_runs = 10      # independent trainings, each from a freshly initialised model
std = 2          # NOTE: used below as the diagonal of the kernel covariance, i.e. as a variance
kl_runs = []

# Hyper-parameters shared by every run.
input_dim = nPix * nPix
# Original author's notes: lr 0.001 / 600 epochs for the "gaussian" setting,
# lr 0.0001 / 1000 epochs for the ellipse setting.
learning_rate = 0.0001
num_epochs = 1000

# Densely sampled analytic PDF used as the background image of every figure.
# It does not depend on the run, so it is computed once.
cmap = sns.color_palette("coolwarm", as_cmap=True)
reward_bins_pdf = np.linspace(40, 60, n_bins * 5)
orientation_bins_pdf = np.linspace(20, 40, n_bins * 5)
reward_mesh_pdf, orientation_mesh_pdf = np.meshgrid(reward_bins_pdf, orientation_bins_pdf)
pos_pdf = np.dstack((reward_mesh_pdf, orientation_mesh_pdf))
pdf_true = scipy.stats.multivariate_normal(mean, cov).pdf(pos_pdf)
pdf_true = pdf_true / np.sum(pdf_true)


def draw_rotated_vline(x_value, ymin, ymax, ax, **kwargs):
    """Plot a vertical segment after rotating it about the distribution mean.

    The segment from (x_value, ymin) to (x_value, ymax) is shifted so that the
    module-level ``mean`` is at the origin, rotated with the module-level
    ``rotation_matrix``, shifted back, and drawn on ``ax``.

    Args:
        x_value: x coordinate of the unrotated vertical line.
        ymin: lower y coordinate of the unrotated segment.
        ymax: upper y coordinate of the unrotated segment.
        ax: matplotlib axes to draw on.
        **kwargs: forwarded unchanged to ``ax.plot``.
    """
    centre = np.asarray(mean)
    start = rotation_matrix @ (np.array([x_value, ymin]) - centre) + centre
    end = rotation_matrix @ (np.array([x_value, ymax]) - centre) + centre
    ax.plot([start[0], end[0]], [start[1], end[1]], **kwargs)


for run in range(n_runs):
    # Fresh MLP per run: flattened image -> n_neurons * 2 particle coordinates.
    model = nn.Sequential(OrderedDict([
        ('flatten', nn.Flatten()),
        ('fc11', nn.Linear(input_dim, 1024)),
        ('relu1', nn.ReLU()),
        ('fc2', nn.Linear(1024, 256)),
        ('relu2', nn.ReLU()),
        ('fc3', nn.Linear(256, n_neurons * 2)),
    ]))
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_values = []

    for epoch in range(num_epochs):
        # NOTE(review): exactly one optimisation step per "epoch", on the first
        # batch of a freshly shuffled pass over the training set. The original
        # loop broke out after its first batch; confirm this is intended.
        images_e, rewards_e, _ = next(iter(train_dataloader))

        optimizer.zero_grad()

        # forward + backward + optimize
        pred = model(images_e)
        pred = torch.reshape(pred, (pred.shape[0], n_neurons, 2))

        # The trailing arguments (6, 1000, 4) are hyper-parameters defined by
        # particle_loss_with_interaction itself (presumably provided by aux).
        loss = particle_loss_with_interaction(pred, n_neurons, rewards_e, 6, 1000, 4)
        loss_values.append(loss.item())
        loss.backward()
        optimizer.step()

    print("Training Complete")

    # Estimated particles: predict for the last training batch and average
    # over the batch dimension. No gradients are needed for this pass.
    # NOTE(review): test_dataloader is never used; confirm that evaluating on
    # a training batch is intended.
    with torch.no_grad():
        pred_after = model(images_e)
    particles_after = torch.reshape(pred_after, (pred_after.shape[0], n_neurons, 2)).numpy()
    particles_after = np.mean(particles_after, axis=0)
    particles_x_after = particles_after[:, 0]
    particles_y_after = particles_after[:, 1]

    # Estimated joint PDF: one Gaussian kernel per particle, summed on the
    # evaluation grid `pos` and normalised to unit mass.
    kernel_cov = [[std, 0], [0, std]]
    joint_pdf = np.zeros((n_bins, n_bins))
    for particle_x, particle_y in zip(particles_x_after, particles_y_after):
        joint_pdf += scipy.stats.multivariate_normal(mean=[particle_x, particle_y], cov=kernel_cov).pdf(pos)
    joint_pdf = joint_pdf / np.sum(joint_pdf)

    # KL(true || estimated). The arrays are compared element by element, so
    # true_joint must be laid out on the same (row, column) grid as `pos`.
    kl = np.sum(scipy.special.rel_entr(true_joint.ravel(), joint_pdf.ravel()))

    print("KL: ", kl)

    kl_runs.append(kl)

    # Figure: true PDF as background image with the estimated particles on top.
    fig, ax = plt.subplots(1, 1, figsize=(1.5, 1.5))
    ax.spines['left'].set_linewidth(linewidth)
    ax.spines['bottom'].set_linewidth(linewidth)
    ax.tick_params(width=0)
    ax.set_box_aspect(1)
    ax.set_ylabel("Orientation (\u00b0)")
    ax.set_xlabel("Reward")
    # imshow draws row 0 at the top by default, so the y extent is passed as
    # (last, first) to keep each row at its orientation value.
    im = ax.imshow(pdf_true,
                   extent=[reward_bins_pdf[0], reward_bins_pdf[-1],
                           orientation_bins_pdf[-1], orientation_bins_pdf[0]],
                   cmap=cmap)
    ax.scatter(particles_x_after, particles_y_after, color="k", s=scatter_size)
    # Zoom in on the region around the mean.
    ax.set_xlim([45.5, 54.5])
    ax.set_ylim([25.5, 34.5])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.colorbar(im, shrink=0.5, ticks=[])

    # Dashed lines marking the (rotated) band of rewards that was not given.
    draw_rotated_vline(left_reward, 25, 35, ax, ls="--", color="k")
    draw_rotated_vline(right_reward, 25, 35, ax, ls="--", color="k")

    fig.savefig(f"particles_circle_{run}.pdf")
    # Release the figure; otherwise one open figure accumulates per run.
    plt.close(fig)

# Persist the KL divergence of every run as a NumPy array.
kl_runs = np.asarray(kl_runs)
np.save("particles_circle_kl_runs.npy", kl_runs)

# Stop in the debugger at the end of the script (presumably left in for
# interactive inspection of the results).
pdb.set_trace()
