import matplotlib.pyplot as plt
from rtpt import RTPT
from yolov5.utils.general import non_max_suppression
from yolov5.models.experimental import attempt_load
from dilpst.src.ilp_problem import ILPProblem
from dilpst.src.infer import InferModule
from dilpst.src.tensor_encoder import TensorEncoder
from dilpst.src.data_utils import DataUtils
from dilpst.src.language import Language
from logic_utils import get_index_by_predname
import dilpst.src.logic as lg
from sklearn.metrics import precision_score, accuracy_score
import data
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import argparse
from datetime import datetime

import torch
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing as mp
import torchvision.models as models

import scipy.optimize
import numpy as np
from tqdm import tqdm
import matplotlib
from torch.optim import lr_scheduler

import numpy as np
from nelogic import *
from percept import YOLOPerceptionModule
from vrlang import load_language
from facts_converter import FactsConverter
from logic_utils import build_infer_module, generate_atoms, generate_clauses
from vfcr import VFCReasoner
from valuation import *
from torch_utils import select_device
from neural_utils import LogisticRegression

# Select matplotlib's non-interactive "Agg" backend so figures can be rendered
# without a display (e.g. on a headless training machine).
matplotlib.use("Agg")

# Disabled alternative imports, kept for reference.
# import data_copy as data
# import utils as utils


# Turn on autograd anomaly detection for the whole process.
# NOTE(review): PyTorch documents this as a debugging aid that slows down
# training noticeably -- confirm it is still wanted for regular runs.
torch.autograd.set_detect_anomaly(True)


