import argparse
import time
import torch
import numpy as np
import scipy.sparse as sparse
import torch.optim as optim
from sklearn.metrics import roc_auc_score, f1_score
from model import Uni_model
import gc
import os
import metric
import utils

# Command-line interface.
# NOTE: --cold and --pretrain are `store_false` switches: both values default
# to True and *passing* the flag turns the value off. Their help text used to
# be a copy-pasted 'use CUDA', which hid that inversion.
parser = argparse.ArgumentParser(description='PyTorch Neural Collaborative Filtering')
parser.add_argument('--dataset', type=str, default='cd', help='dataset name')
parser.add_argument('--lr', type=float, default=0.002, help='initial learning rate')
parser.add_argument('--wd', type=float, default=0.000001, help='weight decay coefficient')
parser.add_argument('--batch_size', type=int, default=2048, help='batch size')
parser.add_argument('--epochs', type=int, default=20, help='upper epoch limit')
parser.add_argument('--embedding_size', type=int, default=64, help='embedding_size')
parser.add_argument('--seed', type=int, default=55, help='random seed')
parser.add_argument('--cuda', action='store_true', help='use CUDA')
parser.add_argument('--test', action='store_true', help='test')
parser.add_argument('--cold', action='store_false',
                    help='set cold=False (default: True); the base user/item embeddings '
                         'are then loaded from <home_dir>/data/uni_recsys/bpr/<dataset>.pt')
# --cluster is parsed but not read anywhere in this script.
parser.add_argument('--cluster', action='store_true',
                    help='clustering switch (currently unused by this script)')
parser.add_argument('--pretrain', action='store_false',
                    help='set pretrain=False (default: True); together with --cold this '
                         'loads feature embeddings from pretrained models')
parser.add_argument('--feature', type=str, default='popularity',
                    help='feature name passed to utils.load_feature')
parser.add_argument('--home_dir', type=str, default='/home/',
                    help='root directory that contains the data/ folder')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='report interval')
parser.add_argument('--save', type=str, default='model.pt', help='path to save the final model')
parser.add_argument('--save_output', type=str, default='model', help='Save identifier')
parser.add_argument('--k', type=int, default=100, help='k passed to utils.load_feature')
parser.add_argument('--n', type=int, default=2,
                    help='argument passed to model.reg_loss when --reg_loss is set')
parser.add_argument('--reg_loss', action='store_true',
                    help='add model.reg_loss(n) to the training loss')
parser.add_argument('--pre_stat', action='store_true',
                    help='load train/test stats via utils.load_pre_stat')

args = parser.parse_args()

# Fix PyTorch's RNG seed so repeated runs are reproducible.
torch.manual_seed(args.seed)

# Nudge the user when a GPU is present but was not requested.
if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###############################################################################
# Load data
###############################################################################
home_dir = args.home_dir
device = torch.device("cuda" if args.cuda else "cpu")

