import matplotlib.pyplot as plt
from rtpt import RTPT
from yolov5.utils.general import non_max_suppression
from yolov5.models.experimental import attempt_load
from dilpst.src.infer import InferModule
from dilpst.src.tensor_encoder import TensorEncoder
from dilpst.src.data_utils import DataUtils
from dilpst.src.language import Language
from logic_utils import get_index_by_predname
import dilpst.src.logic as lg
from sklearn.metrics import precision_score, accuracy_score
import data_clevr
from torch.utils.tensorboard import SummaryWriter
import torch
import os
import argparse
from datetime import datetime

import torch
import torch.nn.functional as F

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.multiprocessing as mp

import scipy.optimize
import numpy as np
from tqdm import tqdm
import matplotlib
from torch.optim import lr_scheduler

import numpy as np
from nelogic import *
from percept import YOLOPerceptionModule, SlotAttentionPerceptionModule
from vrlang import load_language
from facts_converter import FactsConverter
from logic_utils import build_infer_module, generate_atoms, generate_clauses, generate_bk
from vfcr import VFCReasoner
from valuation import SlotAttentionValuationModule
from torch_utils import select_device

from valuation_func import SlotAttentionFrontValuationFunction, SlotAttentionRightSideValuationFunction, SlotAttentionLeftSideValuationFunction

# Select matplotlib's non-interactive "Agg" backend so plotting never needs a display.
matplotlib.use("Agg")


# NOTE(review): autograd anomaly detection is a debugging aid that records extra
# information on every forward/backward pass and slows training; consider
# enabling it only while hunting NaN/inf gradients.
torch.autograd.set_detect_anomaly(True)


def get_args():
    """Parse the command-line options for neural-predicate pretraining.

    Returns:
        argparse.Namespace with the parsed options. ``--dataset`` is required:
        it selects both the CLEVR concept split that is loaded and the
        valuation network that is pretrained, so the script cannot run
        without it.
    """
    parser = argparse.ArgumentParser()
    # generic params
    parser.add_argument(
        "--name",
        default=datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
        help="Name to store the log file as",
    )
    parser.add_argument("--resume", help="Path to log file to resume from")

    parser.add_argument(
        "--epochs", type=int, default=100, help="Number of epochs to train with"
    )
    parser.add_argument(
        "--ap-log", type=int, default=10, help="Number of epochs before logging AP"
    )
    parser.add_argument(
        "--lr", type=float, default=1e-2, help="Outer learning rate of model"
    )
    parser.add_argument(
        "--batch-size", type=int, default=1, help="Batch size to train with"
    )
    parser.add_argument(
        "--num-workers", type=int, default=4, help="Number of threads for data loader"
    )
    parser.add_argument(
        "--dataset",
        choices=["rightside_pretrain", "leftside_pretrain", "front_pretrain"],
        # Without a dataset no valuation network can be built, so fail fast
        # at parse time instead of crashing later in main().
        required=True,
        help="Neural-predicate pretraining task (CLEVR concept dataset) to train on",
    )
    parser.add_argument(
        "--perception-model",
        choices=["yolo", "slotattention"],
        help="Choose yolo or slotattention for object recognition.",
    )
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')

    parser.add_argument(
        "--no-cuda",
        action="store_true",
        help="Run on CPU instead of GPU (not recommended)",
    )
    parser.add_argument(
        "--train-only", action="store_true", help="Only run training, no evaluation"
    )
    parser.add_argument(
        "--eval-only", action="store_true", help="Only run evaluation, no training"
    )
    parser.add_argument("--multi-gpu", action="store_true",
                        help="Use multiple GPUs")

    parser.add_argument("--data-dir", type=str, help="Directory to data")

    parser.add_argument('--program-size', default=5, type=int,
                        help='number of clauses to compose logic programs')
    parser.add_argument('--gamma', default=0.01, type=float,
                        help='smooth parameter in the softor function')

    args = parser.parse_args()
    return args


def compute_acc(outputs, targets):
    """Score arg-max predictions against ground-truth labels.

    Args:
        outputs: score matrix whose axis 1 indexes the candidate classes.
        targets: ground-truth label per row of ``outputs``.

    Returns:
        The fraction of rows whose highest-scoring class equals the target.
    """
    # Reduce each row of scores to the index of its best class.
    return accuracy_score(y_true=targets, y_pred=np.argmax(outputs, axis=1))


def preprocess(data, dataset, device):
    """Turn a batch of raw coordinates into object vectors for a valuation function.

    Args:
        data: batch tensor of coordinates. For ``front_pretrain`` the columns
            ``[:3]`` and ``[3:]`` are embedded as two separate objects
            (presumably a (batch, 6) pair of xyz triples -- confirm against
            the dataset).
        dataset: pretraining task name; one of ``rightside_pretrain``,
            ``leftside_pretrain`` or ``front_pretrain``.
        device: device on which the object vectors are allocated.

    Returns:
        One object-vector tensor for the right/left-side tasks, or a
        two-element list ``[first_object, second_object]`` for
        ``front_pretrain``.

    Raises:
        ValueError: if ``dataset`` is not a supported pretraining task
            (previously this silently returned ``None``).
    """
    if dataset in ('rightside_pretrain', 'leftside_pretrain'):
        return create_obj_vector_from_xyz(data, device=device)
    if dataset == 'front_pretrain':
        return [
            create_obj_vector_from_xyz(data[:, :3], device=device),
            create_obj_vector_from_xyz(data[:, 3:], device=device),
        ]
    raise ValueError(f"Unsupported dataset for preprocessing: {dataset!r}")


