import os
# os.system('pip install torch==1.8.0 --user')
os.system("pip install lmdb")
os.system("pip install prefetch_generator")
os.system('pip install tensorboard_logger')
import torch
import torch.nn as nn
import tqdm
import argparse
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import DataLoader, Subset
from gen_dataset.avazu import AvazuDataset
from gen_dataset.criteo import CriteoDataset
from gen_dataset.avito import AvitoDataset
from gen_dataset.movielens import MovieLens25MDataset, MovieLens20MDataset, MovieLens1MDataset
from gen_dataset.kdd import Kdd12Dataset
from loss import AUCExponential_loss, AUCPR_loss, AUCPolyloss, AUCSquare_loss, AUCMLoss, SimilarityLoss, AUCLogistic_loss, PESG, AUCHinge_loss, AUCPointwiseHinge_loss, AUCCircle_loss, AUCUnivariate_Loss, AUCPointloss, AUCPointlossV2
from loss import IF_AUCExponential_loss, IF_AUCSquare_loss
from loss import AUCSimple_loss, AUCCosloss, AUCDML_Loss, AUCTP_at_FP_loss, AUCComp_loss
from loss import FocalLoss
from loss import Polyloss
from loss import CompositionalLoss, PDSCA
from loss import APLoss, APLoss_SH_V1, SOAP
from utils import calc_auc, collate_fn
import tensorboard_logger as tb_logger
from model.dcn import DCN
from model.dnn import DNN
from model.autoInt import AutoInt
from model.pnn import PNN
from model.deepfm import DeepFM, FM

from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    """``DataLoader`` whose iterator is wrapped in ``BackgroundGenerator``.

    Everything except iteration is inherited unchanged from ``DataLoader``.
    The wrapper comes from the ``prefetch_generator`` package (presumably it
    prefetches batches ahead of the consumer; see that package's docs).
    """

    def __iter__(self):
        base_iterator = super(DataLoaderX, self).__iter__()
        return BackgroundGenerator(base_iterator)

def get_dataset(path, use_group=1, mode='train', cache_path=None):
    """Build the dataset object whose keyword appears in ``path``.

    Keywords are checked in order ('avazu', then 'kdd', then 'avito'), and the
    first match wins. The KDD12 dataset is built without ``cache_path``.
    ``use_group`` is accepted for interface compatibility but not used.

    Raises:
        ValueError: if ``path`` contains none of the known keywords.
    """
    # Lambdas keep construction lazy so only the matching class is touched.
    builders = (
        ('avazu', lambda: AvazuDataset(path, cache_path=cache_path, mode=mode)),
        ('kdd', lambda: Kdd12Dataset(path, mode=mode)),
        ('avito', lambda: AvitoDataset(path, cache_path=cache_path, mode=mode)),
    )
    for keyword, build in builders:
        if keyword in path:
            return build()
    raise ValueError('unknown dataset name')

# Module-level loss objects: rmse() uses `mse` and knowledge_distillation()
# uses `bce`.  NOTE(review): `pl` is not referenced elsewhere in this file.
mse, bce, pl = (
    nn.MSELoss().cuda(),
    torch.nn.BCEWithLogitsLoss().cuda(),
    Polyloss().cuda(),
)
def rmse(pred, target):
    """Return the root of the mean squared error between ``pred`` and ``target``.

    Uses the module-level ``mse`` loss (an ``nn.MSELoss()`` with default
    reduction), so the result is a scalar tensor.
    """
    mean_squared = mse(pred, target)
    return mean_squared.sqrt()

def knowledge_distillation(pred, target, gt):
    """Distill reference logits ``target`` into logits ``pred``.

    ``target`` is squashed with a sigmoid and used as soft labels for the
    module-level ``bce`` (``BCEWithLogitsLoss``) applied to ``pred``. This is
    binary cross-entropy, which equals the KL divergence plus the soft labels'
    entropy (a constant w.r.t. ``pred``). ``gt`` is accepted but unused.
    """
    soft_labels = torch.sigmoid(target)
    return bce(pred, soft_labels)