# DataLoader settings for training: the batch size counts interactions.
params = dict(batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
# DataLoader settings for evaluation: the batch size counts users.
eval_params = dict(batch_size=128, shuffle=False, num_workers=5, pin_memory=True)

# One helper call returns the checkpoint path (`save`) together with the
# user/item feature tensors and the user/item stats.
save, user_feature, item_feature, user_stat, item_stat = utils.load_feature(
    home_dir, args.dataset, args.feature, args.k)

# Extra train/test stats are only loaded when --pre_stat is given.
train_stat, test_stat = (
    utils.load_pre_stat(home_dir, args.dataset) if args.pre_stat else (None, None)
)

# Quick sanity print of what was loaded.
print(user_feature.shape, item_feature.shape)
print(user_feature[0])
print(item_feature[10])

# One feature row per user / per item.
n_users = len(user_feature)
n_items = len(item_feature)

# Warm-start tensors for the model. Everything stays None unless the
# store_false switches were passed: --cold makes `args.cold` False and
# --pretrain makes `args.pretrain` False (both default to True).
user_embedding = None
item_embedding = None
feature_user_embedding = None
feature_item_embedding = None
clustering_user_embedding = None
clustering_item_embedding = None

if not args.cold:
    load_feature_embeddings = not args.pretrain
    if load_feature_embeddings:
        # FIX: `pretrain_path` / `pretrain_path1` were used below without ever
        # being defined, so this branch always died with a NameError (after the
        # base checkpoint had already been loaded). No CLI option provides them
        # yet; they are looked up on `args` so that adding `--pretrain_path` /
        # `--pretrain_path1` options is enough to enable the branch. Until
        # then, fail fast with an explicit message.
        pretrain_path = getattr(args, 'pretrain_path', None)
        pretrain_path1 = getattr(args, 'pretrain_path1', None)
        if not pretrain_path or not pretrain_path1:
            raise RuntimeError(
                'Pretrained feature embeddings were requested (--cold --pretrain) but no '
                'checkpoint paths are configured: set args.pretrain_path and '
                'args.pretrain_path1.')

    # Base model checkpoint: its embedding tables seed the new model.
    with open('{}/data/uni_recsys/bpr/{}.pt'.format(home_dir, args.dataset), 'rb') as f:
        base_model = torch.load(f)
    user_embedding = base_model.userEmbedding.weight
    item_embedding = base_model.itemEmbedding.weight

    if load_feature_embeddings:
        print("Loading Pretrained Embedding")
        # torch.arange with integer arguments already yields an int64 tensor,
        # so the former torch.LongTensor(...) wrapper was redundant.
        all_users = torch.arange(n_users).to(device)
        all_items = torch.arange(n_items).to(device)

        with open(pretrain_path, 'rb') as f:
            feature_model = torch.load(f)
        feature_user_embedding = feature_model.userFeatureEmbedding(all_users)
        feature_item_embedding = feature_model.itemFeatureEmbedding(all_items)

        with open(pretrain_path1, 'rb') as f:
            feature_model = torch.load(f)
        # NOTE(review): the clustering embeddings are computed here but are not
        # passed to the model constructor in this script - confirm whether
        # they are still needed.
        clustering_user_embedding = feature_model.userFeatureEmbedding(all_users)
        clustering_item_embedding = feature_model.itemFeatureEmbedding(all_items)

# Name handed to the evaluation datasets: two known dataset families map to a
# fixed name, anything else falls back to the text before the first '_'.
alter_data = (
    'yelp_AZ' if 'yelp_AZ' in args.dataset
    else 'movie_lens' if 'movie' in args.dataset
    else args.dataset.split('_')[0]
)

# Positional arguments shared by every evaluation split (note: `alter_data`,
# not `args.dataset`, names the data for evaluation).
eval_dataset_args = (home_dir, user_stat, item_stat, alter_data, args.cold, n_users, n_items)

# Datasets.
train_dataset = utils.TrainDatasetNeuMF(
    home_dir, user_stat, item_stat, args.dataset, args.cold, n_users, n_items)
val_dataset = utils.EvalDatasetNeuMF(*eval_dataset_args, 'val')
test_dataset = utils.EvalDatasetNeuMF(*eval_dataset_args, 'test')
cold_test_dataset = utils.EvalDatasetNeuMF(*eval_dataset_args, 'test', 'c_c')

# Data loaders: training uses `params`, all evaluation splits use `eval_params`.
train_loader = torch.utils.data.DataLoader(train_dataset, **params)
val_loader = torch.utils.data.DataLoader(val_dataset, **eval_params)
test_loader = torch.utils.data.DataLoader(test_dataset, **eval_params)
cold_test_loader = torch.utils.data.DataLoader(cold_test_dataset, **eval_params)
###############################################################################
# Build the model
###############################################################################

# Constructor arguments, in the positional order Uni_model is called with.
model_args = (
    args.embedding_size,
    n_users,
    n_items,
    user_feature,
    item_feature,
    user_embedding,
    item_embedding,
    args.cold,
    args.pretrain,
    feature_user_embedding,
    feature_item_embedding,
    train_stat,
    test_stat,
)
model = Uni_model(*model_args).to(device)

# Adam over every model parameter, with weight decay taken from --wd.
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)


def evaluate(model, eval_loader):
    """Score every batch of ``eval_loader`` and summarise ranking quality.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is a 5-tuple ``(users, items, label, u_stat, i_stat)``; the
    model output is reshaped to ``(-1, label.shape[1])`` before NDCG@10 and
    recall@10 are computed per batch via the ``metric`` module.

    Args:
        model: callable taking ``(users, items, u_stat, i_stat)``.
        eval_loader: iterable yielding the 5-tuples described above.

    Returns:
        Tuple ``(auc, best_lambd, truth, pred, ndcg10, recall10)`` where
        ``auc`` is the ROC AUC over all flattened scores, ``best_lambd`` is
        always ``-1``, ``truth`` / ``pred`` are the stacked column arrays of
        labels / scores, and the last two are the means of the per-batch
        NDCG@10 and recall@10 values.
    """
    model.eval()

    score_chunks, label_chunks = [], []
    ndcg_chunks, recall_chunks = [], []

    with torch.no_grad():
        for batch in eval_loader:
            users, items, label, u_stat, i_stat = (
                t.to(device, non_blocking=True) for t in batch
            )
            # One row of scores per row of `label`.
            scores = model(users, items, u_stat, i_stat).view(-1, label.shape[1])

            ndcg_chunks.append(
                metric.NDCG_binary_at_k_batch_torch(scores, label, 10, device=device))
            recall_chunks.append(metric.Recall_at_k_batch_torch(scores, label, 10))

            # Keep flattened copies on the host for the global AUC.
            score_chunks.append(scores.view(-1, 1).cpu().numpy())
            label_chunks.append(label.view(-1, 1).cpu().numpy())

    pred = np.concatenate(score_chunks)
    truth = np.concatenate(label_chunks)
    auc = roc_auc_score(truth, pred)

    # Placeholder kept so the return tuple keeps its shape.
    best_lambd = -1

    mean_ndcg = torch.mean(torch.cat(ndcg_chunks))
    mean_recall = torch.mean(torch.cat(recall_chunks))
    return auc, best_lambd, truth, pred, mean_ndcg, mean_recall


