import matplotlib.pyplot as plt
from rtpt import RTPT
from yolov5.utils.general import non_max_suppression
from yolov5.models.experimental import attempt_load
from dilpst.src.ilp_problem import ILPProblem
from dilpst.src.infer import InferModule
from dilpst.src.tensor_encoder import TensorEncoder
from dilpst.src.data_utils import DataUtils
from dilpst.src.language import Language
from logic_utils import get_index_by_predname
import dilpst.src.logic as lg
from sklearn.metrics import precision_score, accuracy_score
import data
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import argparse
from datetime import datetime

import torch
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing as mp

import scipy.optimize
import numpy as np
from tqdm import tqdm
import matplotlib
from torch.optim import lr_scheduler

import numpy as np
from nelogic import *
from percept import YOLOPerceptionModule
from vrlang import load_language
from facts_converter import FactsConverter
from logic_utils import build_infer_module, generate_atoms, generate_clauses, generate_bk
from vfcr import VFCReasoner
from valuation import *
from torch_utils import select_device

from valuation_func import YOLOOnlineValuationFunction, YOLOClosebyValuationFunction

# Select the non-interactive Agg backend so figures are rendered straight to
# files (see plot_grad_flow) and no display is required.
matplotlib.use("Agg")


# Turn on autograd anomaly detection for the whole process.
# NOTE(review): the PyTorch docs describe this as a debugging aid that slows
# execution down -- confirm it is meant to stay enabled for regular runs.
torch.autograd.set_detect_anomaly(True)


