import os
import time
import random
import scipy.io as scio
import argparse
import copy
import torch
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import MultiStepLR
from utils.earlystopping import EarlyStopping
import higher
import optuna
from functools import partial
# from utils.cpudata import prepare_datasets, prepare_train_loaders
# from utils.utils_loss import jin_lossu, cour_lossu, jin_lossb, cour_lossb, mae_loss, mse_loss, cce_loss, gce_loss, phuber_ce_loss, focal_loss, pll_estimator
from utils.utils_loss import proden_loss, rc_loss, cc_loss, lws_loss, cavl_loss, d2cnn_loss
from utils.utils_algo import accuracy_check
# from cifar_models import densenet, resnet
from models_for_multibranch.resnet import resnet
from models_for_multibranch.linear import linear_model
# data loaders
from utils.cifar10 import load_cifar10
from utils.cifar100 import load_cifar100
from utils.tinyimagenet import load_tinyimagenet
from utils.realworld import load_realworld


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (the CPU generator,
    the current CUDA device and all CUDA devices) with ``seed``, then sets
    ``cudnn.deterministic`` to True and ``cudnn.benchmark`` to False.
    """
    seeders = (
        random.seed,                  # Python random module
        np.random.seed,               # NumPy module
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# + accuracy check for multi-branch models
def accuracy_check(loader, model, device):
    """Return the top-1 accuracy, in percent, of ``model`` over ``loader``.

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        model: callable returning a ``(features, logits)`` pair; only the
            logits are used.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` as a Python float, or ``0.0`` when the
        loader yields no samples.

    When ``model`` is an ``nn.Module`` it is evaluated in eval mode, so
    layers such as batch norm or dropout do not behave as in training (and
    batch-norm running statistics are not updated by evaluation data). The
    previous train/eval flag of every sub-module is restored afterwards.
    """
    # Remember each sub-module's mode so that evaluation leaves no trace.
    modules = list(model.modules()) if isinstance(model, torch.nn.Module) else []
    previous_modes = [module.training for module in modules]
    if modules:
        model.eval()

    correct, num_samples = 0, 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                _, outputs = model(images)
                predicted = outputs.argmax(dim=1)
                correct += (predicted == labels).sum().item()
                num_samples += labels.size(0)
    finally:
        for module, mode in zip(modules, previous_modes):
            module.training = mode

    if num_samples == 0:
        # An empty loader has no defined accuracy; report 0 instead of
        # raising ZeroDivisionError.
        return 0.0
    return 100 * (correct / num_samples)

# - accuracy check for multi-branch models

# + accuracy check for multi-branch models
def branch_accuracy_check(loader, model, model_b, j1, j2, train_Y, mask_j1, mask_j2, device):
    """Compare the classifier with branches ``j1`` and ``j2`` on masked subsets.

    Args:
        loader: iterable yielding 4-tuples whose first item is the image
            batch and whose last item holds the dataset indexes of that
            batch; the two middle items are ignored.
        model: callable returning ``(features, logits)``.
        model_b: callable mapping (detached) features to per-branch logits,
            indexable by branch id.
        j1, j2: ids of the two branches to inspect.
        train_Y: 1-D tensor of true labels for the whole training set.
        mask_j1, mask_j2: tensors aligned with ``train_Y``; entries equal to
            1 select the examples of each subset.
        device: device the image batches are moved to.

    Returns:
        A 4-tuple of accuracies in percent (0-dim tensors):
        ``(classifier on mask_j2, branch j1 on mask_j2,
        classifier on mask_j1, branch j2 on mask_j1)``.
    """
    with torch.no_grad():
        classifier_prediction = torch.zeros_like(train_Y, dtype=torch.long)
        j1_prediction = torch.zeros_like(train_Y, dtype=torch.long)
        j2_prediction = torch.zeros_like(train_Y, dtype=torch.long)

        for images, _, _, indexes in loader:
            # Only the images are needed on the device. The original also
            # tried to move an undefined ``labels`` name, which raised
            # NameError on the first batch.
            X = images.to(device)
            feats, outputs = model(X)
            outputs_b = model_b(feats.clone().detach())
            classifier_prediction[indexes] = outputs.argmax(dim=1).cpu()
            j1_prediction[indexes] = outputs_b[j1].argmax(dim=1).cpu()
            j2_prediction[indexes] = outputs_b[j2].argmax(dim=1).cpu()

        classifier_results = classifier_prediction == train_Y
        j1_results = j1_prediction == train_Y
        j2_results = j2_prediction == train_Y

        def _rate(results, mask):
            # Percentage of correct predictions among the masked examples.
            selected = results[mask == 1]
            return 100 * selected.sum() / len(selected)

        j2_rate_of_classifier = _rate(classifier_results, mask_j2)
        j2_rate_of_j1 = _rate(j1_results, mask_j2)
        j1_rate_of_classifier = _rate(classifier_results, mask_j1)
        j1_rate_of_j2 = _rate(j2_results, mask_j1)
    return j2_rate_of_classifier, j2_rate_of_j1, j1_rate_of_classifier, j1_rate_of_j2
# - accuracy check for multi-branch models

# + loss for multi-branch models
class multibranch_loss:
    """Soft-label losses and label bookkeeping for the classifier and its branches.

    Three per-example label matrices of shape (num_examples, num_class) are
    kept on ``device``:

    * ``conf``          - label confidences revised from the classifier
                          output (see ``update_conf``);
    * ``pseudo_matrix`` - pseudo labels built from the branch outputs
                          (see ``update_pseudo``);
    * ``merge_matrix``  - ``merge_rate * conf + (1 - merge_rate) * pseudo_matrix``
                          (see ``update_merge``).

    ``conf`` and ``pseudo_matrix`` both start as the row-normalised
    ``train_p_Y``. ``args`` only needs a ``merge_rate`` attribute.
    """

    def __init__(self, args, train_p_Y, num_branch, num_class, device):
        self.args = args
        self.num_branch = num_branch
        self.num_class = num_class
        self.device = device
        normalised = train_p_Y / train_p_Y.sum(dim=1, keepdim=True)
        self.conf = normalised.to(device)
        # pseudo: a separate copy, because both matrices are updated in place
        self.pseudo_matrix = normalised.clone().to(device)
        # merge
        self.update_merge()

    @staticmethod
    def _soft_cross_entropy(target, output):
        """Mean over rows of ``-sum(target * log_softmax(output))``.

        ``log_softmax`` is used instead of ``log(softmax(.))``: when a class
        probability underflows to 0 the latter produces ``-inf`` and a zero
        target weight then yields ``0 * -inf = NaN``.
        """
        l = target * torch.log_softmax(output, dim=1)
        return (- torch.sum(l)) / l.size(0)

    def compute_loss_for_classifier(self, output, indexes):
        """Cross entropy between ``conf[indexes]`` and the classifier logits ``output`` (n, c)."""
        target = self.conf[indexes].clone().detach()
        return self._soft_cross_entropy(target, output)

    def compute_loss_for_multibranch(self, output_b, indexes):
        """Cross entropy between the per-branch targets and ``output_b`` (k, n, c).

        The targets come from ``generate_target_for_multibranch``; the loss is
        averaged over all ``k * n`` (branch, example) rows.
        """
        k, n, c = output_b.size()
        target = self.conf[indexes].clone().detach()
        target_b = self.generate_target_for_multibranch(target)
        # target_b_matrix: (n*k, c), output_b_matrix: (n*k, c)
        target_b_matrix = target_b.contiguous().view((n * k, c))
        output_b_matrix = output_b.contiguous().view((n * k, c))
        return self._soft_cross_entropy(target_b_matrix, output_b_matrix)

    def compute_loss_pseudolabel_for_classifier(self, output, indexes, **kwargs):
        """Cross entropy between ``merge_matrix[indexes]`` and the classifier logits."""
        target_pseudo = self.merge_matrix[indexes].clone().detach()
        return self._soft_cross_entropy(target_pseudo, output)

    def compute_meta_loss_pseudolabel_for_classifier(self, output, output_b, labels, meta_weight, **kwargs):
        """Cross entropy between freshly generated pseudo labels and the classifier logits.

        The pseudo labels are not detached, so gradients can flow back into
        ``meta_weight``.
        """
        target_pseudo = self.generate_pseudo_label_from_multibranch(output_b, labels, meta_weight)
        return self._soft_cross_entropy(target_pseudo, output)

    def update_conf(self, output1, indexes):
        """Overwrite ``conf[indexes]`` with the softmax of ``output1`` restricted to
        the entries whose current confidence is positive, re-normalised per row."""
        target = self.conf[indexes].clone().detach()
        output = F.softmax(output1, dim=1)
        revisedY = target.clone()
        revisedY[revisedY > 0] = 1
        revisedY = revisedY * output
        revisedY = revisedY / revisedY.sum(dim=1, keepdim=True)
        self.conf[indexes, :] = revisedY.clone().detach()

    def update_pseudo(self, output_b, labels, meta_weight, indexes, eps=1e-12):
        """Store the pseudo labels generated from the branch outputs in ``pseudo_matrix[indexes]``."""
        target_pseudo = self.generate_pseudo_label_from_multibranch(output_b, labels, meta_weight, eps=eps)
        self.pseudo_matrix[indexes, :] = target_pseudo.clone().detach()

    def update_merge(self, indexes=None):
        """Recompute ``merge_matrix`` from ``conf`` and ``pseudo_matrix``.

        With ``indexes=None`` the whole matrix is rebuilt; otherwise only the
        given rows are refreshed.
        """
        rate = self.args.merge_rate
        if indexes is None:
            self.merge_matrix = (rate * self.conf + (1 - rate) * self.pseudo_matrix).clone().detach()
        else:
            self.merge_matrix[indexes] = (rate * self.conf[indexes] + (1 - rate) * self.pseudo_matrix[indexes]).clone().detach()

    def generate_pseudo_label_from_multibranch(self, output_b, labels, meta_weight, eps=1e-12):
        '''
            Weighted average of the branch softmax outputs, masked by
            ``labels`` and re-normalised per row.

            output_b: (k, n, c) branch logits
            labels: (n, c) mask multiplied into the averaged probabilities
            meta_weight(w): (k, n), sum_i w_i = 1
            returns target_pseudo: (n, c)

            ``eps`` is added before masking so that the row sum is not zero
            when every allowed class has zero averaged probability.

            NOTE(review): ``meta_weight.reshape((k, 1, n))`` reinterprets the
            memory layout and does not transpose. If a caller passes weights
            laid out as (n, k), they would be assigned to the wrong
            (branch, example) pairs - confirm the layout produced by the
            weight networks.
        '''
        k, n, c = output_b.size()
        output_b_matrix = output_b.contiguous().view((k * n, c))
        sm_output_b = F.softmax(output_b_matrix, dim=1).contiguous().view((k, n, c))
        # (k, n) -> (k, n, 1) so that it broadcasts over the class axis
        weight = meta_weight.reshape((k, 1, n)).permute(0, 2, 1)
        target_pseudo = (weight * sm_output_b).sum(dim=0)
        target_pseudo = (target_pseudo + eps) * labels
        target_pseudo = target_pseudo / target_pseudo.sum(dim=1, keepdim=True)
        return target_pseudo

    def generate_target_for_multibranch(self, target, eps=1e-12):
        '''
            Build one target per branch: branch i gets ``target`` with class i
            zeroed out and the row re-normalised.

            target: (n, c)
            returns target_b: (k, n, c)

            Rows that become all-zero (all of the mass was on class i)
            normalise to NaN and fall back to the original ``target`` row.
            Requires ``num_branch == num_class``. ``eps`` is unused and kept
            for interface compatibility.
        '''
        assert self.num_branch == self.num_class
        # (k, n, c): the same target repeated for every branch
        repeated = target.contiguous().view((1, *(target.shape))).repeat(self.num_branch, 1, 1)
        fallback = repeated.clone().detach()
        # keep[i, j] == 0 iff j == i, i.e. branch i drops class i
        keep = torch.ones((self.num_branch, self.num_class), device=target.device) - \
            torch.eye(self.num_branch, device=target.device)
        masked = keep.contiguous().view((self.num_branch, 1, self.num_class)) * repeated
        target_b = masked / masked.sum(dim=2, keepdim=True)
        # 0 / 0 rows -> original target
        target_b = torch.where(torch.isnan(target_b), fallback, target_b)

        return target_b
# - loss for multi-branch models

# + function of uncertainty
def calculate_uncertainty(x, size=3):
    '''
        Entropy of the softmax distribution over the last axis of ``x``.

        x: logits (before softmax)
        size == 3: x is (k, n, c), the result is (n, k)
        size == 2: x is (n, c), the result is (n,)

        Raises ValueError when ``size`` is not 2 or 3, or when ``x`` does not
        have ``size`` dimensions.
    '''
    if size not in (2, 3):
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError("size must be 2 or 3, got {!r}".format(size))
    if x.dim() != size:
        raise ValueError("expected a {}-d tensor, got shape {}".format(size, tuple(x.shape)))

    # log_softmax keeps the log finite when a probability underflows to 0.
    entropy = (- F.softmax(x, dim=-1) * torch.log_softmax(x, dim=-1)).sum(dim=-1)
    # (k, n) -> (n, k) for branch outputs; (n,) is returned as is
    return entropy.T if size == 3 else entropy

def calculate_margin(x, size=3):
    '''
        Negated gap between the two largest softmax probabilities along the
        last axis of ``x`` (values closer to 0 mean a smaller gap).

        x: logits (before softmax); the last axis needs at least 2 entries
        size == 3: x is (k, n, c), the result is (n, k)
        size == 2: x is (n, c), the result is (n,)

        Raises ValueError when ``size`` is not 2 or 3, or when ``x`` does not
        have ``size`` dimensions.
    '''
    if size not in (2, 3):
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError("size must be 2 or 3, got {!r}".format(size))
    if x.dim() != size:
        raise ValueError("expected a {}-d tensor, got shape {}".format(size, tuple(x.shape)))

    top2, _ = F.softmax(x, dim=-1).topk(k=2, dim=-1)
    margin = - (top2[..., 0] - top2[..., 1])
    # (k, n) -> (n, k) for branch outputs; (n,) is returned as is
    return margin.T if size == 3 else margin

def main(args):
    """Train the classifier, the branch model and the meta weight net.

    Builds the data loaders and models selected by ``args``, runs
    ``args.ep`` epochs of training (with optional early stopping), prints the
    validation / test accuracy after every epoch and returns the test
    accuracy of the epoch with the highest validation accuracy (the latest
    such epoch on ties).

    ``device`` is not a parameter: it is read from module scope, where the
    ``__main__`` block of this file assigns it.

    NOTE(review): the dataset, model, weight-mode and loss selections below
    have no fallback branch; an unrecognised ``args.ds`` / ``args.mo`` /
    ``args.weight_mode`` / ``args.lo`` leaves the corresponding names
    unassigned and fails at their first use.
    """
    set_seed(args.seed)
    # prepare data_loader
    if args.ds == "cifar10":
        partial_train_loader, valid_loader, test_loader, dim, K, eval_train_loader, valdata, vallabels = load_cifar10(args.ds, batch_size=args.bs, device=device, has_eval_train_loader=True, has_meta_valid=True)
    if args.ds == "cifar100":
        partial_train_loader, valid_loader, test_loader, dim, K, eval_train_loader, valdata, vallabels = load_cifar100(args.ds, batch_size=args.bs, device=device, has_eval_train_loader=True, has_meta_valid=True)
    if args.ds == "tinyimagenet":
        partial_train_loader, valid_loader, test_loader, dim, K, eval_train_loader, valdata, vallabels = load_tinyimagenet(args.ds, batch_size=args.bs, device=device, has_eval_train_loader=True, has_meta_valid=True)
    if args.ds in ["lost", "birdac", "MSRCv2", "spd", "LYN"]:
        partial_train_loader, valid_loader, test_loader, dim, K, eval_train_loader, valdata, vallabels = load_realworld(args.ds, batch_size=args.bs, device=device, has_eval_train_loader=True, has_meta_valid=True, split_seed=args.split_seed)


    # candidate-label matrix and true labels of the training set
    train_p_Y = torch.Tensor(partial_train_loader.dataset.given_label_matrix)
    train_Y = torch.Tensor(partial_train_loader.dataset.true_labels)

    # 0 means "derive the interval from the train / validation size ratio"
    if args.n_meta_interval == 0:
        args.n_meta_interval = len(partial_train_loader.dataset) // len(valid_loader.dataset)

    # prepare model
    if args.mo == 'resnet':
        model = resnet(depth=32, n_outputs=K)
        model.to(device)
        # + prepare model with multi branch
        # value passed as dim_in to the branch model, per backbone name
        model_dict = {
            'resnet': 64,
        }
        from multibranch_model import MBFAResNet
        model_b = MBFAResNet(depth=32, num_branch=K, num_class=K, dim_in=model_dict[args.mo], bias=True)
        model_b = model_b.to(device)
        # - prepare model with multi branch
    elif args.mo == 'linear':
        model = linear_model(input_dim=dim, output_dim=K)
        model.to(device)
        from multibranch_model import MBLayer
        model_b = MBLayer(num_branch=K, num_class=K, dim_in=dim, bias=True)
        model_b = model_b.to(device)

    # meta net producing the per-branch weights; its input depends on weight_mode
    if args.weight_mode == 'online':
        from online import OnlineSoftmaxWeightLayer
        model_m = OnlineSoftmaxWeightLayer(num_example=len(train_p_Y), num_branch=K)
        model_m = model_m.to(device)
    if args.weight_mode in ['uncertainty', 'margin']:
        from online import OnlineMLPNet
        model_m = OnlineMLPNet(input=K, hidden=100, output=K)
        model_m = model_m.to(device)
    if args.weight_mode == 'output':
        from online import OnlineMLPNet
        model_m = OnlineMLPNet(input=K*K, hidden=100, output=K)
        model_m = model_m.to(device)

    # define valid iter for meta learning
    valid_indices = [ i for i in range(len(valdata))]

    # prepare optimization criterion
    if args.lo == 'multibranch':
        train_p_Y = torch.Tensor(partial_train_loader.dataset.given_label_matrix)
        loss_fn = multibranch_loss(args, train_p_Y, K, K, device)

    # one optimizer per network: classifier (c), branches (b), meta net (m)
    optimizer_c = torch.optim.SGD(
        [
            {"params": model.parameters(), "lr": args.lr},
        ], weight_decay=args.wd, momentum=0.9)

    optimizer_b = torch.optim.SGD(model_b.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    # optimizer_m = torch.optim.SGD(model_m.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    optimizer_m = torch.optim.Adam(model_m.parameters(), lr=1e-3, weight_decay=args.wd)
    save_path = "checkpoints/{}/{}/".format(args.ds, args.lo)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    if args.use_earlystopping:
        early = EarlyStopping(patience=50, path=os.path.join(save_path, "{}_lo={}_seed={}.pt".format(args.ds, args.lo, args.seed)))


    # accuracy of the untrained model, reported as epoch 0
    # NOTE(review): no model.eval() call precedes this first evaluation (the
    # per-epoch evaluations below do call it) - confirm this is intended.
    valid_acc = accuracy_check(loader=valid_loader, model=model, device=device)
    test_acc  = accuracy_check(loader=test_loader,  model=model, device=device)

    print("Epoch {:>3d}, valid acc: {:.2f}, test acc: {:.2f}. ".format(0, valid_acc, test_acc))

    best_val, best_test, best_epoch = -1, -1, -1
    # torch.autograd.set_detect_anomaly(True)
    torch.set_printoptions(profile='full')

    for epoch in range(args.ep):
        model.train()
        t0 = time.time()
        # each batch carries a list of image tensors, all run through the
        # models below (presumably augmented views - confirm against the loader)
        for i, (images_list, labels, true_labels, indexes) in enumerate(partial_train_loader):
            X_list = list(map(lambda x: x.to(device), images_list))
            Y = labels.to(device)
            # training
            feats_list, outputs_list = zip(*list(map(lambda x: model(x), X_list)))
            # update conf
            outputs_b_list = list(map(lambda x: model_b(x), X_list))
            # classifier loss averaged over the image list; computed before
            # conf is revised from the first output
            loss = sum(list(map(lambda x: loss_fn.compute_loss_for_classifier(x, indexes), outputs_list))) / len(outputs_list)
            loss_fn.update_conf(outputs_list[0], indexes)
            if args.use_branch:
                # the branch model is updated on its own loss with its own optimizer
                loss_branch = sum(list(map(lambda x: loss_fn.compute_loss_for_multibranch(x, indexes), outputs_b_list))) / len(outputs_list)
                optimizer_b.zero_grad()
                loss_branch.backward()
                optimizer_b.step()
                
                # meta step: only every n_meta_interval-th batch, from meta_start_epoch on
                if epoch >= args.meta_start_epoch and i % (args.n_meta_interval) == 0:
                    with higher.innerloop_ctx(model, optimizer_c) as (model_virtual, optimizer_c_virtual):
                        # update the virtual model using the weights output by the meta net
                        for _ in range(args.n_inner_loop):
                            metaloop_X_list = X_list[:args.meta_aug_k]
                            metaloop_feats_list, metaloop_outputs_list = zip(*list(map(lambda x: model_virtual(x), metaloop_X_list)))

                            # branch outputs are inputs to the meta step only; no gradient through model_b
                            with torch.no_grad():
                                metaloop_outputs_b_list = list(map(lambda x: model_b(x), metaloop_X_list))
                            if args.weight_mode == 'online':
                                metaloop_weight = model_m(indexes)
                                metaloop_weight_list = [ metaloop_weight for _ in range(args.meta_aug_k)]
                            if args.weight_mode == 'uncertainty':
                                metaloop_uncertainty_b_list = list(map(lambda x: calculate_uncertainty(x.clone().detach()), metaloop_outputs_b_list))
                                metaloop_weight_list = list(map(lambda x: model_m(x), metaloop_uncertainty_b_list))
                            if args.weight_mode == 'margin':
                                metaloop_margin_b_list = list(map(lambda x: calculate_margin(x.clone().detach()), metaloop_outputs_b_list))
                                metaloop_weight_list = list(map(lambda x: model_m(x), metaloop_margin_b_list))
                            if args.weight_mode == 'output':
                                # (k, n, c) -> (n, k*c): one flattened row of branch logits per example
                                _metaloop_outputs_b_list = list(map(lambda x: x.permute(1, 0, 2).reshape(-1, K*K), metaloop_outputs_b_list))
                                metaloop_weight_list = list(map(lambda x: model_m(x), _metaloop_outputs_b_list))

                            metaloop_loss_reg = sum(list(map(lambda x, y, z: loss_fn.compute_meta_loss_pseudolabel_for_classifier(x, y, Y, z), 
                                                    metaloop_outputs_list, metaloop_outputs_b_list, metaloop_weight_list))) / args.meta_aug_k
                            optimizer_c_virtual.step(metaloop_loss_reg)

                        # the meta net is trained on the validation loss of the virtually updated classifier
                        val_indexes = random.choices(valid_indices, k=args.bs)
                        X_val, Y_val = map(lambda x: x[val_indexes].to(device), (valdata, vallabels))
                        feats_val, outputs_val = model_virtual(X_val)
                        loss_meta = F.cross_entropy(outputs_val, Y_val)
                        optimizer_m.zero_grad()
                        loss_meta.backward()
                        optimizer_m.step()
                    
                    # recompute the branch weights with the updated meta net
                    with torch.no_grad():
                        if args.weight_mode == 'online':
                            meta_weight = model_m(indexes)
                        if args.weight_mode == 'uncertainty':
                            uncertainty_b = calculate_uncertainty(outputs_b_list[0].clone().detach())
                            meta_weight = model_m(uncertainty_b)
                        if args.weight_mode == 'margin':
                            margin_b = calculate_margin(outputs_b_list[0].clone().detach())
                            meta_weight = model_m(margin_b)
                        if args.weight_mode == 'output':
                            _outputs_b = outputs_b_list[0].clone().detach().permute(1, 0, 2)
                            _outputs_b = _outputs_b.reshape(-1, K*K)
                            meta_weight = model_m(_outputs_b)

                    # NOTE(review): metaloop_outputs_b_list is only assigned inside
                    # the inner loop above, so this line needs n_inner_loop >= 1.
                    # The weights are computed from outputs_b_list[0] (before the
                    # optimizer_b step) but applied to branch outputs computed
                    # after it - confirm this is intended.
                    loss_fn.update_pseudo(metaloop_outputs_b_list[0], Y, meta_weight, indexes)

                    # NOTE(review): nested in the meta-step branch, so the
                    # regulariser is only added on batches that run the meta step.
                    if epoch >= (args.meta_start_epoch + args.reg_delay_epoch):
                        loss_reg = sum(list(map(lambda x: loss_fn.compute_loss_pseudolabel_for_classifier(x, indexes), outputs_list)))  / len(outputs_list)
                        loss = loss + args.loss_reg_param * loss_reg

            # update classifier
            optimizer_c.zero_grad()
            loss.backward()
            optimizer_c.step()
            t3 = time.time()
            # uncertainty_arr[indexes] = calculate_uncertainty(outputs_list[0], size=2).clone().detach().cpu()
            # uncertainty_arr[indexes] = calculate_margin(outputs_list[0], size=2).clone().detach().cpu()
            # print("{} seconds.".format(t3-t2))
        t_e = time.time()
        print("Using {} seconds.".format(t_e-t0))
        model.eval()
        valid_acc = accuracy_check(loader=valid_loader, model=model, device=device)
        test_acc  = accuracy_check(loader=test_loader,  model=model, device=device)
        print("Epoch {:>3d}, valid acc: {:.2f}, test acc: {:.2f}. ".format(epoch+1, valid_acc, test_acc))
        if args.use_branch:
            # diagnostics: how often the argmax of each label matrix matches the true training label
            result_c, result_b = map(lambda x: (x.argmax(dim=1) == train_Y.to(device)), (loss_fn.conf, loss_fn.pseudo_matrix))
            acc_c, acc_b = map(lambda x: 100 * x.sum() / len(x), (result_c, result_b))
            print("acc_c: {:.2f}, acc_b: {:.2f}.".format(acc_c, acc_b))
            # NOTE(review): hard-coded 0.9 here, whereas loss_fn.update_merge()
            # uses args.merge_rate - the printed value differs from the real
            # merge whenever --merge_rate is not 0.9.
            merge_rate = 0.9
            result_m = (merge_rate * loss_fn.conf + (1 - merge_rate) * loss_fn.pseudo_matrix).argmax(dim=1) == train_Y.to(device)
            acc_m = 100 * result_m .sum() / len(result_m )
            print("acc m: {:.2f}.".format(acc_m))
            # examples the pseudo labels get right but conf gets wrong: count, and share of pseudo-correct ones
            temp_num = (result_b.long() - result_b.long() * result_c.long()).sum()
            temp_rate = 100 * (result_b.long() - result_b.long() * result_c.long()).sum() / result_b.long().sum()
            print(temp_num, temp_rate)
            # the merged targets are refreshed once per epoch
            loss_fn.update_merge()

        if args.use_earlystopping:
            if epoch >= 1:
                early(valid_acc, model, epoch)
            if early.early_stop:
                break
        # model selection on validation accuracy; ties keep the latest epoch
        if valid_acc >= best_val:
            best_val = valid_acc
            best_epoch = epoch
            best_test = test_acc

    print("Best Epoch {:>3d}, Best valid acc: {:.2f}, test acc: {:.2f}. ".format(best_epoch, best_val, best_test))
    return best_test


if __name__ == "__main__":
    # Command-line entry point: parse the options, select the device and train.
    parser = argparse.ArgumentParser()

    # optimisation and experiment selection
    parser.add_argument("-lr", type=float, default=5e-2, help="optimizer's learning rate")
    parser.add_argument("-wd", type=float, default=1e-3, help="weight decay")
    parser.add_argument("-bs", type=int, default=256, help="batch size")
    parser.add_argument("-ep", type=int, default=250, help="number of epochs")
    parser.add_argument("-ds", type=str, default="cifar10", required=False, help="specify a dataset")
    parser.add_argument("-mo", type=str, default="resnet", required=False, help="model name")
    parser.add_argument("-lo", type=str, default="multibranch", required=False, help="specify a loss function")
    parser.add_argument("-seed", type=int, default=0, help="random seed")
    parser.add_argument("-split_seed", type=int, default=42, help="random seed")
    parser.add_argument("-gpu", type=str, default="0")

    # loss param
    parser.add_argument("--use_branch", type=float, default=1)
    parser.add_argument("--loss_reg_param", type=float, default=1)
    parser.add_argument("--merge_rate", type=float, default=0.9)

    # meta param
    parser.add_argument("--weight_mode", type=str, default="output")
    parser.add_argument("--meta_aug_k", type=int, default=1)
    parser.add_argument("--n_inner_loop", type=int, default=1)
    parser.add_argument("--n_meta_interval", type=int, default=4)
    parser.add_argument("--meta_start_epoch", type=int, default=0)
    parser.add_argument("--reg_delay_epoch", type=int, default=0)

    # early_stopping-on
    parser.add_argument("--use_earlystopping", type=int, default=1, help="")

    # recommend-on / optuna-on: accepted here, but not read by the code
    # visible in this file
    parser.add_argument("--use_recommend", type=int, default=1)
    parser.add_argument("--use_optuna", type=int, default=0)
    parser.add_argument("--optuna_aim", type=str, default="split_seed")
    parser.add_argument("--optuna_int_choice", type=str, default="split_seed")

    args = parser.parse_args()

    # ``device`` is a module-level name that main() reads directly.
    if torch.cuda.is_available():
        device = torch.device("cuda:" + args.gpu)
    else:
        device = torch.device("cpu")
    main(args)