# This code contains the PyTorch implementation of AEMC-NE and the experiments on synthetic data (Section 5.1 of the main paper)
import numpy as np
import time
from torch.nn.parameter import Parameter
import torch.optim as optim
import torch
from torch import nn
import random
import pandas as pd
torch.set_default_dtype(torch.float32)
from scipy import sparse

device = torch.device('cuda:0')
torch.cuda.empty_cache()

tanh=nn.Tanh()
sig=nn.Sigmoid()
d=10
n=3000
n1=50
n2=100#100
n3=300#300

torch.manual_seed(0)

W1=torch.rand(n1,d)*2-1
W2=torch.rand(n2,n1)*2-1
W3=torch.rand(n3,n2)*2-1

z=torch.rand(d,n)*2-1

def get_movie_len(missing_rate):
    """Generate one synthetic data matrix and split its nonzero entries into train/test.

    A fresh latent matrix of shape (d, n) is drawn uniformly from [-1, 1) with
    torch's global RNG and pushed through the fixed network
    ``W3 @ tanh(W2 @ tanh(W1 @ latent))``. The result ``X`` (shape (n3, n)) is
    mapped elementwise to ``cos(X) + X`` and transposed to shape (n, n3). A subset
    of the nonzero entries is then sampled without replacement, using numpy's
    global RNG, as training entries; every other nonzero entry is held out.

    Args:
        missing_rate: fraction of the nonzero entries to hold out.

    Returns:
        train: the training entries at their positions, zeros elsewhere.
        test: the held-out entries at their positions, zeros elsewhere.
        train_index: 0/1 integer array marking nonzero entries of ``train``.
        missing_indicator: 0/1 integer array marking held-out nonzero entries.
        missing_num: ``(number of nonzero entries) * missing_rate`` (an
            unrounded float).
    """
    latent = torch.rand(d, n) * 2 - 1
    X = (W3 @ tanh(W2 @ tanh(W1 @ latent))).numpy()
    full = (np.cos(X) + X).T
    observed_mask = (full != 0) * 1
    observed_count = np.sum(observed_mask)

    # Positions (0-based, row-major) of every nonzero entry, then a random
    # subset of them that becomes the training set.
    n_train = int(observed_count * ((10 - 10 * missing_rate) / 10))
    observed_flat = np.flatnonzero(observed_mask)
    train_flat = np.random.choice(observed_flat, size=n_train, replace=False)

    # Start from a copy of the full matrix and blank out the training entries,
    # leaving only the held-out ones.
    held_out = full.flatten()
    held_out[train_flat] = 0
    held_out = held_out.reshape(full.shape)

    train = full - held_out
    train_mask = (train != 0) * 1
    return train, held_out, train_mask, observed_mask - train_mask, observed_count * missing_rate

from torch.nn.modules.activation import ReLU
class RCAutoRec(nn.Module):
    """Autoencoder for matrix completion with an optional element-wise output network.

    ``LNet`` is a tanh MLP that maps each input row of length ``dim_in`` back to a
    row of the same length. When ``mod == 'NE'``, a scalar-to-scalar tanh MLP
    ``ENet`` is applied to every ``LNet`` output entry selected by the mask passed
    to :meth:`forward`. Each selected value ``y`` is replaced by
    ``ENet(y) + 0.8 * y``.

    Args:
        dim_in: length of each input row (also the output row length).
        layerwise_hidden_dim: hidden-layer widths of ``LNet`` (non-empty sequence).
        elementwise_hidden_dim: hidden-layer widths of ``ENet`` (non-empty sequence).
        mod: ``'NE'`` enables the element-wise network; any other value skips it.
    """

    def __init__(self, dim_in, layerwise_hidden_dim, elementwise_hidden_dim, mod):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_in
        self.mod = mod
        # Built in this order (LNet first) so parameter initialisation consumes
        # torch's RNG exactly as before.
        self.LNet = self._build_mlp(self.dim_in, layerwise_hidden_dim, self.dim_out, nn.Tanh())
        self.ENet = self._build_mlp(1, elementwise_hidden_dim, 1, nn.Tanh())

    @staticmethod
    def _build_mlp(in_dim, hidden_dims, out_dim, actfun):
        """Build Linear layers named L1..L{k+1} with ``actfun`` (named L{i}_actf) after each hidden one.

        The same ``actfun`` instance is reused after every hidden layer, and no
        activation follows the final Linear layer. The module names match the
        original layout, so ``state_dict`` keys are unchanged.
        """
        net = nn.Sequential()
        widths = [in_dim] + list(hidden_dims)
        for i, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            net.add_module("L" + str(i), nn.Linear(w_in, w_out))
            net.add_module("L" + str(i) + '_actf', actfun)
        net.add_module("L" + str(len(widths)), nn.Linear(widths[-1], out_dim))
        return net

    def forward(self, x, mask):
        """Reconstruct ``x``; in 'NE' mode, refine the masked entries element-wise.

        Args:
            x: input rows whose last dimension is ``dim_in``.
            mask: expected to have the same shape as ``x``; entries with
                ``mask > 0`` are passed through ``ENet`` (only in 'NE' mode).

        Returns:
            Tensor with the same shape as ``LNet(x)``.
        """
        x = self.LNet(x)
        if self.mod == 'NE':
            selected = mask > 0
            # Derive the count from the selection itself rather than from
            # sum(mask), so non-binary positive masks also work.
            y = x[selected].reshape(-1, 1)
            y = self.ENet(y) + y * 0.8
            x[selected] = y.reshape(-1)
        return x

    def evaluate(self, x, x_test, mask):
        """Return the relative RMSE of the reconstruction on the masked entries.

        Computes ``sqrt(||(x_test - f(x)) * mask||^2 / ||x_test * mask||^2)``,
        where ``f`` is :meth:`forward` called with the same ``mask``. The per-entry
        normalisation by ``#(mask > 0)`` cancels. Runs without building an
        autograd graph. Raises ``ZeroDivisionError`` if ``x_test * mask`` is all
        zeros.

        Returns:
            Python float.
        """
        with torch.no_grad():
            prediction = self.forward(x, mask)
            n_obs = torch.sum(mask > 0)
            err_sq = float(torch.norm(torch.mul(x_test - prediction, mask)) ** 2 / n_obs)
            ref_sq = float(torch.norm(torch.mul(x_test, mask)) ** 2 / n_obs)
        return (err_sq / ref_sq) ** (1 / 2)

