from model import PlainGCN
import nni
from nb_dataset_101 import Nb101Dataset
from nb_dataset_201 import Nb201Dataset
from tb_dataset_101 import Trans101Dataset
from  macro_dataset import Trans101DatasetMacro
from nb_dataset_nlp import NBNLPDataset
from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from utils import (AverageMeter,accuracy_mse,to_cuda,
                   set_seed,BPRLoss,list_mle,pair_loss,
                   top_k_best_rank,weighted_pair_loss,ListNetLoss,
                   average_rank_of_top_10,rel_at_1,wpair2,warp,mp_loss,
                  lambdaLoss,margin_loss)
from scipy.stats import kendalltau,weightedtau
import time

def _make_loss_fn(criterion, device):
    """Return a callable ``(predict, target) -> loss`` for the named criterion.

    Raises:
        ValueError: if ``criterion`` is not one of the supported loss names.
    """
    if criterion == 'mse':
        # Built once here instead of once per training step.
        return nn.MSELoss()
    if criterion == 'bpr':
        return lambda predict, target: BPRLoss()(predict, target)
    if criterion == 'listnet':
        return lambda predict, target: ListNetLoss()(predict, target)
    if criterion == 'lambda':
        # k is 40% of the batch size, truncated to an int.
        return lambda predict, target: lambdaLoss(predict, target, k=int(target.size(0)*0.4))

    # Losses called as loss(predict, target).
    plain_losses = {'listmle': list_mle, 'wpair2': wpair2, 'warp': warp}
    if criterion in plain_losses:
        return plain_losses[criterion]

    # Losses that additionally take the device as a keyword argument.
    device_losses = {'pair': pair_loss, 'wpair': weighted_pair_loss,
                     'mp': mp_loss, 'margin': margin_loss}
    if criterion in device_losses:
        loss_with_device = device_losses[criterion]
        return lambda predict, target: loss_with_device(predict, target, device=device)

    raise ValueError("unsupported criterion: %r" % (criterion,))


def train(train_set,train_loader,model,optimizer,lr_scheduler,criterion,epochs,device,verbose=False):
    """Train ``model`` for ``epochs`` epochs and return it.

    Each batch is moved with ``to_cuda(batch, device)``; the regression target
    is ``batch["n_val_acc"]``. After every epoch the Kendall tau between all
    predictions and targets of that epoch is printed and ``lr_scheduler`` is
    stepped.

    Args:
        train_set: dataset object, passed to ``accuracy_mse`` when the
            criterion is ``'mse'``.
        train_loader: iterable of batches.
        model: callable on a batch; must support ``train()``.
        optimizer: optimizer whose first param group's ``"lr"`` is reported.
        lr_scheduler: stepped once per epoch.
        criterion: loss name (``'mse'``, ``'bpr'``, ``'listmle'``, ``'pair'``,
            ``'wpair'``, ``'listnet'``, ``'wpair2'``, ``'warp'``, ``'mp'``,
            ``'lambda'`` or ``'margin'``).
        epochs: number of epochs.
        device: forwarded to ``to_cuda`` and to the losses that accept it.
        verbose: when true, also print the learning rate and average loss.

    Returns:
        The trained ``model`` (the same object that was passed in).

    Raises:
        ValueError: if ``criterion`` is unknown, or if ``train_loader`` yields
            no batches during an epoch.
    """
    # Resolve the loss once; an unknown name fails here rather than as a
    # NameError on an unbound ``loss`` in the middle of the first step.
    loss_fn = _make_loss_fn(criterion, device)
    is_mse = criterion == 'mse'

    for epoch in range(epochs):
        model.train()
        # Learning rate in effect during this epoch (read before scheduler.step()).
        lr = optimizer.param_groups[0]["lr"]
        losses=AverageMeter("loss")
        mses=AverageMeter("mse")

        predicts, targets = [], []
        for step,batch in enumerate(train_loader):
            batch = to_cuda(batch,device)
            target = batch["n_val_acc"]
            optimizer.zero_grad()
            predict = model(batch)

            loss = loss_fn(predict, target.float())
            loss.backward()
            optimizer.step()

            predicts.append(predict.cpu().detach().numpy())
            targets.append(target.cpu().detach().numpy())
            losses.update(loss.item(), target.size(0))
            if is_mse:
                mse = accuracy_mse(predict.squeeze(),target.squeeze(),train_set)
                mses.update(mse.item(),target.size(0))

        if not predicts:
            # e.g. drop_last=True with fewer samples than one batch.
            raise ValueError("train_loader produced no batches; nothing to train on")

        predicts = np.concatenate(predicts)
        targets = np.concatenate(targets)
        kendall_tau = kendalltau(predicts, targets)[0]
        # Note: this line reports the loss of the LAST batch, not the epoch average.
        print('epoch:',epoch,'loss:',loss.item(),'ktau:',kendall_tau)

        lr_scheduler.step()
        if verbose:
            print("Epoch:", epoch + 1, "lr:", lr, "loss", losses.avg, "ktau:", kendall_tau)

    return model

