import argparse
import random

import matplotlib.pyplot as plt
import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.metrics import roc_auc_score
from sklearn.svm import OneClassSVM

import dataloaders


class Config(argparse.Namespace):
    """Typed namespace for this script's command-line options.

    Only annotations are declared (no class-level defaults). argparse sets the
    attribute values when an instance is passed as ``namespace=`` to
    ``parse_args``.
    """

    algorithm: str  # model selector string passed via --algorithm
    seed: int  # seed value passed via --seed


if __name__ == "__main__":
    # Parser
    parser = argparse.ArgumentParser()
    # "IF" selects IsolationForest; any other value falls back to OneClassSVM (see Model).
    parser.add_argument("--algorithm", type=str, default="IF")
    parser.add_argument("--seed", type=int, default=42)
    config = parser.parse_args(namespace=Config())
    print("Setting:", config)

    # Seed the Python and NumPy global RNGs so runs are repeatable.
    random.seed(config.seed)
    np.random.seed(config.seed)

    # Dataset
    batch_size = 128
    train_loader = dataloaders.make_toy_data(batch_size=batch_size, is_train=True)
    # NOTE(review): valid_loader is never used below and is built with is_train=True
    # (same settings as the training split). It is kept because make_toy_data may draw
    # from the seeded global RNG, so removing the call could change the test data.
    valid_loader = dataloaders.make_toy_data(batch_size=batch_size, is_train=True)
    test_loader = dataloaders.make_toy_data(batch_size=batch_size, is_train=False)

    # Model: IsolationForest for "IF", otherwise OneClassSVM.
    if config.algorithm == "IF":
        clf = IsolationForest(random_state=config.seed)
    else:
        clf = OneClassSVM(gamma="auto")

    # Train only on samples labelled 0; fit() receives no labels.
    X_train, Y_train = dataloaders.to_numpy(train_loader)
    X_test, Y_test = dataloaders.to_numpy(test_loader)
    clf.fit(X_train[Y_train == 0])

    # Test. decision_function is positive for inliers and negative for outliers,
    # so `1 - score` grows with anomalousness. ROC AUC is rank-based, so the
    # constant offset does not change the result.
    Y_predict = 1 - clf.decision_function(X_test)
    test_score = roc_auc_score(y_true=Y_test, y_score=Y_predict)
    print("AUC:", test_score)

    # Plot. rcdefaults() restores the defaults without touching blacklisted
    # params such as the backend (unlike updating from rcParamsDefault).
    plt.rcdefaults()
    plt.rcParams.update(
        {
            "xtick.labelsize": 18,
            "ytick.labelsize": 18,
            "axes.labelsize": 18,
            "legend.fontsize": 18,
            "ps.useafm": True,
            "pdf.use14corefonts": True,
            "text.usetex": True,
            # NOTE(review): with text.usetex=True, matplotlib only honours generic
            # families (serif, sans-serif, ...) here and may fall back to serif;
            # confirm Times New Roman actually renders.
            "font.family": "Times New Roman",
        }
    )

    colors = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    # Evaluate the anomaly score on a square grid covering [-grid_limit, grid_limit]^2.
    grid_size = 100
    grid_limit = 7.5
    xx = np.linspace(-grid_limit, grid_limit, grid_size)
    yy = np.linspace(-grid_limit, grid_limit, grid_size)
    XX, YY = np.meshgrid(xx, yy)
    # Flatten to (grid_size**2, 2) points. The row-major order matches the reshape back below.
    PP = np.dstack((XX, YY)).astype(np.float32).reshape(-1, 2)
    ZZ = 1 - clf.decision_function(PP).reshape(grid_size, grid_size)

    # plt.subplots() creates its own figure; no extra figure is opened beforehand.
    fig, ax = plt.subplots()
    mappable = ax.contourf(XX, YY, ZZ, levels=20, cmap="viridis")
    fig.colorbar(mappable)

    X, y = train_loader.dataset.tensors  # type: ignore
    X = X.numpy()
    y = y.numpy()

    ax.scatter(X[y == 0][:, 0], X[y == 0][:, 1], color=colors[1], label="Unlabeled")
    ax.scatter(X[y == 1][:, 0], X[y == 1][:, 1], color=colors[0], label="Anomaly")
    # Show the scatter labels (legend.fontsize is configured above).
    ax.legend()

    plt.show()
