from email.policy import default
import os
# os.system('pip install torch==1.8.0 --user')
# Install third-party packages through the shell every time this module is
# executed or imported. The exit status returned by os.system is not checked,
# so a failed install would only surface later (e.g. as an ImportError at the
# imports below).
# NOTE(review): prefetch_generator and tensorboard_logger are imported further
# down in this file; lmdb is not imported here -- presumably it is needed by
# the gen_dataset package. Confirm before removing.
os.system("pip install lmdb")
os.system("pip install prefetch_generator")
os.system('pip install tensorboard_logger')

import itertools
import random
import torch
import torch.nn as nn
import tqdm
import copy
import argparse
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.utils.data import DataLoader, Subset
from gen_dataset.avazu import AvazuDataset
from gen_dataset.criteo import CriteoDataset
from gen_dataset.avito import AvitoDataset
from gen_dataset.movielens import MovieLens25MDataset, MovieLens20MDataset, MovieLens1MDataset
from gen_dataset.kdd import Kdd12Dataset
from loss import AUCExponential_loss, AUCPR_loss, AUCPolyloss, AUCSquare_loss, AUCMLoss, SimilarityLoss, AUCLogistic_loss, PESG, AUCHinge_loss, AUCPointwiseHinge_loss, AUCCircle_loss, AUCUnivariate_Loss, AUCPointloss, AUCPointlossV2
from loss import IF_AUCExponential_loss, IF_AUCSquare_loss
from loss import AUCSimple_loss, AUCCosloss, AUCDML_Loss, AUCTP_at_FP_loss, AUCComp_loss
from loss import FocalLoss
from loss import Polyloss
from loss import CompositionalLoss, PDSCA
from loss import APLoss, APLoss_SH_V1, SOAP
from utils import calc_auc, collate_fn
import tensorboard_logger as tb_logger
from model.dcn import DCN
from model.dnn import DNN
from model.autoInt import AutoInt
from model.pnn import PNN
from model.deepfm import DeepFM, FM

from prefetch_generator import BackgroundGenerator

class DataLoaderX(DataLoader):
    """``DataLoader`` whose iterator is wrapped in a ``BackgroundGenerator``.

    Presumably this lets upcoming batches be prepared in the background while
    the current one is consumed -- see the prefetch_generator package.
    """

    def __iter__(self):
        # Build the regular DataLoader iterator first, then hand it to the
        # background wrapper.
        base_iterator = super().__iter__()
        return BackgroundGenerator(base_iterator)

def get_dataset(path, use_group=1, mode='train',cache_path=None):
    """Instantiate the dataset class whose keyword occurs in ``path``.

    Keywords are tried in the order ``'avazu'``, ``'kdd'``, ``'avito'``; the
    first one that is a substring of ``path`` decides the class.

    Args:
        path: dataset path or name; matched by substring.
        use_group: accepted for interface compatibility; not used here.
        mode: forwarded to the dataset constructor.
        cache_path: forwarded to the Avazu and Avito datasets (the KDD12
            dataset is constructed without it).

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if none of the keywords occurs in ``path``.
    """
    # Ordered (keyword, factory) pairs; factories are lazy so that only the
    # selected dataset class is ever touched.
    factories = (
        ('avazu', lambda: AvazuDataset(path, cache_path=cache_path, mode=mode)),
        ('kdd', lambda: Kdd12Dataset(path, mode=mode)),
        ('avito', lambda: AvitoDataset(path, cache_path=cache_path, mode=mode)),
    )
    for keyword, build in factories:
        if keyword in path:
            return build()
    raise ValueError('unknown dataset name')

# Module-level loss objects, created once when the module is loaded.
# NOTE: a `global` statement at module scope has no effect; `mse` is already a
# module-level name.
global mse
# Mean-squared-error criterion used by rmse() below.
mse = nn.MSELoss().cuda()
# NOTE(review): `pl` is not referenced anywhere else in the visible part of
# this file -- confirm whether it is still needed.
pl = Polyloss().cuda()
def rmse(pred, target):
    """Return the square root of the module-level ``mse`` criterion on the inputs.

    Args:
        pred: predicted values.
        target: reference values, passed to ``mse`` together with ``pred``.

    Returns:
        ``torch.sqrt(mse(pred, target))``.
    """
    return torch.sqrt(mse(pred, target))


