"""Train and test concept learning with EBMs."""

import argparse
import os
import os.path as osp

import pkbar
import torch
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader

from src.classifier_dataset import ReferIt3DClassDataset
from src.classifiers import RelClassifier


def train_(model, data_loaders, args):
    """Train a classifier."""
    # Setup
    device = args.device
    optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.0, 0.999))
    start_epoch = 0

    if osp.exists(args.ckpnt):
        checkpoint = torch.load(args.ckpnt, map_location=args.device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]

    # Training loop
    for epoch in range(start_epoch, args.epochs):
        print("Epoch: %d/%d" % (epoch + 1, args.epochs))
        kbar = pkbar.Kbar(target=len(data_loaders['train']), width=25)
        for step, ex in enumerate(data_loaders['train']):

            # Forward pass
            out = model(
                ex['target_boxes'].to(device),
                ex['anchor_boxes'].to(device)
            )  # (B, N, N2, 1)
            out = out * ex['a_mask'].to(device).unsqueeze(1).unsqueeze(-1)
            out = out.sum(2)  # (B, N, 1)
            out = out.reshape(-1)
            out = out[ex['t_mask'].reshape(-1) > 0].float()

            # Backward pass
            labels = ex['labels'].reshape(-1)
            labels = labels[ex['t_mask'].reshape(-1) > 0].to(device)
            loss = (
                F.binary_cross_entropy_with_logits(out, labels.float())
                + F.kl_div(out.log_softmax(-1), labels.float())
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            kbar.update(step, [("loss", loss)])

        # Evaluation and model storing
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            },
            args.ckpnt
        )
        print("\nValidation")
        eval_(model, data_loaders['test'], args)
    return model


@torch.no_grad()
def eval_(model, data_loader, args):
    """Evaluate the binary accuracy of ``model`` on val/test data.

    A target is predicted positive when its anchor-masked, anchor-summed
    logit is greater than zero. Only positions with ``t_mask > 0`` count.

    Args:
        model: module taking ``(target_boxes, anchor_boxes)``.
        data_loader: loader yielding batch dicts with ``'target_boxes'``,
            ``'anchor_boxes'``, ``'a_mask'``, ``'t_mask'`` and ``'labels'``.
        args: namespace providing ``device``.

    Returns:
        Fraction of valid targets classified correctly (0.0 when the loader
        yields no valid targets). The model's train/eval mode is restored to
        what it was on entry.
    """
    was_training = model.training
    model.eval()
    device = args.device
    kbar = pkbar.Kbar(target=len(data_loader), width=25)
    num_correct = 0
    num_examples = 0
    for step, ex in enumerate(data_loader):
        out = model(
            ex['target_boxes'].to(device),
            ex['anchor_boxes'].to(device)
        )  # (B, N, N2, 1) per the original shape comment
        # Zero out masked anchors, then sum over axis 2 -> (B, N, 1).
        out = out * ex['a_mask'].to(device).unsqueeze(1).unsqueeze(-1)
        out = out.sum(2).reshape(-1)
        keep = ex['t_mask'].reshape(-1) > 0

        pred = out[keep.to(device)] > 0
        labels = ex['labels'].reshape(-1)[keep].to(device)
        num_correct += (pred == labels).sum().item()
        num_examples += len(labels)

        # max(..., 1) avoids ZeroDivisionError before any valid target.
        kbar.update(
            step,
            [
                ("accuracy1", num_correct / max(num_examples, 1))
            ]
        )
    accuracy = num_correct / max(num_examples, 1)
    print(
        "\nAccuracies:",
        accuracy
    )
    model.train(was_training)
    return accuracy


def main():
    """Run the training/test pipeline for the 'below' relation classifier.

    Parses command-line options, builds the sr3d train/test loaders for the
    'below' split, trains a ``RelClassifier`` (resuming from the checkpoint
    if one exists), and evaluates it once more on the test split.
    """
    # Parse arguments
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--checkpoint_path", default="checkpoints/")
    argparser.add_argument("--logs_dir", default="runs/")  # not used yet
    argparser.add_argument("--checkpoint", default="below_classifier.pt")
    argparser.add_argument("--epochs", default=100, type=int)
    argparser.add_argument("--batch_size", default=128, type=int)
    argparser.add_argument("--lr", default=1e-3, type=float)
    argparser.add_argument("--num_workers", default=4, type=int)
    args = argparser.parse_args()
    args.ckpnt = osp.join(args.checkpoint_path, args.checkpoint)
    args.device = torch.device(
        'cuda:0' if torch.cuda.is_available() else 'cpu'
    )
    print(args.device)
    os.makedirs(args.checkpoint_path, exist_ok=True)

    # Data loaders: shuffle and drop the ragged last batch only for training.
    data_loaders = {
        mode: DataLoader(
            ReferIt3DClassDataset('sr3d', mode, 'below'),
            batch_size=args.batch_size,
            shuffle=mode == 'train',
            drop_last=mode == 'train',
            num_workers=args.num_workers
        )
        for mode in ('train', 'test')
    }

    # Train the classifier, then report final test accuracy.
    model = train_(
        RelClassifier().to(args.device),
        data_loaders, args
    )
    eval_(model, data_loaders['test'], args)


if __name__ == "__main__":
    main()
