"""Train and test classification on ScanNet."""

import argparse
import json
import os
import os.path as osp

import numpy as np
import pkbar
import torch
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
import wandb

from models.backbone_module import PointnetPPClass
from src.scannet_cls_dataset import ScanNetClsDataset, NUM_CLASSES


def train_classifier(model, data_loaders, args):
    """Train a 3d object classifier, resuming from ``args.ckpnt`` if it exists.

    Args:
        model: module called on ``ex["point_clouds"]`` batches; its output is
            used as per-class logits for ``F.cross_entropy``.
        data_loaders: dict with 'train' and 'test' loaders whose batches are
            dicts holding at least "point_clouds" and "obj_labels".
        args: namespace providing ``device``, ``lr``, ``wd``, ``ckpnt``,
            ``epochs`` and ``eval``.

    Returns:
        The model. With ``args.eval`` set, or when the checkpoint already
        reached ``args.epochs``, it is returned right after the (optional)
        checkpoint load. Otherwise it carries the weights of the last trained
        epoch, which are not necessarily those of the best checkpoint.
    """
    # Setup
    device = args.device
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    start_epoch = 0
    best_acc = 0
    if osp.exists(args.ckpnt):
        # map_location lets a checkpoint written on GPU be resumed on CPU
        checkpoint = torch.load(args.ckpnt, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        best_acc = checkpoint["acc"]
    if args.eval or start_epoch >= args.epochs:
        return model

    # Training loop
    train_loader = data_loaders['train']
    for epoch in range(start_epoch, args.epochs):
        print(f"Epoch: {epoch + 1}/{args.epochs}")
        kbar = pkbar.Kbar(target=len(train_loader), width=25)
        model.train()
        num_correct = 0  # for micro-accuracy
        num_examples = 0
        total_loss = 0
        for step, ex in enumerate(train_loader):
            labels = ex["obj_labels"].to(device)
            logits = model(ex["point_clouds"].to(device))
            loss = F.cross_entropy(logits, labels)
            # Keep the running stats as plain Python numbers
            num_correct += (logits.argmax(1) == labels).sum().item()
            num_examples += len(logits)
            loss_value = loss.item()
            total_loss += loss_value
            # Report a float so the bar does not hold on to the graph
            kbar.update(step, [("loss", loss_value)])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # max(..., 1) keeps an empty train loader from dividing by zero
        train_acc = num_correct / max(num_examples, 1)
        wandb.log(
            {
                "training_loss": total_loss / max(len(train_loader), 1),
                "training_acc": train_acc
            },
            step=epoch
        )
        print("\nTraining accuracy:", train_acc)

        # Evaluation
        print("\nValidation")
        micro, macro, loss = eval_classifier(model, data_loaders['test'], args)
        wandb.log(
            {"micro_acc": micro, "macro_acc": macro, "val_loss": loss},
            step=epoch
        )

        # Store
        if micro >= best_acc:
            # Track the best score seen so far; without this a later, worse
            # epoch could overwrite a better checkpoint from the same run.
            best_acc = micro
            torch.save(
                {
                    "epoch": epoch + 1,
                    "acc": micro,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict()
                },
                args.ckpnt
            )
        else:  # keep the best weights, only advance the stored epoch
            checkpoint = torch.load(args.ckpnt, map_location=device)
            checkpoint["epoch"] = epoch + 1
            torch.save(checkpoint, args.ckpnt)
    # Test
    test_acc = eval_classifier(model, data_loaders['test'], args)
    print(f"Test Accuracy: {test_acc}")
    return model


@torch.no_grad()
def eval_classifier(model, data_loader, args):
    """Score the model on a data loader and dump per-object predictions.

    The model is switched to eval mode. Per-object predicted classes are
    written to 'cls_results_train.json', keyed by scan id; each scan holds a
    256-slot list in which slots never written stay at -1.

    Returns:
        Tuple (micro accuracy, macro accuracy, loss averaged over batches).
        Macro accuracy averages only over classes that have examples.
    """
    model.eval()
    progress = pkbar.Kbar(target=len(data_loader), width=25)
    hits = dict.fromkeys(range(model.num_classes), 0)
    seen = dict.fromkeys(range(model.num_classes), 0)
    loss_sum = 0
    predictions_per_scan = {}

    for batch_idx, batch in enumerate(data_loader):
        targets = batch["obj_labels"].to(args.device)
        outputs = model(batch["point_clouds"].to(args.device))
        loss_sum += F.cross_entropy(outputs, targets).item()
        predicted = outputs.cpu().numpy().argmax(1)

        # Record the predicted class of every object under its scan
        for scan_id, obj_id, pred in zip(
            batch['scan_ids'], batch['object_ids'], predicted
        ):
            if scan_id not in predictions_per_scan:
                predictions_per_scan[scan_id] = [-1.0] * 256
            predictions_per_scan[scan_id][obj_id] = pred.item()

        # Per-class counters for micro/macro accuracy
        for pred, target in zip(predicted, batch["obj_labels"].numpy()):
            seen[target] += 1
            if pred == target:
                hits[target] += 1

        running_micro = sum(hits.values()) / sum(seen.values())
        progress.update(batch_idx, [("accuracy", running_micro)])

    # Final accuracies
    micro = sum(hits.values()) / sum(seen.values())
    per_class = [hits[cls] / count for cls, count in seen.items() if count]
    macro = np.mean(per_class)
    print(f"\nAccuracy: {micro}, macroAccuracy: {macro}")
    with open('cls_results_train.json', 'w') as fid:
        json.dump(predictions_per_scan, fid)

    mean_loss = loss_sum / len(data_loader)
    return micro, macro, mean_loss


def main(args):
    """Build loaders and model, train, then evaluate on the train loader."""

    def _build_loader(split):
        # In overfit mode both loaders read the 'test' split
        dataset = ScanNetClsDataset(
            'test' if args.overfit else split,
            num_points=2048,
            use_color=args.use_color,
            use_size=args.use_size,
            use_height=args.use_height,
            use_normals=args.use_normals,
            overfit=args.overfit
        )
        return DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=(split == 'train'),
            drop_last=False,
            num_workers=4
        )

    # Data loaders for classification
    data_loaders = {}
    for split in ('train', 'test'):
        data_loaders[split] = _build_loader(split)

    # Extra per-point feature channels; xyz is not counted in the total
    feature_widths = (
        (args.use_color, 3),
        (args.use_size, 3),
        (args.use_normals, 3),
        (args.use_height, 1),
    )
    in_dim = sum(width for enabled, width in feature_widths if enabled)

    # Train classifier
    classifier = PointnetPPClass(
        num_classes=NUM_CLASSES, input_feature_dim=in_dim
    )
    classifier = train_classifier(
        classifier.to(args.device), data_loaders, args
    )
    eval_classifier(classifier, data_loaders['train'], args)


if __name__ == "__main__":
    # Command-line options
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", default="checkpoints/")
    parser.add_argument("--checkpoint", default="classifier_scannet.pt")
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--wd", type=float, default=0)
    # Boolean switches, all off unless given on the command line
    for switch in (
        "--use_color", "--use_normals", "--use_height", "--use_size",
        "--overfit", "--eval",
    ):
        parser.add_argument(switch, action="store_true")
    args = parser.parse_args()

    # Derived settings: full checkpoint file path and compute device
    args.ckpnt = osp.join(args.checkpoint_path, args.checkpoint)
    if torch.cuda.is_available():
        args.device = 'cuda:0'
    else:
        args.device = 'cpu'
    os.makedirs(args.checkpoint_path, exist_ok=True)

    # The wandb run is named after the checkpoint file up to its first dot
    run_name = args.checkpoint.partition('.')[0]
    wandb.init(project="p++_classifier", name=run_name)
    main(args)
