from libauc_dev.losses import CrossEntropyLoss, MIL_softmax_loss, MIL_attention_loss, TPMIL_softmax_loss, TPMIL_attention_loss, tpAUC_KL_Loss 
from losses import AUCMLoss, PSQLoss, PHLoss, PSHLoss, PLLoss, PSMLoss, PBHLoss
from libauc_dev.optimizers import MIDAM, SOTAs
from libauc.optimizers import PESG, Adam, SGD
from models import ResNet20, ResNet20_stoc_MIL, ResNet20_MIL, ResNet20_softmax, FFNN, FFNN_MIL, FFNN_stoc_MIL, FFNN_softmax
from imbalanced_sampler import imbalanced_sampler
from libauc_dev.utils import set_all_seeds, collate_fn, TabularDataset, instance_max_y, evaluate_tpauc, evaluate_tpauc, random_sample_y
from libauc_dev.metrics import auc_roc_score
from libauc_dev.datasets import BreastCancer, Colon, Oral, Lung

import torch 
import numpy as np
import time
from sklearn.model_selection import KFold
import tensorflow as tf
FLAGS = tf.compat.v1.flags.FLAGS
# Command-line flags for this experiment script; values are read via FLAGS.<name> below.
tf.compat.v1.flags.DEFINE_float('imbalanced_ratio', 0.5, 'controlled imbalance ratio passed to the training batch sampler')
tf.compat.v1.flags.DEFINE_string('loss', 'CEmean', 'loss function / pooling variant, e.g. CEmean, CEmax, AUCMLoss, AUCMatt-s, TPAUCmax-s, SOTAs-smx')
tf.compat.v1.flags.DEFINE_string('optimizer', 'Adam', 'Adam or Momentum')
tf.compat.v1.flags.DEFINE_string('dataset', 'MUSK1', 'MUSK1, MUSK2, Fox, Tiger, Elephant, Bonds, Atoms, Chains, BreastCancer, Colon, Oral or Lung')
tf.compat.v1.flags.DEFINE_float('gamma', 0, 'consecutive regularization')
tf.compat.v1.flags.DEFINE_float('lr', 0.01, 'learning rate')
tf.compat.v1.flags.DEFINE_float('momentum', 0.9, 'momentum for SGD')
tf.compat.v1.flags.DEFINE_float('decay', 1e-4, 'regularization weight decay')
tf.compat.v1.flags.DEFINE_float('moving_momentum', 0.0, 'momentum for moving average')
tf.compat.v1.flags.DEFINE_float('epsilon', 0.99, 'epsilon (also used as tau_1 and tau_2) for BALoss')
tf.compat.v1.flags.DEFINE_float('eta', 1e-2, 'eta_0 for BALoss')
tf.compat.v1.flags.DEFINE_string('activation', 'sigmoid', 'last-layer activation: sigmoid, l2, l1, scale or none')
tf.compat.v1.flags.DEFINE_float('tau', 0.1, 'softmax parameter for stoc-instance-max method')
tf.compat.v1.flags.DEFINE_integer('instance_batch_size', 1, 'sampling number from each bag')
tf.compat.v1.flags.DEFINE_integer('batch_size', 16, 'number of bags per training batch')
tf.compat.v1.flags.DEFINE_integer('seed', 123, 'random seed; a value other than 123 also re-splits the MIL benchmark train/test data')



def erm_loss_eval(loss1=0, loss2=0, loss_func='CSQLoss'):
    """Combine two (averaged) loss components into a single reported loss value.

    The first component is added as-is; the second is passed through the
    surrogate selected by ``loss_func``.

    Args:
      loss1: first loss component (scalar).
      loss2: second loss component (scalar) fed to the surrogate.
      loss_func: 'CSQLoss' (square), 'CLLoss' (logistic, log(1 + exp(x))),
        'CHLoss' (hinge, max(x, 0)), or 'CSHLoss' / 'AUCMLoss'
        (squared hinge, max(x, 0) ** 2).

    Returns:
      ``loss1 + surrogate(loss2)``.

    Raises:
      ValueError: if ``loss_func`` is not one of the names above.
    """
    if loss_func == 'CSQLoss':
        return loss1 + loss2**2
    elif loss_func == 'CLLoss':
        # logaddexp(0, x) == log(1 + exp(x)) without overflowing for large x.
        return loss1 + np.logaddexp(0, loss2)
    elif loss_func == 'CHLoss':
        return loss1 + max(loss2, 0)
    elif loss_func in ('CSHLoss', 'AUCMLoss'):
        return loss1 + max(loss2, 0)**2
    raise ValueError('Unknown loss_func: %r' % (loss_func,))

