import torch
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import torch.distributions as D
from scipy.interpolate import UnivariateSpline
import copy
import scipy.stats as st



# Fix the RNG seed so the sampled data set is reproducible.
torch.manual_seed(1)
# Number of 1-D training samples.
n = 500
# Standard-normal samples, stored in ascending order.
x = torch.sort(torch.randn(n)).values
# Regularisation strength: multiplies the l1 penalty in the cvxpy objective
# and the squared-weight penalty in score_matching.
beta = 20
# define sign
def sign_torch_1(x):
  """Return the elementwise sign of ``x`` with the convention sign(0) = +1.

  Entries whose magnitude does not exceed 1.5e-7 are treated as exactly
  zero. Zero entries produce 0/0 = NaN in the ratio below, and that NaN is
  then mapped to +1.
  """
  significant = torch.abs(x) > 1.5e-7
  clipped = torch.where(significant, x, 0)
  ratio = clipped / torch.abs(clipped)
  return torch.nan_to_num(ratio, nan=1)

# Build the data for the convex program from the sorted samples x.
# A1_pre[i, j] = x[i] - x[j]; A1 holds the absolute pairwise differences.
A1_pre = x[:, None] - x[None, :]
A1 = A1_pre.abs()
# Subtract each column's mean from the absolute-difference matrix.
A = A1 - A1.mean(dim=0)
# Column sums of the signs of the pairwise differences (sign_torch_1 maps
# a zero difference to +1, so b1 and b2 differ on the diagonal entries).
b1 = sign_torch_1(A1_pre).sum(dim=0)
b2 = -sign_torch_1(-A1_pre).sum(dim=0)
A_stack = torch.hstack((A, A))
b_stack = torch.hstack((b1, b2))
# beta = torch.max(b_stack)-5

# Solve the l1-regularised least-squares program over 2n coefficients with cvxpy.
y_nd = cp.Variable(2 * n)
obj_nd = (
    cp.sum_squares(A_stack @ y_nd) / 2
    + b_stack @ y_nd
    + beta * cp.norm(y_nd, 1)
)
prob_nd = cp.Problem(cp.Minimize(obj_nd))
prob_nd.solve()
print('cvxpy opt value: ', prob_nd.value)
y_opt = torch.Tensor(y_nd.value)
# Indices of the coefficients whose magnitude exceeds 1e-6.
print(torch.nonzero(y_opt.abs() > 1e-6, as_tuple=True)[0])


#reconstruct nn
def score_matching(model, samples):
  """Return the regularised score-matching training loss as a scalar tensor.

  ``model`` must expose linear layers ``fc1`` and ``fc2`` and be callable on
  ``samples``. The loss is the sum of three terms:

  * half the sum of the squared model outputs on ``samples``;
  * the sum over samples and hidden units of
    ``sign(fc1(samples)) * fc1.weight * fc2.weight`` (for the absolute-value
    network ``net_n`` in this file this is the derivative of the output with
    respect to the input, using sign(0) = +1);
  * ``beta / 2`` times the sum of squared ``fc1`` and ``fc2`` weights
    (biases are not penalised). ``beta`` is read from module scope.
  """
  inner_w = model.fc1.weight
  outer_w = model.fc2.weight

  output = model(samples)
  fit_term = (output ** 2).sum() / 2

  # Per-unit product of inner and outer weights, broadcast over the samples.
  unit_gain = inner_w.squeeze() * outer_w.squeeze()
  slope_term = sign_torch_1(model.fc1(samples)) * unit_gain

  penalty_term = beta * ((inner_w ** 2).sum() + (outer_w ** 2).sum()) / 2
  return fit_term + slope_term.sum() + penalty_term

# Turn the convex solution into weights of a two-layer network.
y_opt = torch.Tensor(y_nd.value)
# First n coefficients belong to units 1..n, last n to units n+1..2n.
y0, y1 = y_opt[:n], y_opt[-n:]
# Output bias: negated mean (over samples) of the un-centred features times y_opt.
b0 = -(torch.hstack((A1, A1)) @ y_opt).mean(dim=0)