from torch.nn.functional import relu
def rank_confidence(pred, target, gt, margin=0.0, loss_type='hinge'):
    """Per-sample ranking-consistency loss between current and reference logits.

    Pushes the current logit ``pred`` past the reference logit ``target`` in
    the direction of the label: upward for ``gt == 1``, downward for
    ``gt == 0``.

    Args:
        pred: current model logits.
        target: reference logits, same shape as ``pred``.
        gt: binary labels (0/1), broadcastable against ``pred``.
        margin: margin for the ``'hinge'`` (probability space) and ``'mse'``
            (logit space) variants; ignored by ``'poly'`` and ``'log'``.
        loss_type: one of ``'hinge'``, ``'poly'``, ``'log'``, ``'mse'``.

    Returns:
        Scalar tensor: the mean loss over all elements.

    Raises:
        ValueError: if ``loss_type`` is not recognised.
    """
    if loss_type == 'hinge':
        # Compared on sigmoid probabilities, not raw logits.
        loss = gt * relu(target.sigmoid() + margin - pred.sigmoid()) + (1 - gt) * relu(margin + pred.sigmoid() - target.sigmoid())
    elif loss_type == 'poly':
        diff = pred - target
        loss = gt * (1 - torch.sigmoid(diff)) + (1 - gt) * torch.sigmoid(diff)
    elif loss_type == 'log':
        # -log(sigmoid(d)) == softplus(-d) and -log(1 - sigmoid(d)) == softplus(d).
        # The softplus form never evaluates log(0) = -inf; multiplying such an
        # -inf by a zero label weight would otherwise yield NaN for large |d|.
        diff = pred - target
        loss = gt * torch.nn.functional.softplus(-diff) + (1 - gt) * torch.nn.functional.softplus(diff)
    elif loss_type == 'mse':
        loss = gt * (margin - (pred - target)) ** 2 + (1 - gt) * (margin - (target - pred)) ** 2
    else:
        raise ValueError('loss type is not exist', loss_type)
    return loss.mean()


def relational_rank_confidence(pred, target, gt, margin=0.3, loss_type='hinge'):
    """Pairwise (positive, negative) ranking-consistency loss.

    For every pair of a positive sample ``i`` (``gt == 1``) and a negative
    sample ``j`` (``gt == 0``), compares the current logit gap
    ``pred[i] - pred[j]`` with the reference gap ``target[i] - target[j]``.
    It penalises the current gap for not exceeding the reference gap.

    Assumes ``pred``, ``target`` and ``gt`` share the same 1-D shape (boolean
    masks built from ``gt`` index ``pred``/``target``). TODO confirm for all models.

    Args:
        pred: current model logits.
        target: reference logits.
        gt: binary labels (0/1).
        margin: margin for the ``'hinge'`` and ``'mse'`` variants.
        loss_type: one of ``'hinge'``, ``'log'``, ``'mse'``, ``'poly'``.

    Returns:
        Scalar tensor: the mean loss over all positive/negative pairs. It is a
        zero scalar when the batch has no positives or no negatives, because
        the mean over an empty pair set would be NaN.

    Raises:
        ValueError: if ``loss_type`` is not recognised.
    """
    pos_mask = gt == 1
    neg_mask = gt == 0
    pos_pred, neg_pred = pred[pos_mask], pred[neg_mask]
    pos_target, neg_target = target[pos_mask], target[neg_mask]

    # Broadcast to all pairs: entry [j, i] is pos[i] - neg[j]; shape (n_neg, n_pos).
    pred_gap = pos_pred.unsqueeze(0) - neg_pred.unsqueeze(1)
    target_gap = pos_target.unsqueeze(0) - neg_target.unsqueeze(1)
    excess = pred_gap - target_gap

    if loss_type == 'hinge':
        pn_loss = relu(margin - excess)
    elif loss_type == 'log':
        # -log(sigmoid(x)) == softplus(-x), computed without taking log(0).
        pn_loss = torch.nn.functional.softplus(-excess)
    elif loss_type == 'mse':
        pn_loss = (margin - excess) ** 2
    elif loss_type == 'poly':
        # Mirrors rank_confidence's 'poly' term for a positive label; 'poly'
        # is a valid --loss_type and train() forwards it here.
        pn_loss = 1 - torch.sigmoid(excess)
    else:
        raise ValueError('no that loss type', loss_type)

    if pn_loss.numel() == 0:
        return pred.new_zeros(())
    return pn_loss.mean()