def train(model, optimizer, data_loader, criterion, margin=0.2):
    """Run one training pass over ``data_loader`` and report loss statistics.

    Every batch is unpacked as ``(index, fields, target)``. ``fields`` and
    ``target`` are moved to the current CUDA device, the model output ``y`` is
    scored with ``criterion(y, target.float())`` and one optimizer step is
    taken per batch.

    Args:
        model: module called as ``model(fields)``; its output is passed through
            a sigmoid for the reported prediction statistics.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable yielding ``(index, fields, target)`` batches.
        criterion: loss callable taking ``(y, target.float())``.
        margin: accepted for interface compatibility; not used here.

    Returns:
        A tuple ``(mean_loss, false_pred, true_pred)`` where ``mean_loss`` is
        the sample-weighted average of the per-batch loss, ``true_pred`` is the
        target-weighted mean of the sigmoid predictions and ``false_pred`` the
        ``(1 - target)``-weighted mean (for 0/1 targets: the average predicted
        probability on positives and on negatives respectively).
    """
    model.train()
    targets, predicts = [], []
    total_loss = 0
    total_samples = 0

    for _index, fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
        # `async` is a reserved keyword since Python 3.7, so the former
        # `.cuda(async=True)` no longer even parses; PyTorch's name for the
        # same option is `non_blocking`.
        fields = fields.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        y = model(fields)
        loss = criterion(y, target.float())

        # Weight by batch size so the final average is per sample.
        batch_size = len(y)
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        model.zero_grad()
        loss.backward()
        optimizer.step()

        targets.extend(target.tolist())
        predicts.extend(torch.sigmoid(y).tolist())

    # Convert once instead of rebuilding the arrays for every expression.
    predicts_arr = np.array(predicts)
    targets_arr = np.array(targets)
    true_pred = np.sum(predicts_arr * targets_arr) / np.sum(targets_arr)
    false_pred = np.sum(predicts_arr * (1 - targets_arr)) / np.sum(1 - targets_arr)

    print('training margin is', true_pred - false_pred, 'false pred avg mean:', false_pred, 'true pred avg mean:', true_pred)
    return total_loss/total_samples, false_pred, true_pred


def meta_train(model, optimizer, data_loader_dict, train_day, criterion):
    """Train ``model`` on each day of ``train_day`` in order and log the results.

    For every day the matching loader is taken from ``data_loader_dict``,
    ``train`` is run once on it, and the returned loss and prediction means
    (plus their difference as ``margin``) are written to the module-level
    ``logger`` using the day's position in ``train_day`` as the step.

    Args:
        model: model passed through to ``train``.
        optimizer: optimizer passed through to ``train``.
        data_loader_dict: mapping from day to its data loader.
        train_day: ordered sequence of days to train on.
        criterion: loss callable passed through to ``train``.
    """
    for step, day in enumerate(train_day):
        day_loader = data_loader_dict[day]
        print('begin training on %s' % str(day))
        day_loss, neg_mean, pos_mean = train(model, optimizer, day_loader, criterion)

        # Logged in a fixed order: loss, the two means, then their gap.
        day_metrics = (
            ('train_loss', day_loss),
            ('fp_mean', neg_mean),
            ('tp_mean', pos_mean),
            ('margin', pos_mean - neg_mean),
        )
        for tag, value in day_metrics:
            logger.log_value(tag, value, step)
        
    

def test(model, data_loader, criterion):
    """Evaluate ``model`` on ``data_loader`` without gradient tracking.

    Every batch is unpacked as ``(index, fields, target)``, moved to the
    current CUDA device and scored with ``model(fields, False)``.

    Args:
        model: module called as ``model(fields, False)``.
        data_loader: iterable yielding ``(index, fields, target)`` batches.
        criterion: loss callable taking ``(y, target.float())``.

    Returns:
        A tuple ``(auc, auprc, mean_loss, mean_rmse)``: ROC-AUC and average
        precision computed from the raw model outputs over all samples, the
        sample-weighted average of the per-batch ``criterion`` loss, and the
        sample-weighted average of the per-batch ``rmse`` of the sigmoid
        outputs.
    """
    model.eval()
    targets, predicts = [], []
    total_loss = 0
    total_rmse_loss = 0
    total_samples = 0
    with torch.no_grad():
        for _index, fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            # `async` is a reserved keyword since Python 3.7, so the former
            # `.cuda(async=True)` no longer even parses; PyTorch's name for
            # the same option is `non_blocking`.
            fields = fields.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            y = model(fields, False)
            loss = criterion(y, target.float())
            rmse_loss = rmse(torch.sigmoid(y), target.float())

            # Weight by batch size so the final averages are per sample.
            batch_size = len(y)
            total_loss += loss.item() * batch_size
            total_rmse_loss += rmse_loss.item() * batch_size
            total_samples += batch_size

            targets.extend(target.tolist())
            # Raw (un-sigmoided) outputs are used as ranking scores below.
            predicts.extend(y.tolist())
        auc = roc_auc_score(targets, predicts)
        auprc = average_precision_score(targets, predicts)
    print('validation auc:', auc)
    return auc, auprc, total_loss/total_samples, total_rmse_loss/total_samples