# parameters (read once from the command-line flags)
BATCH_SIZE = FLAGS.batch_size

imratio = FLAGS.imbalanced_ratio
# Presumably the number of positive bags per training batch; it is used below to
# slice the batch ids handed to the SOTAs losses.
pos_num = int(imratio * BATCH_SIZE)
lr, weight_decay = FLAGS.lr, FLAGS.decay
margin = 1.0  # NOTE(review): not referenced elsewhere in this file; the sweep uses `para`
# Seed the RNGs (via set_all_seeds) before the data re-split and model construction below.
set_all_seeds(FLAGS.seed)

# dataloader
val_sample_ratio = 1.0
# Image bag datasets. Each loader returns ((train_X, train_Y), (test_X, test_Y));
# flag=True selects the 3D data layout (used by CELoss / AUCMLoss), flag=False the
# 2D-data-in-bag layout used by every other loss.
_IMAGE_DATASETS = {'BreastCancer': BreastCancer, 'Colon': Colon, 'Oral': Oral, 'Lung': Lung}
# Classic MIL benchmarks stored as .npz archives with train_X / train_Y / test_X / test_Y.
_MIL_NPZ_FILES = {
  'MUSK1': 'musk_1.npz', 'MUSK2': 'musk_2.npz', 'Fox': 'fox.npz', 'Tiger': 'tiger.npz',
  'Elephant': 'elephant.npz', 'Bonds': 'bonds.npz', 'Atoms': 'atoms.npz', 'Chains': 'chains.npz',
}
if FLAGS.dataset in _IMAGE_DATASETS:
  (train_data, train_labels), (test_data, test_labels) = _IMAGE_DATASETS[FLAGS.dataset](flag=FLAGS.loss in ['CELoss', 'AUCMLoss'])
elif FLAGS.dataset in _MIL_NPZ_FILES:
  # np.load returns an NpzFile that holds the archive open; the context manager
  # closes it once the arrays have been read into memory.
  with np.load('/your-path/data/' + _MIL_NPZ_FILES[FLAGS.dataset], allow_pickle=True) as tmp:
    train_data = tmp['train_X']
    test_data = tmp['test_X']
    train_labels = tmp['train_Y'].astype(int)
    test_labels = tmp['test_Y'].astype(int)
  if FLAGS.seed != 123:
    # Non-default seed: pool train and test bags, shuffle, and re-split them
    # while keeping the original train/test sizes.
    teN = len(test_labels)
    trN = len(train_labels)
    randIds = np.random.permutation(trN+teN)
    X = np.concatenate([train_data,test_data],axis=0)
    Y = np.concatenate([train_labels,test_labels],axis=0)
    X = X[randIds]
    Y = Y[randIds]
    train_data = X[:trN]
    test_data = X[trN:]
    train_labels = Y[:trN]
    test_labels = Y[trN:]
  num_class=1
else:
  # Fail fast instead of hitting a NameError on train_data further down.
  raise ValueError('Unsupported dataset: %s' % FLAGS.dataset)

# Wrap the raw arrays as datasets and look up the per-instance feature dimension.
traindSet = TabularDataset(train_data, train_labels)
testSet = TabularDataset(test_data, test_labels)
collate_function = collate_fn  # unused in this file; the loaders pass collate_fn directly
DIMS_dict = {
  'MUSK1': 166, 'MUSK2': 166, 'Fox': 230, 'Tiger': 230, 'Elephant': 230,
  'Bonds': 16, 'Chains': 24, 'Atoms': 10, 'PDGM': 155,
  'BreastCancer': 672, 'Colon': 256, 'Oral': 256, 'Lung': 256, 'hypertension': 31,
}
DIMS = DIMS_dict[FLAGS.dataset]
# Input channels for the ResNet20 family (the FFNN models take `dims` instead).
if FLAGS.dataset in ['PDGM', 'hypertension']:
  # the 3D-model losses treat the feature dimension as channels
  inchannels = DIMS if FLAGS.loss in ['CELoss', 'AUCMLoss'] else 1
