import torch
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
import torch.distributions as D
from scipy.interpolate import UnivariateSpline
import scipy.stats as st


torch.manual_seed(1)
# Number of training samples; the convex problem below has 4*n variables.
n = 500
# Initial regularization strength (reassigned later from b_stack).
beta = 20
# Training data: n standard-normal draws, sorted in ascending order.
x = torch.randn(n).sort().values


# Indicator (Heaviside step) helpers: ind(x) = 1[x >= 0], ind_0(x) = 1[x > 0]
def ind(x):
  """Return the elementwise non-strict step 1[x >= 0] as an int64 tensor."""
  nonneg = x >= 0
  return nonneg.long()

def ind_0(x):
  """Return the elementwise strict step 1[x > 0] as an int64 tensor."""
  positive = x > 0
  return positive.long()

# Build the data of the convex reformulation. For every pair of sorted
# samples (i, j):
#   A3[i, j] = max(x_i - x_j, 0)   (ramp with breakpoint x_j, active for x_i > x_j)
#   A4[i, j] = max(x_j - x_i, 0)   (ramp with breakpoint x_j, active for x_i < x_j)
A3_pre = x.unsqueeze(1) - x
A3 = torch.where(A3_pre>=0, A3_pre, 0)
A4_pre = -A3_pre
A4 = torch.where(A4_pre>=0, A4_pre, 0)
# Column-centred ramp features (the mean over samples i is removed from each
# column j); consistent with the output bias b0 computed later as minus the
# mean prediction.
A1 = A3 - A3.mean(axis=0)
A2 = A4 - A4.mean(axis=0)
# Linear terms, one per column j: number of samples on the active side of the
# breakpoint, with the non-strict (ind, >=) and strict (ind_0, >) indicator;
# negated for the left-facing ramps.
b1 = ind(A3_pre).sum(axis=0)
b2 = ind_0(A3_pre).sum(axis=0)
b3 = -ind(A4_pre).sum(axis=0)
b4 = -ind_0(A4_pre).sum(axis=0)
# Four neuron families of n columns each -> 4n variables in total.
A_stack = torch.hstack([A1, A1, A2, A2])
b_stack = torch.hstack([b1,b2,b3,b4])
# Regularization strength. This overrides the beta = 20 set at the top and is a
# 0-dim tensor that score_matching_relu_noskip also reads at call time. The
# gradient of the smooth part at y = 0 is b_stack, so any beta >= max|b_stack|
# would make y = 0 optimal; hence beta is set just below max(b_stack).
beta = torch.max(b_stack)-10

# Solve  min_y  1/2 ||A_stack y||^2 + b_stack^T y + beta ||y||_1  with cvxpy
# (default solver), then print the optimal objective value.
y_nd = cp.Variable(4*n)
obj_nd = cp.sum_squares(A_stack@y_nd)/2 + b_stack@y_nd + beta*cp.norm(y_nd,1) 
prob_nd = cp.Problem(cp.Minimize(obj_nd))
prob_nd.solve()
print('cvxpy opt value: ', prob_nd.value)


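# Score-matching loss and two-layer ReLU network (no skip connection) used to rebuild the convex solution as a neural net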
def score_matching_relu_noskip(model, samples):
  """Regularized empirical score-matching loss of a two-layer ReLU network.

  With s(z) = model(z) = fc2(relu(fc1(z))), returns

      sum_i s(x_i)^2 / 2
    + sum_i s'(x_i)
    + beta / 2 * (||fc1.weight||^2 + ||fc2.weight||^2)

  The derivative s'(z) = sum_k 1[w_k z + b_k >= 0] * w_k * alpha_k is formed in
  closed form from the layer weights. Biases are not penalized.

  Args:
    model: network exposing ``fc1``/``fc2`` linear layers (e.g. ``net_n``).
      Assumes a scalar input (fc1.in_features == 1), so the squeezed weights
      are vectors.
    samples: inputs, assumed of shape (N, 1) as in the calls in this script.

  Returns:
    The loss as a 0-dim tensor, differentiable w.r.t. the model parameters.

  Note: ``beta`` is read from module scope at call time.
  """
  in_w = model.fc1.weight.squeeze()
  out_w = model.fc2.weight.squeeze()
  # 1/2 * sum of squared predicted scores.
  half_sq = (model(samples) ** 2).sum() / 2
  # Sum over samples of ds/dz, using ind() as the ReLU derivative.
  slope = (ind(model.fc1(samples)) * (in_w * out_w)).sum()
  # Weight decay on the two weight matrices only.
  decay = beta * ((model.fc1.weight ** 2).sum() + (model.fc2.weight ** 2).sum()) / 2
  return (half_sq + slope + decay).squeeze()

class net_n(torch.nn.Module):
  """Two-layer scalar-to-scalar ReLU network: x -> fc2(relu(fc1(x))).

  Args:
    hidden_dim: number of hidden ReLU units. Defaults to ``4 * n`` (read from
      module scope when the instance is built), matching the 4n neuron
      families of the convex reformulation.
  """

  def __init__(self, hidden_dim=None):
    super().__init__()
    if hidden_dim is None:
      hidden_dim = 4 * n
    self.fc1 = torch.nn.Linear(1, hidden_dim)
    self.fc2 = torch.nn.Linear(hidden_dim, 1)

  def forward(self, x):
    """Map inputs of shape (..., 1) to outputs of shape (..., 1)."""
    return self.fc2(torch.nn.functional.relu(self.fc1(x)))


# Build a sparse candidate solution from the cvxpy optimum. Only neuron 0 of
# family 1 (y0[0]) and neuron n-1 of family 3 (y2[-1]) are non-zero. In the
# stacked (4n,) solution vector these are entries 0 and 3*n - 1 (valid for any
# n, not just n = 500).
# NOTE(review): y_opt is not used in the visible script.
y_opt = torch.Tensor(y_nd.value)
y_mn_1 = (y_nd.value[0]-y_nd.value[3*n-1])/2
y_mn_n = -y_mn_1
# Optional common shift t of the two active coefficients. A candidate value is
# printed, but the reconstruction below uses t = 0.
t = y_mn_n*4/5
print('t=',t)
t = 0
y0 = torch.zeros(n)
y0[0] = y_mn_1 + t
y1 = torch.zeros(n)
y2 = torch.zeros(n)
y2[-1] = y_mn_n + t
y3 = torch.zeros(n)
# Output bias: minus the mean (over samples) of the bias-free network output
# sum_k y_k * ramp_k(x_i), built from the uncentred ramp features.
b0 = -(torch.hstack([A3, A3, A4, A4])@torch.hstack([y0,y1,y2,y3])).mean(axis=0)

# Offset that turns the non-strict ramp families (1, 3) into their strict
# counterparts (2, 4): the shifted unit is inactive exactly at z = x_j.
eps=1e-6


def _rescale_neurons(W, b, alpha):
  """Balance ReLU units alpha * relu(W z + b) without changing their output.

  Scaling (W, b) by gamma > 0 and alpha by 1/gamma leaves each unit's output
  unchanged, because ReLU is positively homogeneous. The choice
  gamma = sqrt(|alpha / W|) makes |W * gamma| == |alpha / gamma|, which
  minimizes (W*gamma)^2 + (alpha/gamma)^2 for a fixed product W * alpha. That
  sum is the weight-decay term penalized by score_matching_relu_noskip.

  Returns:
    (W * gamma, b * gamma, alpha / gamma). Units with alpha == 0 get
    gamma == 0; their 0/0 output weight is replaced by 0.
  """
  gamma = torch.sqrt(torch.abs(alpha/W))
  return W*gamma, b*gamma, torch.nan_to_num(alpha/gamma, 0)


# Family 1 (neurons 1..n): right-facing ramps relu(z - x_j), weights y0.
W1_rec, b1_rec, alpha1_rec = _rescale_neurons(torch.ones(n), -x, torch.Tensor(y0))
# Family 2 (neurons n+1..2n): strict right-facing ramps relu(z - x_j - eps),
# weights y1. Like family 1, these pair with the A3 ramp features
# (A_stack = [A1, A1, A2, A2]; b0 uses [A3, A3, A4, A4]), so the input
# weight is +1.
W2_rec, b2_rec, alpha2_rec = _rescale_neurons(torch.ones(n), -x-eps, torch.Tensor(y1))
# Family 3 (neurons 2n+1..3n): left-facing ramps relu(x_j - z), weights y2.
W3_rec, b3_rec, alpha3_rec = _rescale_neurons(-torch.ones(n), x, torch.Tensor(y2))
# Family 4 (neurons 3n+1..4n): strict left-facing ramps relu(x_j - eps - z),
# weights y3.
W4_rec, b4_rec, alpha4_rec = _rescale_neurons(-torch.ones(n), x-eps, torch.Tensor(y3))
# Output bias.
b0_nd = b0
# Concatenate the families in the same order as the convex variables.
W_nd = torch.concat([W1_rec, W2_rec, W3_rec, W4_rec])
b_nd = torch.concat([b1_rec, b2_rec, b3_rec, b4_rec])
alpha_nd = torch.concat([alpha1_rec, alpha2_rec, alpha3_rec, alpha4_rec])

  
# Load the reconstructed parameters into a net_n instance and report its loss.
model_nd = net_n()
layer_in, layer_out = model_nd.fc1, model_nd.fc2
layer_in.weight = torch.nn.Parameter(W_nd.view(-1, 1))       # (4n, 1)
layer_in.bias = torch.nn.Parameter(b_nd)                      # (4n,)
layer_out.weight = torch.nn.Parameter(alpha_nd.view(1, -1))  # (1, 4n)
# Output bias kept with shape (1, 1); it broadcasts over the (N, 1) output.
layer_out.bias = torch.nn.Parameter(b0_nd.reshape(1, 1))

rec_loss = score_matching_relu_noskip(model_nd, x.unsqueeze(1))
print('reconstruct nn loss: ', rec_loss.item())

# Non-convex baseline: train net_n directly with Adam from `trials` random
# initializations, recording the loss at every epoch. weight_decay stays 0
# because the loss already contains the beta-weighted weight penalty.
num_epoch = 500
trials = 10
loss_list = torch.zeros(trials, num_epoch)
train_inputs = x.unsqueeze(1)
for i in range(trials):
  model_torch = net_n()
  optimizer = torch.optim.Adam(model_torch.parameters(), lr=1e-2, weight_decay=0)
  for t in range(num_epoch):
    optimizer.zero_grad()
    loss = score_matching_relu_noskip(model_torch, train_inputs)
    loss_list[i, t] = loss.item()
    loss.backward()
    optimizer.step()

# Loss of the reconstructed convex solution vs. the best loss of any run.
opt_nn = score_matching_relu_noskip(model_nd, train_inputs).item()
best_trained = loss_list.min().item()
print('final loss compare ', opt_nn, best_trained)


# Training curves of the non-convex runs from epoch `plot_start` on, against
# the loss of the reconstructed convex solution (dashed horizontal line).
plot_start = 100
# loss_list[i][k] is the loss at epoch k, so plot against the integer epochs
# plot_start, ..., num_epoch - 1.
epochs = torch.arange(plot_start, num_epoch)
plt.axhline(y = opt_nn, color = 'b', linestyle = 'dashed', linewidth=5.0)
for i in range(trials):
  plt.plot(epochs, loss_list[i][plot_start:], linewidth=5.0)
plt.xlabel('# Epochs',fontsize=30)
plt.ylabel('Training Loss',fontsize=30)
plt.tight_layout()
plt.savefig('relu_noskip.pdf')
plt.show()

# Evaluate the reconstructed score model on a dense grid over [-10, 10].
x_hat = torch.linspace(-10,10,10000)
score_vals = model_nd(x_hat[:, None]).detach()
y_hat = score_vals.squeeze()
plt.plot(x_hat,y_hat, linewidth=5.0)
plt.xlabel('x',fontsize=30)
plt.ylabel("Score",fontsize=30)
plt.tight_layout()
plt.savefig('relu_noskip_score.pdf')
plt.show()


eps = torch.Tensor([0.5])
steps = 100
n_samples = 100000

# Learned score: the model trained with the score-matching loss, i.e. an
# estimate of d/dx log p(x). The training samples x were drawn from N(0, 1).
# It is evaluated under torch.no_grad() so that no autograd graph over the
# (n_samples, 4n) hidden activations is built at every step.
def force(z):
    with torch.no_grad():
        return model_nd(z.unsqueeze(1)).squeeze()

# Prior distribution from which the initial positions are sampled
prior = D.Uniform(-20, 20)
# Unadjusted Langevin dynamics, z <- z + eps * s(z) + sqrt(2 * eps) * N(0, 1):
# 5 rounds of `steps` updates, halving the step size after each round.
z = prior.sample((n_samples, ))
for j in range(5):
    for i in range(steps):
        z = z + eps * force(z) + torch.sqrt(2 * eps) * torch.randn(size=z.shape)
    eps=eps/2
s = z.detach()

# NOTE(review): N and n_ are not used anywhere in the visible script.
N = 50000
n_ = N//5000

# Histogram of the Langevin samples with a Gaussian-KDE overlay. Both artists
# carry labels so that plt.legend() has entries to show. Without labels,
# matplotlib warns that no labelled artists exist and draws an empty legend.
plt.hist(s, density=True, bins=50, label='Samples')
# Freeze the x-limits chosen for the histogram before overlaying the KDE.
mn, mx = plt.xlim()
plt.xlim(mn, mx)
kde_xs = np.linspace(mn, mx, 300)
kde = st.gaussian_kde(s)
plt.plot(kde_xs, kde.pdf(kde_xs), linewidth=5.0, label='KDE')
plt.legend(loc="upper left")
plt.xlabel("Sample",fontsize=30)
plt.ylabel("Frequency",fontsize=30)
plt.tight_layout()
plt.savefig('relu_noskip_hist.pdf')
plt.show()