def train(model, optimizer, data_loader, criterion, valid_prediction=None, margin=0.2):
    """Run one training pass over ``data_loader``.

    Args:
        model: network called as ``model(fields)`` and returning logits.
        optimizer: optimizer stepped once per batch.
        data_loader: yields ``(index, fields, target)`` batches; ``index`` is
            used to look up ``valid_prediction``.
        criterion: base loss on logits, called as ``criterion(y, target.float())``.
        valid_prediction: optional tensor of reference logits indexed by
            sample index. When given, the rank-consistency regularisers
            ``rank_confidence`` and ``relational_rank_confidence`` are added.
        margin: unused; the regularisers below are called with ``margin=1``.

    Returns:
        Tuple of (mean base loss per sample, mean predicted probability over
        negatives, mean predicted probability over positives).

    Reads the module-level ``args.loss_type``.
    """
    model.train()
    targets, predicts = [], []
    total_loss = 0
    total_samples = 0
    loss_type = args.loss_type
    print('use loss_type: %s' % loss_type)
    for index, fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
        # `async=` is a SyntaxError on Python >= 3.7; PyTorch spells it `non_blocking=`.
        fields, target = fields.cuda(non_blocking=True), target.cuda(non_blocking=True)
        y = model(fields)
        logloss = criterion(y, target.float())

        if valid_prediction is not None:
            # Clamp reference logits so extreme past predictions cannot dominate.
            previous_y = valid_prediction[index].clamp(-10, 10)
            rc_loss = rank_confidence(y, previous_y, target, margin=1, loss_type=loss_type)
            rrc_loss = relational_rank_confidence(y, previous_y, target, margin=1, loss_type=loss_type)
            # NOTE(review): margin=1 and the 0.2 weights are hard-coded; neither
            # the `margin` parameter nor --margin/--alpha reach this point.
            loss = logloss + rc_loss * 0.2 + rrc_loss * 0.2
        else:
            loss = logloss

        # Only the base loss is accumulated for reporting.
        total_loss += logloss.item() * len(y)
        total_samples += len(y)

        model.zero_grad()
        loss.backward()
        optimizer.step()

        targets.extend(target.tolist())
        predicts.extend(torch.sigmoid(y).detach().tolist())

    targets_arr = np.asarray(targets)
    predicts_arr = np.asarray(predicts)
    true_pred = np.sum(predicts_arr * targets_arr) / np.sum(targets_arr)
    false_pred = np.sum(predicts_arr * (1 - targets_arr)) / np.sum(1 - targets_arr)

    print('training margin is', true_pred - false_pred, 'false pred avg mean:', false_pred, 'true pred avg mean:', true_pred, 'training loss:', total_loss / total_samples)
    return total_loss / total_samples, false_pred, true_pred


def meta_train(model, optimizer, data_loader_dict, train_day, pretrain_days, criterion):
    """Train day by day, regularising each day with the previous validation pass.

    For every day in ``train_day`` the model is trained once via ``train``,
    and the training statistics are logged with the module-level ``logger``.
    From step ``pretrain_days`` on (excluding the final day), the model is
    then evaluated on the next day's loader. The predictions from that pass
    are handed to ``train`` for the next day.

    NOTE(review): the result of ``test(..., is_prediction=True)`` is unpacked
    as ``(auc, predictions)``; confirm ``test`` returns that pair.
    """
    last_step = len(train_day) - 1
    valid_prediction = None
    for step, day in enumerate(train_day):
        loader = data_loader_dict[day]
        print('begin offline training on %s' % str(day))
        loss, fp_mean, tp_mean = train(model, optimizer, loader, criterion, valid_prediction)
        stats = (
            ('train_loss', loss),
            ('fp_mean', fp_mean),
            ('tp_mean', tp_mean),
            ('margin', tp_mean - fp_mean),
        )
        for tag, value in stats:
            logger.log_value(tag, value, step)
        if not (pretrain_days <= step < last_step):
            continue
        next_day = train_day[step + 1]
        print('begin offline validate on day %s' % str(next_day))
        auc, valid_prediction = test(model, data_loader_dict[next_day], criterion, is_prediction=True)
        logger.log_value('valid_auc', auc, step)
        