def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

def main(args):
    """Train a model day by day and evaluate it on the final day.

    Builds one dataset and one shuffled ``DataLoaderX`` per day, instantiates
    the model selected by ``args.models`` and the optimizer selected by
    ``args.optim``, runs ``meta_train`` over every day except the last one and
    finally evaluates on the last day with ``test``. As a side effect the
    module-level ``logger`` is (re)created in ``args.tb_folder``.

    Args:
        args: parsed command-line namespace. Reads ``device``,
            ``dataset_path``, ``tb_folder``, ``batch_size``, ``models``,
            ``ebd_dim``, ``dnn_layers``, ``dropout``, ``use_bn``, ``optim``,
            ``lr`` and ``wdcy``.

    Returns:
        ``(auc, auprc, loss, rmse_loss)`` as returned by ``test`` on the last
        day's loader.

    Raises:
        ValueError: if ``args.dataset_path``, ``args.models`` or
            ``args.optim`` is not one of the supported values.
    """
    # NOTE(review): parsing validates the device string, but the result is not
    # used yet -- the model and the batches are moved with `.cuda()`, i.e. to
    # the current CUDA device, regardless of --device.
    device = torch.device(args.device)
    print('begin build dataset')
    if args.dataset_path == 'avazu':
        online_day = list(range(141021, 141031))
        data_root_path = '/home/avazu/'
    elif args.dataset_path == 'avito':
        online_day = ['2015-04-25', '2015-04-26', '2015-04-27', '2015-04-28',
            '2015-04-29', '2015-04-30', '2015-05-01', '2015-05-02',
            '2015-05-03', '2015-05-04', '2015-05-05', '2015-05-06',
            '2015-05-07', '2015-05-08', '2015-05-09', '2015-05-10',
            '2015-05-11', '2015-05-12', '2015-05-13', '2015-05-14',
            '2015-05-15', '2015-05-16', '2015-05-17', '2015-05-18',
            '2015-05-19', '2015-05-20']
        data_root_path = '/home/data/avito/'
    else:
        # Previously this fell through and crashed later with an unbound
        # `online_day`; fail early with an explicit message instead.
        raise ValueError(
            "unsupported dataset %r: expected 'avazu' or 'avito'" % (args.dataset_path,)
        )
    # Every day but the last is used for training; the last one is held out.
    train_day = online_day[:-1]

    global logger
    logger = tb_logger.Logger(logdir=args.tb_folder, flush_secs=2)

    dataset_dict = dict()
    for day in online_day:
        # Each day has its own cache directory named '.<day>' under the root.
        dataset_dict[day] = get_dataset(
                                args.dataset_path,
                                use_group=0,
                                cache_path=os.path.join(data_root_path, '.'+str(day)),
                                mode=None
                            )
        # Overwritten on every iteration; the last day's value is what the
        # model is built with.
        field_dims = dataset_dict[day].field_dims

    dataloader_dict = dict()
    for day in online_day:
        dataloader_dict[day] = DataLoaderX(dataset_dict[day], batch_size=args.batch_size, shuffle=True, num_workers=12, pin_memory=True)

    test_data_loader = dataloader_dict[online_day[-1]]
    print('begin generate models')
    if args.models == 'dcn':
        model = DCN(field_dims, args.batch_size, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout, use_bn=args.use_bn)
    elif args.models == 'fm':
        model = FM(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout)
    elif args.models == 'dnn':
        model = DNN(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout, use_bn=args.use_bn)
    elif args.models == 'pnn':
        model = PNN(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout, use_bn=args.use_bn)
    elif args.models == 'autoint':
        model = AutoInt(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout)
    elif args.models == 'deepfm':
        model = DeepFM(field_dims, args.ebd_dim, dnn_layers=args.dnn_layers, dropout=args.dropout)
    else:
        # The option is called `models`; reading `args.model` here used to
        # raise AttributeError instead of this ValueError.
        raise ValueError('no that models %s' % args.models)
    print(model)
    print('params are', count_parameters(model))
    # NOTE: despite the message, no DataParallel wrapper is created here; the
    # model is only moved to the GPU.
    print('build DataParallel')
    model = model.cuda()

    torch.backends.cudnn.benchmark = True
    criterion = torch.nn.BCEWithLogitsLoss().cuda()

    if args.optim == 'adam':
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.wdcy
        )
    elif args.optim == 'adagrad':
        optimizer = torch.optim.Adagrad(
            model.parameters(), lr=args.lr, weight_decay=args.wdcy
        )
    else:
        # Previously an unknown name left `optimizer` unbound.
        raise ValueError(
            "unsupported optimizer %r: expected 'adam' or 'adagrad'" % (args.optim,)
        )

    print('begin training')
    meta_train(model, optimizer, dataloader_dict, train_day, criterion)
    auc, auprc, loss, rmse_loss = test(model, test_data_loader, criterion)
    return (auc, auprc, loss, rmse_loss)



