import sys

sys.path.insert(0, "../../../../")

import time
import torch
import pickle
import argparse
import numpy as np
import pandas as pd
from random import shuffle, sample

from ganrl.embedding.wide_and_deep import WideAndDeep

# Prefer the GPU whenever torch can see one; otherwise run everything on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Command-line options (both integers):
#   --num_epochs : number of training iterations; each one draws a single
#                  random mini-batch from the training rows (default 100)
#   --batch_size : rows per sampled mini-batch (default 50)
parser = argparse.ArgumentParser()
parser.add_argument("--num_epochs", default=100, type=int, help="number of epochs")
parser.add_argument("--batch_size", default=50, type=int, help="batch size")
args = parser.parse_args()

def _timed_load(label, path):
    """Unpickle ``path`` into a numpy array, printing a label and the elapsed time.

    Prints ``label`` before loading and ``"Took: <seconds>"`` afterwards.
    """
    started = time.time()
    print(label)
    with open(path, "rb") as handle:
        loaded = np.asarray(pickle.load(handle))
    print("Took: ", time.time() - started)
    return loaded


# Local pickles produced by an earlier preprocessing step; rows are aligned by index.
sparse_f = _timed_load("Load sparse_f", "./ml-latest-small/sparse_f.pkl")
dense_f = _timed_load("Load dense_f", "./ml-latest-small/dense_f.pkl")
labels = _timed_load("Load labels", "./ml-latest-small/label.pkl")

# All three inputs are already numpy arrays, so their shapes can be printed directly.
print(sparse_f.shape, dense_f.shape, labels.shape)

# Shuffle every row index in place, then take the first 70% for training and
# the remaining 30% as a held-out test set.
all_sessions = list(range(len(sparse_f)))
shuffle(all_sessions)
anchor = int(len(all_sessions) * 0.7)
train_ids = all_sessions[:anchor]
test_ids = all_sessions[anchor:]

# The deep part consumes the sparse features and the wide part the dense ones.
model = WideAndDeep(
    in_deep_dim=sparse_f.shape[1],
    in_wide_dim=dense_f.shape[1],
).to(device=DEVICE)
# lr=1e-3 is Adam's default; spelled out to mirror the original TF AdamOptimizer(0.001).
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# === train the network
# Each "epoch" here is ONE optimizer step on a single randomly sampled
# mini-batch of training rows, not a full pass over train_ids.
model.train()
# Cap the batch so sample() cannot raise when the training split is small.
_train_batch = min(args.batch_size, len(train_ids))
for epoch in range(args.num_epochs):
    _user_ids = sample(train_ids, k=_train_batch)
    _sparse_f = sparse_f[_user_ids]
    _dense_f = dense_f[_user_ids]
    _labels = labels[_user_ids]

    inputs = [torch.tensor(_sparse_f.astype(np.float32), device=DEVICE),
              torch.tensor(_dense_f.astype(np.float32), device=DEVICE)]
    y_hat, _ = model(inputs)

    # Port of the original TensorFlow objective:
    #     difference_y = y - y_label
    #     loss = tf.reduce_mean(tf.multiply(difference_y, difference_y))
    #     opt = tf.train.AdamOptimizer(learning_rate=0.001)
    # i.e. mean squared error. The target must be THIS batch's labels
    # (_labels, not the full `labels` array), reshaped to y_hat's shape so the
    # difference is element-wise rather than a silent (batch, 1) x (N,) broadcast.
    # Assumes one label per sampled row — TODO confirm against label.pkl.
    y_true = torch.tensor(_labels.astype(np.float32), device=DEVICE).reshape(y_hat.shape)

    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(y_hat, y_true)  # == mean((y_true - y_hat) ** 2)
    loss.backward()
    opt.step()
    print(epoch, loss.item())

# === Evaluate the model
# Score one random batch drawn from the HELD-OUT rows (test_ids) so accuracy is
# not measured on data the model was trained on. The batch is capped at the
# test-set size so sample() does not raise for small splits.
model.eval()
_user_ids = sample(test_ids, k=min(args.batch_size, len(test_ids)))
_sparse_f = sparse_f[_user_ids]
_dense_f = dense_f[_user_ids]
_labels = labels[_user_ids]

inputs = [torch.tensor(_sparse_f.astype(np.float32), device=DEVICE),
          torch.tensor(_dense_f.astype(np.float32), device=DEVICE)]

with torch.no_grad():
    y_hat, _ = model(inputs)
    # Flatten the scores and threshold at 0.5 into hard 0/1 predictions.
    y_hat = y_hat.cpu().numpy().ravel()
    y_hat = 1.0 * (y_hat >= 0.5)
    # Flatten the labels as well so the comparison is element-wise even if the
    # labels are stored as (batch, 1); divide by the actual number of rows scored.
    y_true = np.ravel(_labels)
    accuracy = float(np.mean(y_true == y_hat))
    result = "{} {}\nAccuracy: {}".format(_labels, y_hat, accuracy)
    print(_labels, y_hat)
    print(accuracy)

with open("./ml-latest-small/item_embedding_training_result.txt", "w", encoding="utf-8") as file:
    file.write(result)

# === Get the embedding
# Brackets are turned into spaces; newlines/whitespace are handled by split().
_ARRAY_BRACKETS = str.maketrans("[]", "  ")


def _parse_array_cell(cell):
    """Parse a printed 1-D numpy array (e.g. "[0. 1.  0.]") into a float64 vector.

    Assumes the CSV cell holds the full, non-elided ``str()`` of the array
    (no "..." summarisation) — TODO confirm against the script that wrote it.
    """
    return np.asarray(cell.translate(_ARRAY_BRACKETS).split(), dtype=np.float64)


df_sparse = pd.read_csv("./ml-latest-small/sparse_feature.csv")
df_sparse["sparse_f"] = df_sparse["sparse_f"].apply(_parse_array_cell)
df_dense = pd.read_csv("./ml-latest-small/dense_feature.csv")
df_dense["dense_f"] = df_dense["dense_f"].apply(_parse_array_cell)
print(df_sparse.shape, df_dense.shape)
print(df_sparse.columns, df_dense.columns)

# Stack the per-row vectors into 2-D (rows, features) float32 matrices.
all_sparse_f = np.stack(df_sparse["sparse_f"].tolist()).astype(np.float32)
all_dense_f = np.stack(df_dense["dense_f"].tolist()).astype(np.float32)

inputs = [torch.tensor(all_sparse_f, device=DEVICE), torch.tensor(all_dense_f, device=DEVICE)]

# Inference only. eval() makes the output deterministic if WideAndDeep has
# dropout/batch-norm layers. no_grad() avoids building an autograd graph over
# every item at once.
model.eval()
with torch.no_grad():
    _, item_embedding = model(inputs)
print(item_embedding.shape)

# NOTE(review): every other path in this script lives under ./ml-latest-small,
# but the embedding is written to ../ml-latest-small. Left unchanged in case a
# downstream stage reads it from there — confirm this is intentional.
with open("../ml-latest-small/item_embedding.pkl", "wb") as file:
    pickle.dump(item_embedding.cpu().numpy(), file, protocol=pickle.HIGHEST_PROTOCOL)
