import argparse
import datetime
import os

import torch
from torch.utils.data import DataLoader

from datasets.scanobject import ScanObject
from datasets.sncore_4k import ShapeNetCore4k
from datasets.sncore_splits import *
from loss import ChamferDistance
from models.foldingnet import SkipVariationalFoldingNet
from utils.common import set_random_seed, init_np_seed, save_results, rescale
from utils.realistic_projection import Realistic_Projection
from utils.eval_utils import *

# Pick the best available accelerator (CUDA, then Apple-silicon MPS, then CPU) so the
# script also runs on machines without MPS instead of failing at the first `.to(device)`.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


def get_args(argv=None):
    """Parse command-line options for the synth-to-real (ShapeNetCore) VAE evaluation.

    Seeds used for the paper experiments: 1 -> SN1, 41 -> SN2, 13718 -> SN3.
    ShapeNetCore train_dataset length: SN1 -> 7342, SN2 -> 16381, SN3 -> 11581.

    :param argv: list of arguments to parse; ``None`` (default) parses ``sys.argv[1:]``.
    :returns: the parsed namespace. ``data_root`` is user-expanded (``~``), and the two
              ShapeNetCore splits not chosen as ``--src`` are stored as ``tar1`` and
              ``tar2`` (in SN1, SN2, SN3 order).
    """
    root = '/root/data/3D_OS-main/3D_OS_release_data'
    model_path = 'model_logs/vae_SN1.pt'
    sn_sources = ('SN1', 'SN2', 'SN3')
    parser = argparse.ArgumentParser(description='Evaluate VAE performance on ShapeNetCore or ScanObjectNN.')
    parser.add_argument('--data_root', type=str, default=root)
    parser.add_argument('--src', type=str, default='SN1', choices=sn_sources)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_points', type=int, default=2048)
    parser.add_argument('--dim', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--model_path', type=str, default=model_path)
    args = parser.parse_args(argv)
    args.data_root = os.path.expanduser(args.data_root)
    # argparse `choices` already rejects any --src outside sn_sources (it exits with a
    # usage error), so exactly two splits remain: they are the OOD targets.
    args.tar1, args.tar2 = (name for name in sn_sources if name != args.src)
    return args


def _sncore_split_classes(split_name):
    """Return ``list(split.keys())`` for the module-level ShapeNetCore split called *split_name*.

    Split objects such as ``SN1`` are brought into this module by
    ``from datasets.sncore_splits import *``. Resolving them through ``globals()`` finds
    the same names without calling ``eval`` on a (command-line supplied) string.

    :raises ValueError: if this module defines no name *split_name*.
    """
    try:
        split = globals()[split_name]
    except KeyError:
        raise ValueError(f"Unknown ShapeNetCore split: {split_name}") from None
    return list(split.keys())


def get_sncore_test_loaders(opt):
    """
    Returns all dataloaders used for evaluation, this function is compatible with VAE after training.
    :param opt: namespace with src, tar1, tar2 (ShapeNetCore split names), data_root, num_points,
                batch_size and num_workers.
    :returns test_loader: test ID data loader - no augmentation, shuffle=False, drop_last=False
             tar1_loader: test OOD 1 data loader - no augmentation, shuffle=False, drop_last=False
             tar2_loader: test OOD 2 data loader - no augmentation, shuffle=False, drop_last=False
    :raises ValueError: if src, tar1 or tar2 does not name a known split.
    """
    base_data_params = {'data_root': opt.data_root, 'num_points': opt.num_points, 'transforms': None}
    loader_params = {'batch_size': opt.batch_size, 'shuffle': False, 'drop_last': False,
                     'num_workers': opt.num_workers, 'worker_init_fn': init_np_seed}
    print(f"OOD evaluation data - "
          f"src: {opt.src}, tar1: {opt.tar1}, tar2: {opt.tar2}")

    # in domain test data
    print(f"===> Creating {opt.src} dataset...")
    test_data = ShapeNetCore4k(**base_data_params, split='test',
                               class_choice=_sncore_split_classes(opt.src), pretrain=False)
    # targets (out of domain data) test data
    print(f"===> Creating {opt.tar1} dataset...")
    tar1_data = ShapeNetCore4k(**base_data_params, split='test',
                               class_choice=_sncore_split_classes(opt.tar1), pretrain=False)
    print(f"===> Creating {opt.tar2} dataset...")
    tar2_data = ShapeNetCore4k(**base_data_params, split='test',
                               class_choice=_sncore_split_classes(opt.tar2), pretrain=False)

    # loaders (evaluation order does not matter, so no shuffling)
    test_loader = DataLoader(test_data, **loader_params)
    tar1_loader = DataLoader(tar1_data, **loader_params)
    tar2_loader = DataLoader(tar2_data, **loader_params)

    return test_loader, tar1_loader, tar2_loader


def get_sonn_test_loaders(opt):
    """Build the ScanObjectNN loaders (ID, OOD 1, OOD 2) for a real-world source.

    :param opt: namespace with ``src`` ('SR1' or 'SR2'), ``data_root``, ``num_points``,
                ``batch_size`` and ``num_workers``.
    :returns: (id_loader, ood1_loader, ood2_loader); none of them shuffles or drops the last batch.
    :raises ValueError: for any other ``opt.src``.
    """
    print(f"Arguments: {opt}")

    loader_kwargs = dict(batch_size=opt.batch_size, drop_last=False, shuffle=False,
                         num_workers=opt.num_workers, sampler=None, worker_init_fn=init_np_seed)

    # Everything is evaluated on ScanObjectNN real-world data. split='all' merges the
    # (otherwise unused) training samples with the test samples; if the model had been
    # pretrained on ScanObject this would have to be 'test' instead.
    # All num_points are kept (default 2048, avoids sampling randomness) and no
    # augmentation is applied at inference time.
    dataset_kwargs = dict(data_root=opt.data_root, sonn_split="main_split", h5_file="objectdataset.h5",
                          split='all', num_points=opt.num_points, transforms=None, pretrain=False)

    # source -> (in-distribution class set, first OOD class set); SR1 and SR2 swap roles.
    roles = {
        'SR1': ("sonn_2_mdSet1", "sonn_2_mdSet2"),
        'SR2': ("sonn_2_mdSet2", "sonn_2_mdSet1"),
    }
    if opt.src not in roles:
        raise ValueError(f"OOD evaluation - wrong src: {opt.src}")
    print(f"Src is {opt.src}\n")
    id_choice, ood1_choice = roles[opt.src]

    def make_loader(class_choice):
        return DataLoader(ScanObject(class_choice=class_choice, **dataset_kwargs), **loader_kwargs)

    id_loader = make_loader(id_choice)
    ood1_loader = make_loader(ood1_choice)
    # The second OOD set is shared by SR1 and SR2: samples from ScanObjectNN categories
    # that map poorly onto ModelNet categories.
    ood2_loader = make_loader("sonn_ood_common")
    return id_loader, ood1_loader, ood2_loader


def get_loaders_test(opt):
    """
    Build the dataloaders for the Real-to-Real benchmark: only the test (ID) loader and the target (OOD) loader.

    :param opt: namespace with ``src`` ('SR12', 'SR13' or 'SR23'), ``data_root``, ``num_points``,
                ``batch_size`` and ``num_workers``.
    :returns: (test_loader, target_loader), neither shuffled nor dropping the last batch.
    :raises ValueError: if ``opt.src`` is not one of the sources above.
    """
    loader_kwargs = dict(batch_size=opt.batch_size, drop_last=False, shuffle=False,
                         num_workers=opt.num_workers, sampler=None, worker_init_fn=init_np_seed)

    # All num_points are kept (default 2048, avoids sampling randomness); no augmentation at inference.
    shared = dict(data_root=opt.data_root, sonn_split="main_split", h5_file="objectdataset.h5",
                  num_points=opt.num_points, transforms=None, pretrain=False)
    id_set = ScanObject(split='test', class_choice=opt.src, **shared)

    # Each source is paired with the class set it does not cover
    # (NOTE(review): presumably SR1 = sonn_2_mdSet1, SR2 = sonn_2_mdSet2, SR3 = sonn_ood_common).
    held_out = {
        "SR12": "sonn_ood_common",
        "SR13": "sonn_2_mdSet2",
        "SR23": "sonn_2_mdSet1",
    }
    if opt.src not in held_out:
        raise ValueError(f"Unknown source: {opt.src}")
    target_name = held_out[opt.src]
    ood_set = ScanObject(split='all', class_choice=target_name, **shared)
    print(f"SRC: {opt.src}, Test - target: {target_name}")

    return DataLoader(id_set, **loader_kwargs), DataLoader(ood_set, **loader_kwargs)


@torch.no_grad()
def evaluate(model, test_loader, criterion):
    """Score every sample in *test_loader* by the reconstruction error of *model*.

    The first of the four model outputs is used as the reconstruction. *criterion* is
    called once per sample, on (original, reconstruction) each given a leading axis of size 1.

    :returns: the per-sample criterion values stacked along dim 0, in loader order.
    """
    model.eval()
    per_sample = []
    for batch in test_loader:
        batch = batch.to(device)
        reconstruction, _, _, _ = model(batch)
        per_sample.extend(
            criterion(original[None], rebuilt[None])
            for original, rebuilt in zip(batch, reconstruction)
        )
    return torch.stack(per_sample, dim=0)


def eval_anomaly(args, model, src_label=0):
    """
    OOD evaluation using the per-sample Chamfer distance between input and reconstruction as anomaly score.

    The paper reports AUROC 3 and FPR 3, i.e. the scenario (known) src -> unknown set 1 + unknown set 2.
    The known src set gets ``src_label`` (default 0); tar1 and tar2 get the opposite label.

    :param args: namespace whose ``src`` starts with "SN" (ShapeNetCore) or "SR" (ScanObjectNN).
    :returns: (res_tar1, res_tar2, res_big_tar), the metric dicts from get_ood_metrics for
              src vs tar1, src vs tar2 and src vs tar1 + tar2.
    :raises ValueError: if ``args.src`` starts with neither "SN" nor "SR".
    """
    print(f"AUROC - Src label: {src_label}, Tar label: {int(not src_label)}")

    if args.src.startswith("SN"):
        loader_factory = get_sncore_test_loaders
    elif args.src.startswith("SR"):
        loader_factory = get_sonn_test_loaders
    else:
        raise ValueError(f"Anomaly evaluation - wrong src: {args.src}")
    src_loader, tar1_loader, tar2_loader = loader_factory(args)

    model = model.to(device)
    model.eval()
    chamfer = ChamferDistance().to(device)

    print("===> Evaluating model with Chamfer distance...")
    src_scores = evaluate(model, src_loader, chamfer)
    tar1_scores = evaluate(model, tar1_loader, chamfer)
    tar2_scores = evaluate(model, tar2_loader, chamfer)

    # The Chamfer distance is the anomaly score, so the known (normal) samples carry label 0.
    # NOTE(review): per the original note, get_ood_metrics reports inverted AUROC and related
    # metrics for this convention — confirm in utils.eval_utils.
    print("===> Computing Anomaly metrics with Chamfer distance...")
    res_tar1 = get_ood_metrics(src_scores, tar1_scores, src_label)
    res_tar2 = get_ood_metrics(src_scores, tar2_scores, src_label)
    # Both unknown sets pooled into a single OOD set (the "3" scenario).
    pooled_tar = np.concatenate([to_numpy(tar1_scores), to_numpy(tar2_scores)], axis=0)
    res_big_tar = get_ood_metrics(src_scores, pooled_tar, src_label)

    print_ood_output(res_tar1, res_tar2, res_big_tar)
    metric_keys = ('auroc', 'fpr_at_95_tpr', 'aupr_in', 'aupr_out')
    cells = [f"{res[key]}" for res in (res_tar1, res_tar2, res_big_tar) for key in metric_keys]
    print("to spreadsheet: " + ",".join(cells))

    return res_tar1, res_tar2, res_big_tar


def eval_real2real(args, model, model_path, src_label=0):
    """Real-to-Real OOD evaluation using the per-sample Chamfer distance as anomaly score.

    Scores the ID test split and the held-out target split from ``get_loaders_test(args)``,
    prints AUROC / FPR95 / AUPR-in / AUPR-out and appends them to ``log.txt`` in the
    current working directory.

    :param args: namespace accepted by ``get_loaders_test`` (``src`` in SR12/SR13/SR23).
    :param model_path: only used to label the log entry.
    :param src_label: label of the known (ID) samples; target samples get the opposite label.
    :returns: the metric dict from ``get_ood_metrics``.
    """
    tar_label = int(not src_label)
    print(f"AUROC - Src label: {src_label}, Tar label: {tar_label}")

    # Dataloader
    test_loader, target_loader = get_loaders_test(args)

    # Init
    model = model.to(device)
    model.eval()
    criterion_cd = ChamferDistance().to(device)

    print("===> Evaluating model with Chamfer distance...")
    src_scores = evaluate(model, test_loader, criterion_cd)
    tar_scores = evaluate(model, target_loader, criterion_cd)

    # The Chamfer distance is the anomaly score, so the known (normal) samples carry label 0.
    # NOTE(review): per the original note, get_ood_metrics reports inverted AUROC and related
    # metrics for this convention — confirm in utils.eval_utils.
    print("===> Computing Anomaly metrics with Chamfer distance...")
    res = get_ood_metrics(src_scores, tar_scores, src_label)
    auroc, fpr, aupr_in, aupr_out = res['auroc'], res['fpr_at_95_tpr'], res['aupr_in'], res['aupr_out']
    print(f"SRC->TAR:      AUROC: {auroc:.4f}, FPR95: {fpr:.4f}, AUPR_IN: {aupr_in:.4f}, AUPR_OUT: {aupr_out:.4f}")
    # Append-only log; the `with` block closes the file, so no explicit close() is needed.
    with open('log.txt', 'a', encoding='utf-8', newline='\n') as f:
        f.write(f"SRC: {args.src}, infer model: {model_path}\n")
        f.write(f"log time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"SRC->TAR:  AUROC: {auroc:.4f}, FPR95: {fpr:.4f}, AUPR_IN: {aupr_in:.4f}, AUPR_OUT: {aupr_out:.4f}\n")
        f.write("\n")
    return res


@torch.no_grad()
def evaluate_with_depth(model, test_loader):
    """Score each sample by Chamfer distance and by depth-image MSE between input and reconstruction.

    The reconstruction is the first of the four model outputs. For each sample, both the
    original and the reconstruction (given a leading axis of size 1) are projected with
    ``Realistic_Projection.get_img`` and compared with ``MSELoss``. The stacked MSE values
    are then passed to ``rescale`` together with the minimum and maximum Chamfer score
    (NOTE(review): presumably to map MSE onto the Chamfer range so the caller can mix the
    two — confirm against utils.common.rescale).

    :returns: (chamfer_scores, mse_rescaled), both in loader order.
    """
    model.eval()
    criterion_cd = ChamferDistance().to(device)
    criterion_mse = torch.nn.MSELoss().to(device)
    projector = Realistic_Projection(device)
    chamfer_scores, mse_scores = [], []
    for pc in test_loader:
        pc = pc.to(device)
        folding2, _, _, _ = model(pc)
        for origin, recon in zip(pc, folding2):
            origin = origin.unsqueeze(0)
            recon = recon.unsqueeze(0)
            chamfer_scores.append(criterion_cd(origin, recon))
            origin_depth = projector.get_img(origin)
            recon_depth = projector.get_img(recon)
            mse_scores.append(criterion_mse(origin_depth, recon_depth))
    chamfer_scores = torch.stack(chamfer_scores, dim=0)
    mse_scores = torch.stack(mse_scores, dim=0)
    # Tensor reductions rather than the builtins min()/max(): the builtins iterate the
    # tensor element by element in Python, forcing one device sync per sample on GPU/MPS.
    mse_rescaled = rescale(mse_scores, chamfer_scores.min(), chamfer_scores.max())
    return chamfer_scores, mse_rescaled


def eval_anomaly_with_depth(args, model, src_label=0, w=1.0):
    """
    OOD evaluation whose per-sample score mixes Chamfer distance with the rescaled depth-image MSE
    returned by ``evaluate_with_depth``: ``score = w * chamfer + (1 - w) * mse_rescaled``.

    The paper reports AUROC 3 and FPR 3, i.e. the scenario (known) src -> unknown set 1 + unknown set 2.
    The known src set gets ``src_label`` (default 0); tar1 and tar2 get the opposite label.

    :param args: namespace whose ``src`` starts with "SN" (ShapeNetCore) or "SR" (ScanObjectNN).
    :param w: weight of the Chamfer term; the MSE term gets ``1 - w``.
    :returns: (res_tar1, res_tar2, res_big_tar), the metric dicts from get_ood_metrics.
    :raises ValueError: if ``args.src`` starts with neither "SN" nor "SR".
    """
    print(f"AUROC - Src label: {src_label}, Tar label: {int(not src_label)}")

    if args.src.startswith("SN"):
        loader_factory = get_sncore_test_loaders
    elif args.src.startswith("SR"):
        loader_factory = get_sonn_test_loaders
    else:
        raise ValueError(f"Anomaly evaluation - wrong src: {args.src}")
    src_loader, tar1_loader, tar2_loader = loader_factory(args)

    model = model.to(device)
    model.eval()
    # Rounded copies are only for display; the exact w is used when mixing the scores.
    shown_cd_weight, shown_mse_weight = round(w, 2), round(1 - w, 2)

    print("===> Evaluating model with Chamfer distance and MSE...")
    print(f"===> Weight for Chamfer distance is {shown_cd_weight}.")
    print(f"===> Weight for MSE is {shown_mse_weight}.")

    def mixed_scores(loader):
        chamfer_part, mse_part = evaluate_with_depth(model, loader)
        return w * chamfer_part + (1 - w) * mse_part

    src_scores = mixed_scores(src_loader)
    tar1_scores = mixed_scores(tar1_loader)
    tar2_scores = mixed_scores(tar2_loader)

    # The reconstruction error is the anomaly score, so the known (normal) samples carry label 0.
    # NOTE(review): per the original note, get_ood_metrics reports inverted AUROC and related
    # metrics for this convention — confirm in utils.eval_utils.
    print("===> Computing Anomaly metrics with Chamfer distance and MSE...")
    res_tar1 = get_ood_metrics(src_scores, tar1_scores, src_label)
    res_tar2 = get_ood_metrics(src_scores, tar2_scores, src_label)
    # Both unknown sets pooled into a single OOD set (the "3" scenario).
    pooled_tar = np.concatenate([to_numpy(tar1_scores), to_numpy(tar2_scores)], axis=0)
    res_big_tar = get_ood_metrics(src_scores, pooled_tar, src_label)

    print_ood_output(res_tar1, res_tar2, res_big_tar)
    metric_keys = ('auroc', 'fpr_at_95_tpr', 'aupr_in', 'aupr_out')
    cells = [f"{res[key]}" for res in (res_tar1, res_tar2, res_big_tar) for key in metric_keys]
    print("to spreadsheet: " + ",".join(cells))

    return res_tar1, res_tar2, res_big_tar


def eval_real2real_with_depth(args, model, model_path, src_label=0, w=1.0):
    """Real-to-Real OOD evaluation scoring each sample by ``w * chamfer + (1 - w) * mse_rescaled``.

    Same pipeline as the Chamfer-only real-to-real evaluation, but using both scores
    returned by ``evaluate_with_depth``. Metrics are printed and appended to ``log.txt``
    in the current working directory.

    :param args: namespace accepted by ``get_loaders_test`` (``src`` in SR12/SR13/SR23).
    :param model_path: only used to label the log entry.
    :param src_label: label of the known (ID) samples; target samples get the opposite label.
    :param w: weight of the Chamfer term; the MSE term gets ``1 - w``.
    :returns: the metric dict from ``get_ood_metrics``.
    """
    tar_label = int(not src_label)
    print(f"AUROC - Src label: {src_label}, Tar label: {tar_label}")

    # Dataloader
    test_loader, target_loader = get_loaders_test(args)

    # Init
    model = model.to(device)
    model.eval()
    # The rounded weights are for display only. Scores are mixed with the exact w (as in
    # eval_anomaly_with_depth); mixing with the rounded values would silently change any
    # weight with more than two decimals (e.g. w=0.125 would become 0.12 / 0.88).
    w1, w2 = round(w, 2), round(1 - w, 2)

    print("===> Evaluating model with Chamfer distance and MSE...")
    print(f"===> Weight for Chamfer distance is {w1}.")
    print(f"===> Weight for MSE is {w2}.")
    src_scores1, src_scores2 = evaluate_with_depth(model, test_loader)
    src_scores = w * src_scores1 + (1 - w) * src_scores2
    tar_scores1, tar_scores2 = evaluate_with_depth(model, target_loader)
    tar_scores = w * tar_scores1 + (1 - w) * tar_scores2

    # The reconstruction error is the anomaly score, so the known (normal) samples carry label 0.
    # NOTE(review): per the original note, get_ood_metrics reports inverted AUROC and related
    # metrics for this convention — confirm in utils.eval_utils.
    print("===> Computing Anomaly metrics with Chamfer distance and MSE...")
    res = get_ood_metrics(src_scores, tar_scores, src_label)
    auroc, fpr, aupr_in, aupr_out = res['auroc'], res['fpr_at_95_tpr'], res['aupr_in'], res['aupr_out']
    print(f"SRC->TAR:      AUROC: {auroc:.4f}, FPR95: {fpr:.4f}, AUPR_IN: {aupr_in:.4f}, AUPR_OUT: {aupr_out:.4f}")
    # Append-only log; the `with` block closes the file, so no explicit close() is needed.
    with open('log.txt', 'a', encoding='utf-8', newline='\n') as f:
        f.write(f"SRC: {args.src}, infer model: {model_path}\n")
        f.write(f"log time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"SRC->TAR:  AUROC: {auroc:.4f}, FPR95: {fpr:.4f}, AUPR_IN: {aupr_in:.4f}, AUPR_OUT: {aupr_out:.4f}\n")
        f.write("\n")
    return res


def infer(args, real2real_src="SR12"):
    """Load the trained VAE from ``args.model_path`` and run both OOD evaluations.

    1. Real-to-Real evaluation with ``real2real_src`` as source, run on a copy of *args*
       so the caller's ``args.src`` is left untouched.
    2. Anomaly evaluation for ``args.src`` (the split chosen on the command line);
       its results are stored with ``save_results``.

    :param args: namespace produced by ``get_args``.
    :param real2real_src: source for the Real-to-Real run ('SR12', 'SR13' or 'SR23').
    """
    set_random_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    print("===> Loading model...")
    model = SkipVariationalFoldingNet(args.dim).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))

    # Use a copy for the Real-to-Real run. eval_anomaly needs the original args.src:
    # given 'SR12' it would dispatch to get_sonn_test_loaders, which accepts only SR1/SR2
    # and raises ValueError.
    real2real_args = argparse.Namespace(**vars(args))
    real2real_args.src = real2real_src
    eval_real2real(real2real_args, model, args.model_path)

    res_tar1, res_tar2, res_big_tar = eval_anomaly(args, model)
    save_results(res_tar1, res_tar2, res_big_tar, args.src, args.model_path)


if __name__ == '__main__':
    # Script entry point: parse the command line, then run the whole evaluation.
    infer(get_args())