print("# train batches :", len(train_loader))

# Number of full batches in the training set, clamped so it is never zero.
args.log_interval = max(1, len(train_loader.dataset) // train_loader.batch_size)

# At any point you can hit Ctrl + C to break out of training early.
def train(train_loader, val_loader):
    """Train the module-level ``model`` and checkpoint it on recall@10.

    Uses the module-level ``model``, ``optimizer``, ``args``, ``device`` and
    ``save``. Every training batch is an 8-tuple of tensors
    ``(user, item, neg_user, neg_item, u_stat, i_stat, negu_stat, negi_stat)``;
    the model scores the first pair and the ``neg_*`` pair and
    ``model.loss_function`` is applied to the two squeezed score tensors.
    With ``--reg_loss`` the value of ``model.reg_loss(args.n)`` is added.

    After each epoch the model is evaluated on ``val_loader`` and the whole
    model object is written to ``save`` whenever recall@10 improves on the
    best value seen so far.

    Args:
        train_loader: iterable of training batches (see above).
        val_loader: loader handed to ``evaluate`` for model selection.

    Returns:
        None.
    """
    # Selection metric is recall@10 (the variable used to be called best_auc).
    best_recall = -np.inf
    total_batches = len(train_loader)

    for epoch in range(args.epochs):
        epoch_start_time = time.time()
        model.train()
        train_loss = 0.0
        batches_done = 0
        start_time = time.time()

        for data in train_loader:
            data = [x.to(device, non_blocking=True) for x in data]
            (user, item, neg_user, neg_item, u_stat, i_stat, negu_stat, negi_stat) = data

            # The optimizer was built from model.parameters(), so this clears
            # the same gradients the former extra model.zero_grad() did.
            optimizer.zero_grad()

            pos_predictions = model(user, item, u_stat, i_stat)
            neg_predictions = model(neg_user, neg_item, negu_stat, negi_stat)
            loss = model.loss_function(pos_predictions.squeeze(), neg_predictions.squeeze())
            if args.reg_loss:
                loss += model.reg_loss(args.n)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            batches_done += 1

        elapsed = time.time() - start_time
        # Average over the batches that actually ran. The old code divided by
        # args.log_interval (floor(N / batch_size), which can be 0 or one less
        # than the real batch count) and printed the last loop index, which is
        # off by one and undefined when the loader is empty.
        divisor = max(batches_done, 1)
        print('| epoch {:3d} | {:4d}/{:4d} batches | ms/batch {:4.2f} | '
              'loss {:4.2f}'.format(
            epoch + 1, batches_done, total_batches,
            elapsed * 1000 / divisor, train_loss / divisor))

        auc, lambd, truth, pred, n10, r10 = evaluate(model, val_loader)
        print('| end of epoch {:3d} | time: {:4.2f}s | auc: {:5.3f} | r10: {:5.3f} | n10: {:5.3f}'.format(
            epoch + 1, time.time() - epoch_start_time, auc, r10, n10))
        print('-' * 89)

        # Keep only the checkpoint with the best recall@10.
        if r10 > best_recall:
            torch.save(model, save)
            best_recall = r10


def _load_checkpoint(path):
    """Return the model object stored at ``path`` by ``torch.save(model, path)``."""
    with open(path, 'rb') as f:
        return torch.load(f)


def _report_metrics(model, loader):
    """Evaluate ``model`` on ``loader`` and print the auc / r10 / n10 summary line."""
    auc, _, _, _, n10, r10 = evaluate(model, loader)
    print('| auc: {:5.3f} | r10: {:5.3f} | n10: {:5.3f}'.format(auc, r10, n10))


if __name__ == '__main__':
    # Clamp to 1: a training set smaller than one batch would otherwise give 0
    # here and break any later division by args.log_interval.
    args.log_interval = max(1, len(train_loader.dataset) // train_loader.batch_size)
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        if not args.test:
            print('Starting training')
            # FIX: model selection used to run on test_loader while the final
            # report used val_loader. Select on the validation split and
            # report on the test splits, like the --test and Ctrl+C paths do.
            train(train_loader, val_loader)
        else:
            print('Start testing')

        # Both modes finish by scoring the saved checkpoint on the test splits.
        model = _load_checkpoint(save)
        _report_metrics(model, test_loader)
        _report_metrics(model, cold_test_loader)

    except KeyboardInterrupt:
        print('-' * 89)
        # An interrupt before the first checkpoint is written leaves nothing
        # to load; skip the evaluation instead of raising FileNotFoundError.
        if os.path.exists(save):
            model = _load_checkpoint(save)
            _report_metrics(model, test_loader)
            _report_metrics(model, cold_test_loader)
        else:
            print('No checkpoint found at {}; skipping evaluation'.format(save))

        print('Exiting from training early')