# Units 1..n: inner weight +1, inner bias -x_i, outer weight y0_i.
W1_rec = torch.ones(n)
b1_rec = -x
alpha1_rec = torch.Tensor(y0)
# gamma = sqrt(|alpha / W|): scaling the inner layer by gamma and dividing
# the outer weight by gamma gives both weights the same magnitude.
gamma1_rec = (alpha1_rec / W1_rec).abs().sqrt()
W1_rec = gamma1_rec * W1_rec
b1_rec = gamma1_rec * b1_rec
# alpha == 0 gives gamma == 0 and a 0/0 NaN; map it back to 0.
alpha1_rec = torch.nan_to_num(alpha1_rec / gamma1_rec, nan=0)

# Units n+1..2n: inner weight -1, inner bias +x_i, outer weight y1_i.
W2_rec = -torch.ones(n)
b2_rec = x
alpha2_rec = torch.Tensor(y1)
gamma2_rec = (alpha2_rec / W2_rec).abs().sqrt()
W2_rec = gamma2_rec * W2_rec
b2_rec = gamma2_rec * b2_rec
alpha2_rec = torch.nan_to_num(alpha2_rec / gamma2_rec, nan=0)

# Output-layer bias.
b0_nd = b0
# Concatenate both halves into length-2n parameter vectors.
W_nd = torch.cat((W1_rec, W2_rec))
b_nd = torch.cat((b1_rec, b2_rec))
alpha_nd = torch.cat((alpha1_rec, alpha2_rec))


class net_n(torch.nn.Module):
  """Two-layer network 1 -> 2n -> 1 with an absolute-value activation.

  The hidden width ``2 * n`` is taken from the module-level ``n``.
  """

  def __init__(self):
    super().__init__()
    # Creation order is kept (fc1, then fc2) so random initialisation is unchanged.
    self.fc1 = torch.nn.Linear(1, 2 * n)
    self.fc2 = torch.nn.Linear(2 * n, 1)

  def forward(self, x):
    """Return ``fc2(|fc1(x)|)`` for inputs whose last dimension is 1."""
    hidden = torch.abs(self.fc1(x))
    return self.fc2(hidden)
  
# Load the parameters reconstructed from the convex solution into a fresh
# network. Shapes follow torch.nn.Linear: weight is (out_features,
# in_features) and bias is (out_features,).
model_nd = net_n()
model_nd.fc1.weight = torch.nn.Parameter(W_nd.unsqueeze(1))     # (2n, 1)
model_nd.fc1.bias = torch.nn.Parameter(b_nd)                    # (2n,)
model_nd.fc2.weight = torch.nn.Parameter(alpha_nd.unsqueeze(0))  # (1, 2n)
# b0_nd is a 0-dim tensor; the bias of a 1-output layer must have shape (1,).
# The previous (1, 1) shape only worked because it broadcast in the forward pass.
model_nd.fc2.bias = torch.nn.Parameter(b0_nd.reshape(1))

print('reconstruct nn loss: ', score_matching(model_nd, x.unsqueeze(1)).item())

# Baseline: train randomly initialised networks on the same loss with Adam.
num_epoch = 500
trials = 10
# loss_list[i, t] is the loss of trial i measured before the update of epoch t.
loss_list = torch.zeros((trials, num_epoch))
for i in range(trials):
  model_torch = net_n()
  optimizer = torch.optim.Adam(model_torch.parameters(), lr=1e-2, weight_decay=0)
  for t in range(num_epoch):
    optimizer.zero_grad()
    loss = score_matching(model_torch, x[:, None])
    loss_value = loss.item()
    loss_list[i, t] = loss_value
    print(t, loss_value)
    loss.backward()
    optimizer.step()

