import matplotlib.pyplot as plt
from rtpt import RTPT
from yolov5.utils.general import non_max_suppression
from yolov5.models.experimental import attempt_load
from dilpst.src.ilp_problem import ILPProblem
from dilpst.src.infer import InferModule
from dilpst.src.tensor_encoder import TensorEncoder
from dilpst.src.data_utils import DataUtils
from dilpst.src.language import Language
from logic_utils import get_index_by_predname
import dilpst.src.logic as lg
from sklearn.metrics import precision_score, accuracy_score, roc_curve
import data
import data_clevr
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import argparse
from datetime import datetime
import time
import pickle

import torch
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing as mp

import scipy.optimize
import numpy as np
from tqdm import tqdm
import matplotlib
from torch.optim import lr_scheduler

import numpy as np
from nelogic import *
from percept import YOLOPerceptionModule, SlotAttentionPerceptionModule
from vrlang import load_language
from facts_converter import FactsConverter
from logic_utils import build_infer_module, generate_atoms, generate_clauses, generate_bk
from vfcr import VFCReasoner
from valuation import *
from torch_utils import select_device


# Select the non-interactive "Agg" backend so plotting never needs a display.
# NOTE(review): matplotlib.pyplot is already imported at the top of this file;
# recent matplotlib releases allow switching to Agg after that import, but
# confirm this for the matplotlib version pinned by the project.
matplotlib.use("Agg")

# import data_copy as data
# import utils as utils

# Enable autograd anomaly detection globally for everything run by this module.
# NOTE(review): PyTorch documents anomaly detection as a debugging mode that
# slows program execution; since this script measures runtimes, confirm the
# overhead is intended to be part of the reported numbers.
torch.autograd.set_detect_anomaly(True)