def evaluate(test_set,test_loader,model,criterion,device):
    """Run ``model`` over ``test_loader`` and compute ranking metrics.

    Predictions are compared against ``batch["n_test_acc"]``. Progress is
    printed every 10 steps and the top-k / rank / weighted-tau metrics are
    printed once at the end.

    Returns:
        An 11-tuple ``(kendall_tau, predicts, targets, loss_avg, mse_avg,
        wtau, top5, top10, ref5, ref10, rank10)``. ``loss_avg`` and
        ``mse_avg`` are the ``avg`` of meters that are never updated in this
        function, so they carry whatever ``AverageMeter`` starts with.
    """
    model.eval()
    print("start eval...")

    pred_chunks = []
    target_chunks = []
    loss_meter = AverageMeter("loss")
    mse_meter = AverageMeter("mse")

    with torch.no_grad():
        for step, raw_batch in enumerate(test_loader):
            moved = to_cuda(raw_batch, device)
            labels = moved["n_test_acc"]
            scores = model(moved)
            pred_chunks.append(scores.cpu().numpy())
            target_chunks.append(labels.cpu().numpy())
            # Same progress line regardless of the criterion.
            if step % 10 == 0:
                print("step:", step)

    predicts = np.concatenate(pred_chunks)
    targets = np.concatenate(target_chunks)

    # Rank-correlation metrics.
    kendall_tau = kendalltau(predicts, targets)[0]
    wtau = weightedtau(predicts, targets)[0]

    # Top-k style metrics; rel_at_1 is given the denormalized targets.
    top5, top10 = (top_k_best_rank(predicts, targets, k) for k in (5, 10))
    rank10 = average_rank_of_top_10(predicts, targets)
    ref5 = rel_at_1(predicts, test_set.denormalize(targets), 5)
    ref10 = rel_at_1(predicts, test_set.denormalize(targets), 10)

    print("top5:",top5,"top10:",top10,"rank10:",rank10,"ref5:",ref5,"ref10:",ref10,"wtau:",wtau)
    return (kendall_tau, predicts, targets, loss_meter.avg, mse_meter.avg,
            wtau, top5, top10, ref5, ref10, rank10)