def test(model, data_loader, criterion, is_prediction=False):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Args:
        model: network called as ``model(fields, False)`` and returning logits.
        data_loader: yields ``(index, fields, target)`` batches; ``index`` is
            the sample position used to fill the per-sample prediction tensor.
        criterion: loss on logits, called as ``criterion(y, target.float())``.
        is_prediction: if True, return the AUC together with the per-sample
            logits so ``meta_train`` can regularise the next training day.

    Returns:
        If ``is_prediction``: ``(auc, valid_prediction)``, where
        ``valid_prediction`` holds one raw logit per dataset sample.
        Otherwise: ``(auc, auprc, mean loss per sample, mean batch RMSE per sample)``.
    """
    model.eval()
    targets, predicts = [], []
    total_loss = 0
    total_rmse_loss = 0
    total_samples = 0
    with torch.no_grad():
        # One slot per dataset sample, filled in via each batch's `index`.
        valid_prediction = torch.zeros(len(data_loader.dataset)).cuda()
        for index, fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            # `async=` is a SyntaxError on Python >= 3.7; PyTorch spells it `non_blocking=`.
            fields, target = fields.cuda(non_blocking=True), target.cuda(non_blocking=True)
            y = model(fields, False)
            loss = criterion(y, target.float())
            # Per-batch RMSE on probabilities, averaged per sample below.
            rmse_loss = rmse(torch.sigmoid(y), target.float())
            total_loss += loss.item() * len(y)
            total_rmse_loss += rmse_loss.item() * len(y)
            total_samples += len(y)

            targets.extend(target.tolist())
            # Raw logits are fine for AUC/AP: both depend only on ranking.
            predicts.extend(y.tolist())
            valid_prediction[index] = y.detach()
        auc = roc_auc_score(targets, predicts)
        auprc = average_precision_score(targets, predicts)
    if is_prediction:
        print('validation auc:', auc)
        # meta_train unpacks this as `auc, valid_prediction`.
        return auc, valid_prediction
    return auc, auprc, total_loss / total_samples, total_rmse_loss / total_samples

def count_parameters(model):
    """Return the total number of elements across trainable parameters.

    Parameters with ``requires_grad == False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def _dataset_schedule(dataset_name):
    """Return ``(online_day, train_day, pretrain_days, data_root_path)`` for a dataset name.

    Raises:
        ValueError: for names other than ``'avazu'`` and ``'avito'``.
    """
    if dataset_name == 'avazu':
        online_day = list(range(141021, 141031))
        data_root_path = '/home/avazu/'
    elif dataset_name == 'avito':
        online_day = ['2015-04-25', '2015-04-26', '2015-04-27', '2015-04-28',
            '2015-04-29', '2015-04-30', '2015-05-01', '2015-05-02',
            '2015-05-03', '2015-05-04', '2015-05-05', '2015-05-06',
            '2015-05-07', '2015-05-08', '2015-05-09', '2015-05-10',
            '2015-05-11', '2015-05-12', '2015-05-13', '2015-05-14',
            '2015-05-15', '2015-05-16', '2015-05-17', '2015-05-18',
            '2015-05-19', '2015-05-20']
        data_root_path = '/home/data/avito/'
    else:
        # Fail loudly instead of hitting a NameError on the unset day lists.
        raise ValueError('unsupported dataset %r: expected avazu or avito' % dataset_name)
    # The last online day is held out for the final test.
    train_day = online_day[:-1]
    pretrain_days = 3
    return online_day, train_day, pretrain_days, data_root_path


def _build_model(args, field_dims):
    """Instantiate the architecture named by ``args.models``.

    Raises:
        ValueError: for an unknown model name.
    """
    if args.models == 'dcn':
        return DCN(field_dims, args.batch_size, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout, use_bn=args.use_bn)
    if args.models == 'fm':
        return FM(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout)
    if args.models == 'dnn':
        return DNN(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout, use_bn=args.use_bn)
    if args.models == 'pnn':
        return PNN(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout, use_bn=args.use_bn)
    if args.models == 'autoint':
        return AutoInt(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout)
    if args.models == 'deepfm':
        return DeepFM(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout)
    raise ValueError('no that models %s' % args.models)


def _build_optimizer(args, model):
    """Create the optimizer named by ``args.optim`` (``'adam'`` or ``'adagrad'``).

    Raises:
        ValueError: for an unknown optimizer name.
    """
    if args.optim == 'adam':
        return torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdcy)
    if args.optim == 'adagrad':
        return torch.optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wdcy)
    raise ValueError('unknown optimizer %s' % args.optim)