elif FLAGS.dataset in ['BreastCancer', 'Colon', 'Oral', 'Lung']:
  inchannels = 3
  



# You need to include sigmoid activation in the last layer for any customized models!
kf = KFold(n_splits=5)
N = len(traindSet)
print(N)
# KFold.split only needs the number of samples, so a dummy (N, 1) array stands in for X.
tmpX = np.zeros((N,1))
best_val_auc = 0
best_para = 1.0

# Candidate values for the per-loss hyper-parameter `para` swept in the training loop.
if FLAGS.activation in ('sigmoid', 'l2', 'l1', 'scale'):
  parameter_set = [0.1, 0.5, 1.0]
else:
  parameter_set = [0.1, 1, 10]

if FLAGS.loss == 'PBHLoss':
  # PBHLoss takes an (r, b) pair: cross product of the base set with these scales.
  scale_set = [2.0, 10.0, 100.0]
  parameter_set = [(r, b) for r in parameter_set for b in scale_set]
elif FLAGS.loss in ['PLLoss','PSMLoss','SOTAs-smx','SOTAs-att']:
  parameter_set = [0.1, 1.0, 10.0]
elif FLAGS.loss in ['TPAUCmax-s', 'TPAUCatt-s']:
  parameter_set = [0.9, 0.5, 0.1]
elif FLAGS.loss in ['CELoss','CEmax','CEatt','CEmean','CEsoftmax']:
  # Cross-entropy variants: nothing to sweep, and last_activation=None is passed to the models.
  parameter_set = [0]
  FLAGS.activation = None
  print(FLAGS.activation)

testloader = torch.utils.data.DataLoader(testSet, batch_size=1, num_workers=1, shuffle=False, collate_fn=collate_fn)


# Cross-validated hyper-parameter sweep. For each of the 5 KFold splits of the
# training bags (`part` counts the split) and each candidate `para` in
# `parameter_set`, a new model, loss and optimizer are built and trained for
# 100 epochs. Train / validation / test TP-AUC are printed after every epoch.
part = 0