def get_args():
    """Declare the command-line options of this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the generic run options (name, resume path,
        epochs, learning rate, batch size, ...), the dataset and perception
        model selection, the device flags and the logic-program
        hyperparameters (``program_size``, ``gamma``).
    """
    cli = argparse.ArgumentParser()

    # generic params
    default_run_name = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
    cli.add_argument("--name", default=default_run_name,
                     help="Name to store the log file as")
    cli.add_argument("--resume", help="Path to log file to resume from")

    cli.add_argument("--epochs", type=int, default=100,
                     help="Number of epochs to train with")
    cli.add_argument("--ap-log", type=int, default=10,
                     help="Number of epochs before logging AP")
    cli.add_argument("--lr", type=float, default=1e-2,
                     help="Outer learning rate of model")
    cli.add_argument("--batch-size", type=int, default=1,
                     help="Batch size to train with")
    cli.add_argument("--num-workers", type=int, default=4,
                     help="Number of threads for data loader")
    cli.add_argument("--dataset",
                     choices=["closeby_pretrain", "online_pretrain"],
                     help="Use Kandinsky Pattern dataset")
    cli.add_argument("--perception-model",
                     choices=["yolo", "slotattention"],
                     help="Choose yolo or slotattention for object recognition.")
    cli.add_argument('--device', default='',
                     help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    # Boolean switches: each one is False unless given on the command line.
    switches = (
        ("--no-cuda", "Run on CPU instead of GPU (not recommended)"),
        ("--train-only", "Only run training, no evaluation"),
        ("--eval-only", "Only run evaluation, no training"),
        ("--multi-gpu", "Use multiple GPUs"),
    )
    for flag, description in switches:
        cli.add_argument(flag, action="store_true", help=description)

    cli.add_argument("--data-dir", type=str, help="Directory to data")

    cli.add_argument('--program-size', default=5, type=int,
                     help='number of clauses to compose logic programs')
    cli.add_argument('--gamma', default=0.01, type=float,
                     help='smooth parameter in the softor function')

    return cli.parse_args()


def compute_acc(outputs, targets):
    """Return the accuracy of arg-max predictions against ``targets``.

    Args:
        outputs: array-like of scores; the predicted label of each row is the
            index of its largest entry along axis 1.
        targets: ground-truth labels passed to ``accuracy_score`` as the true
            values.

    Returns:
        The value computed by ``accuracy_score(targets, predictions)``.
    """
    predicted_labels = np.argmax(outputs, axis=1)
    accuracy = accuracy_score(targets, predicted_labels)
    return accuracy


def preprocess(z, dataset):
    """Split the YOLO output into the argument sequence of a valuation function.

    Args:
        z: YOLO output, indexed as ``z[:, k]`` (presumably one slice per
            detected object -- TODO confirm against the perception module).
        dataset: ``'closeby_pretrain'`` (first 2 slices) or
            ``'online_pretrain'`` (first 5 slices).

    Returns:
        list of the slices ``z[:, 0], z[:, 1], ...`` in order; the caller
        unpacks it as the positional arguments of the neural predicate.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.  (The
            previous behaviour was to return ``None`` silently, which only
            failed later when the result was unpacked.)
    """
    slices_per_dataset = {'closeby_pretrain': 2, 'online_pretrain': 5}
    try:
        num_slices = slices_per_dataset[dataset]
    except KeyError:
        supported = ', '.join(sorted(slices_per_dataset))
        raise ValueError(
            f"unsupported dataset {dataset!r}; expected one of: {supported}"
        ) from None
    return [z[:, k] for k in range(num_slices)]


def run(net, predict_net, loader, optimizer, criterion, writer, args, device, train=False, epoch=0, pool=None, rtpt=None, max_obj_num=4):
    """Run one pass over ``loader`` and report mean BCE loss and accuracy.

    Every batch is fed through ``net``, split by ``preprocess`` and scored by
    ``predict_net``; the score is compared with the batch targets using
    binary cross-entropy.  Gradients are only computed when ``train`` is
    true, so validation / test passes no longer back-propagate.

    Args:
        net: callable applied to the image batch (the YOLO perception module).
        predict_net: callable receiving the unpacked result of
            ``preprocess``; its ``diff`` attribute is printed after a
            training pass.
        loader: sized iterable yielding ``(imgs, target_set)`` pairs whose
            items support ``.to(device)``.
        optimizer: optimizer updating ``predict_net``; may be ``None`` for
            evaluation passes.
        criterion, writer, pool, rtpt, max_obj_num: accepted so existing call
            sites keep working; not used by this function.
        args: parsed arguments; only ``args.dataset`` is read.
        device: device the batch is moved to.
        train: back-propagate (and, for ``epoch > 0``, update parameters).
        epoch: current epoch index.

    Returns:
        ``(mean_loss, 0, mean_acc)``: per-batch averages.  The middle element
        is a placeholder for the loss standard deviation.

    Raises:
        ValueError: if ``loader`` contains no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        # Previously this surfaced as a ZeroDivisionError (or a NameError in
        # the debug prints) after the loop.
        raise ValueError("run() needs a loader with at least one batch")

    bce = torch.nn.BCELoss()
    loss_sum, acc_sum = 0.0, 0.0

    for sample in tqdm(loader):
        # move the batch to the target device
        imgs, target_set = (item.to(device) for item in sample)

        # reset grad
        if train and optimizer is not None:
            optimizer.zero_grad()

        # Build the autograd graph only when it will be back-propagated.
        with torch.set_grad_enabled(train):
            # yolo net to predict each object
            x = net(imgs)
            zs = preprocess(x, args.dataset)
            # NOTE(review): squeeze() removes every size-1 dimension, which
            # includes the batch dimension when the batch size is 1 --
            # confirm the result still matches the shape of target_set.
            predicted = predict_net(*zs).squeeze()

            # binary cross-entropy loss computation
            loss = bce(predicted, target_set)

        loss_sum += loss.item()
        binary_output = np.where(predicted.detach().cpu().numpy() > 0.5, 1, 0)
        acc_sum += accuracy_score(
            binary_output, target_set.detach().cpu().numpy())

        if train:
            loss.backward()
            # update parameters for the step
            # NOTE(review): no update happens during epoch 0, as in the
            # original code -- presumably so the first epoch logs the
            # untrained baseline; confirm.
            if optimizer is not None and epoch > 0:
                optimizer.step()

    mean_loss = loss_sum / num_batches  # mean loss over the batches
    mean_acc = acc_sum / num_batches  # mean accuracy over the batches
    # TODO compute std
    if train:
        # Debug dump of the last batch of the pass.
        print('yolo_output: ', x)
        print('preprocessed: ', zs)
        print('diff; ', predict_net.diff)
        print('predicted: ', predicted)
        print('target_set: ', target_set)
        print('train mean loss: ', mean_loss)
        print('train mean acc: ', mean_acc)
    else:
        print('val mean loss: ', mean_loss)
        print('val mean acc: ', mean_acc)
    return mean_loss, 0, mean_acc


def get_data_loader(args):
    """Create the train / val / test data loaders for the selected dataset.

    Args:
        args: parsed arguments; ``dataset``, ``data_dir``, ``batch_size`` and
            ``num_workers`` are read.

    Returns:
        ``(train_loader, val_loader, test_loader)``.  Only the training
        loader shuffles its samples.
    """
    split_names = ('train', 'val', 'test')

    # All three datasets are created first, then wrapped in loaders.
    split_datasets = {
        split: data.KANDINSKY(args.dataset, args.data_dir, split)
        for split in split_names
    }
    loaders = [
        torch.utils.data.DataLoader(
            split_datasets[split],
            shuffle=(split == 'train'),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        for split in split_names
    ]
    return tuple(loaders)


def _resolve_device(args):
    """Turn the ``--device`` / ``--no-cuda`` options into a torch device string."""
    requested = (args.device or '').strip()
    if args.no_cuda or requested.lower() == 'cpu':
        return 'cpu'
    if not requested:
        # No explicit id: use the default CUDA device when there is one.
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    # "--device 0,1,2,3" lists several ids, but one device string can only
    # name a single device, so the first id is used.
    return 'cuda:' + requested.split(',')[0].strip()


def main():
    """Pre-train the neural-predicate valuation function on YOLO outputs.

    Parses the command line, builds the data loaders, the YOLO perception
    module and the dataset-specific valuation function, then for each epoch
    runs a training pass (unless ``--eval-only``) followed by validation and
    test passes (unless ``--train-only``), logging loss and accuracy to
    TensorBoard.  The valuation-function weights are written to
    ``weights/<dataset>.pt`` every 10 epochs and after the final epoch.

    Raises:
        ValueError: if ``--dataset`` was not given.
    """
    args = get_args()
    if args.dataset not in ('closeby_pretrain', 'online_pretrain'):
        # argparse already restricts the choices, so this means the option
        # was omitted; fail here instead of with a NameError further down.
        raise ValueError(
            "--dataset is required (closeby_pretrain or online_pretrain)")

    # NOTE(review): assumes the perception module and valuation functions
    # accept any torch device string ('cpu', 'cuda', 'cuda:N') -- TODO confirm.
    device = _resolve_device(args)
    print('device: ', device)
    run_name = args.name + '_γ=' + str(args.gamma)
    writer = SummaryWriter(f"runs/np_pretrain/{run_name}", purge_step=0)

    # get torch data loader
    train_loader, val_loader, test_loader = get_data_loader(args)

    if args.dataset == 'closeby_pretrain':
        yolo_net = YOLOPerceptionModule(nn_id='yolo', e=2, d=11, device=device)
        predict_net = YOLOClosebyValuationFunction(device)
    else:  # 'online_pretrain'
        yolo_net = YOLOPerceptionModule(nn_id='yolo', e=5, d=11, device=device)
        predict_net = YOLOOnlineValuationFunction(device)
    start_epoch = 0

    params = list(predict_net.parameters())
    print('PARAMS: ', params, len(params))
    optimizer = torch.optim.Adam(params, lr=args.lr)
    # Passed through to run(), which currently computes its own BCE loss.
    criterion = torch.nn.SmoothL1Loss()

    # Create RTPT object
    rtpt = RTPT(name_initials='HS',
                experiment_name='VFCR/NeuralPred:' + args.dataset, max_iterations=args.epochs)
    rtpt.start()

    weights_path = 'weights/' + args.dataset + '.pt'
    last_epoch = start_epoch + args.epochs - 1

    try:
        for epoch in range(start_epoch, start_epoch + args.epochs):
            if not args.eval_only:
                # training step
                mean_loss, _, mean_acc = run(
                    yolo_net, predict_net, train_loader, optimizer, criterion, writer, args, device=device, train=True, epoch=epoch, pool=None, rtpt=rtpt)
                writer.add_scalar("metric/train_loss",
                                  mean_loss, global_step=epoch)
                writer.add_scalar("metric/train_acc",
                                  mean_acc, global_step=epoch)
                rtpt.step(subtitle=f"loss={mean_loss:2.2f}")

                cur_lr = optimizer.param_groups[0]["lr"]
                writer.add_scalar(
                    "lr", cur_lr, global_step=epoch * len(train_loader))

            if not args.train_only:
                # validation step
                mean_loss_val, _, mean_acc_val = run(
                    yolo_net, predict_net, val_loader, None, criterion, writer, args, device=device, train=False, epoch=epoch, pool=None, rtpt=rtpt)
                writer.add_scalar("metric/val_loss",
                                  mean_loss_val, global_step=epoch)
                writer.add_scalar("metric/val_acc",
                                  mean_acc_val, global_step=epoch)

                # test step
                mean_loss_test, _, mean_acc_test = run(
                    yolo_net, predict_net, test_loader, None, criterion, writer, args, device=device, train=False, epoch=epoch, pool=None, rtpt=rtpt)
                writer.add_scalar("metric/test_loss",
                                  mean_loss_test, global_step=epoch)
                writer.add_scalar("metric/test_acc",
                                  mean_acc_test, global_step=epoch)

            if args.eval_only:
                # Nothing changes between epochs without training, and the
                # stored weights must not be overwritten by an eval run.
                break

            # save mlp weights for neural predicate (valuation function)
            if epoch % 10 == 0 or epoch == last_epoch:
                os.makedirs('weights', exist_ok=True)
                torch.save(predict_net.state_dict(), weights_path)
    finally:
        # Flush pending TensorBoard events even if a pass raised.
        writer.close()


def plot_grad(named_parameters, writer, epoch=0):
    """Log the mean absolute gradient of each weight parameter.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g. the
            result of ``module.named_parameters()``.
        writer: object with an ``add_scalar(tag, value, global_step=...)``
            method (the script uses a TensorBoard ``SummaryWriter``).
        epoch: value passed as ``global_step``.

    Parameters that do not require gradients, whose name contains ``"bias"``
    or that have no gradient yet are skipped.  Each remaining parameter is
    logged under the tag ``"grad/<name>"``.
    """
    for name, param in named_parameters:
        if not param.requires_grad or "bias" in name or param.grad is None:
            continue
        mean_abs_grad = param.grad.abs().mean().detach().cpu().numpy()
        writer.add_scalar("grad/" + name, mean_abs_grad, global_step=epoch)


def plot_grad_flow(named_parameters, epoch=0):
    """Plot the gradients flowing through different layers during training.

    Can be used for checking for possible gradient vanishing / exploding
    problems.  The bar chart of max and mean absolute gradients per layer is
    saved as ``figures/gradient_epoch<epoch>.png``.

    Usage: call after ``loss.backward()`` as
    ``plot_grad_flow(model.named_parameters())``.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs.
        epoch: number used in the output file name.

    Parameters that do not require gradients, whose name contains ``"bias"``
    or that have no gradient yet are skipped.
    """
    from matplotlib.lines import Line2D

    ave_grads = []
    max_grads = []
    layers = []
    for name, param in named_parameters:
        # Skip parameters without a gradient instead of failing on
        # ``None.abs()`` (same filter as plot_grad).
        if not param.requires_grad or "bias" in name or param.grad is None:
            continue
        abs_grad = param.grad.abs()
        layers.append(name)
        ave_grads.append(abs_grad.mean().detach().cpu().numpy())
        max_grads.append(abs_grad.max().detach().cpu().numpy())

    # NOTE(review): the current pyplot figure is neither created nor cleared
    # here, so successive calls draw on top of each other -- confirm this
    # overlay is intended.
    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k")
    plt.xticks(range(len(ave_grads)), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/gradient_epoch' +
                str(epoch) + '.png', bbox_inches='tight')


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
