"""Training GCMC model on the MovieLens data set.

The script loads the full graph to the training device.
"""
import argparse
import logging
import os
import random
import string
import time

import numpy as np
import torch as th
import torch.nn as nn
from data import MovieLens
from model import BiDecoder, GCMCLayer
from utils import (MetricLogger, get_activation, get_optimizer, torch_net_info,
                   torch_total_param_num)


class Net(nn.Module):
    """GCMC network: a ``GCMCLayer`` encoder followed by a ``BiDecoder``.

    All hyper-parameters are read from ``args``: ``model_activation``,
    ``rating_vals``, ``src_in_units``, ``dst_in_units``, ``gcn_agg_units``,
    ``gcn_out_units``, ``gcn_dropout``, ``gcn_agg_accum``, ``share_param``,
    ``device`` and ``gen_r_num_basis_func``.
    """

    def __init__(self, args):
        super().__init__()
        self._act = get_activation(args.model_activation)

        # Positional arguments are passed in the order GCMCLayer is called
        # with in this file; the remaining options go by keyword.
        encoder_positional = (
            args.rating_vals,
            args.src_in_units,
            args.dst_in_units,
            args.gcn_agg_units,
            args.gcn_out_units,
            args.gcn_dropout,
            args.gcn_agg_accum,
        )
        self.encoder = GCMCLayer(
            *encoder_positional,
            agg_act=self._act,
            share_user_item_param=args.share_param,
            device=args.device,
        )

        # One output class per distinct rating value.
        decoder_options = {
            "in_units": args.gcn_out_units,
            "num_classes": len(args.rating_vals),
            "num_basis": args.gen_r_num_basis_func,
        }
        self.decoder = BiDecoder(**decoder_options)

    def forward(self, enc_graph, dec_graph, ufeat, ifeat):
        """Encode users/movies on ``enc_graph`` and decode on ``dec_graph``.

        Returns whatever the decoder produces for ``dec_graph``; callers in
        this file treat it as per-rating-class scores (softmax over dim 1).
        """
        user_repr, movie_repr = self.encoder(enc_graph, ufeat, ifeat)
        return self.decoder(dec_graph, user_repr, movie_repr)


def evaluate(args, net, dataset, segment="valid"):
    """Return the RMSE of ``net`` on the validation or test split.

    The predicted rating is the expectation of the possible rating values
    under the softmax of the network output (taken over dim 1).

    Parameters
    ----------
    args : namespace providing ``device``.
    net : callable model; switched to eval mode before the forward pass.
    dataset : object providing ``possible_rating_values``, the user/movie
        features and the ``<segment>_truths`` / ``<segment>_enc_graph`` /
        ``<segment>_dec_graph`` attributes.
    segment : ``"valid"`` or ``"test"``.

    Raises
    ------
    NotImplementedError
        If ``segment`` is neither ``"valid"`` nor ``"test"``.
    """
    rating_scale = th.FloatTensor(dataset.possible_rating_values).to(
        args.device
    )

    if segment not in ("valid", "test"):
        raise NotImplementedError

    if segment == "valid":
        truths = dataset.valid_truths
        enc_graph, dec_graph = dataset.valid_enc_graph, dataset.valid_dec_graph
    else:
        truths = dataset.test_truths
        enc_graph, dec_graph = dataset.test_enc_graph, dataset.test_dec_graph

    # Forward pass in eval mode without building an autograd graph.
    net.eval()
    with th.no_grad():
        scores = net(
            enc_graph, dec_graph, dataset.user_feature, dataset.movie_feature
        )

    # Expected rating = sum_k softmax(scores)[k] * rating_value[k].
    class_probs = th.softmax(scores, dim=1)
    expected_ratings = (class_probs * rating_scale.view(1, -1)).sum(dim=1)

    mse = ((expected_ratings - truths) ** 2.0).mean().item()
    return np.sqrt(mse)


