from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import argparse
import numpy as np
from model import *
import torch
import matplotlib.pyplot as plt
import time

# Command-line hyper-parameters for the training run below.
parser = argparse.ArgumentParser(
    description='Train a regression network on a synthetic polynomial target.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--learning_rate', type=float, default=0.0001,
                    help="learning rate passed to the Adam optimizer")
parser.add_argument('--maxlr', type=float, default=0.001,
                    help="peak learning rate (max_lr) of the OneCycleLR schedule")
parser.add_argument('--epochs', type=int, default=2000,
                    help="number of training epochs")
parser.add_argument('--batch_size', type=int, default=25000,
                    help="mini-batch size for the train and test loaders")
args = parser.parse_args()

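# Alternative target (disabled): y is the 5th power of a weighted sum of the
# ten features, drawn for 10000 samples.
# def generate_sample():
#     x = torch.randn(10000,10)
#     y = torch.pow(x[:,0]+3*x[:,1]+2*x[:,2]+1.1*x[:,3]+1.2*x[:,4]+1.3*x[:,5]+1.4*x[:,6]+1.5*x[:,7]+1.6*x[:,8]+1.7*x[:,9],5)
#     return x,y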

def generate_sample(n_samples=50000):
    """Draw a synthetic regression data set with a polynomial target.

    Args:
        n_samples: Number of rows to draw. Defaults to 50000, the size that
            was previously hard-coded.

    Returns:
        A tuple ``(x, y)``. ``x`` is an ``(n_samples, 10)`` tensor of
        standard-normal features and ``y`` is the ``(n_samples,)`` noise-free
        target
        ``x0^5 + 3*x1^4 + 2*x2^3 + 5*x3*x4 + 3*x5^2 + 2*x6*x7*x8 + 2*x9``.
    """
    x = torch.randn(n_samples, 10)
    y = (
        torch.pow(x[:, 0], 5)
        + 3 * torch.pow(x[:, 1], 4)
        + 2 * torch.pow(x[:, 2], 3)
        + 5 * x[:, 3] * x[:, 4]
        + 3 * torch.pow(x[:, 5], 2)
        + 2 * x[:, 6] * x[:, 7] * x[:, 8]
        + 2 * x[:, 9]
    )
    return x, y

# Experiment 2: simulation two for d=10 and q=5.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

X, Y = generate_sample()
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.1)

# Add standard-normal noise to the training targets only; the test targets
# stay noise-free.
y_train += torch.normal(0, 1, y_train.shape)


train_split = TensorDataset(X_train, y_train)
test_split = TensorDataset(X_test, y_test)

# Create batches; only the training set is reshuffled every epoch.
train_batches = DataLoader(train_split, batch_size=args.batch_size, shuffle=True)
test_batches = DataLoader(test_split, batch_size=args.batch_size, shuffle=False)


input_size = 10
order = 5
# Move the model to `device` instead of calling .cuda() unconditionally, so
# the CPU fallback selected above actually works on machines without a GPU.
#model = Feed_foward_same_depth(10, 80, 5).to(device)
model = Feed_foward_same_width(10, 5).to(device)
#model = OurModel(input_size, order).to(device)

model = torch.nn.DataParallel(model)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
# NOTE(review): per the PyTorch docs, OneCycleLR sets the optimizer's learning
# rate itself (starting from max_lr / div_factor), so args.learning_rate
# passed to Adam above likely has no effect -- confirm this is intended.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.maxlr,
    pct_start=0.25,
    final_div_factor=10,
    steps_per_epoch=len(train_batches),
    epochs=args.epochs,
)

# Per-batch loss histories (one entry per training / evaluation batch).
train_loss = []
valid_loss = []
# Per-epoch mean of the batch losses.
train_epochs_loss = []
valid_epochs_loss = []

# train
data_start = time.time()
for epoch in range(args.epochs):
    model.train()
    train_epoch_loss = []
    for idx, (data_x, data_y) in enumerate(train_batches, 0):
        data_x = data_x.to(device)
        data_y = data_y.to(device)
        outputs = model(data_x).squeeze()
        optimizer.zero_grad()
        # MSELoss takes (input, target); same order as the evaluation loop.
        loss = criterion(outputs, data_y)
        loss.backward()

        # Element-wise gradient value clipping to [-1, 1]. The library helper
        # skips parameters whose .grad is None, which the previous manual
        # loop over m.grad.data would have crashed on.
        torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)

        optimizer.step()
        batch_loss = loss.item()
        train_epoch_loss.append(batch_loss)
        train_loss.append(batch_loss)
        #if idx%(len(train_batches)//2)==0:
        #    print("epoch={}/{},{}/{}of train, loss={}".format(
        #        epoch, args.epochs, idx, len(train_batches),loss.item()))
        # The OneCycleLR schedule advances once per batch, not per epoch.
        scheduler.step()
    train_epochs_loss.append(np.average(train_epoch_loss))

    # test
    model.eval()
    valid_epoch_loss = []
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for data_x, data_y in test_batches:
            data_x = data_x.to(device)
            data_y = data_y.to(device)
            outputs = model(data_x).squeeze()
            batch_loss = criterion(outputs, data_y).item()
            valid_epoch_loss.append(batch_loss)
            valid_loss.append(batch_loss)
    valid_epochs_loss.append(np.average(valid_epoch_loss))

# Total wall-clock time of the run, in seconds.
print(time.time() - data_start)

print(train_epochs_loss)
print(valid_epochs_loss)

# Disabled snippet (kept as a bare string literal, so it is never executed):
# evaluates the model on a 100x100 grid over [-10, 10]^2 and saves a 3-D
# surface plot to "Our_fitting.png".
# NOTE(review): the grid is reshaped to 2 input columns, whereas the model
# above is constructed with 10 as its first argument -- presumably this
# targets a 2-input variant; confirm before re-enabling.
'''
x = np.linspace(-10, 10, 100)
y = np.linspace(-10, 10, 100)

X, Y = np.meshgrid(x, y)
X=torch.from_numpy(X).float()
Y=torch.from_numpy(Y).float()
Z=model(torch.cat((torch.unsqueeze(X,dim=2),torch.unsqueeze(Y,dim=2)),dim=2).reshape(-1,2).cuda()).reshape(100,100)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(X, Y, Z.cpu().detach().numpy(), rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')
fig.savefig("Our_fitting.png")
plt.show()
'''

# Disabled snippet (kept as a bare string literal, so it is never executed):
# plots the per-batch training loss in one panel and the per-epoch train and
# validation losses from epoch index 100 onward in another, then saves the
# figure to "experiment1.png".
'''
plt.figure(figsize=(12,4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[100:],'-o',label="train_loss")
plt.plot(valid_epochs_loss[100:],'-o',label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.savefig("experiment1.png")
plt.show()
'''