def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse.  When ``None``
            (the default) the arguments are taken from ``sys.argv``, which is
            the behaviour existing callers rely on.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        SystemExit: Raised by argparse when an option is invalid or when the
            mandatory ``--dataset`` option is missing.
    """
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument(
        "--name",
        default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        help="Name to store the log file as",
    )
    parser.add_argument("--resume", help="Path to log file to resume from")

    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of epochs to train with"
    )
    parser.add_argument(
        "--ap-log", type=int, default=10, help="Number of epochs before logging AP"
    )
    parser.add_argument(
        "--lr", type=float, default=1e-2, help="Outer learning rate of model"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, help="Batch size to train with"
    )
    parser.add_argument(
        "--num-workers", type=int, default=4, help="Number of threads for data loader"
    )
    # The dataset name is concatenated into the run name and handed to the
    # dataset constructor, so a missing value would only fail later with an
    # unrelated TypeError.  Requiring it makes argparse report the problem.
    parser.add_argument(
        "--dataset",
        required=True,
        choices=["twopairs", "threepairs", "red-triangle", "closeby",
                 "online", "online-pair", "nine-circles"],
        help="Name of the Kandinsky pattern dataset to use",
    )
    parser.add_argument(
        "--perception-model",
        choices=["yolo", "slotattention"],
        help="Choose yolo or slotattention for object recognition.",
    )
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    parser.add_argument(
        "--no-cuda",
        action="store_true",
        help="Run on CPU instead of GPU (not recommended)",
    )
    parser.add_argument(
        "--train-only", action="store_true", help="Only run training, no evaluation"
    )
    parser.add_argument(
        "--eval-only", action="store_true", help="Only run evaluation, no training"
    )
    parser.add_argument("--multi-gpu", action="store_true",
                        help="Use multiple GPUs")

    parser.add_argument("--data-dir", type=str, help="Directory to data")
    # Slot attention params
    parser.add_argument('--n-slots', default=10, type=int,
                        help='number of slots for slot attention module')
    parser.add_argument('--n-iters-slot-att', default=3, type=int,
                        help='number of iterations in slot attention module')
    parser.add_argument('--n-attr', default=18, type=int,
                        help='number of attributes per object')
    parser.add_argument('--program-size', default=5, type=int,
                        help='number of clauses to compose logic programs')

    return parser.parse_args(argv)


def compute_acc(outputs, targets):
    """Return the accuracy of the arg-max predictions against ``targets``.

    The predicted class of each row of ``outputs`` is the index of its
    largest entry along axis 1; the result is whatever ``accuracy_score``
    reports for ``targets`` versus those indices.
    """
    return accuracy_score(targets, np.argmax(outputs, axis=1))


def run(net, predict_net,  loader, optimizer, criterion, writer, args, device, train=False, epoch=0, pool=None, rtpt=None, max_obj_num=4):
    """Run one pass over ``loader`` and report mean BCE loss and accuracy.

    For every batch the images are passed through ``net``, the result through
    ``predict_net``, and the prediction is compared with the targets using
    binary cross-entropy.  Predictions above 0.5 count as class 1 for the
    accuracy.

    Args:
        net: Module applied to the image batch.
        predict_net: Module applied to the output of ``net``; since
            ``BCELoss`` is used, it is expected to emit probabilities.
        loader: Sized iterable yielding ``(imgs, target_set)`` pairs.
        optimizer: Optimizer to step while training, or ``None``.
        criterion, writer, args, pool, rtpt, max_obj_num: Not used by this
            function; kept so that existing call sites keep working.
        device: Device the batch tensors are moved to.
        train: When true, gradients are computed and (from epoch 1 on) the
            optimizer is stepped; otherwise the pass is evaluation only.
        epoch: Index of the current epoch.

    Returns:
        A tuple ``(mean_loss, 0, mean_acc)`` where both means are taken over
        the batches; the middle entry is a placeholder for the (not yet
        computed) standard deviation of the loss.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("run() needs a data loader with at least one batch")

    bce = torch.nn.BCELoss()

    # Put layers such as batch-norm/dropout into the matching mode so that
    # evaluation neither uses batch statistics nor updates running statistics.
    net.train(train)
    predict_net.train(train)

    loss_sum, acc_sum = 0.0, 0.0
    # Gradients are only needed while training; evaluation skips the graph.
    with torch.set_grad_enabled(train):
        for sample in tqdm(loader):
            # move the batch to the target device
            imgs, target_set = (tensor.to(device) for tensor in sample)

            if train and optimizer is not None:
                optimizer.zero_grad()

            # infer and predict the target probability
            features = net(imgs)
            # Reshape to the target's shape instead of a bare squeeze(): a
            # batch of one would otherwise collapse to a 0-dim tensor and
            # no longer match the target in BCELoss.
            predicted = predict_net(features).reshape(target_set.shape)

            # binary cross-entropy loss computation
            loss = bce(predicted, target_set)
            loss_sum += loss.item()

            binary_output = np.where(
                predicted.detach().cpu().numpy() > 0.5, 1, 0)
            acc_sum += accuracy_score(
                target_set.detach().cpu().numpy(), binary_output)

            if train:
                loss.backward()
                # NOTE(review): the parameter update is skipped during epoch 0,
                # presumably to record the untrained model first -- confirm.
                if optimizer is not None and epoch > 0:
                    optimizer.step()

    mean_loss = loss_sum / num_batches  # mean of the per-batch losses
    mean_acc = acc_sum / num_batches  # mean of the per-batch accuracies
    # TODO compute std
    if train:
        print('train mean loss: ', mean_loss)
        print('train mean acc: ', mean_acc)
    else:
        print('val mean loss: ', mean_loss)
        print('val mean acc: ', mean_acc)
    return mean_loss, 0, mean_acc


def get_data_loader(args):
    """Build the train, validation and test data loaders.

    One ``data.KANDINSKY`` dataset is created per split from ``args.dataset``
    and ``args.data_dir``.  Every loader uses ``args.batch_size`` and
    ``args.num_workers``; only the training loader shuffles.

    Returns:
        The tuple ``(train_loader, val_loader, test_loader)``.
    """
    split_names = ('train', 'val', 'test')

    # Build all datasets first, then wrap them, in train/val/test order.
    split_datasets = [
        data.KANDINSKY(args.dataset, args.data_dir, split_name)
        for split_name in split_names
    ]

    loaders = []
    for split_name, split_dataset in zip(split_names, split_datasets):
        loader = torch.utils.data.DataLoader(
            split_dataset,
            shuffle=(split_name == 'train'),
            batch_size=args.batch_size,
            num_workers=args.num_workers,
        )
        loaders.append(loader)

    return tuple(loaders)


