from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import argparse
import numpy as np
from model import *
import torch
import matplotlib.pyplot as plt

# Command-line hyper-parameters for the x^2 + y^2 fitting experiment.
parser = argparse.ArgumentParser(
    description='Fit f(x, y) = x^2 + y^2 from noisy samples with OurModel.')
parser.add_argument('--learning_rate', type=float, default=1e-6,
                    help="initial SGD learning rate (OneCycleLR resets the "
                         "optimizer lr, so the schedule is driven by --maxlr)")
parser.add_argument('--maxlr', type=float, default=1e-5,
                    help="peak learning rate of the OneCycleLR schedule")
parser.add_argument('--epochs', type=int, default=600,
                    help="number of training epochs")
parser.add_argument('--batch_size', type=int, default=5000,
                    help="mini-batch size for training and evaluation")
args = parser.parse_args()

def generate_sample(n_samples=10000, dim=2, scale=10.0):
    """Draw noise-free samples of the squared Euclidean norm f(x) = ||x||^2.

    Args:
        n_samples: number of points to draw (default 10000).
        dim: input dimension d (default 2, i.e. f(x, y) = x^2 + y^2).
        scale: standard deviation of every input coordinate; inputs are
            ``scale * N(0, 1)`` (default 10).

    Returns:
        ``(x, y)`` where ``x`` has shape ``(n_samples, dim)`` and ``y`` has
        shape ``(n_samples,)`` with ``y[i] = sum_j x[i, j] ** 2``.
    """
    x = scale * torch.randn(n_samples, dim)
    # Sum of squares directly: avoids the sqrt-then-square round trip of
    # torch.norm(...) ** 2 (torch.norm is also deprecated).
    y = torch.sum(x * x, dim=1)
    return x, y



# Experiment 1 (easy case): input dimension d=2, polynomial order q=2.
# Target function: f(x, y) = x^2 + y^2, learned from noisy samples.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

X, Y = generate_sample()
# 90/10 train/test split (unseeded, so the split differs between runs).
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.1)

# Add unit-variance Gaussian noise to the training targets only; the test
# targets stay noise-free, so validation loss measures fit to the true f.
y_train += torch.normal(0, 1, y_train.shape)

train_split = TensorDataset(X_train, y_train)
test_split = TensorDataset(X_test, y_test)

# Create mini-batches; only the training set is shuffled.
train_batches = DataLoader(train_split, batch_size=args.batch_size, shuffle=True)
test_batches = DataLoader(test_split, batch_size=args.batch_size, shuffle=False)

# Read the input dimension from the data instead of hard-coding it.
input_size = X_train.shape[1]
order = 2
# Place the model on `device` (rather than an unconditional .cuda()) so the
# script also runs on CPU-only machines.
model = OurModel(input_size, order).to(device)
model = torch.nn.DataParallel(model)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
# OneCycleLR sets the optimizer's lr itself (starting at max_lr / div_factor),
# so the SGD lr above is effectively overridden. It is sized in batches and
# must be stepped once per training batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=args.maxlr, pct_start=0.25,
    steps_per_epoch=len(train_batches), epochs=args.epochs)

# Loss history: per-batch values and per-epoch averages.
train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

# Train for args.epochs epochs; after each epoch, evaluate on the held-out
# (noise-free) test split. Per-batch losses go to train_loss / valid_loss and
# per-epoch averages to train_epochs_loss / valid_epochs_loss.
for epoch in range(args.epochs):
    model.train()
    train_epoch_loss = []
    train_batch_sizes = []
    for data_x, data_y in train_batches:
        data_x = data_x.to(device)
        data_y = data_y.to(device)
        # squeeze() drops singleton dims so outputs line up with data_y
        # (presumably OurModel returns (batch, 1) -- TODO confirm).
        outputs = model(data_x).squeeze()
        optimizer.zero_grad()
        # MSELoss expects (input, target); same order as in validation below.
        loss = criterion(outputs, data_y)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_epoch_loss.append(batch_loss)
        train_batch_sizes.append(len(data_y))
        train_loss.append(batch_loss)
        # The scheduler was sized with steps_per_epoch=len(train_batches),
        # so it is stepped once per batch, not once per epoch.
        scheduler.step()
    # Weight batch means by batch size so the epoch value is the per-sample
    # MSE even when the last batch is smaller than the others.
    train_epochs_loss.append(np.average(train_epoch_loss, weights=train_batch_sizes))

    # Validate in eval mode without building autograd graphs.
    model.eval()
    valid_epoch_loss = []
    valid_batch_sizes = []
    with torch.no_grad():
        for data_x, data_y in test_batches:
            data_x = data_x.to(device)
            data_y = data_y.to(device)
            outputs = model(data_x).squeeze()
            batch_loss = criterion(outputs, data_y).item()
            valid_epoch_loss.append(batch_loss)
            valid_batch_sizes.append(len(data_y))
            valid_loss.append(batch_loss)
    valid_epochs_loss.append(np.average(valid_epoch_loss, weights=valid_batch_sizes))
print(train_epochs_loss)
print(valid_epochs_loss)

# Visualize the learned surface on a grid over [-10, 10]^2, i.e. about one
# standard deviation of the 10 * randn training inputs per coordinate.
grid_size = 100
x = np.linspace(-10, 10, grid_size)
y = np.linspace(-10, 10, grid_size)

X, Y = np.meshgrid(x, y)
X = torch.from_numpy(X).float()
Y = torch.from_numpy(Y).float()
# Evaluate in eval mode without autograd, on the model's `device` so this
# also works on CPU-only machines.
model.eval()
with torch.no_grad():
    # (grid_size, grid_size, 2) -> (grid_size**2, 2) rows of (x, y) points.
    grid_points = torch.stack((X, Y), dim=2).reshape(-1, 2).to(device)
    Z = model(grid_points).reshape(grid_size, grid_size)
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_surface(X.numpy(), Y.numpy(), Z.cpu().numpy(), rstride=1, cstride=1,
                cmap='viridis', edgecolor='none')
fig.savefig("Our_fitting.png")
plt.show()
# NOTE: the triple-quoted string below is a disabled loss-curve plot kept for
# reference; it is a no-op expression and is never executed. Left panel: the
# per-batch training loss (train_loss). Right panel: per-epoch train/valid
# loss. If re-enabled, the [100:] slices skip the first 100 epochs, so the
# right panel is empty when --epochs <= 100.
'''
plt.figure(figsize=(12,4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[100:],'-o',label="train_loss")
plt.plot(valid_epochs_loss[100:],'-o',label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.savefig("experiment1.png")
plt.show()
'''