from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import argparse
import numpy as np
from model import *
import torch
import matplotlib.pyplot as plt

# Command-line hyperparameters for the training run below.
parser = argparse.ArgumentParser(
    description="Fit OurModel to noisy samples of the cubic target f(x, y) = x^3 + y^3."
)
parser.add_argument(
    "--learning_rate",
    type=float,
    default=0.0001,
    help="lr passed to the SGD optimizer (note: OneCycleLR overrides the "
    "starting lr with maxlr / 25, so this value does not set it)",
)
parser.add_argument(
    "--maxlr",
    type=float,
    default=0.001,
    help="peak learning rate of the OneCycleLR schedule",
)
parser.add_argument(
    "--epochs", type=int, default=2000, help="number of training epochs"
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=5000,
    help="mini-batch size for both training and validation loaders",
)
args = parser.parse_args()

def generate_sample(n_samples=50000, dim=2, scale=10.0, power=3):
    """Draw random inputs and the noiseless target for the toy regression task.

    Inputs are i.i.d. Gaussian with standard deviation ``scale``. The target is
    the sum over coordinates of each coordinate raised to ``power``. With the
    defaults this is f(x, y) = x**3 + y**3 evaluated on 50000 two-dimensional
    points, exactly as the original hard-coded version produced.

    Args:
        n_samples: number of points to draw.
        dim: input dimension (number of columns of ``x``).
        scale: standard deviation of the Gaussian inputs.
        power: exponent applied to every coordinate before summing.

    Returns:
        A tuple ``(x, y)`` where ``x`` has shape ``(n_samples, dim)`` and
        ``y`` has shape ``(n_samples,)``.
    """
    x = scale * torch.randn(n_samples, dim)
    # Element-wise power, then sum across the feature axis -> one target per row.
    y = torch.pow(x, power).sum(dim=1)
    return x, y



# Experiment 1: an easy case with input dimension d=2 and polynomial order q=3.
# Target: f(x, y) = x^3 + y^3 (as built by generate_sample).
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Draw the full sample, then hold out 10% of it for validation.
X, Y = generate_sample()
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.1)

# Corrupt only the training targets with standard-normal noise; the held-out
# targets stay noiseless.
y_train = y_train + torch.normal(0, 1, y_train.shape)

# Wrap both splits as datasets and batch them; only the training loader
# reshuffles its samples each epoch.
train_split, test_split = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)
train_batches, test_batches = (
    DataLoader(dataset, batch_size=args.batch_size, shuffle=reshuffle)
    for dataset, reshuffle in ((train_split, True), (test_split, False))
)


# Model hyperparameters; `order` is passed to OurModel (defined in model.py,
# not visible here) -- presumably the polynomial degree, matching the cubic
# target. TODO confirm against model.py.
input_size = 2
order = 3
# Place the model on the same device the batches are moved to. A hard-coded
# .cuda() crashes on CPU-only machines even though `device` falls back to CPU.
model = OurModel(input_size, order).to(device)
model = torch.nn.DataParallel(model)

criterion = torch.nn.MSELoss()
# NOTE: OneCycleLR overwrites the optimizer's lr when it is constructed
# (starting lr = max_lr / div_factor, div_factor defaults to 25), so
# args.learning_rate does not determine the starting learning rate.
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
# Schedule is sized per batch: it must be stepped once per training batch.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.maxlr,
    pct_start=0.25,
    steps_per_epoch=len(train_batches),
    epochs=args.epochs,
)

# Loss histories: per-batch values (train_loss, valid_loss) and per-epoch
# means (train_epochs_loss, valid_epochs_loss).
train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

# Train for args.epochs epochs, evaluating on the held-out split after each one.
for epoch in range(args.epochs):
    # --- training pass ---
    model.train()
    train_epoch_loss = []
    for data_x, data_y in train_batches:
        data_x = data_x.to(device)
        data_y = data_y.to(device)
        outputs = model(data_x).squeeze()
        optimizer.zero_grad()
        # MSELoss takes (input, target); use the same order as validation.
        loss = criterion(outputs, data_y)
        loss.backward()
        # Clamp every gradient element to [-1, 1]. Unlike clamping p.grad
        # directly, clip_grad_value_ skips parameters whose .grad is None.
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)
        optimizer.step()
        batch_loss = loss.item()
        train_epoch_loss.append(batch_loss)
        train_loss.append(batch_loss)
        # OneCycleLR was built with steps_per_epoch, so step once per batch.
        scheduler.step()
    train_epochs_loss.append(np.average(train_epoch_loss))

    # --- validation pass (no autograd graph needed) ---
    model.eval()
    valid_epoch_loss = []
    with torch.no_grad():
        for data_x, data_y in test_batches:
            data_x = data_x.to(device)
            data_y = data_y.to(device)
            outputs = model(data_x).squeeze()
            batch_loss = criterion(outputs, data_y).item()
            valid_epoch_loss.append(batch_loss)
            valid_loss.append(batch_loss)
    valid_epochs_loss.append(np.average(valid_epoch_loss))
print(train_epochs_loss)
print(valid_epochs_loss)

# Disabled plotting code, kept inside a bare string literal so it never runs.
# To re-enable it, remove the surrounding ''' quotes. It plots the per-batch
# training loss (left) and the per-epoch train/valid mean losses (right,
# skipping the first 100 epochs), then saves the figure to experiment1.png.
'''
plt.figure(figsize=(12,4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[100:],'-o',label="train_loss")
plt.plot(valid_epochs_loss[100:],'-o',label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.savefig("experiment1.png")
plt.show()
'''