# -*- coding: utf-8 -*-
"""loan_all_meta_final.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/14qOkh7vy2IllURmEnD49VXpOqugHziCz
"""

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import json
from train import *
from scipy.stats import sem
from scipy import mean, std
import argparse
from utils import prepare_data

# Global NumPy seed applied once at start-up; some experiment branches below
# reseed per run.
np.random.seed(2)

# Every command-line option as (flag, default, type). Order matches the
# original declarations so argparse's help text and the printed Namespace
# keep the same layout.
_CLI_OPTIONS = (
    ('--weight_decay', 5e-4, float),
    ('--momentum', 0.9, float),
    ('--lr', 0.1, float),
    ('--inner_lr', 0.01, float),
    ('--epochs', 100, int),
    ('--meta_step', 80, int),
    ('--test_size', 0.2, float),
    ('--meta_size', 0.2, float),  # for meta_weight_net
    ('--batch_size', 24, int),
    ('--outer_batch_size', 16, int),
    ('--method', 'meta_balance', str),
    ('--inner_sampling', 'Simple', str),
    ('--outer_sampling', 'SMOTE', str),
    ('--runs', 1, int),
    ('--loss_reweight_beta', 0.9999, float),
)


def _build_parser():
    """Return an ArgumentParser with every option from ``_CLI_OPTIONS`` registered."""
    cli = argparse.ArgumentParser()
    for flag, default, kind in _CLI_OPTIONS:
        cli.add_argument(flag, default=default, type=kind)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Echo the effective configuration at the top of the run log.
print("#####################################################")
print(args)

# Last-epoch ROC-AUC values appended by the baseline branches. They are kept at
# module level for inspection; nothing in this script reads them back.
ROC_AUC = []
ROC_AUC_all = []

_KNOWN_METHODS = (
    'meta_balance',
    'meta_balance_separate',
    'meta_weight_net',
    'old_baselines',
    'loss_reweight',
)


def _final_score(roc_auc):
    """Return the last entry of a run's ROC-AUC history (its final evaluation)."""
    return roc_auc[-1]


def _report(results):
    """Print the per-run scores, then their mean and standard error of the mean.

    ``np.mean`` is used instead of ``scipy.mean``: the NumPy aliases in the
    top-level ``scipy`` namespace are deprecated and removed in newer SciPy.
    With a single run the SEM is undefined and ``sem`` returns ``nan``.
    """
    print(results)
    print(np.mean(results), sem(results))


if args.method == 'meta_balance':
    inner_method = args.inner_sampling
    outer_method = args.outer_sampling
    print(inner_method, outer_method)

    train_loader, train_loader_outer, test_loader = prepare_data(inner_method, outer_method, args)
    All_Final = []
    for i in range(args.runs):
        # Reseed NumPy's global RNG per run; this only influences code that
        # draws from np.random.
        np.random.seed(i)
        inner_lr, meta_batch_update_factor = args.inner_lr, args.meta_step
        roc_auc = train(inner_lr, meta_batch_update_factor, train_loader, train_loader_outer, test_loader, args)
        ra = _final_score(roc_auc)
        print(ra, max(roc_auc))
        All_Final.append(ra)

    _report(All_Final)

elif args.method == 'meta_balance_separate':
    method = args.inner_sampling
    print(method)

    train_loader, train_loader_outer, test_loader = prepare_data_separate(method, args)
    # Collect and summarize per-run scores, consistent with the other methods.
    All_Final = []
    for i in range(args.runs):
        np.random.seed(i)
        inner_lr, meta_batch_update_factor = args.inner_lr, args.meta_step
        roc_auc = train_separate(inner_lr, meta_batch_update_factor, train_loader, train_loader_outer, test_loader, args)
        ra = _final_score(roc_auc)
        print(ra, max(roc_auc))
        All_Final.append(ra)

    _report(All_Final)

elif args.method == 'meta_weight_net':
    train_loader, train_meta_loader, test_loader = prepare_data_meta_weight_net(args)

    All_Final = []
    for i in range(args.runs):  # repeat the same experiment args.runs times
        roc_auc, roc_auc_meta = train_meta_weight_net(train_loader, train_meta_loader, test_loader, args)
        final = _final_score(roc_auc)
        # Index of the evaluation where roc_auc_meta peaked (presumably the
        # score on the meta split -- confirm in train_meta_weight_net), and the
        # roc_auc value recorded at that same index.
        max_i = roc_auc_meta.index(max(roc_auc_meta))
        print(final, max_i, roc_auc[max_i])
        All_Final.append(final)

    _report(All_Final)

elif args.method == 'old_baselines':
    # --inner_sampling may hold a comma-separated list of sampling methods.
    for method in args.inner_sampling.split(','):
        print(method)
        train_loader, test_loader = prepare_baseline(method, args)
        All_Final = []
        for i in range(args.runs):  # repeat the same experiment args.runs times
            roc_auc = train_baselines(train_loader, test_loader, args)
            final = _final_score(roc_auc)
            ROC_AUC.append(final)
            print(final)
            All_Final.append(final)

        _report(All_Final)

elif args.method == 'loss_reweight':
    train_loader, test_loader = prepare_baseline("Simple", args)
    All_Final = []
    for i in range(args.runs):  # repeat the same experiment args.runs times
        roc_auc = train_loss_reweight(train_loader, test_loader, args)
        final = _final_score(roc_auc)
        ROC_AUC.append(final)
        print(final)
        All_Final.append(final)

    _report(All_Final)

else:
    # Previously an unrecognized --method silently did nothing.
    raise SystemExit(
        f"Unknown --method {args.method!r}; expected one of: {', '.join(_KNOWN_METHODS)}"
    )