def create_obj_vector_from_xyz(xyz, dim=19, device=None):
    """Embed a batch of 3-D coordinates into fixed-width object vectors.

    Column 0 is set to 1.0 (objectness), columns 1-3 receive ``xyz`` and all
    remaining columns are zero.

    Args:
        xyz: coordinate tensor with the batch along dim 0; it must broadcast
            to (batch, 3), i.e. normally shape (batch, 3).
        dim: width of each output vector; must be at least 4.
        device: device for the result (``None`` uses torch's default device).

    Returns:
        float32 tensor of shape (batch, dim).

    Raises:
        ValueError: if ``dim`` is too small to hold objectness plus xyz.
    """
    if dim < 4:
        raise ValueError(f"dim must be >= 4 to hold objectness + xyz, got {dim}")
    # Allocate directly on the target device rather than building the tensor
    # on the CPU and copying it over.
    v = torch.zeros((xyz.size(0), dim), dtype=torch.float32, device=device)
    v[:, 0] = 1.0  # objectness
    v[:, 1:4] = xyz
    return v


def run(predict_net, loader, optimizer, criterion, writer, args, device, train=False, epoch=0, pool=None, rtpt=None, max_obj_num=4):
    """Run one pass of ``predict_net`` over ``loader`` and report loss/accuracy.

    Each batch is scored with binary cross-entropy against its targets. In
    training mode the loss is back-propagated and, from epoch 1 on, the
    optimizer takes a step. In evaluation mode no autograd graph is built and
    the parameters are left untouched.

    Args:
        predict_net: valuation network. It is called as ``predict_net(z)`` for
            the right/left-side tasks and ``predict_net(z_1, z_2)`` for
            ``front_pretrain``. Because the loss is ``BCELoss``, its outputs
            must lie in [0, 1].
        loader: sized iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer for ``predict_net``; required when ``train``.
        criterion: unused (the loss is always ``torch.nn.BCELoss``); kept for
            call-site compatibility.
        writer, pool, rtpt, max_obj_num: unused; kept for call-site
            compatibility.
        args: parsed CLI options; only ``args.dataset`` is read.
        device: device each batch is moved to.
        train: whether to back-propagate and update parameters.
        epoch: current epoch index; no optimizer step is taken at epoch 0.

    Returns:
        ``(mean_loss, std_loss, mean_acc)`` over the batches of ``loader``.

    Raises:
        ValueError: if ``loader`` yields no batches or ``args.dataset`` is not
            a supported pretraining task.
    """
    bce = torch.nn.BCELoss()
    loss_list = []
    acc_sum = 0.0
    zs = predicted = target_set = None

    for sample in tqdm(loader):
        zs, target_set = map(lambda x: x.to(device), sample)
        if train:
            optimizer.zero_grad()

        # Only record an autograd graph when we are going to back-propagate.
        with torch.set_grad_enabled(train):
            if args.dataset in ('rightside_pretrain', 'leftside_pretrain'):
                predicted = predict_net(preprocess(zs, args.dataset, device))
            elif args.dataset == 'front_pretrain':
                z_1, z_2 = preprocess(zs, args.dataset, device)
                predicted = predict_net(z_1, z_2)
            else:
                raise ValueError(f"Unsupported dataset: {args.dataset!r}")
            # Match the target's shape exactly. A bare .squeeze() collapses a
            # batch of size 1 to a 0-d tensor, which BCELoss rejects.
            predicted = predicted.reshape(target_set.shape)
            loss = bce(predicted, target_set)

        if train:
            loss.backward()
            # NOTE(review): epoch 0 never updates the parameters (behaviour
            # kept from the original); confirm this baseline epoch is intended.
            if epoch > 0:
                optimizer.step()

        loss_list.append(loss.item())
        binary_output = np.where(predicted.detach().cpu().numpy() > 0.5, 1, 0)
        acc_sum += accuracy_score(target_set.detach().cpu().numpy(), binary_output)

    if not loss_list:
        raise ValueError("run() received a loader that yielded no batches")

    mean_loss = float(np.mean(loss_list))
    std_loss = float(np.std(loss_list))
    mean_acc = acc_sum / len(loss_list)
    if train:
        # Debug dump of the last batch seen in this epoch.
        print('zs: ', zs)
        print('predicted: ', predicted)
        print('target_set: ', target_set)
        print('train mean loss: ', mean_loss)
        print('train mean acc: ', mean_acc)
    else:
        print('val mean loss: ', mean_loss)
        print('val mean acc: ', mean_acc)
    return mean_loss, std_loss, mean_acc