def get_params():
    """Parse the experiment's command-line options.

    Unknown arguments are ignored (``parse_known_args``), so extra flags do
    not abort the run.

    Returns:
        argparse.Namespace with the experiment, dataset and training settings.
        ``--p-dim`` is exposed as the attribute ``p_dim``.
    """
    parser = ArgumentParser()
    # exp and dataset
    parser.add_argument("--exp_name", type=str, default='rank')
    parser.add_argument("--bench", type=str, default='101',choices=['101','201','TB101','Macro','NLP'])
    parser.add_argument("--train_split", type=int, default=100)
    parser.add_argument("--test_split", type=str, default='all')
    parser.add_argument("--dataset", type=str, default='cifar10',choices=['cifar10','cifar100','imagenet16'])
    # training settings
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--gpu", type=int, default=0)
    # type=int so a command-line value has the same type as the default
    # (without it, "--epochs 50" would yield the string "50").
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--task", default='class_object', type=str, choices=['image classification','class_scene','class_object','jigsaw','segmentsemantic','normal','autoencoder','room_layout'])
    parser.add_argument("--loss", default='mse', type=str,choices=['mse','bpr','listmle','pair','wpair','listnet','wpair2','warp','mp','lambda','margin'])
    parser.add_argument("--layers", default=4, type=int)
    parser.add_argument("--hidden_size",default=144,type=int)
    parser.add_argument("--p-dim",default=128,type=int)
    parser.add_argument("--lr", default=1e-3, type=float)
    parser.add_argument("--wd", default=1e-3, type=float)
    parser.add_argument("--dropout", default=0.15, type=float)
    parser.add_argument("--train_batch_size", default=10, type=int)
    parser.add_argument("--test_batch_size", default=10240, type=int)
    parser.add_argument("--runs", default=100, type=int)

    args , _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    # Entry point: for each run, build the datasets for the selected
    # benchmark, train a fresh PlainGCN, evaluate it, and append the results
    # to a per-configuration log file. Means over all runs are printed and
    # logged at the end.

    # Local import: only needed here, to create the log directory.
    import os

    # Command-line settings, then updated with the parameters NNI returns.
    params = vars(get_params())
    tune_params = nni.get_next_parameter()
    params.update(tune_params)

    ktaus,wktaus,top5s,top10s,ref5s,ref10s,rank10s = [],[],[],[],[],[],[]
    train_times,eval_times = [],[]
    # The run index doubles as the random seed.
    for i in range(params['runs']):

        set_seed(i)
        torch.manual_seed(i)
        np.random.seed(i)

        #device
        if torch.cuda.is_available():
            device = torch.device('cuda:'+str(params['gpu']))
        else:
            device = torch.device('cpu')

        #dataset
        bench = params['bench']
        # Common "<train_split>_<test_split>" prefix of every log file name.
        split_tag = str(params['train_split']) + '_' + str(params['test_split'])
        if bench == "101":
            train_set = Nb101Dataset(split = params['train_split'],datatype='train',seed=i)
            test_set = Nb101Dataset(split= params['test_split'],datatype='test',seed=i+1)
            num_features = 5
            pre_file_path = 'rank_log/101/'
            file_path = pre_file_path + split_tag + '_' + params['loss'] + '.log'
        elif bench == "201":
            train_set = Nb201Dataset(split=params['train_split'], data_type='train',data_set=params['dataset'])
            test_set = Nb201Dataset(split=params['test_split'], data_type='test',data_set=params['dataset'])
            num_features = 7
            pre_file_path = 'rank_log/201/'
            file_path = pre_file_path + split_tag + '_' + params['dataset'] + '_' + params['loss'] + '.log'
        elif bench == "TB101":
            train_set = Trans101Dataset(split=params['train_split'], data_type='train', task=params['task'])
            test_set = Trans101Dataset(split=params['test_split'], data_type='test', task=params['task'])
            num_features = 6
            pre_file_path = 'rank_log/tb101/micro/'
            file_path = pre_file_path + split_tag + '_' + params['task'] + '_' + params['loss'] + '.log'
        elif bench == "Macro":
            train_set = Trans101DatasetMacro(split=params['train_split'], data_type='train', task=params['task'])
            test_set = Trans101DatasetMacro(split=params['test_split'], data_type='test', task=params['task'])
            num_features = 6
            pre_file_path = 'rank_log/tb101/macro/'
            file_path = pre_file_path + split_tag + '_' + params['task'] + '_' + params['loss'] + '.log'
        elif bench == "NLP":
            train_set = NBNLPDataset(split=params['train_split'], data_type='train', )
            test_set = NBNLPDataset(split=params['test_split'], data_type='test')
            num_features = 9
            pre_file_path = 'rank_log/nlp/'
            file_path = pre_file_path + split_tag + '_' + params['loss'] + '.log'
        else:
            # argparse restricts --bench, but the NNI update above could
            # still inject an unexpected value.
            raise ValueError("unsupported bench: %r" % (bench,))

        # Create the log directory before training so a missing directory
        # cannot discard a finished run when the log is opened afterwards.
        os.makedirs(pre_file_path, exist_ok=True)

        train_loader = DataLoader(train_set, batch_size= params['train_batch_size'],
                                  num_workers=0, shuffle=True, drop_last=True)
        test_loader = DataLoader(test_set, batch_size= params['test_batch_size'], shuffle=False,
                                 num_workers=6)
        epochs =  int(params['epochs'])
        # One fewer fully-connected layer than conv layers.
        gnn_model = PlainGCN(num_features=num_features,num_classes=1,hidden=params["hidden_size"],num_fc_layers= params['layers']-1,num_conv_layers= params['layers'],dropout=params['dropout'],p_dim=params['p_dim']).to(device)

        model = gnn_model


        print("===run ", i, "===========")
        # loss,optimizer and lr_scheduler
        criterion = params['loss']
        optimizer = optim.Adam(model.parameters(),lr=params['lr'],weight_decay=params['wd'])
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,epochs)
        train_s_time = time.time()
        model = train(train_set,train_loader,model,optimizer,lr_scheduler,criterion,epochs,device)
        train_e_time = time.time()
        eval_s_time = time.time()
        ktau,predict_all,target_all,loss,mse,wktau,top5,top10,ref5,ref10,rank10 = evaluate(test_set,test_loader,model,criterion,device)
        eval_e_time = time.time()

        train_times.append(train_e_time-train_s_time)
        eval_times.append(eval_e_time-eval_s_time)


        print("=======val result===========")
        print("ktau:",ktau)
        with open(file_path, "a") as file:
            file.write("[run " +str(i)+"]" +"  ktau: "+str(ktau)+"\n")
        ktaus.append(ktau)
        wktaus.append(wktau)
        top5s.append(top5)
        top10s.append(top10)
        ref5s.append(ref5)
        ref10s.append(ref10)
        rank10s.append(rank10)

    # Compute each mean once, then emit it to both stdout and the log file.
    summary = [
        ('ktau_mean:', ktaus),
        ('wktau_mean:', wktaus),
        ('top5_mean:', top5s),
        ('top10_mean:', top10s),
        ('ref5_mean:', ref5s),
        ('ref10_mean:', ref10s),
        ('rank10_mean:', rank10s),
        ('train_time_mean:', train_times),
        ('eval_time_mean:', eval_times),
    ]
    means = [(label, np.array(values).mean()) for label, values in summary]

    for label, mean_value in means:
        print(label, mean_value)

    # file_path is the one set by the last run.
    with open(file_path, "a") as file:
        for label, mean_value in means:
            file.write(label + str(mean_value) + "\n")




