print ('Start Training')
print ('-'*30)
for train_id, val_id in kf.split(tmpX):
  for para in parameter_set:
    # Training batches come from this split's training indices via
    # imbalanced_sampler (imratio is passed through); validation bags come from
    # the held-out indices.
    trainloader =  torch.utils.data.DataLoader(dataset=traindSet, sampler=imbalanced_sampler(data_source=traindSet, batch_size=BATCH_SIZE,imratio=imratio,idx=train_id), batch_size=BATCH_SIZE, num_workers=1, shuffle=False, collate_fn=collate_fn)
    validloader =  torch.utils.data.DataLoader(dataset=traindSet, sampler=imbalanced_sampler(data_source=traindSet, batch_size=2,imratio=None,idx=val_id,sample_scale=val_sample_ratio), batch_size=1, num_workers=1, shuffle=False, collate_fn=collate_fn)

    # Build the model (a new instance for every split / hyper-parameter).
    if FLAGS.dataset in ['MUSK1','MUSK2','Fox','Tiger','Elephant','Bonds','Atoms','Chains']:
      # MIL benchmark datasets: fully-connected networks over DIMS features.
      if FLAGS.loss in ['CEatt','AUCMatt','AUCMatt-det','SOTAs-att']:
        model = FFNN_MIL(num_classes=num_class, last_activation=FLAGS.activation, dims=DIMS)
        print('Attention-based MIL-FFNN model')
      elif FLAGS.loss in ['AUCMatt-s','TPAUCatt-s']:
        model = FFNN_stoc_MIL(num_classes=num_class, last_activation=None, dims=DIMS)
        print('Stochastic-attention-based MIL-FFNN model')
      elif FLAGS.loss in ['AUCMsoftmax', 'CEsoftmax', 'SOTAs-smx']:
        model = FFNN_softmax(num_classes=num_class, last_activation=FLAGS.activation, dims=DIMS, tau=FLAGS.tau)
        print('Softmax-based FFNN model')
      else:
        model = FFNN(num_classes=num_class, last_activation=FLAGS.activation, dims=DIMS)
      model = model.cuda()
    else:
      # Image bag datasets: ResNet20 variants with `inchannels` input channels.
      if FLAGS.loss in ['CEatt','AUCMatt','AUCMatt-det','SOTAs-att']:
        model = ResNet20_MIL(num_classes=1, last_activation=FLAGS.activation, inchannels=inchannels)
        print('Attention-based MIL-FFNN model')
      elif FLAGS.loss in ['AUCMatt-s','TPAUCatt-s']:
        model = ResNet20_stoc_MIL(num_classes=1, last_activation=None, inchannels=inchannels)
        print('Stochastic-attention-based MIL-FFNN model')
      elif FLAGS.loss in ['AUCMsoftmax', 'CEsoftmax', 'SOTAs-smx']:
        model = ResNet20_softmax(num_classes=1, last_activation=FLAGS.activation, inchannels=inchannels, tau=FLAGS.tau)
        print('Softmax-based FFNN model')
      else:
        model = ResNet20(num_classes=1, last_activation=FLAGS.activation, inchannels=inchannels)
      model = model.cuda()

    # Define the loss; `para` is the candidate hyper-parameter under evaluation.
    if FLAGS.loss == 'PSQLoss':
      Loss = PSQLoss(margin=para)
    elif FLAGS.loss == 'PHLoss':
      Loss = PHLoss(margin=para)
    elif FLAGS.loss == 'PSHLoss':
      Loss = PSHLoss(margin=para)
    elif FLAGS.loss in ['AUCMLoss', 'AUCMmean', 'AUCMmax', 'AUCMatt', 'AUCMsoftmax']:
      Loss = AUCMLoss(margin=para)
    elif FLAGS.loss in ['AUCMatt-s']:
      Loss = MIL_attention_loss(data_length=N, threshold=para)
    elif FLAGS.loss in ['AUCMmax-s']:
      Loss = MIL_softmax_loss(data_length=N, threshold=para, tau=FLAGS.tau)
    elif FLAGS.loss in ['TPAUCmax-s']:
      Loss = TPMIL_softmax_loss(data_length=N, rate=para, momentum=FLAGS.moving_momentum)
    elif FLAGS.loss in ['TPAUCatt-s']:
      Loss = TPMIL_attention_loss(data_length=N, rate=para, momentum=FLAGS.moving_momentum)
    elif FLAGS.loss in ['SOTAs-smx', 'SOTAs-att']:
      Loss = tpAUC_KL_Loss(pos_len=N, Lambda=para, tau=para)
    elif FLAGS.loss in ['AUCMatt-det']:
      # NOTE(review): `AUCM_loss` is not imported anywhere in this file, so this
      # (deprecated) option raises NameError as written.
      Loss = AUCM_loss(threshold=para)
    elif FLAGS.loss == 'PBHLoss':
      r, b = para
      Loss = PBHLoss(r=r, b=b)
    elif FLAGS.loss == 'PLLoss':
      Loss = PLLoss(margin=para)
    elif FLAGS.loss == 'PSMLoss':
      Loss = PSMLoss(margin=para)
    elif FLAGS.loss in ['CELoss','CEmax','CEatt','CEmean','CEsoftmax']:
      Loss = CrossEntropyLoss()
    elif FLAGS.loss == 'BALoss':
      # NOTE(review): `bag_AUC_loss` is not imported anywhere in this file, so
      # this option raises NameError as written.
      Loss = bag_AUC_loss(data_length = N, threshold = para, epsilon=FLAGS.epsilon, tau_1=FLAGS.epsilon, tau_2=FLAGS.epsilon, eta_0=FLAGS.eta)

    # Define the optimizer. AUCM losses use PESG with the loss's a/b/alpha
    # variables; MIDAM / SOTAs override the generic choice for their losses.
    if FLAGS.loss in ['AUCMLoss', 'AUCMmean', 'AUCMmax', 'AUCMatt', 'AUCMsoftmax']:
      optimizer = PESG(model,
                       a=Loss.a,
                       b=Loss.b,
                       alpha=Loss.alpha,
                       imratio=imratio,
                       lr=lr,
                       gamma=FLAGS.gamma,
                       margin=para,
                       weight_decay=weight_decay)
    elif FLAGS.optimizer == 'Adam':
      optimizer = Adam(model, lr=lr, weight_decay=weight_decay, gamma=FLAGS.gamma)
    elif FLAGS.optimizer == 'Momentum':
      optimizer = SGD(model, lr=lr, weight_decay=weight_decay, momentum=FLAGS.momentum)
    if FLAGS.loss in ['AUCMatt-s','AUCMmax-s','AUCMatt-det']:
      optimizer = MIDAM(model, a=Loss.a, b=Loss.b, alpha=Loss.alpha, lr=lr, weight_decay=weight_decay, momentum=FLAGS.momentum)
    if FLAGS.loss in ['SOTAs-smx','SOTAs-att']:
      optimizer = SOTAs(model, loss_fn=Loss, lr=lr, weight_decay=weight_decay, momentum=FLAGS.momentum)
    print('Margin=%s, part=%s'%(para, part))

    for epoch in range(100):
      tr_loss = 0
      tr_loss_1 = 0
      tr_loss_2 = 0
      # Step-wise decay of the learning rate (and smoothing) at epochs 50 and 75.
      if epoch in [50,75]:
        if FLAGS.loss in ['SOTAs-smx','SOTAs-att']:
          optimizer.update_lr(decay_factor=10, coef_decay_factor=2)
        else:
          optimizer.update_stepsize(decay_factor=10)
        if FLAGS.loss in ['AUCMatt-s','AUCMmax-s','TPAUCmax-s','TPAUCatt-s']:
          Loss.update_smoothing(decay_factor=2)
      start_time = time.process_time()
      for idx, data in enumerate(trainloader):
        # Batch-local names keep the module-level train_data / train_labels intact.
        batch_bags, batch_labels, ids = data
        # Uniform random sampling: score each bag from `instance_batch_size`
        # sampled instances, pooled according to the chosen loss (when the
        # model is attention-based, it is auto-reduced).
        # NOTE(review): the PSQ/PH/PSH/PL/PSM/PBH/BA losses match none of the
        # lists below, so y_pred stays empty and torch.cat raises for them.
        y_pred = []
        sd = []
        for i in range(len(ids)):
          if FLAGS.loss in ['CELoss', 'CEmean', 'CEatt', 'AUCMLoss', 'AUCMmean', 'AUCMatt', 'AUCMsoftmax', 'CEsoftmax', 'AUCMatt-det', 'SOTAs-smx', 'SOTAs-att']:
            y_pred.append(random_sample_y(batch_bags[i], model, batch_size=FLAGS.instance_batch_size, mode='plain'))
          elif FLAGS.loss in ['AUCMmax','CEmax']:
            y_pred.append(random_sample_y(batch_bags[i], model, batch_size=FLAGS.instance_batch_size, mode='max'))
          elif FLAGS.loss in ['AUCMmax-s','TPAUCmax-s']:
            y_pred.append(random_sample_y(batch_bags[i], model, batch_size=FLAGS.instance_batch_size, mode='exp'))
          elif FLAGS.loss in ['AUCMatt-s', 'TPAUCatt-s']:
            tmp_pred, tmp_sd = random_sample_y(batch_bags[i], model, batch_size=FLAGS.instance_batch_size, mode='att')
            y_pred.append(tmp_pred)
            sd.append(tmp_sd)
        y_pred = torch.cat(y_pred, dim=0)
        if FLAGS.loss in ['AUCMatt-s', 'TPAUCatt-s']:
          sd = torch.cat(sd, dim=0)

        ids = torch.from_numpy(np.array(ids)).cuda()
        batch_labels = torch.from_numpy(np.array(batch_labels)).cuda()
        if FLAGS.loss in ['AUCMLoss', 'AUCMmean', 'AUCMmax', 'AUCMatt', 'AUCMsoftmax']:
          loss, real_loss_1, real_loss_2 = Loss(y_pred, batch_labels)
          tr_loss_1 = tr_loss_1  + real_loss_1.cpu().detach().numpy()
          tr_loss_2 = tr_loss_2  + real_loss_2.cpu().detach().numpy()
        elif FLAGS.loss in ['AUCMatt-s', 'TPAUCatt-s']:
          loss = Loss(y_pred, sd, batch_labels, ids)
        elif FLAGS.loss in [ 'AUCMmax-s', 'TPAUCmax-s']:
          loss = Loss(y_pred, batch_labels, ids)
        elif FLAGS.loss in [ 'SOTAs-smx', 'SOTAs-att']:
          loss = Loss(y_pred, batch_labels, ids[:pos_num])
        else:
          loss = Loss(y_pred, batch_labels).mean()
        # Accumulate for every loss (AUCM losses previously always reported 0).
        tr_loss = tr_loss  + loss.cpu().detach().numpy()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      end_time = time.process_time()
      dur_time = end_time - start_time
      # `idx` is the last zero-based batch index, so idx + 1 batches were seen.
      n_batches = idx + 1
      tr_loss = tr_loss/n_batches
      if FLAGS.loss in ['AUCMLoss']:
        tr_loss_1 = tr_loss_1/n_batches
        tr_loss_2 = tr_loss_2/n_batches
        tr_loss = erm_loss_eval(tr_loss_1, tr_loss_2, FLAGS.loss)
      model.eval()
      print ('Epoch=%s, BatchID=%s, training_loss=%.4f, lr=%.4f'%(epoch, idx, tr_loss,  optimizer.lr))
      print ('Epoch=%s, BatchID=%s, time=%.4f'%(epoch, idx, dur_time))
      if FLAGS.loss in ['AUCMLoss', 'AUCMmean', 'AUCMmax', 'AUCMatt', 'AUCMatt-s', 'AUCMmax-s', 'AUCMsoftmax', 'AUCMatt-det']:
        print (str(Loss.a.cpu().detach().numpy())+', '+str(Loss.b.cpu().detach().numpy())+ ', '+str(Loss.alpha.cpu().detach().numpy()))
      with torch.no_grad():
        # Evaluation pooling mirrors the pooling used during training.
        if FLAGS.loss in ['CELoss', 'CEmean', 'AUCMLoss', 'AUCMmean', 'AUCMatt', 'CEatt', 'CEsoftmax', 'AUCMsoftmax', 'SOTAs-smx', 'SOTAs-att']:
          tr_tpauc = evaluate_tpauc(trainloader, model, mode='mean')
          te_tpauc = evaluate_tpauc(testloader, model, mode='mean')
          val_tpauc = evaluate_tpauc(validloader, model, mode='mean')
        elif FLAGS.loss in ['AUCMatt-s', 'TPAUCatt-s']:
          tr_tpauc = evaluate_tpauc(trainloader, model, mode='att')
          te_tpauc = evaluate_tpauc(testloader, model, mode='att')
          val_tpauc = evaluate_tpauc(validloader, model, mode='att')
        elif FLAGS.loss in ['AUCMmax-s', 'TPAUCmax-s']:
          tr_tpauc = evaluate_tpauc(trainloader, model, mode='softmax', tau=FLAGS.tau)
          te_tpauc = evaluate_tpauc(testloader, model, mode='softmax', tau=FLAGS.tau)
          val_tpauc = evaluate_tpauc(validloader, model, mode='softmax', tau=FLAGS.tau)
        else:
          tr_tpauc = evaluate_tpauc(trainloader, model, mode='max')
          te_tpauc = evaluate_tpauc(testloader, model, mode='max')
          val_tpauc = evaluate_tpauc(validloader, model, mode='max')

        model.train()
        print('Epoch=%s, BatchID=%s, TP_Tr_AUC(0.1)=%.4f, TP_Val_AUC(0.1)=%.4f, TP_Test_AUC(0.1)=%.4f, TP_Tr_AUC(0.3)=%.4f, TP_Val_AUC(0.3)=%.4f, TP_Test_AUC(0.3)=%.4f, TP_Tr_AUC(0.5)=%.4f,TP_Val_AUC(0.5)=%.4f, TP_Test_AUC(0.5)=%.4f, TP_Tr_AUC(0.7)=%.4f, TP_Val_AUC(0.7)=%.4f, TP_Test_AUC(0.7)=%.4f, TP_Tr_AUC(0.9)=%.4f, TP_Val_AUC(0.9)=%.4f, TP_Test_AUC(0.9)=%.4f, lr=%.4f \n'%(epoch, idx, tr_tpauc[0], val_tpauc[0], te_tpauc[0], tr_tpauc[1], val_tpauc[1], te_tpauc[1], tr_tpauc[2], val_tpauc[2], te_tpauc[2], tr_tpauc[3], val_tpauc[3], te_tpauc[3], tr_tpauc[4], val_tpauc[4], te_tpauc[4], optimizer.lr))
  part += 1


