import torch
import numpy as np
import cvxpy as cp
import torch.distributions as D
import scipy.stats as st
import matplotlib.pyplot as plt

# Fix the torch RNG so that the data draws, network initialisation and
# Langevin noise below are reproducible.
torch.manual_seed(0)
# Weight-decay strength read by dsm_loss and cvx_train.
beta = 0.5
# Number of training samples; nn_type1 also reads it to size its hidden layer.
n = 1000
# Adam steps per noise level in the training loop.
num_epoch = 300
# Single noise level used by the one-sigma experiment.
sigma = 0.1
# Clean samples x ~ N(0, 1) and standard-normal perturbations.
x = torch.randn(n)
eps = torch.randn(n)
# Perturbed inputs x + sigma*eps and the denoising-score-matching target
# -eps/sigma, i.e. the score of N(x_input; x, sigma^2) w.r.t. x_input.
# Neither is read by the active code below (the multi-sigma training loop
# rebinds x and label); they serve the commented-out experiments.
x_input = x + sigma*eps
label = -eps/sigma



# define model
class nn_type1(torch.nn.Module):
    """Two-layer scalar network with an absolute-value activation.

    Computes ``fc2(|fc1(x)|)``: a 1 -> hidden affine map, elementwise ``abs``,
    then a hidden -> 1 affine map.

    Args:
        hidden: number of hidden units. When omitted, the module-level ``n``
            is used (read at construction time), which is the original
            behaviour of ``nn_type1()``.
    """

    def __init__(self, hidden=None):
        super().__init__()
        if hidden is None:
            hidden = n
        self.fc1 = torch.nn.Linear(1, hidden)
        self.fc2 = torch.nn.Linear(hidden, 1)

    def forward(self, x):
        """Map inputs of shape ``(..., 1)`` to outputs of shape ``(..., 1)``."""
        return self.fc2(torch.abs(self.fc1(x)))

def dsm_loss(model, samples, labels, reg=None):
    """Denoising score-matching loss with L2 decay on both weight matrices.

        loss = 1/2 * sum_i (model(samples)_i - labels_i)^2
               + reg/2 * (||fc1.weight||^2 + ||fc2.weight||^2)

    Biases are not regularised, and the data term is a sum (not a mean).

    Args:
        model: module exposing ``fc1`` / ``fc2`` linear layers (e.g. ``nn_type1``).
        samples: model inputs, shape ``(N, 1)``.
        labels: target scores, shape ``(N,)`` or ``(N, 1)``.
        reg: weight-decay strength; defaults to the module-level ``beta``.

    Returns:
        Scalar loss tensor.
    """
    if reg is None:
        reg = beta
    pred_score = model(samples)
    # Align labels with the (N, 1) prediction. ``labels.unsqueeze(1)`` on an
    # already (N, 1) tensor would yield (N, 1, 1) and broadcast to an N x N
    # residual without any error.
    residual = pred_score - labels.reshape_as(pred_score)
    fit = (residual ** 2).sum() / 2
    penalty = reg * ((model.fc1.weight ** 2).sum() + (model.fc2.weight ** 2).sum()) / 2
    return fit + penalty

# A_pre = x.unsqueeze(1) - x
# A = torch.abs(A_pre)
# tilde_A = A - A.mean(axis=0)
# tilde_y = label - label.mean()
# y = cp.Variable(n)
# obj = cp.sum_squares(tilde_A@y+tilde_y)/2 + beta*cp.norm(y,1)
# prob = cp.Problem(cp.Minimize(obj))
# prob.solve()
# print(prob.value)

# W_rec = torch.ones(n)
# b_rec = -x
# y_opt = torch.Tensor(y.value)
# alpha_rec = -torch.Tensor(y_opt)
# gamma_rec = torch.sqrt(torch.abs(alpha_rec/W_rec))
# W_rec = W_rec*gamma_rec
# b_rec = b_rec*gamma_rec
# alpha_rec = alpha_rec/gamma_rec
# b0_rec = (A@y_opt + label).mean()
# model_rec = nn_type1()
# model_rec.fc1.weight = torch.nn.Parameter(W_rec.unsqueeze(1))
# model_rec.fc1.bias = torch.nn.Parameter(b_rec)
# model_rec.fc2.weight = torch.nn.Parameter(alpha_rec.unsqueeze(1).T)
# model_rec.fc2.bias = torch.nn.Parameter(b0_rec.unsqueeze(0).unsqueeze(1))