def get_args():
    """Parse the command-line options of the runtime benchmark.

    Every option is registered from the table below in the same order as
    before; ``--name`` defaults to the local timestamp at call time formatted
    as ``%Y-%m-%d_%H:%M:%S``.

    Returns:
        argparse.Namespace holding all parsed options.
    """
    cli = argparse.ArgumentParser()
    option_table = [
        # generic params
        (("--name",), dict(
            default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
            help="Name to store the log file as",
        )),
        (("--resume",), dict(help="Path to log file to resume from")),
        (("--epochs",), dict(
            type=int, default=100, help="Number of epochs to train with")),
        (("--ap-log",), dict(
            type=int, default=10, help="Number of epochs before logging AP")),
        (("--lr",), dict(
            type=float, default=1e-2, help="Outer learning rate of model")),
        (("--batch-size",), dict(
            type=int, default=64, help="Batch size to train with")),
        (("--num-workers",), dict(
            type=int, default=4, help="Number of threads for data loader")),
        (("--e",), dict(
            type=int, default=4,
            help="The maximum number of objects in one image.")),
        (("--n-subs",), dict(
            type=int, default=360,
            help="The maximum number of substitutions for existentially quantified variables.")),
        (("--dataset",), dict(
            choices=["twopairs", "threepairs", "red-triangle", "closeby",
                     "online", "online-pair", "nine-circles", "clevr-hans3", "clevr-hans7"],
            help="Use kandinsky or clevr-hans dataset",
        )),
        (("--dataset_type",), dict(
            choices=["kandinsky", "clevr"], help="kandinsky or clevr")),
        (("--perception-model",), dict(
            choices=["yolo", "slotattention"],
            help="Choose yolo or slotattention for object recognition.",
        )),
        (("--device",), dict(
            default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')),
        (("--no-cuda",), dict(
            action="store_true",
            help="Run on CPU instead of GPU (not recommended)")),
        (("--train-only",), dict(
            action="store_true", help="Only run training, no evaluation")),
        (("--eval-only",), dict(
            action="store_true", help="Only run evaluation, no training")),
        (("--multi-gpu",), dict(action="store_true", help="Use multiple GPUs")),
        (("--data-dir",), dict(type=str, help="Directory to data")),
        (("--program-size",), dict(
            default=5, type=int,
            help='number of clauses to compose logic programs')),
        (("--gamma",), dict(
            default=0.01, type=float,
            help='smooth parameter in the softor function')),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()


def compute_acc(outputs, targets):
    """Return the accuracy of arg-max predictions against ``targets``.

    The predicted label of each row is the index of its largest entry along
    axis 1 of ``outputs``; scoring is delegated to sklearn's
    ``accuracy_score(targets, predictions)``.
    """
    return accuracy_score(targets, np.argmax(outputs, axis=1))


def get_prob(v_T, VFCR, args):
    """Read the target-predicate probabilities out of a valuation tensor.

    Args:
        v_T: valuation returned by ``VFCR(imgs)``.
        VFCR: reasoner exposing ``predict(v=..., predname=...)`` and
            ``predict_multi(v=..., prednames=...)``.
        args: namespace providing ``dataset_type`` and ``dataset``.

    Returns:
        For Kandinsky, ``VFCR.predict`` on predicate ``'kp'``; for CLEVR-Hans,
        ``VFCR.predict_multi`` on ``'kp1'`` .. ``'kpN'`` where N is 3 for
        ``clevr-hans3`` and 7 for ``clevr-hans7``.

    Raises:
        ValueError: for an unsupported dataset type or CLEVR-Hans variant
            (these previously crashed with an UnboundLocalError).
    """
    if args.dataset_type == 'kandinsky':
        return VFCR.predict(v=v_T, predname='kp')
    if args.dataset_type == 'clevr':
        num_classes = {'clevr-hans3': 3, 'clevr-hans7': 7}.get(args.dataset)
        if num_classes is None:
            raise ValueError("Invalid clevr dataset: " + str(args.dataset))
        prednames = ['kp' + str(k) for k in range(1, num_classes + 1)]
        return VFCR.predict_multi(v=v_T, prednames=prednames)
    raise ValueError("Invalid dataset type: " + str(args.dataset_type))


def predict_time(VFCR, loader, args, device, writer):
    """Run the reasoner over every batch of ``loader``; this is the timed workload.

    Args:
        VFCR: callable reasoner; ``VFCR(imgs)`` returns a valuation tensor.
        loader: iterable yielding ``(imgs, target_set)`` pairs of tensors.
        args: forwarded to ``get_prob`` (reads ``dataset_type``/``dataset``).
        device: device every batch is moved to before inference.
        writer: not used here; kept so existing callers keep working.

    Returns:
        The prediction for the last batch, or ``None`` if ``loader`` is empty
        (previously an empty loader raised UnboundLocalError).
    """
    predicted = None
    for sample in tqdm(loader):
        # Targets are also moved to the device so the transfer stays part of
        # the measured work, even though they are not used for prediction.
        imgs, _target_set = (x.to(device) for x in sample)

        # infer and predict the target probability
        v_T = VFCR(imgs)
        predicted = get_prob(v_T, VFCR, args)

    # CUDA kernels are launched asynchronously: wait for the last batch to
    # finish so wall-clock timing around this call covers all queued work.
    if str(device).startswith('cuda') and torch.cuda.is_available():
        torch.cuda.synchronize(device)
    return predicted


def get_kandinsky_loader(args, batch_size=None):
    """Build the train/val/test DataLoaders for a Kandinsky-patterns dataset.

    Args:
        args: namespace providing ``dataset``, ``data_dir`` and
            ``num_workers`` (plus ``batch_size`` when ``batch_size`` is omitted).
        batch_size: batch size used by all three loaders. Defaults to
            ``args.batch_size`` so ``get_kandinsky_loader(args)`` works like
            ``get_clevr_loader(args)``. Previously this argument was required.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    if batch_size is None:
        batch_size = args.batch_size
    loaders = []
    for split in ('train', 'val', 'test'):
        dataset = data.KANDINSKY(args.dataset, args.data_dir, split)
        loaders.append(torch.utils.data.DataLoader(
            dataset,
            shuffle=(split == 'train'),
            batch_size=batch_size,
            num_workers=args.num_workers,
        ))
    return tuple(loaders)


def get_clevr_loader(args, batch_size=None):
    """Build the train/val/test DataLoaders for a CLEVR-Hans dataset.

    Args:
        args: namespace providing ``dataset``, ``data_dir``, ``num_workers``
            and ``batch_size``.
        batch_size: optional override for the batch size of all three
            loaders. This matches ``get_kandinsky_loader``. ``None`` (the
            default) keeps using ``args.batch_size``.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    if batch_size is None:
        batch_size = args.batch_size
    loaders = []
    for split in ('train', 'val', 'test'):
        dataset = data_clevr.CLEVRHans(args.dataset, args.data_dir, split)
        loaders.append(torch.utils.data.DataLoader(
            dataset,
            shuffle=(split == 'train'),
            batch_size=batch_size,
            num_workers=args.num_workers,
        ))
    return tuple(loaders)


def get_data_loader(args):
    """Return ``(train_loader, val_loader, test_loader)`` for ``args.dataset_type``.

    Both dataset families use ``args.batch_size``.

    Raises:
        ValueError: if ``args.dataset_type`` is neither ``'kandinsky'`` nor
            ``'clevr'`` (previously this silently returned ``None``).
    """
    if args.dataset_type == 'kandinsky':
        # Pass the batch size explicitly: the Kandinsky loader originally
        # required it as a positional argument.
        return get_kandinsky_loader(args, args.batch_size)
    if args.dataset_type == 'clevr':
        return get_clevr_loader(args)
    raise ValueError("Invalid dataset type: " + str(args.dataset_type))


def get_vfcr_model(args, lang, clauses, atoms, bk, device):
    """Assemble the visual forward-chaining reasoner for ``args.dataset_type``.

    Kandinsky uses the YOLO perception/valuation modules with ``e=args.e`` and
    ``d=11``. CLEVR uses the Slot Attention modules with ``e=10`` and ``d=19``.

    Args:
        args: namespace providing ``dataset_type``, ``e`` and ``n_subs``.
        lang, clauses, atoms, bk: logic representation. Note that ``get_lang``
            returns these in the order ``(lang, clauses, bk, atoms)``.
        device: device forwarded to every sub-module.

    Returns:
        The assembled ``VFCReasoner``.

    Raises:
        ValueError: if ``args.dataset_type`` is neither ``'kandinsky'`` nor ``'clevr'``.
    """
    if args.dataset_type == 'kandinsky':
        PM = YOLOPerceptionModule(
            nn_id='yolo', e=args.e, d=11, device=device)
        VM = YOLOValuationModule(lang=lang, device=device)
    elif args.dataset_type == 'clevr':
        PM = SlotAttentionPerceptionModule(
            nn_id='slot_attention', e=10, d=19, device=device)
        VM = SlotAttentionValuationModule(lang=lang, device=device)
    else:
        # Raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would surface as a NameError on PM below.
        raise ValueError("Invalid dataset type: " + str(args.dataset_type))
    FC = FactsConverter(lang=lang, perception_module=PM,
                        valuation_module=VM, device=device)
    # infer_step is fixed at 4 here; the substitution cap comes from --n-subs.
    IM = build_infer_module(clauses, atoms, lang,
                            m=len(clauses), infer_step=4, max_subs_num=args.n_subs, device=device)
    # Visual forward-chaining reasoner
    VFCR = VFCReasoner(perception_module=PM, facts_converter=FC,
                       infer_module=IM, atoms=atoms, bk=bk, clauses=clauses)
    return VFCR


def get_lang(args):
    """Load the logic language, clauses, background knowledge and ground atoms.

    Each artefact is echoed to stdout as soon as it is loaded.

    Args:
        args: namespace providing ``dataset_type`` and ``dataset``.

    Returns:
        ``(lang, clauses, bk, atoms)``. Note that ``get_vfcr_model`` expects
        ``atoms`` before ``bk``.
    """
    utils = DataUtils(dataset_type=args.dataset_type, dataset=args.dataset)
    language = utils.load_language()
    print(language)

    clause_list = utils.get_clauses(language)
    print('clauses: ')
    for clause in clause_list:
        print(clause)

    background = utils.get_bk(language)
    print('bk: ', background)

    ground_atoms = generate_atoms(language)
    for position, ground_atom in enumerate(ground_atoms):
        print(position, ground_atom)
    print(len(ground_atoms), 'atoms')

    return language, clause_list, background, ground_atoms


def _resolve_device(args):
    """Map ``--device`` / ``--no-cuda`` onto a ``torch.device``.

    ``--no-cuda`` or ``'cpu'`` selects the CPU. ``''`` selects the default CUDA
    device. ``'N'`` selects ``cuda:N``. A list such as ``'0,1'`` selects its
    first entry, because this benchmark runs on a single device.
    """
    spec = args.device.strip()
    if args.no_cuda or spec.lower() == 'cpu':
        return torch.device('cpu')
    first = spec.split(',')[0].strip()
    return torch.device('cuda:' + first if first else 'cuda')


def _get_train_loader(args, batch_size):
    """Return the training DataLoader of the configured dataset family at ``batch_size``."""
    if args.dataset_type == 'clevr':
        # get_clevr_loader reads the batch size from args; use a copy with it overridden.
        loader_args = argparse.Namespace(**vars(args))
        loader_args.batch_size = batch_size
        return get_clevr_loader(loader_args)[0]
    return get_kandinsky_loader(args, batch_size)[0]


def main():
    """Time VFCR inference over the training split for a sweep of batch sizes.

    For each batch size the whole training loader is passed through the
    reasoner once. The elapsed wall-clock time is printed and logged to
    TensorBoard under ``runs/runtime/<dataset>`` with the batch size as the step.
    The list of runtimes is then pickled to ``result/runtime/<dataset>.pickle``.
    """
    args = get_args()

    # Previously 'cuda:' + args.device, which is invalid for the default ''
    # and for 'cpu', and ignored --no-cuda.
    device = _resolve_device(args)
    print('device: ', device)

    # load logical representations
    lang, clauses, bk, atoms = get_lang(args)

    # Visual forward-chaining reasoner
    VFCR = get_vfcr_model(args, lang, clauses, atoms, bk, device)

    # Batch sizes 1, 5, 10, ..., 50. The second 15 in the original list was a
    # typo for 25 and produced a duplicate TensorBoard step.
    batch_sizes = [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    runtimes = []

    run_name = args.dataset
    writer = SummaryWriter(f"runs/runtime/{run_name}", purge_step=0)
    rtpt = RTPT(name_initials='HS', experiment_name='NS-FR/TIME/' + args.dataset,
                max_iterations=len(batch_sizes))
    rtpt.start()
    try:
        for bs in batch_sizes:
            print('=== BATCH SIZE ' + str(bs) + ' ===')
            # get torch data loader (matching --dataset_type, not always Kandinsky)
            train_loader = _get_train_loader(args, bs)
            # perf_counter is monotonic and high-resolution, unlike time.time.
            start = time.perf_counter()
            predict_time(VFCR, train_loader, args, device, writer)
            diff = time.perf_counter() - start
            print(diff)
            runtimes.append(diff)
            writer.add_scalar('runtime with different batch sizes',
                              diff, global_step=bs)
            rtpt.step(subtitle=f"runtime={diff:2.2f}")
    finally:
        writer.close()

    # Create the output directory so a long sweep is not lost on a missing path.
    out_dir = os.path.join('result', 'runtime')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, args.dataset + '.pickle'), 'wb') as f:
        pickle.dump(runtimes, f)

    print(runtimes)


# Run the benchmark only when executed as a script, never on import.
if __name__ == '__main__':
    main()