def train(args):
    """Train the GCMC network on the full MovieLens graph.

    Loads the dataset, builds the network, optimizer and CSV metric loggers
    from ``args``, then runs full-graph training iterations numbered
    ``1 .. args.train_max_iter - 1``. Every ``train_valid_interval``
    iterations the model is validated; the test split is evaluated whenever
    the validation RMSE improves. The learning rate is decayed after
    ``train_decay_patience`` non-improving validations, and training stops
    early once ``train_early_stopping_patience`` is exceeded while the
    learning rate is already at ``train_min_lr``.

    Parameters
    ----------
    args : argparse.Namespace
        Options as produced by ``config()``. ``src_in_units``,
        ``dst_in_units`` and ``rating_vals`` are written onto it here from
        the loaded dataset before the network is built.
    """
    print(args)
    dataset = MovieLens(
        args.data_name,
        args.device,
        use_one_hot_fea=args.use_one_hot_fea,
        symm=args.gcn_agg_norm_symm,
        test_ratio=args.data_test_ratio,
        valid_ratio=args.data_valid_ratio,
    )
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values

    ### build the net
    net = Net(args=args)
    net = net.to(args.device)
    nd_possible_rating_values = th.FloatTensor(
        dataset.possible_rating_values
    ).to(args.device)
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(
        net.parameters(), lr=learning_rate
    )
    print("Loading network finished ...\n")

    ### prepare training data
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(
        ["iter", "loss", "rmse"],
        ["%d", "%.4f", "%.4f"],
        os.path.join(args.save_dir, "train_loss%d.csv" % args.save_id),
    )
    valid_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "valid_loss%d.csv" % args.save_id),
    )
    test_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "test_loss%d.csv" % args.save_id),
    )

    ### declare the loss information
    best_valid_rmse = np.inf
    # Initialised so the final summary cannot hit an unbound name when no
    # validation round ever ran (e.g. train_valid_interval > train_max_iter).
    best_test_rmse = np.inf
    no_better_valid = 0
    best_iter = -1
    count_sq_err = 0  # summed squared rating error since the last log line
    count_num = 0  # number of predictions behind count_sq_err
    count_loss = 0  # summed loss over ALL iterations (never reset)

    dataset.train_enc_graph = dataset.train_enc_graph.int().to(args.device)
    dataset.train_dec_graph = dataset.train_dec_graph.int().to(args.device)
    # Validation encodes on the training graph; only the decode graph differs.
    dataset.valid_enc_graph = dataset.train_enc_graph
    dataset.valid_dec_graph = dataset.valid_dec_graph.int().to(args.device)
    dataset.test_enc_graph = dataset.test_enc_graph.int().to(args.device)
    dataset.test_dec_graph = dataset.test_dec_graph.int().to(args.device)

    print("Start training ...")
    dur = []
    try:
        # NOTE: the range excludes train_max_iter itself, so
        # train_max_iter - 1 optimizer steps are performed.
        for iter_idx in range(1, args.train_max_iter):
            # Rebuilt every iteration so validation output never appends to
            # a stale (or not yet defined) message.
            logging_str = ""

            # The first three iterations are excluded from the timing.
            if iter_idx > 3:
                t0 = time.time()
            net.train()
            pred_ratings = net(
                dataset.train_enc_graph,
                dataset.train_dec_graph,
                dataset.user_feature,
                dataset.movie_feature,
            )
            loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
            count_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
            optimizer.step()

            if iter_idx > 3:
                dur.append(time.time() - t0)

            if iter_idx == 1:
                print("Total #Param of net: %d" % (torch_total_param_num(net)))
                print(
                    torch_net_info(
                        net,
                        save_path=os.path.join(
                            args.save_dir, "net%d.txt" % args.save_id
                        ),
                    )
                )

            # Metric only: no autograd graph is needed for the squared error.
            with th.no_grad():
                real_pred_ratings = (
                    th.softmax(pred_ratings, dim=1)
                    * nd_possible_rating_values.view(1, -1)
                ).sum(dim=1)
                sq_err = ((real_pred_ratings - train_gt_ratings) ** 2).sum()
            count_sq_err += sq_err.item()
            count_num += pred_ratings.shape[0]

            if iter_idx % args.train_log_interval == 0:
                # Iterations are 1-based, so iter_idx is the number of losses
                # accumulated; CSV and console now report the same mean.
                mean_loss = count_loss / iter_idx
                # Square root so the value is an RMSE like evaluate() returns.
                train_rmse = np.sqrt(count_sq_err / count_num)
                # No timing samples exist during the first three iterations.
                mean_time = np.average(dur) if dur else float("nan")
                train_loss_logger.log(
                    iter=iter_idx, loss=mean_loss, rmse=train_rmse
                )
                logging_str = (
                    "Iter={}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                        iter_idx, mean_loss, train_rmse, mean_time
                    )
                )
                count_sq_err = 0
                count_num = 0

            if iter_idx % args.train_valid_interval == 0:
                valid_rmse = evaluate(
                    args=args, net=net, dataset=dataset, segment="valid"
                )
                valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
                if not logging_str:
                    # Validation ran on an iteration without a training line.
                    logging_str = "Iter={}".format(iter_idx)
                logging_str += ",\tVal RMSE={:.4f}".format(valid_rmse)

                if valid_rmse < best_valid_rmse:
                    best_valid_rmse = valid_rmse
                    no_better_valid = 0
                    best_iter = iter_idx
                    test_rmse = evaluate(
                        args=args, net=net, dataset=dataset, segment="test"
                    )
                    best_test_rmse = test_rmse
                    test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                    logging_str += ", Test RMSE={:.4f}".format(test_rmse)
                else:
                    no_better_valid += 1
                    if (
                        no_better_valid > args.train_early_stopping_patience
                        and learning_rate <= args.train_min_lr
                    ):
                        logging.info(
                            "Early stopping threshold reached. Stop training."
                        )
                        break
                    if no_better_valid > args.train_decay_patience:
                        new_lr = max(
                            learning_rate * args.train_lr_decay_factor,
                            args.train_min_lr,
                        )
                        if new_lr < learning_rate:
                            learning_rate = new_lr
                            logging.info("\tChange the LR to %g", new_lr)
                            for p in optimizer.param_groups:
                                p["lr"] = learning_rate
                            no_better_valid = 0

            if logging_str:
                print(logging_str)

        print(
            "Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}".format(
                best_iter, best_valid_rmse, best_test_rmse
            )
        )
    finally:
        # Close the CSV loggers even if training raised part-way through.
        train_loss_logger.close()
        valid_loss_logger.close()
        test_loss_logger.close()