def training(mr, mod, main_hidden_dim, act_hidden_dim, epoch, lr, wd):
    """Train one RCAutoRec on a freshly generated synthetic matrix, tracking test error.

    Args:
        mr: missing rate forwarded to ``get_movie_len``.
        mod: model mode forwarded to ``RCAutoRec`` (``'NE'`` enables the
            element-wise network).
        main_hidden_dim: hidden widths of the layer-wise network.
        act_hidden_dim: hidden widths of the element-wise network.
        epoch: number of full-batch optimisation steps.
        lr: AdamW learning rate.
        wd: AdamW weight decay.

    Returns:
        ``(last_rmse, best_rmse, best_epoch)``. The RMSE values come from
        ``RCAutoRec.evaluate`` on the held-out entries. The value recorded for
        epoch ``i`` is measured before that epoch's optimiser step. Epoch 0 is not
        evaluated (a placeholder 0 is recorded), so ``best_rmse``/``best_epoch``
        consider only epochs ``i >= 1`` whenever any exist.
    """
    train_set, test_set, train_index, missing_indicator, _ = get_movie_len(mr)
    train_set, test_set, train_index, missing_indicator = (
        torch.as_tensor(a, dtype=torch.float32, device=device)
        for a in (train_set, test_set, train_index, missing_indicator)
    )
    dim_in = train_set.shape[1]
    net = RCAutoRec(dim_in, main_hidden_dim, act_hidden_dim, mod).to(device)
    optimizer = optim.AdamW(net.parameters(), lr=lr, weight_decay=wd)
    # gamma=1 makes ExponentialLR leave the learning rate unchanged.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)
    RMSE_list = []
    for i in range(epoch):
        optimizer.zero_grad()
        x_pred = net(train_set, train_index)
        # Mean squared error over the observed (training) entries only.
        loss = torch.norm(torch.mul(train_set - x_pred, train_index)) ** 2 / train_index.sum()
        RMSE = 0  # placeholder for epoch 0, which is not evaluated
        if i >= 1:
            RMSE = net.evaluate(train_set, test_set, missing_indicator)
        if i >= 1 and i < 5 or i % 20 == 0:
            print('epoch=', i, 'loss=', loss.item(), 'RMSE', RMSE, "   lr  {:.6f}".format(scheduler.get_last_lr()[0]))
        loss.backward()
        optimizer.step()
        scheduler.step()
        RMSE_list.append(RMSE)

    # Exclude the epoch-0 placeholder; otherwise min() would always pick 0.
    evaluated = RMSE_list[1:]
    if evaluated:
        best_rmse = min(evaluated)
        best_epoch = evaluated.index(best_rmse) + 1
    else:
        best_rmse, best_epoch = RMSE_list[-1], len(RMSE_list) - 1
    return RMSE_list[-1], best_rmse, best_epoch


# ---- Experiment configuration ------------------------------------------------
mr = 0.5  # missing rate
mod = 'NE'  # 'NE' -> AEMC-NE; any other value -> AEMC (no element-wise network)
main_hidden_dim = [300, 100, 30, 100, 300]  # hidden widths of the main (layer-wise) network
act_hidden_dim = [20, 20]  # hidden widths of the element-wise network
lr = 1e-2  # learning rate
epoch = 2000  # optimisation steps per trial
n_v = 1  # number of settings (columns of RE)
n_rep = 5  # repeated trials per setting (rows of RE)
wd = 0.1  # AdamW weight decay (same for every trial)

# ---- Run the trials and report the final recovery errors ---------------------
RE = np.zeros([n_rep, n_v])
for setting in range(n_v):
    for trial in range(n_rep):
        final_err, _best_err, _best_epoch = training(
            mr, mod, main_hidden_dim, act_hidden_dim, epoch, lr, wd
        )
        RE[trial, setting] = final_err
        print('trial', trial, 'wd', wd, 'recovery error', final_err)
print(RE)
print('mean is', np.mean(RE, 0), 'std is', np.std(RE, 0))