if __name__ == '__main__':
    # Command-line entry point: parse the options, run main() `running_times`
    # times (each run logs to its own tensorboard folder) and print the mean
    # and standard deviation of every test metric across the runs.
    parser = argparse.ArgumentParser(description="Train and test the field-wise learning model")
    # NOTE(review): main() only accepts the exact values 'avazu' and 'avito'
    # for --dataset_path, so the default below is rejected -- confirm the
    # intended default.
    parser.add_argument('--dataset_path', type=str, default='/home/data/ml-20m/ratings.csv', help="path to the dataset")
    parser.add_argument('--gpus',type=int,nargs='+',default=None,help='gpus')
    parser.add_argument('--ebd_dim', type=int, default=20, help="embedding dimension")
    parser.add_argument('--log_ebd', action='store_true', default=False, help="whether to use log scale for embedding dimensions")
    parser.add_argument('--lr', type=float, default=1e-3, help="learning rate for AdaGrad optimiser")
    parser.add_argument('--wdcy', type=float, default=1e-8, help="weight decay")
    parser.add_argument('--dropout', type=float, default=0.2, help="dropout ratio")
    parser.add_argument('--margin', type=float, default=0.2, help="margin")
    # The help text used to be a copy of the --dropout one.
    parser.add_argument('--alpha', type=float, default=0.2, help="alpha")
    parser.add_argument('--models', type=str, default='dcn')
    # `type=list` would split the raw argument string into single characters
    # ('400' -> ['4', '0', '0']); parse a list of ints like --dnn_layers does.
    parser.add_argument('--layers', type=int, nargs='+', default=[400, 400, 400])
    parser.add_argument('--dnn_layers', type=int, nargs='+', default=[1024, 512, 256])
    parser.add_argument('--optim', type=str, default='adagrad')
    parser.add_argument('--use_bn', type=int, default=1, help='use bn')
    parser.add_argument('--use_mpn', type=int, default=1, help='use mbr module')
    parser.add_argument('--use_topk', type=int, default=1, help='use topk')
    parser.add_argument('--use_group', type=int, default=0, help='use group')
    parser.add_argument('--use_am', type=int, default=1, help='use auc maximumzation')
    parser.add_argument('--early_stopping', type=int, default=1, help='use early_stopping')

    parser.add_argument('--batch_size', type=int, default=2048, help="batch size")
    # The help text used to be a copy of the --batch_size one.
    parser.add_argument('--running_times', type=int, default=1, help="number of independent training runs")
    parser.add_argument('--epoch', type=int, default=40, help="max training epochs")
    parser.add_argument('--device', type=str, default="cuda:0", help="device to use")
    args = parser.parse_args()
    print(args)

    test_auc = []
    test_ap = []
    test_loss = []
    test_rmse_loss = []

    for i in range(args.running_times):
        # One tensorboard folder per run: models/<model>_<batch>_<lr>_<run>.
        args.tb_folder = os.path.join('models', '_'.join([args.models, str(args.batch_size), str(args.lr), str(i)]))
        run_auc, run_ap, run_loss, run_rmse_loss = main(args)
        print("best_auc: {0}, best_ap:{1}, best_loss: {2}, best_rmse_loss:{3}".format(run_auc, run_ap, run_loss, run_rmse_loss))
        test_auc.append(run_auc)
        test_ap.append(run_ap)
        test_loss.append(run_loss)
        test_rmse_loss.append(run_rmse_loss)

    print('mean and std auc is', np.mean(test_auc), np.std(test_auc))
    print('mean and std ap is', np.mean(test_ap), np.std(test_ap))
    print('mean and std loss is', np.mean(test_loss), np.std(test_loss))
    print('mean and std rmse_loss is', np.mean(test_rmse_loss), np.std(test_rmse_loss))
    
    