def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``False``/``1``/``no``.

    ``type=bool`` cannot be used with argparse for this: ``bool("False")``
    is ``True`` because any non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got %r" % (value,)
    )


def config():
    """Parse the command line and prepare the output directory.

    Returns
    -------
    argparse.Namespace
        The parsed options, with ``device`` converted to a ``torch.device``
        (CPU when the given index is negative), ``save_dir`` / ``save_id``
        filled with random values when omitted, and ``save_dir`` rewritten
        to ``log/<save_dir>``, which is created if it does not exist.
    """
    parser = argparse.ArgumentParser(description="GCMC")
    parser.add_argument("--seed", default=123, type=int)
    parser.add_argument(
        "--device",
        default=0,
        type=int,
        help="Running device. E.g `--device 0`, if using cpu, set `--device -1`",
    )
    parser.add_argument("--save_dir", type=str, help="The saving directory")
    parser.add_argument("--save_id", type=int, help="The saving log id")
    parser.add_argument("--silent", action="store_true")
    parser.add_argument(
        "--data_name",
        default="ml-1m",
        type=str,
        help="The dataset name: ml-100k, ml-1m, ml-10m",
    )
    parser.add_argument(
        "--data_test_ratio", type=float, default=0.1
    )  ## for ml-100k the test ratio is 0.2
    parser.add_argument("--data_valid_ratio", type=float, default=0.1)
    parser.add_argument("--use_one_hot_fea", action="store_true", default=False)
    parser.add_argument("--model_activation", type=str, default="leaky")
    parser.add_argument("--gcn_dropout", type=float, default=0.7)
    # Explicit parser so that `--gcn_agg_norm_symm False` really disables it.
    parser.add_argument("--gcn_agg_norm_symm", type=_str2bool, default=True)
    parser.add_argument("--gcn_agg_units", type=int, default=500)
    parser.add_argument("--gcn_agg_accum", type=str, default="sum")
    parser.add_argument("--gcn_out_units", type=int, default=75)
    parser.add_argument("--gen_r_num_basis_func", type=int, default=2)
    parser.add_argument("--train_max_iter", type=int, default=2000)
    parser.add_argument("--train_log_interval", type=int, default=1)
    parser.add_argument("--train_valid_interval", type=int, default=1)
    parser.add_argument("--train_optimizer", type=str, default="adam")
    parser.add_argument("--train_grad_clip", type=float, default=1.0)
    parser.add_argument("--train_lr", type=float, default=0.01)
    parser.add_argument("--train_min_lr", type=float, default=0.001)
    parser.add_argument("--train_lr_decay_factor", type=float, default=0.5)
    parser.add_argument("--train_decay_patience", type=int, default=50)
    parser.add_argument(
        "--train_early_stopping_patience", type=int, default=100
    )
    parser.add_argument("--share_param", default=False, action="store_true")

    args = parser.parse_args()
    # A negative device index selects the CPU.
    args.device = (
        th.device(args.device) if args.device >= 0 else th.device("cpu")
    )

    ### configure save_dir to save all the info
    if args.save_dir is None:
        # Dataset name plus a random two-character suffix.
        args.save_dir = (
            args.data_name
            + "_"
            + "".join(
                random.choices(string.ascii_uppercase + string.digits, k=2)
            )
        )
    if args.save_id is None:
        args.save_id = np.random.randint(20)
    args.save_dir = os.path.join("log", args.save_dir)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.save_dir, exist_ok=True)

    return args


if __name__ == "__main__":
    # The seed is a command-line option, so the arguments are parsed first.
    args = config()
    seed = args.seed

    # Seed NumPy and PyTorch; CUDA generators are seeded only when available.
    np.random.seed(seed)
    th.manual_seed(seed)
    if th.cuda.is_available():
        th.cuda.manual_seed_all(seed)

    train(args)
