import argparse

import numpy as np
from sklearn.metrics import accuracy_score, recall_score, roc_curve

import torch
import os
os.chdir("/Users/metailp")
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from rtpt import RTPT
import random
from nsfr_utils import denormalize_kandinsky, get_data_loader, get_data_pos_loader, get_prob, get_nsfr_model, get_nsfr_model_mi, get_nsfr_model_mi_train,update_initial_clauses
from nsfr_utils import save_images_with_captions, to_plot_images_kandinsky, generate_captions
from logic_utils import get_lang,get_lang_mi,get_lang_mi_train, get_searched_clauses
from mode_declaration import get_mode_declarations

from clause_generator import ClauseGenerator


def get_args(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script does when launched from the command line. Passing an
            explicit list allows programmatic use (notebooks, tests).

    Returns:
        argparse.Namespace holding one attribute per option declared below
        (dashes in option names become underscores, e.g. ``args.batch_size``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=24,
                        help="Batch size to infer with")
    parser.add_argument("--batch-size-bs", type=int,
                        default=1, help="Batch size in beam search")
    parser.add_argument("--e", type=int, default=6,
                        help="The maximum number of objects in one image")
    parser.add_argument("--dataset", default="twopairs",choices=["twopairs", "threepairs", "red-triangle", "closeby",
                                              "online", "online-pair", "nine-circles", "clevr-hans0", "clevr-hans1", "clevr-hans2"], help="Use kandinsky patterns dataset")
    parser.add_argument("--dataset-type", default="kandinsky",
                        help="kandinsky or clevr")
    parser.add_argument('--device', default='cpu',
                        help='cuda device, i.e. 0 or cpu')
    # NOTE(review): ``default=True`` combined with ``store_true`` means this
    # flag is True whether or not it is given, so the CPU branch is always
    # taken by the caller. Left unchanged on purpose: the default of
    # ``--device`` is 'cpu', which is only valid on that branch.
    parser.add_argument("--no-cuda",default=True, action="store_true",
                        help="Run on CPU instead of GPU (not recommended)")
    parser.add_argument("--small-data", action="store_true",
                        help="Use small training data.")
    parser.add_argument("--no-xil", action="store_true",
                        help="Do not use confounding labels for clevr-hans.")
    parser.add_argument("--num-workers", type=int, default=4,
                        help="Number of threads for data loader")
    parser.add_argument('--gamma', default=0.01, type=float,
                        help='Smooth parameter in the softor function')
    parser.add_argument("--plot", action="store_true",
                        help="Plot images with captions.")
    parser.add_argument("--t-beam", type=int, default=4,
                        help="Number of rule expantion of clause generation.")
    parser.add_argument("--n-beam", type=int, default=4,
                        help="The size of the beam.")
    parser.add_argument("--n-max", type=int, default=50,
                        help="The maximum number of clauses.")
    parser.add_argument("--m", type=int, default=3,
                        help="The size of the logic program.")
    parser.add_argument("--n-obj", type=int, default=2,
                        help="The number of objects to be focused.")
    parser.add_argument("--epochs", type=int, default=101,
                        help="The number of epochs.")
    parser.add_argument("--lr", type=float, default=1e-3,
                        help="The learning rate.")
    parser.add_argument("--n-data", type=float, default=200,
                        help="The number of data to be used.")
    parser.add_argument("--pre-searched", action="store_true",
                        help="Using pre searched clauses.")
    # parse_args(None) falls back to sys.argv[1:], preserving the CLI behaviour.
    return parser.parse_args(argv)

# Reference: get_nsfr_model (imported from nsfr_utils) is called below as
# get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False)


def discretise_NSFR(NSFR, args, device):
    """Rebuild a non-trainable model from the clauses of an existing one.

    The language, background knowledge and atoms are reloaded with
    ``get_lang`` for ``args.dataset_type`` / ``args.dataset``; the clauses are
    taken from ``NSFR.get_clauses()`` and a new model is created with
    ``get_nsfr_model(..., train=False)``.

    Args:
        NSFR: Model exposing ``get_clauses()``.
        args: Parsed options; ``dataset_type`` and ``dataset`` are read here
            and the whole namespace is forwarded to ``get_nsfr_model``.
        device: Device forwarded to ``get_nsfr_model``.

    Returns:
        Whatever ``get_nsfr_model`` returns for the given clauses.
    """
    lark_path = 'src/lark/exp.lark'
    lang_base_path = 'data/lang/'
    # The other call site in this file unpacks six values from get_lang
    # (an extra background-knowledge entry sits before the atoms), whereas
    # this function used to unpack five and would raise ValueError. The
    # starred target accepts either arity: first four positional results,
    # atoms last, anything in between ignored.
    lang, _initial_clauses, bk_clauses, bk, *_extra_bk, atoms = get_lang(
        lark_path, lang_base_path, args.dataset_type, args.dataset)
    # Use the rules currently held by the model rather than the initial ones.
    clauses = NSFR.get_clauses()
    return get_nsfr_model(args, lang, clauses, atoms, bk, bk_clauses, device, train=False)


def predict(NSFR, loader, args, device,  th=None, split='train'):
    """Run the model over a data loader and score its binary predictions.

    Args:
        NSFR: Callable model; ``NSFR(imgs)`` yields the valuation that
            ``get_prob`` turns into a per-sample score.
        loader: Iterable of ``(imgs, target)`` tensor pairs.
        args: Parsed options; ``args.plot`` and ``args.dataset`` are read
            here and the namespace is forwarded to ``get_prob``.
        device: Device the batch tensors are moved to.
        th: Decision threshold. When ``None`` the threshold is chosen among
            the ROC-curve thresholds as the one giving the highest accuracy.
        split: Name of the split; only used in the output folder of the
            captioned images written when ``args.plot`` is set.

    Returns:
        Tuple ``(accuracy, recall, threshold)`` where ``recall`` is the
        result of ``recall_score(..., average=None)`` and ``threshold`` is
        either ``th`` or the selected best-accuracy threshold.
    """
    predicted_list = []
    target_list = []
    count = 0  # running sample offset, used to number the saved images

    for sample in tqdm(loader):
        # move the batch to the requested device
        imgs, target_set = (x.to(device) for x in sample)

        # infer and predict the target probability
        V_T = NSFR(imgs)
        predicted = get_prob(V_T, NSFR, args)
        predicted_list.append(predicted.detach())
        target_list.append(target_set.detach())
        if args.plot:
            imgs = to_plot_images_kandinsky(imgs)
            captions = generate_captions(
                V_T, NSFR.atoms, NSFR.pm.e, th=0.3)
            save_images_with_captions(
                imgs, captions, folder='result/kandinsky/' + args.dataset + '/' + split + '/', img_id_start=count, dataset=args.dataset)
        count += V_T.size(0)  # batch size

    predicted = torch.cat(predicted_list, dim=0).detach().cpu().numpy()
    target_set = torch.cat(target_list, dim=0).to(
        torch.int64).detach().cpu().numpy()

    if th is not None:
        # Fixed threshold supplied by the caller.
        labels = predicted > th
        accuracy = accuracy_score(target_set, labels)
        rec_score = recall_score(target_set, labels, average=None)
        return accuracy, rec_score, th

    # No threshold given: scan the ROC thresholds and keep the most accurate.
    fpr, tpr, thresholds = roc_curve(target_set, predicted, pos_label=1)
    print('ths', thresholds)
    accuracies = np.array(
        [accuracy_score(target_set, predicted > thresh) for thresh in thresholds])
    best = accuracies.argmax()
    max_accuracy = accuracies[best]
    max_accuracy_threshold = thresholds[best]
    # Recall must be measured at the selected threshold. The previous code
    # reused the loop variable, i.e. the *last* ROC threshold, so the
    # reported recall did not correspond to the reported accuracy.
    rec_score = recall_score(
        target_set, predicted > max_accuracy_threshold, average=None)

    print('target_set: ', target_set, target_set.shape)
    print('predicted: ', predicted, predicted.shape)
    print('accuracy: ', max_accuracy)
    print('threshold: ', max_accuracy_threshold)
    print('recall: ', rec_score)

    return max_accuracy, rec_score, max_accuracy_threshold

def binary_cross_entropy(predictions, targets):
    """Return the mean binary cross-entropy between predictions and targets.

    Args:
        predictions: Array-like of probabilities in [0, 1].
        targets: Array-like of labels (0 or 1), broadcastable against
            ``predictions``.

    Returns:
        The mean of ``-(t * log(p) + (1 - t) * log(1 - p))`` as a NumPy
        float64 scalar.
    """
    epsilon = 1e-15
    # Work in float64: in float32 ``1 - 1e-15`` rounds to exactly 1.0, so the
    # clip below would not keep log(1 - p) away from log(0). asarray also
    # lets plain Python lists be passed in.
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    # Keep probabilities strictly inside (0, 1) so both logs stay finite.
    predictions = np.clip(predictions, epsilon, 1 - epsilon)

    bce_loss = - (targets * np.log(predictions) + (1 - targets) * np.log(1 - predictions))

    return np.mean(bce_loss)

def mask_for_gradient(softmax_weight, mask, mask_index):
    """Zero the mask entry of every row that holds a newly saturated column.

    A weight counts as saturated when it is above 0.99. The first time a
    column is seen saturated, its index is appended to ``mask_index`` and the
    mask entry of the row it was found in is set to 0; columns already listed
    in ``mask_index`` are ignored. Both ``mask`` and ``mask_index`` are
    modified in place, and ``mask`` is returned.
    """
    for row_id, row_weights in enumerate(softmax_weight):
        # The column count is always taken from the first row.
        for col_id in range(len(softmax_weight[0])):
            if not row_weights[col_id] > 0.99:
                continue
            if col_id in mask_index:
                continue
            mask[row_id] = 0
            mask_index.append(col_id)
    return mask

def _find_atom_index(atoms, target_str):
    """Return the position of the atom whose string form equals ``target_str``.

    If several atoms match, the last one wins.

    Raises:
        ValueError: if no atom matches, instead of leaving the index unbound.
    """
    found = None
    for index, atom in enumerate(atoms):
        if str(atom) == target_str:
            found = index
    if found is None:
        raise ValueError(
            f"target atom '{target_str}' is not among the meta-interpreter atoms")
    return found


def _meta_train_step(model, target_index, img, lr, bce, target_set):
    """Run one optimisation step pushing one atom's valuation towards 1.

    Returns a tuple ``(loss_value, program)`` where ``program`` is a list copy
    of ``model.store_program()`` taken after the parameter update.
    """
    valuation = model(img)
    # parameters of the meta interpreter that need to be optimized
    params = model.get_params()
    # NOTE(review): a fresh RMSprop instance is created on every step, so its
    # running statistics never accumulate. Kept as-is because changing it
    # would alter the training dynamics; confirm whether it is intended.
    optimizer = torch.optim.RMSprop(params, lr=lr)

    # valuation of the selected target atom for the first element
    target_prob = valuation[0][target_index]
    # the goal atom should hold: target 1.0 with matching dtype and device
    target = torch.ones_like(target_prob)

    print('grad is', params[0].grad)
    loss = bce(target_prob, target)
    print('====================================')
    optimizer.zero_grad()
    loss.backward()
    print('grad is', params[0].grad)
    optimizer.step()

    print("target: ", target_set.detach().cpu().numpy())
    print('predict', target_prob)
    print('weight is', params)
    model.print_program()
    learned_program = list(model.store_program())
    loss_value = loss.item()
    print("loss: ", loss_value)
    return loss_value, learned_program


def train_nsfr(args, NSFR,  train_loader, val_loader, test_loader, device, writer, lark_path, lang_base_path, lang, clauses, atoms, bk, bk2, rtpt):
    """Train two meta-interpreter models per batch and evaluate NSFR per epoch.

    For every training batch two models are built with
    ``get_lang_mi_train`` / ``get_nsfr_model_mi_train`` from the valuation of
    the first image: one from ``bk`` and one from ``bk2`` (``meta_arg=1``).
    1100 optimisation steps follow: steps 0..600 train the ``bk2`` model on
    the atom 'plan(a,h)', the remaining steps train the ``bk`` model on
    'plan(a,e)'. The per-step losses and stored programs of the batch are
    written with ``torch.save`` (the files are overwritten by each batch).
    After each epoch ``NSFR`` is evaluated on the validation, training and
    test loaders and the metrics are sent to ``writer``.

    Args:
        args: Parsed options (``epochs``, ``lr``, ``dataset_type``,
            ``dataset`` are read here; the namespace is also forwarded).
        NSFR: Base model; called on image batches and passed to ``predict``.
        train_loader, val_loader, test_loader: Iterables of
            ``(imgs, target)`` tensor pairs.
        device: Device the batch tensors are moved to.
        writer: Object exposing ``add_scalar`` (a SummaryWriter in ``main``).
        lark_path, lang_base_path: Forwarded to ``get_lang_mi_train``.
        lang, clauses, atoms: Forwarded to the model/language builders.
        bk, bk2: Background knowledge for the first / second meta model.
        rtpt: Progress tracker; ``step`` is called once per epoch.

    Returns:
        List with one entry per epoch: the sum, over the batches of that
        epoch, of the loss of the final optimisation step.

    Raises:
        ValueError: if one of the two target atoms is not present.
    """
    n_steps = 1100      # optimisation steps per batch
    switch_step = 600   # steps 0..switch_step use the first goal
    run_tag = 2         # suffix used in the names of the saved files

    bce = torch.nn.BCELoss()
    loss_list = []
    for epoch in range(args.epochs):
        # Previously this was reset before every step and never updated, so
        # the logged training loss was always 0.
        loss_i = 0
        for sample in tqdm(train_loader):
            loss_iteration = [[]]
            program = [[]]

            # move the batch to the requested device
            imgs, target_set = (x.to(device) for x in sample)
            # valuation of the base model; only the first element is used
            V_T = NSFR(imgs)

            # meta model built on bk
            lang_mi_train, clauses_mi_train, atoms_mi_train, terms_mi_train = get_lang_mi_train(
                lark_path, lang_base_path, args.dataset_type, args.dataset,
                atoms, V_T[0], bk)
            NSFR_mi_train = get_nsfr_model_mi_train(
                args, lang, clauses, atoms, bk, device, atoms_mi_train,
                clauses_mi_train, lang_mi_train, terms_mi_train, V_T[0], train=True)

            # meta model built on bk2 with meta_arg=1
            lang_mi, clauses_mi, atoms_mi, terms_mi = get_lang_mi_train(
                lark_path, lang_base_path, args.dataset_type, args.dataset,
                atoms, V_T[0], bk2, meta_arg=1)
            NSFR_mi_train2 = get_nsfr_model_mi_train(
                args, lang, clauses, atoms, bk2, device, atoms_mi,
                clauses_mi, lang_mi, terms_mi, V_T[0], meta_arg=1, train=True)

            # The atom lists do not change between steps, so the target
            # indices are resolved once (and fail early if an atom is absent).
            first_phase = (NSFR_mi_train2, _find_atom_index(atoms_mi, 'plan(a,h)'))
            second_phase = (NSFR_mi_train, _find_atom_index(atoms_mi_train, 'plan(a,e)'))

            step_loss = 0
            for step in range(n_steps):
                print(step)
                model, target_index = first_phase if step <= switch_step else second_phase
                step_loss, learned_program = _meta_train_step(
                    model, target_index, imgs[0], args.lr, bce, target_set)
                program[0].extend(learned_program)
                loss_iteration[0].append(step_loss)

            loss_i += step_loss

            torch.save(loss_iteration, f'12loss_iteration_{run_tag}_plot_BFSDFS_learning.pt')
            torch.save(program, f'12learned_{run_tag}_BFSDFS__learning.pt')

        loss_list.append(loss_i)
        rtpt.step(subtitle=f"loss={loss_i:2.2f}")
        writer.add_scalar("metric/train_loss", loss_i, global_step=epoch)
        print("loss: ", loss_i)

        # Evaluate after every epoch. No local variable may be called
        # ``predict`` in this function: it would shadow the module-level
        # function used here and make these calls fail.
        NSFR.print_program()
        print("Predicting on validation data set...")
        acc_val, _rec_val, th_val = predict(
            NSFR, val_loader, args, device, th=0.33, split='val')
        writer.add_scalar("metric/val_acc", acc_val, global_step=epoch)
        print("acc_val: ", acc_val)

        print("Predicting on training data set...")
        acc, _rec, _th = predict(
            NSFR, train_loader, args, device, th=th_val, split='train')
        writer.add_scalar("metric/train_acc", acc, global_step=epoch)
        print("acc_train: ", acc)

        print("Predicting on test data set...")
        # split='test' keeps plotted test images out of the train folder
        acc, _rec, _th = predict(
            NSFR, test_loader, args, device, th=th_val, split='test')
        writer.add_scalar("metric/test_acc", acc, global_step=epoch)
        print("acc_test: ", acc)

    return loss_list


def main(n):
    """Run one full experiment: build the model, train, then evaluate.

    Args:
        n: Run index; it is appended to the experiment name that is used for
            the TensorBoard log directory (``runs/<name>``) and for RTPT.

    The options are read from the command line through ``get_args``. The
    TensorBoard writer is closed when the run ends, whether it completes or
    raises, so pending events are flushed to disk.
    """
    args = get_args()
    if args.dataset_type == 'kandinsky':
        if args.small_data:
            name = 'small_KP/aILP:' + args.dataset + '_' + str(n)
        else:
            name = 'KP/aILP' + args.dataset + '_' + str(n)
    else:
        if not args.no_xil:
            name = 'CH/aILP:' + args.dataset + '_' + str(n)
        else:
            name = 'CH/aILP-noXIL:' + args.dataset + '_' + str(n)
    print('args ', args)
    if args.no_cuda:
        device = torch.device('cpu')
    elif len(args.device.split(',')) > 1:
        # multi gpu
        device = torch.device('cuda')
    else:
        device = torch.device('cuda:' + args.device)

    print('device: ', device)
    writer = SummaryWriter(f"runs/{name}", purge_step=0)
    try:
        # Create RTPT object and start the tracking
        rtpt = RTPT(name_initials='HS', experiment_name=name,
                    max_iterations=args.epochs)
        rtpt.start()

        # get torch data loaders
        train_loader, val_loader,  test_loader = get_data_loader(args)

        # NOTE(review): the positive-only loaders are not used afterwards in
        # this function; the call is kept in case it has needed side effects.
        train_pos_loader, val_pos_loader, test_pos_loader = get_data_pos_loader(
            args)

        # load logical representations
        lark_path = 'src/lark/exp.lark'
        lang_base_path = 'data/lang/'
        lang, clauses, bk_clauses, bk, bk2, atoms = get_lang(
            lark_path, lang_base_path, args.dataset_type, args.dataset)

        print("clauses: ", clauses)

        # base model, built with train=False
        NSFR = get_nsfr_model(args, lang, clauses, atoms, bk,
                              bk_clauses, device, train=False)

        # the meta-interpreter models are constructed inside train_nsfr
        train_nsfr(args, NSFR,  train_loader,
                   val_loader, test_loader, device, writer, lark_path, lang_base_path, lang, clauses, atoms, bk, bk2, rtpt)

        # validation split (fixed threshold)
        print("Predicting on validation data set...")
        acc_val, rec_val, th_val = predict(
            NSFR, val_loader, args, device, th=0.33, split='val')

        print("Predicting on training data set...")
        # training split, scored at the validation threshold
        acc, rec, th = predict(
            NSFR, train_loader, args, device, th=th_val, split='train')

        print("Predicting on test data set...")
        # test split, scored at the validation threshold
        acc_test, rec_test, th_test = predict(
            NSFR, test_loader, args, device, th=th_val, split='test')

        print("training acc: ", acc, "threashold: ", th, "recall: ", rec)
        print("val acc: ", acc_val, "threashold: ", th_val, "recall: ", rec_val)
        print("test acc: ", acc_test, "threashold: ", th_test, "recall: ", rec_test)
    finally:
        # Flush and release the event file even if the run failed.
        writer.close()


if __name__ == "__main__":
    # Script entry point: one run, whose index is forwarded to main() and
    # ends up in the experiment name. Widen the range to repeat the run.
    for run_index in range(1):
        main(n=run_index)
