import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
#import plotly.express as px
import sklearn.metrics
import math
import datetime as dt
import torch
from model import *
import argparse
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler

# Load the raw semicolon-separated measurements. The 'Date' and 'Time'
# columns are combined into one datetime index named 'data', and the
# strings 'nan' / '?' are treated as missing values.
# NOTE(review): `infer_datetime_format` is deprecated in pandas 2.x and
# dict-style `parse_dates` in pandas 2.2; revisit when upgrading pandas.
data_path = './household_power_consumption.txt'

_read_options = dict(
    sep=';',
    parse_dates={'data': ['Date', 'Time']},
    infer_datetime_format=True,
    na_values=['nan', '?'],
    index_col='data',
)
data = pd.read_csv(data_path, **_read_options)

# Discard every row with at least one missing reading, then average the
# remaining rows into 30-minute bins (bins left empty become NaN rows).
data_clear = data.dropna()
data_resample = data_clear.resample('30Min').mean()
data_resample.shape

# Split the resampled frame into the target (Global_active_power) and the
# remaining feature columns, both re-indexed 0..n-1.
X = data_resample.drop("Global_active_power", axis=1)
X = X.reset_index(drop=True)

y = data_resample["Global_active_power"].reset_index(drop=True)

from sklearn.model_selection import train_test_split
# Random 70/30 split with a fixed seed.
# NOTE(review): rows are shuffled before splitting, so neighbouring
# time steps end up on both sides; a chronological split may be more
# representative for forecasting — confirm this is intended.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Drop rows with missing values using ONE joint mask per split, so the
# features and the target always keep exactly the same rows. Dropping
# them independently could leave X and y misaligned (or of different
# lengths) whenever their NaN patterns differ.
train_keep = X_train.notna().all(axis=1) & y_train.notna()
X_train, y_train = X_train[train_keep], y_train[train_keep]

test_keep = X_test.notna().all(axis=1) & y_test.notna()
X_test, y_test = X_test[test_keep], y_test[test_keep]

# NORMALIZE DATA
# The scaler is fitted on the training features only; the test features
# are transformed with those same min/max statistics. Refitting on the test
# set would leak test information and put the splits on different scales.
scaler = MinMaxScaler(feature_range=(0, 1))

X_train_norm = scaler.fit_transform(X_train)
print(X_train.shape)
print(y_train.shape)

X_test_norm = scaler.transform(X_test)
print(X_test.shape)
print(y_test.shape)

# Convert features and targets to float32 tensors for PyTorch.
X_train = torch.from_numpy(X_train_norm).float()
X_test = torch.from_numpy(X_test_norm).float()
y_train = torch.Tensor(y_train.to_numpy()).float()
y_test = torch.Tensor(y_test.to_numpy()).float()

# Command-line hyper-parameters for the training run.
parser = argparse.ArgumentParser(
    description='Train OurModel on the household power consumption data.')
# OneCycleLR resets the optimizer's learning rate when it is constructed, so
# --learning_rate only seeds the SGD optimizer and the schedule drives the rest.
parser.add_argument('--learning_rate', type=float, default=0.001,
                    help="initial SGD learning rate (overridden by the OneCycleLR schedule)")
parser.add_argument('--maxlr', type=float, default=0.01,
                    help="peak learning rate of the OneCycleLR schedule")
parser.add_argument('--epochs', type=int, default=100, help="number of training epochs")
parser.add_argument('--batch_size', type=int, default=5000, help="mini-batch size")
args = parser.parse_args()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Wrap the tensors so DataLoader can batch them. Only the training batches
# are shuffled; the test order is kept fixed.
train_split = TensorDataset(X_train, y_train)
test_split = TensorDataset(X_test, y_test)

train_batches = DataLoader(train_split, batch_size=args.batch_size, shuffle=True)
test_batches = DataLoader(test_split, batch_size=args.batch_size, shuffle=False)

# NOTE(review): input_size presumably must equal the number of feature
# columns in X_train — TODO confirm against OurModel in model.py.
input_size = 6
order = 2
# Move the model to the selected device. The previous hard-coded `.cuda()`
# crashed on CPU-only machines even though `device` handles that case;
# DataParallel simply forwards to the wrapped module when no GPU exists.
model = OurModel(input_size, order).to(device)
model = torch.nn.DataParallel(model)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
# The one-cycle schedule spans steps_per_epoch * epochs steps, so it must be
# stepped once per training batch. It ramps up to args.maxlr over the first
# 25% of those steps and then anneals.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.maxlr, pct_start=0.25,
                                                steps_per_epoch=len(train_batches), epochs=args.epochs)

# Loss histories: per batch (train_loss / valid_loss) and per epoch
# (mean of the batch losses in that epoch).
train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

for epoch in range(args.epochs):
    # ---- training pass ----
    model.train()
    train_epoch_loss = []
    for data_x, data_y in train_batches:
        data_x = data_x.to(device)
        data_y = data_y.to(device)
        outputs = model(data_x).squeeze()
        optimizer.zero_grad()
        # (input, target) order, matching the validation loop below.
        loss = criterion(outputs, data_y)
        loss.backward()

        # Element-wise gradient VALUE clipping to [-1, 1] (not norm
        # clipping). clip_grad_value_ skips parameters whose .grad is None,
        # where a manual `p.grad.data.clamp_` loop would raise.
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)

        optimizer.step()
        # OneCycleLR was sized with steps_per_epoch=len(train_batches), so it
        # advances once per batch.
        scheduler.step()

        batch_loss = loss.item()
        train_epoch_loss.append(batch_loss)
        train_loss.append(batch_loss)
    train_epochs_loss.append(np.average(train_epoch_loss))

    # ---- validation pass ----
    model.eval()
    valid_epoch_loss = []
    # No gradients are needed for evaluation; disabling autograd avoids
    # building graphs and saves memory/time.
    with torch.no_grad():
        for data_x, data_y in test_batches:
            data_x = data_x.to(device)
            data_y = data_y.to(device)
            outputs = model(data_x).squeeze()
            batch_loss = criterion(outputs, data_y).item()
            valid_epoch_loss.append(batch_loss)
            valid_loss.append(batch_loss)
    valid_epochs_loss.append(np.average(valid_epoch_loss))

print(train_epochs_loss)
print(valid_epochs_loss)
