import matplotlib.pyplot as plt
from rtpt import RTPT
from yolov5.utils.general import non_max_suppression
from yolov5.models.experimental import attempt_load
from dilpst.src.ilp_problem import ILPProblem
from dilpst.src.infer import InferModule
from dilpst.src.tensor_encoder import TensorEncoder
from dilpst.src.data_utils import DataUtils
from dilpst.src.language import Language
from logic_utils import get_index_by_predname
import dilpst.src.logic as lg
from sklearn.metrics import precision_score, accuracy_score, roc_curve, recall_score
import data
import data_clevr
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import argparse
from datetime import datetime

import torch
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing as mp

import scipy.optimize
import numpy as np
from tqdm import tqdm
import matplotlib
from torch.optim import lr_scheduler

import numpy as np
from nelogic import *
from percept import YOLOPerceptionModule, SlotAttentionPerceptionModule
from vrlang import load_language
from facts_converter import FactsConverter
from logic_utils import build_infer_module, generate_atoms, generate_clauses, generate_bk
from vfcr import VFCReasoner
from valuation import *
from torch_utils import select_device


# Use the non-interactive Agg backend so that no display is needed
# (e.g. when running on a headless server).
matplotlib.use("Agg")

# Autograd anomaly detection is a debugging aid: it adds checks to every
# autograd operation and slows execution down.
# NOTE(review): this script only runs inference; consider disabling it.
torch.autograd.set_detect_anomaly(True)