# opt_nn = dsm_loss(model_rec, x.unsqueeze(1), label).item()
# print('reconstruct nn loss: ', opt_nn)

# # non-convex nn
# num_epoch = 500
# trials = 10
# loss_list = torch.zeros(trials, num_epoch)
# for i in range(trials):
#   model_torch = nn_type1()
#   optimizer = torch.optim.Adam(model_torch.parameters(), lr=1e-2, weight_decay=0)
#   for t in range(num_epoch):
#     optimizer.zero_grad()
#     loss = dsm_loss(model_torch, x.unsqueeze(1), label)
#     loss_list[i][t] = loss.item()
#     # print(t, loss.item())
#     loss.backward()
#     optimizer.step()

# print(torch.min(loss_list).item())
# plt.axhline(y = opt_nn, color = 'b', linestyle = 'dashed', linewidth=5.0)
# for i in range(trials):
#   plt.plot(torch.linspace(0,200,200), loss_list[i][0:200], linewidth=5.0)
# # plt.legend(['cvx'])
# plt.xlabel('# Epochs',fontsize=30)
# plt.ylabel('Training Loss',fontsize=30)
# plt.tight_layout()
# plt.savefig('dsm_type1_training.pdf')
# plt.show()








# Annealed Langevin dynamics schedule.
# Number of noise levels (one score network is trained per level).
L = 10
# Noise levels, decreasing linearly from 1 to 0.01.
sigmas = torch.linspace(1,0.01,L)
# Base step size; it is the step used at the smallest noise level.
eps_0 = 2e-5
# Per-level step size eps_0 * sigma_i^2 / sigma_L^2 (larger steps at larger noise).
epss = eps_0*sigmas**2/sigmas[-1]**2
# print(epss)
# Langevin steps taken at each noise level.
T = 10

# cvx training
def cvx_train(x, label):
    """Solve the convex reformulation and rebuild an equivalent ``nn_type1``.

    Solves

        min_y  1/2 * ||tilde_A @ y + tilde_y||^2 + beta * ||y||_1

    where ``A[i, j] = |x_i - x_j|``, ``tilde_A`` is ``A`` with its column means
    removed and ``tilde_y`` is ``label`` with its mean removed. Hidden unit j of
    the returned model computes ``gamma_j * |x_in - x_j|`` with output weight
    ``-y_j / gamma_j`` (``gamma_j = sqrt(|y_j|)``), so the network evaluates
    ``-A @ y + b0`` with ``b0 = mean(A @ y + label)``.

    Args:
        x: 1-D tensor of training inputs; one hidden unit is built per entry.
        label: 1-D tensor of regression targets, same length as ``x``.

    Returns:
        An ``nn_type1`` whose fc1/fc2 parameters are all replaced by the
        reconstructed weights.

    Raises:
        RuntimeError: if the convex solver returns no solution.
    """
    num = len(x)
    A = torch.abs(x.unsqueeze(1) - x)  # A[i, j] = |x_i - x_j|
    tilde_A = A - A.mean(axis=0)
    tilde_y = label - label.mean()
    y = cp.Variable(num)
    # Hand cvxpy plain numpy constants rather than relying on implicit
    # tensor -> array conversion.
    A_np = tilde_A.detach().cpu().numpy()
    y_np = tilde_y.detach().cpu().numpy()
    obj = cp.sum_squares(A_np @ y + y_np) / 2 + beta * cp.norm(y, 1)
    prob = cp.Problem(cp.Minimize(obj))
    prob.solve()
    print('cvx opt value ', prob.value)
    if y.value is None:
        raise RuntimeError(f'convex solve failed (status: {prob.status})')
    y_opt = torch.tensor(y.value, dtype=A.dtype)

    # Unit j: pre-activation x_in - x_j (weight 1, bias -x_j), output weight -y_j.
    alpha = -y_opt
    # Balance the layers: scale unit j's input side by gamma_j and divide its
    # output weight by gamma_j; |gamma*t| = gamma*|t| for gamma >= 0, so the
    # function is unchanged. gamma_j == 0 only when alpha_j == 0, so dividing by
    # 1 there yields 0 instead of the NaN from 0/0.
    gamma = torch.sqrt(torch.abs(alpha))
    safe_gamma = torch.where(gamma > 0, gamma, torch.ones_like(gamma))
    W_rec = gamma
    b_rec = -x * gamma
    alpha_rec = alpha / safe_gamma
    b0_rec = (A @ y_opt + label).mean()

    model_rec = nn_type1()
    model_rec.fc1.weight = torch.nn.Parameter(W_rec.unsqueeze(1))
    model_rec.fc1.bias = torch.nn.Parameter(b_rec)
    model_rec.fc2.weight = torch.nn.Parameter(alpha_rec.unsqueeze(0))
    # nn.Linear keeps its bias with shape (out_features,) == (1,).
    model_rec.fc2.bias = torch.nn.Parameter(b0_rec.reshape(1))
    return model_rec