def main(args):
    """Run one experiment: day-by-day training, then test on the last online day.

    Sets the module-level ``logger`` (a tensorboard logger writing to
    ``args.tb_folder``) that ``meta_train`` reads.

    Returns:
        ``(auc, auprc, loss, rmse_loss)`` from ``test`` on the held-out last day.
    """
    print('begin build dataset')
    online_day, train_day, pretrain_days, data_root_path = _dataset_schedule(args.dataset_path)

    global logger
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    dataset_dict = {
        day: get_dataset(
            args.dataset_path,
            use_group=0,
            cache_path=os.path.join(data_root_path, '.' + str(day)),
            mode=None,
        )
        for day in online_day
    }
    # Field cardinalities are taken from the last online day's dataset.
    field_dims = dataset_dict[online_day[-1]].field_dims

    dataloader_dict = {
        day: DataLoaderX(dataset_dict[day], batch_size=args.batch_size, shuffle=True, num_workers=12, pin_memory=True)
        for day in online_day
    }
    test_data_loader = dataloader_dict[online_day[-1]]

    print('begin generate models')
    model = _build_model(args, field_dims)
    print(model)
    print('params are', count_parameters(model))
    # NOTE(review): no DataParallel wrapping happens here; the model is moved to the default GPU.
    print('build DataParallel')
    model = model.cuda()

    torch.backends.cudnn.benchmark = True
    criterion = torch.nn.BCEWithLogitsLoss().cuda()
    optimizer = _build_optimizer(args, model)

    print('begin training')
    meta_train(model, optimizer, dataloader_dict, train_day, pretrain_days, criterion)
    auc, auprc, loss, rmse_loss = test(model, test_data_loader, criterion)
    return (auc, auprc, loss, rmse_loss)



if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Train and test the field-wise learning model")
    # main() only understands the literal names 'avazu' and 'avito'.
    # NOTE(review): the default below is not one of them, so a bare run fails.
    parser.add_argument('--dataset_path', type=str, default='/home/data/ml-20m/ratings.csv', help="dataset name: 'avazu' or 'avito'")
    parser.add_argument('--gpus', type=int, nargs='+', default=None, help='gpus')
    parser.add_argument('--ebd_dim', type=int, default=20, help="embedding dimension")
    parser.add_argument('--log_ebd', action='store_true', default=False, help="whether to use log scale for embedding dimensions")
    parser.add_argument('--lr', type=float, default=1e-3, help="learning rate for the optimiser")
    parser.add_argument('--wdcy', type=float, default=1e-8, help="weight decay")
    parser.add_argument('--dropout', type=float, default=0.2, help="dropout ratio")
    # NOTE(review): --margin and --alpha are not read by train(), which hard-codes its values.
    parser.add_argument('--margin', type=float, default=0.2, help="margin")
    parser.add_argument('--alpha', type=float, default=0.2, help="loss ratio")
    parser.add_argument('--models', type=str, default='dcn', choices=['dcn', 'fm', 'dnn', 'pnn', 'autoint', 'deepfm'])
    # `type=list` would split a single argument into characters; parse ints instead.
    parser.add_argument('--layers', type=int, nargs='+', default=[400, 400, 400])
    parser.add_argument('--dnn_layers', type=int, nargs='+', default=[1024, 512, 256])
    parser.add_argument('--optim', type=str, default='adagrad', choices=['adam', 'adagrad'])
    parser.add_argument('--use_bn', type=int, default=1, help='use bn')
    parser.add_argument('--use_mpn', type=int, default=1, help='use mbr module')
    parser.add_argument('--use_topk', type=int, default=1, help='use topk')
    parser.add_argument('--use_group', type=int, default=0, help='use group')
    parser.add_argument('--use_am', type=int, default=1, help='use auc maximization')
    parser.add_argument('--early_stopping', type=int, default=1, help='use early_stopping')

    parser.add_argument('--loss_type', type=str, default='log', choices=['log', 'mse', 'hinge', 'poly'], help="rank-consistency loss variant")
    parser.add_argument('--batch_size', type=int, default=2048, help="batch size")
    parser.add_argument('--running_times', type=int, default=1, help="number of independent runs")
    parser.add_argument('--epoch', type=int, default=40, help="max training epochs")
    parser.add_argument('--device', type=str, default="cuda:0", help="device to use")
    # Module-scope assignment: train() reads `args` as a global.
    args = parser.parse_args()
    print(args)

    metric_names = ('auc', 'ap', 'loss', 'rmse_loss')
    results = {name: [] for name in metric_names}

    for i in range(args.running_times):
        args.tb_folder = os.path.join('models', '_'.join([args.models, str(args.batch_size), str(args.lr), str(i)]))
        test_results = main(args)
        print("best_auc: {0}, best_ap:{1}, best_loss: {2}, best_rmse_loss:{3}".format(*test_results))
        for name, value in zip(metric_names, test_results):
            results[name].append(value)

    for name in metric_names:
        print('mean and std %s is' % name, np.mean(results[name]), np.std(results[name]))
    
    