def get_data_loader(args):
    """Build the train/val DataLoaders for the CLEVR concept task in ``args.dataset``.

    The train loader shuffles its samples and the val loader keeps them in
    order. Both use ``args.batch_size`` and ``args.num_workers``.

    Returns:
        ``(train_loader, val_loader)``.
    """
    # Construct both splits first (train, then val), then wrap them in loaders.
    split_datasets = {
        split: data_clevr.CLEVRConcept(args.dataset, split)
        for split in ('train', 'val')
    }
    common_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers)

    train_loader = torch.utils.data.DataLoader(
        split_datasets['train'], shuffle=True, **common_kwargs
    )
    val_loader = torch.utils.data.DataLoader(
        split_datasets['val'], shuffle=False, **common_kwargs
    )
    return train_loader, val_loader


def _resolve_device(args):
    """Translate the ``--device`` / ``--no-cuda`` flags into a torch device string.

    ``--no-cuda`` or ``--device cpu`` selects the CPU. An empty ``--device``
    picks ``cuda:0`` when CUDA is available and the CPU otherwise. For a
    comma-separated list such as ``0,1`` only the first id is used, because
    this script trains on a single device.
    """
    requested = args.device.strip().lower()
    if args.no_cuda or requested == 'cpu':
        return 'cpu'
    if not requested:
        return 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return 'cuda:' + requested.split(',')[0].strip()


def _build_predict_net(dataset, device):
    """Instantiate the valuation function that is pretrained for ``dataset``."""
    net_classes = {
        'rightside_pretrain': SlotAttentionRightSideValuationFunction,
        'leftside_pretrain': SlotAttentionLeftSideValuationFunction,
        'front_pretrain': SlotAttentionFrontValuationFunction,
    }
    try:
        net_class = net_classes[dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {dataset!r}") from None
    return net_class(device)


def main():
    """Pretrain one neural-predicate valuation function and save its weights.

    Each epoch trains the network (unless ``--eval-only``) and evaluates it on
    the val split (unless ``--train-only``). Metrics are logged to
    TensorBoard under ``runs/np_pretrain/<dataset>``. When training, the
    state dict is saved to ``weights/<dataset>.pt`` every 10 epochs and after
    the final epoch.
    """
    args = get_args()

    device = _resolve_device(args)
    print('device: ', device)
    run_name = args.dataset
    writer = SummaryWriter(f"runs/np_pretrain/{run_name}", purge_step=0)
    try:
        train_loader, val_loader = get_data_loader(args)
        predict_net = _build_predict_net(args.dataset, device)
        start_epoch = 0
        last_epoch = start_epoch + args.epochs - 1

        params = list(predict_net.parameters())
        print('PARAMS: ', params, len(params))
        optimizer = torch.optim.Adam(params, lr=args.lr)
        # NOTE: run() always uses BCELoss; this criterion is passed but unused.
        criterion = torch.nn.SmoothL1Loss()

        # Create RTPT object
        rtpt = RTPT(name_initials='HS',
                    experiment_name='NS-FR/NeuralPred:' + args.dataset, max_iterations=args.epochs)
        rtpt.start()

        weights_path = os.path.join('weights', args.dataset + '.pt')
        if not args.eval_only:
            os.makedirs(os.path.dirname(weights_path), exist_ok=True)

        # NOTE(review): --resume is not honoured, so --eval-only evaluates a
        # freshly initialised network.
        for epoch in range(start_epoch, start_epoch + args.epochs):
            if not args.eval_only:
                # training step
                mean_loss, _, mean_acc = run(
                    predict_net, train_loader, optimizer, criterion, writer, args,
                    device=device, train=True, epoch=epoch, rtpt=rtpt)
                writer.add_scalar("metric/train_loss",
                                  mean_loss, global_step=epoch)
                writer.add_scalar("metric/train_acc",
                                  mean_acc, global_step=epoch)
                rtpt.step(subtitle=f"loss={mean_loss:2.2f}")

                cur_lr = optimizer.param_groups[0]["lr"]
                writer.add_scalar(
                    "lr", cur_lr, global_step=epoch * len(train_loader))

            if not args.train_only:
                # validation step
                mean_loss_val, _, mean_acc_val = run(
                    predict_net, val_loader, None, criterion, writer, args,
                    device=device, train=False, epoch=epoch, rtpt=rtpt)
                writer.add_scalar("metric/val_loss",
                                  mean_loss_val, global_step=epoch)
                writer.add_scalar("metric/val_acc",
                                  mean_acc_val, global_step=epoch)

            # Save the valuation-function weights periodically and after the
            # last epoch. Never save in eval-only mode, because that would
            # overwrite trained weights with untrained ones.
            if not args.eval_only and (epoch % 10 == 0 or epoch == last_epoch):
                torch.save(predict_net.state_dict(), weights_path)
    finally:
        # Flush pending TensorBoard events even if training fails.
        writer.close()


# Script entry point: start pretraining only when run directly, not on import.
if __name__ == "__main__":
    main()