# Loss of the reconstructed network versus the best loss seen in any trial.
opt_nn = score_matching(model_nd, x[:, None]).item()
print('final loss compare ', opt_nn, loss_list.min().item())



# Training curves of the non-convex runs against the loss of the reconstructed
# network (dashed horizontal line).
plt.axhline(y=opt_nn, color='b', linestyle='dashed', linewidth=5.0)
# The first epochs are left out of the plot. The x-axis uses the actual epoch
# indices plot_start..num_epoch-1; the previous linspace(100, 500, 400) was
# slightly stretched and only matched the data length when num_epoch == 500.
plot_start = 100
epoch_axis = np.arange(plot_start, num_epoch)
for trial_losses in loss_list:
  plt.plot(epoch_axis, trial_losses[plot_start:num_epoch], linewidth=5.0)
# plt.legend(['cvx'])
plt.xlabel('# Epochs',fontsize=30)
plt.ylabel('Training Loss',fontsize=30)
plt.tight_layout()
plt.savefig('abs_noskip_%i.pdf'%n)
# plt.yscale('log')
plt.show()


# Evaluate the reconstructed network on a dense grid over [-10, 10] and plot it.
x_hat = torch.linspace(-10, 10, 10000)
with torch.no_grad():
  # Gradients are not needed for plotting; the result is a plain 1-D tensor.
  y_hat = model_nd(x_hat[:, None]).squeeze()
plt.plot(x_hat, y_hat, linewidth=5.0)
plt.xlabel('x',fontsize=30)
plt.ylabel("Score",fontsize=30)
# plt.xlim([min(x)*1.5,max(x)*1.5])
# plt.ylim([-2.1,4])
plt.tight_layout()
plt.savefig('abs_noskip_score.pdf')
plt.show()


# Langevin sampling with a step size that is halved after every round.
eps = torch.Tensor([0.5])  # initial step size
steps = 100                # updates per round
n_samples = 100000         # number of parallel chains


def force(x):
    """Return the output of ``model_nd`` for a 1-D batch of positions ``x``.

    The reconstructed network is used as the drift of the Langevin update.
    Evaluation runs under ``torch.no_grad()`` so that no autograd graph is
    built for the large batch on every step; the returned values are the
    same as with a detached forward pass.
    """
    with torch.no_grad():
        return model_nd(x.unsqueeze(1)).squeeze()


# Prior distribution from which the initial positions are sampled
prior = D.Uniform(-20, 20)
# Run Langevin Dynamics: z <- z + eps * force(z) + sqrt(2 * eps) * noise
z = prior.sample((n_samples, ))
for _ in range(5):
    for _ in range(steps):
        z = z + eps * force(z) + torch.sqrt(2 * eps) * torch.randn(size=z.shape)
    # Halve the step size before the next round.
    eps = eps / 2
s = z.detach()

# NOTE(review): N and n_ are not used anywhere in the visible part of the
# file; kept in case other code relies on them.
N = 50000
n_ = N//5000

# Histogram of the Langevin samples with a Gaussian kernel density estimate.
# Matplotlib and SciPy are given a NumPy array rather than a torch tensor.
s_np = s.numpy()
plt.hist(s_np, density=True, bins=50, label="Histogram")
# Freeze the x-limits chosen for the histogram and evaluate the KDE on them.
mn, mx = plt.xlim()
plt.xlim(mn, mx)
kde_xs = np.linspace(mn, mx, 300)
kde = st.gaussian_kde(s_np)
plt.plot(kde_xs, kde.pdf(kde_xs), linewidth=5.0, label="KDE")
# Both artists carry labels, so the legend has entries to show (without labels
# plt.legend only emitted a warning and drew nothing).
plt.legend(loc="upper left")
# plt.ylabel("Probability")
# plt.xlabel("Data")
# plt.title("Histogram")
plt.xlabel("Sample",fontsize=30)
plt.ylabel("Frequency",fontsize=30)
plt.tight_layout()
plt.savefig('abs_noskip_hist.pdf')
plt.show()