def _resolve_device(args):
    """Translate the ``--device`` / ``--no-cuda`` options into a device string."""
    requested = (args.device or '').strip().lower()
    if args.no_cuda or requested == 'cpu':
        return 'cpu'
    if not torch.cuda.is_available():
        print('CUDA is not available, falling back to CPU')
        return 'cpu'
    if not requested:
        # No explicit index given: let torch pick the current CUDA device.
        return 'cuda'
    if requested.startswith('cuda'):
        return requested
    # A list such as "0,1,2,3" is not a valid torch device string; this
    # function places the models on the first listed device.
    return 'cuda:' + requested.split(',')[0]


def main():
    """Train and evaluate the ResNet-50 + logistic-regression baseline.

    Parses the command line, builds the data loaders and both networks, and
    then loops over the epochs.  Training runs unless ``--eval-only`` is set;
    validation and test runs happen unless ``--train-only`` is set.  Losses,
    accuracies and the learning rate are written to TensorBoard under
    ``runs/<name>:<dataset>``.
    """
    args = get_args()

    device = _resolve_device(args)
    name = args.name + ':' + args.dataset
    writer = SummaryWriter(f"runs/{name}", purge_step=0)

    try:
        # get torch data loader
        train_loader, val_loader, test_loader = get_data_loader(args)

        start_epoch = 0
        resnet = models.resnet50(pretrained=True)
        resnet.to(device)
        # 1000 matches the size of torchvision's ResNet-50 classifier output.
        predict_net = LogisticRegression(input_dim=1000)
        predict_net.to(device)

        # setting optimizer: both networks are trained jointly
        params = list(resnet.parameters()) + list(predict_net.parameters())
        optimizer = torch.optim.Adam(params, lr=args.lr)
        # Only forwarded to run(), which takes it as a positional argument.
        criterion = torch.nn.SmoothL1Loss()

        # Create RTPT object
        rtpt = RTPT(name_initials='HS',
                    experiment_name='VFCR/ResNet:'+args.dataset, max_iterations=args.epochs)
        rtpt.start()

        eval_splits = (("val", val_loader), ("test", test_loader))

        for epoch in range(start_epoch, args.epochs + start_epoch):
            if not args.eval_only:
                # training step
                mean_loss, _, mean_acc = run(
                    resnet, predict_net, train_loader, optimizer, criterion,
                    writer, args, device=device, train=True, epoch=epoch,
                    rtpt=rtpt)
                writer.add_scalar("metric/train_loss",
                                  mean_loss, global_step=epoch)
                writer.add_scalar("metric/train_acc",
                                  mean_acc, global_step=epoch)
                rtpt.step(subtitle=f"loss={mean_loss:2.2f}")

                cur_lr = optimizer.param_groups[0]["lr"]
                writer.add_scalar(
                    "lr", cur_lr, global_step=epoch * len(train_loader))

            if not args.train_only:
                # validation and test steps
                for split, loader in eval_splits:
                    split_loss, _, split_acc = run(
                        resnet, predict_net, loader, None, criterion,
                        writer, args, device=device, train=False,
                        epoch=epoch, rtpt=rtpt)
                    writer.add_scalar(f"metric/{split}_loss",
                                      split_loss, global_step=epoch)
                    writer.add_scalar(f"metric/{split}_acc",
                                      split_acc, global_step=epoch)

            if args.eval_only:
                # Nothing is trained between epochs, so one pass is enough.
                break
    finally:
        # Flush pending events and release the event file.
        writer.close()


# Run the training/evaluation pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