# score = []
# for i in range(L):
#   print('training ', i, ' model')
#   x = torch.randn(n)
#   noise = torch.randn(n)
#   input = x + sigmas[i]*noise
#   label = -noise/sigmas[i]
#   score.append(cvx_train(input,label))

# train score functions
# One score network per noise level sigmas[i]. For x_tilde = x + sigma*noise
# the DSM regression target is the score of the Gaussian perturbation kernel,
# grad_{x_tilde} log N(x_tilde; x, sigma^2) = -(x_tilde - x)/sigma^2 = -noise/sigma.
score = []
for i in range(L):
    x = torch.randn(n)
    noise = torch.randn(n)
    noisy_input = x + sigmas[i]*noise
    # The minus sign matters: a +noise/sigma target trains the negated score,
    # which makes the Langevin update push samples away from the data.
    label = -noise/sigmas[i]
    model = nn_type1()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=0)
    for t in range(num_epoch):
        optimizer.zero_grad()
        loss = dsm_loss(model, noisy_input.unsqueeze(1), label)
        loss.backward()
        optimizer.step()
    score.append(model)



# Annealed Langevin dynamics: start from Uniform(-1, 1) and take T steps at
# each noise level, from the largest sigma to the smallest, using that level's
# trained score network and step size epss[i].
n_samples = 10000
prior = D.Uniform(-1, 1)
z = prior.sample((n_samples, )).unsqueeze(1)
# Sampling needs forward passes only. Without no_grad every update would extend
# an autograd graph through all L*T network evaluations and keep every
# (n_samples x hidden) activation alive until the end.
with torch.no_grad():
    for i in range(L):
        step = epss[i]
        for t in range(T):
            z = z + (step/2) * score[i](z) + torch.sqrt(step) * torch.randn(size=z.shape)


s = z.detach().squeeze()
# Drop diverged chains. gaussian_kde rejects +/-inf as well as NaN, so keep
# only finite values, and report how many were discarded instead of hiding it.
finite = torch.isfinite(s)
num_dropped = int((~finite).sum())
if num_dropped:
    print('dropped non-finite samples:', num_dropped)
s = s[finite]
if s.numel() == 0:
    raise RuntimeError('all Langevin samples diverged; nothing to plot')
print('s=',s)
print('min s, max s', torch.min(s), torch.max(s))

s_np = s.numpy()
plt.hist(s_np, density=True, bins=50)
# Freeze the histogram's x-range so the KDE curve is drawn over the same span.
mn, mx = plt.xlim()
plt.xlim(mn, mx)
# plt.xlim(-10,10)
kde_xs = np.linspace(mn, mx, 300)
kde = st.gaussian_kde(s_np)
plt.plot(kde_xs, kde.pdf(kde_xs), linewidth=5.0)
# plt.legend(loc="upper left")
plt.xlabel("Sample",fontsize=30)
# NOTE: with density=True the y-axis is a probability density, not a count.
plt.ylabel("Frequency",fontsize=30)
plt.tight_layout()
plt.savefig('dsm_type1_noncvx_sample.pdf')
plt.show()