def get_args():
    """Parse the command-line options of the prediction script.

    Returns:
        argparse.Namespace with one attribute per option declared below.
        Options declared without a default are ``None`` when omitted, and
        ``store_true`` flags are ``False``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # generic params
    add("--name",
        default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        help="Name to store the log file as")
    add("--resume", help="Path to log file to resume from")

    # schedule / optimisation
    add("--epochs", type=int, default=100,
        help="Number of epochs to train with")
    add("--ap-log", type=int, default=10,
        help="Number of epochs before logging AP")
    add("--lr", type=float, default=1e-2,
        help="Outer learning rate of model")
    add("--batch-size", type=int, default=64,
        help="Batch size to train with")
    add("--num-workers", type=int, default=4,
        help="Number of threads for data loader")

    # size of the reasoning problem
    add("--e", type=int, default=4,
        help="The maximum number of objects in one image.")
    add("--n-subs", type=int, default=360,
        help="The maximum number of substitutions for existentially quantified variables.")

    # dataset and perception model selection
    dataset_names = ["twopairs", "threepairs", "red-triangle", "closeby",
                     "online", "online-pair", "nine-circles",
                     "clevr-hans3", "clevr-hans7"]
    add("--dataset", choices=dataset_names,
        help="Use kandinsky or clevr-hans dataset")
    add("--dataset_type", choices=["kandinsky", "clevr"],
        help="kandinsky or clevr")
    add("--perception-model", choices=["yolo", "slotattention"],
        help="Choose yolo or slotattention for object recognition.")

    # hardware
    add("--device", default="",
        help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
    add("--no-cuda", action="store_true",
        help="Run on CPU instead of GPU (not recommended)")

    # run mode
    add("--train-only", action="store_true",
        help="Only run training, no evaluation")
    add("--eval-only", action="store_true",
        help="Only run evaluation, no training")
    add("--multi-gpu", action="store_true", help="Use multiple GPUs")

    add("--data-dir", type=str, help="Directory to data")

    # logic-program hyper-parameters
    add("--program-size", default=5, type=int,
        help="number of clauses to compose logic programs")
    add("--gamma", default=0.01, type=float,
        help="smooth parameter in the softor function")

    return parser.parse_args()


def compute_acc(outputs, targets):
    """Return the accuracy of arg-max predictions against ``targets``.

    Args:
        outputs: array-like of per-sample scores whose class axis is axis 1;
            the predicted label of each sample is its arg-max along that axis.
        targets: ground-truth labels, one per sample of ``outputs``.

    Returns:
        The fraction of correct predictions, as computed by
        ``sklearn.metrics.accuracy_score``.
    """
    predicted_labels = np.argmax(outputs, axis=1)
    return accuracy_score(y_true=targets, y_pred=predicted_labels)


def denormalize(imgs):
    """Map CLEVR images from the [-1, 1] range back to [0, 1].

    This is the inverse of the normalisation
    ``image = (image - 0.5) * 2.0``. It works for any operand that supports
    ``*`` and ``+`` with Python floats (scalars, numpy arrays, torch tensors).
    """
    half = 0.5
    return imgs * half + half


def get_prob(v_T, VFCR, args):
    """Read the target-predicate probabilities out of a valuation.

    Args:
        v_T: valuation produced by ``VFCR(imgs)``.
        VFCR: reasoner exposing ``predict`` and ``predict_multi``.
        args: parsed CLI args. ``args.dataset_type`` selects single- vs
            multi-predicate prediction, and ``args.dataset`` selects the
            number of clevr classes.

    Returns:
        For kandinsky, the result of ``VFCR.predict`` for predicate ``kp``.
        For clevr-hans3 / clevr-hans7, the result of ``VFCR.predict_multi``
        for predicates ``kp1`` .. ``kp3`` / ``kp1`` .. ``kp7``.

    Raises:
        ValueError: if the dataset type, or the clevr dataset, is not
            supported. Previously this surfaced as an ``UnboundLocalError``.
    """
    if args.dataset_type == 'kandinsky':
        return VFCR.predict(v=v_T, predname='kp')

    if args.dataset_type == 'clevr':
        n_classes = {'clevr-hans3': 3, 'clevr-hans7': 7}.get(args.dataset)
        if n_classes is None:
            raise ValueError(
                "Unsupported clevr dataset: " + str(args.dataset))
        # One target predicate per class: kp1, kp2, ..., kpN.
        prednames = ['kp' + str(k) for k in range(1, n_classes + 1)]
        return VFCR.predict_multi(v=v_T, prednames=prednames)

    raise ValueError("Invalid dataset type: " + str(args.dataset_type))


def predict(VFCR, loader, args, device, writer, th=None):
    """Run the reasoner over ``loader`` and score its predictions.

    For every batch, the images are passed through ``VFCR`` and the target
    probabilities are read with ``get_prob``. Both predictions and targets
    are reduced with ``np.argmax`` (predictions over axis 0, targets over
    axis 1) and accumulated. For the first 10 batches, the images and the
    valuation text are also logged to ``writer``.

    Args:
        VFCR: the visual forward-chaining reasoner, optionally wrapped in
            ``torch.nn.DataParallel``.
        loader: iterable of ``(imgs, target_set)`` batches.
        args: parsed CLI args (``dataset_type`` and ``dataset`` are read).
        device: device the batches are moved to.
        writer: tensorboard ``SummaryWriter`` used for logging.
        th: decision threshold. If ``None``, the threshold that maximises
            accuracy on this loader is searched among the ``roc_curve``
            thresholds.

    Returns:
        ``(accuracy, per-class recall, threshold)``.
    """
    # DataParallel does not forward custom methods (predict, predict_multi,
    # get_valuation_text), so those are called on the wrapped module.
    reasoner = VFCR.module if isinstance(VFCR, torch.nn.DataParallel) else VFCR

    predicted_list = []
    target_set_list = []
    # Pure inference: nothing is back-propagated, so skip graph construction.
    with torch.no_grad():
        for i, sample in tqdm(enumerate(loader, start=0)):
            imgs, target_set = map(lambda x: x.to(device), sample)

            # infer and predict the target probability
            v_T = VFCR(imgs)
            predicted = get_prob(v_T, reasoner, args)

            # NOTE(review): predictions are arg-maxed over axis 0 but targets
            # over axis 1 - confirm this matches the layout returned by
            # predict / predict_multi.
            predicted_list.extend(
                list(np.argmax(predicted.detach().cpu().numpy(), axis=0)))
            target_set_list.extend(
                list(np.argmax(target_set.detach().cpu().numpy(), axis=1)))

            if i < 10:
                if args.dataset_type == 'clevr':
                    writer.add_images(
                        'images', denormalize(imgs).detach().cpu(), 0)
                else:
                    writer.add_images('images', imgs.detach().cpu(), 0)
                writer.add_text('v_T', reasoner.get_valuation_text(v_T), 0)

    # Arrays (not lists) so that `.shape` and vectorised comparisons work.
    predicted = np.asarray(predicted_list)
    target_set = np.asarray(target_set_list)

    print('predicted: ', predicted)
    print('target_set: ', target_set)

    if th is None:
        # NOTE(review): roc_curve only supports binary targets.
        fpr, tpr, thresholds = roc_curve(target_set, predicted)
        accuracies = np.array([
            accuracy_score(target_set, predicted > thresh)
            for thresh in thresholds
        ])
        best = accuracies.argmax()
        max_accuracy = accuracies[best]
        max_accuracy_threshold = thresholds[best]
        # Recall is reported at the selected (accuracy-maximising) threshold.
        rec_score = recall_score(
            target_set, predicted > max_accuracy_threshold, average=None)

        print('target_set: ', target_set, target_set.shape)
        print('predicted: ', predicted, predicted.shape)
        print('accuracy: ', max_accuracy)
        print('threshold: ', max_accuracy_threshold)
        print('recall: ', rec_score)

        return max_accuracy, rec_score, max_accuracy_threshold

    accuracy = accuracy_score(target_set, predicted > th)
    rec_score = recall_score(target_set, predicted > th, average=None)
    return accuracy, rec_score, th


def _build_loaders(dataset_cls, args):
    """Build (train, val, test) DataLoaders for one dataset class.

    ``dataset_cls`` is called as ``dataset_cls(args.dataset, args.data_dir,
    split)`` for split ``'train'``, ``'val'`` and ``'test'`` (in that
    order). Only the training loader is shuffled. All loaders share
    ``args.batch_size`` and ``args.num_workers``.
    """
    splits = ('train', 'val', 'test')
    # Build all datasets first, then the loaders (same order as before).
    split_datasets = [dataset_cls(args.dataset, args.data_dir, split)
                      for split in splits]
    return tuple(
        torch.utils.data.DataLoader(
            dataset,
            shuffle=(split == 'train'),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        for split, dataset in zip(splits, split_datasets)
    )


def get_kandinsky_loader(args):
    """Return ``(train_loader, val_loader, test_loader)`` for a Kandinsky dataset."""
    return _build_loaders(data.KANDINSKY, args)


def get_clevr_loader(args):
    """Return ``(train_loader, val_loader, test_loader)`` for a CLEVR-Hans dataset."""
    return _build_loaders(data_clevr.CLEVRHans, args)


def get_data_loader(args):
    """Return ``(train_loader, val_loader, test_loader)`` for ``args.dataset_type``.

    Raises:
        ValueError: if ``args.dataset_type`` is neither ``'kandinsky'`` nor
            ``'clevr'`` (including ``None``, since the CLI option has no
            default). Previously ``None`` was returned, which failed later
            with an opaque unpacking error.
    """
    if args.dataset_type == 'kandinsky':
        return get_kandinsky_loader(args)
    if args.dataset_type == 'clevr':
        return get_clevr_loader(args)
    raise ValueError("Invalid dataset type: " + str(args.dataset_type))


def get_vfcr_model(args, lang, clauses, atoms, bk, device):
    """Assemble the visual forward-chaining reasoner for ``args.dataset_type``.

    Pipeline: perception module -> valuation module -> facts converter ->
    inference module (4 inference steps, at most ``args.n_subs``
    substitutions) -> ``VFCReasoner``.

    Args:
        args: parsed CLI args (``dataset_type``, ``e`` and ``n_subs`` are read).
        lang, clauses, atoms, bk: logical representation from ``get_lang``.
        device: torch device passed to every sub-module.

    Raises:
        ValueError: for an unsupported ``args.dataset_type``. A ``ValueError``
            is used rather than ``assert``, so the check survives ``python -O``.
    """
    if args.dataset_type == 'kandinsky':
        PM = YOLOPerceptionModule(
            nn_id='yolo', e=args.e, d=11, device=device)
        VM = YOLOValuationModule(lang=lang, device=device)
    elif args.dataset_type == 'clevr':
        # NOTE(review): e=10 / d=19 are hard-coded here rather than taken from
        # args.e - confirm they match the slot-attention model.
        PM = SlotAttentionPerceptionModule(
            nn_id='slot_attention', e=10, d=19, device=device)
        VM = SlotAttentionValuationModule(lang=lang, device=device)
    else:
        raise ValueError("Invalid dataset type: " + str(args.dataset_type))
    FC = FactsConverter(lang=lang, perception_module=PM,
                        valuation_module=VM, device=device)
    IM = build_infer_module(clauses, atoms, lang,
                            m=len(clauses), infer_step=4,
                            max_subs_num=args.n_subs, device=device)
    # Visual forward-chaining reasoner
    VFCR = VFCReasoner(perception_module=PM, facts_converter=FC,
                       infer_module=IM, atoms=atoms, bk=bk, clauses=clauses)
    return VFCR


def get_lang(args):
    """Load the logical vocabulary for the selected dataset.

    Reads the language, the clauses and the background knowledge through
    ``DataUtils``, then grounds the atoms with ``generate_atoms``. Each piece
    is echoed to stdout as soon as it is loaded.

    Returns:
        ``(lang, clauses, bk, atoms)``
    """
    data_utils = DataUtils(dataset_type=args.dataset_type,
                           dataset=args.dataset)

    lang = data_utils.load_language()
    print(lang)

    clauses = data_utils.get_clauses(lang)
    print('clauses: ')
    for clause in clauses:
        print(clause)

    bk = data_utils.get_bk(lang)
    print('bk: ', bk)

    atoms = generate_atoms(lang)
    for index, atom in enumerate(atoms):
        print(index, atom)
    print(len(atoms), 'atoms')

    return lang, clauses, bk, atoms


def main():
    """Evaluate the reasoner on the validation, train and test splits.

    The decision threshold is selected on the validation split and reused
    for the train and test splits. Results are printed to stdout, and sample
    images/valuations are written to tensorboard under
    ``runs/predict/<dataset>``.
    """
    args = get_args()
    print('args ', args)

    if args.dataset is None or args.dataset_type is None:
        # Both are optional on the CLI, but everything below depends on them.
        raise SystemExit('--dataset and --dataset_type are required')

    # --device: '' -> default CUDA device, 'cpu' -> CPU, '0' -> that GPU,
    # '0,1,2,3' -> several GPUs via DataParallel. --no-cuda forces the CPU.
    device_spec = args.device.strip()
    use_cpu = args.no_cuda or device_spec.lower() == 'cpu'
    multi_gpu = not use_cpu and len(device_spec.split(',')) > 1
    if use_cpu:
        device = torch.device('cpu')
    elif multi_gpu or not device_spec:
        device = torch.device('cuda')
    else:
        device = torch.device('cuda:' + device_spec)
    print('device: ', device)

    run_name = 'predict/' + args.dataset
    writer = SummaryWriter(f"runs/{run_name}", purge_step=0)
    try:
        # get torch data loader
        train_loader, val_loader, test_loader = get_data_loader(args)

        # load logical representations
        lang, clauses, bk, atoms = get_lang(args)

        # Visual forward-chaining reasoner
        VFCR = get_vfcr_model(args, lang, clauses, atoms, bk, device)
        if multi_gpu:
            VFCR = torch.nn.DataParallel(VFCR)

        # validation step: selects the decision threshold
        acc_val, rec_val, th_val = predict(
            VFCR, val_loader, args, device, writer, th=None)

        # training step (threshold from validation)
        acc, rec, th = predict(
            VFCR, train_loader, args, device, writer, th=th_val)

        # test step (threshold from validation)
        acc_test, rec_test, th_test = predict(
            VFCR, test_loader, args, device, writer, th=th_val)
    finally:
        # Flush pending tensorboard events even if evaluation fails.
        writer.close()

    print("training acc: ", acc, "threshold: ", th, "recall: ", rec)
    print("val acc: ", acc_val, "threshold: ", th_val, "recall: ", rec_val)
    print("test acc: ", acc_test, "threshold: ", th_test, "recall: ", rec_test)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()
