import argparse
import os
import random

import numpy as np
import torch

import dataloaders
from models import detectors, losses, trainers, utils


class Config(argparse.Namespace):
    """Typed container for one experiment's settings.

    Subclassing ``argparse.Namespace`` gives keyword construction and
    attribute access for free. The annotations below exist for readers and
    type checkers only. No defaults are declared here, so a field that is not
    passed to the constructor is simply absent: accessing it raises
    ``AttributeError``.
    """

    # Data selection
    dataset: str
    normal_class: str
    unseen_anomaly: str

    # Method and its hyperparameter
    algorithm: str
    alpha: float

    # Optimisation
    n_epoch: int
    learning_rate: float
    batch_size: int

    # Reproducibility
    seed: int


if __name__ == "__main__":
    # Parser: only training hyperparameters are exposed on the command line;
    # the dataset split (MVTec / "good" / "top") is fixed below.
    parser = argparse.ArgumentParser()
    parser.add_argument("--algorithm", type=str, default="AE")
    parser.add_argument("--alpha", type=float, default=0.1)
    parser.add_argument("--n_epoch", type=int, default=100)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # ``--alpha`` is forwarded only for these algorithms. Every other algorithm
    # gets alpha=0.0 (presumably unused by its loss -- confirm in
    # ``losses.load``), which also keeps the run key below independent of
    # the flag for those algorithms.
    alpha_algorithms = {"PU", "PUAE", "PUSVDD", "LOE", "SOEL"}

    config = Config(
        dataset="MVTec",
        normal_class="good",
        unseen_anomaly="top",
        algorithm=args.algorithm,
        alpha=args.alpha if args.algorithm in alpha_algorithms else 0.0,
        n_epoch=args.n_epoch,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        seed=args.seed,
    )

    print("Setting:", config)

    # Key: every config value, in constructor order, joined by "_". It names
    # the checkpoint, result and figure files, so each setting maps to its own
    # output files.
    key = "_".join(str(v) for v in vars(config).values())

    # Seed every RNG source used by this script. torch.manual_seed seeds the
    # RNG on all devices, including CUDA.
    torch.manual_seed(config.seed)
    random.seed(config.seed)
    np.random.seed(config.seed)

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", device)

    # Dataset: honour the configured batch size. batch_size is part of the run
    # key, so the loaders must use the value the outputs are labelled with.
    train_loader, valid_loader, test_loader, test_seen_loader, test_unseen_loader = dataloaders.load_mvtec(
        batch_size=config.batch_size
    )

    # Output Directories
    for path in ["checkpoints", "results", "images"]:
        os.makedirs(path, exist_ok=True)

    # Deep Model
    model = detectors.ResidualSVDD(n_latent=128, use_affine=False)
    model = model.to(device)
    print("Model:", type(model))

    # Criterion
    criterion = losses.load(algorithm=config.algorithm, alpha=config.alpha)

    # Train
    if isinstance(model, detectors.DeepSVDD):
        # Set center for DeepSVDD model (computed from the training data).
        print("Set center for DeepSVDD:")
        utils.set_center(model=model, train_loader=train_loader, device=device, eps=0.1)

    print("Training:")
    trainer = trainers.Trainer(device=device)
    trainer.fit(
        model=model,
        criterion=criterion,
        train_loader=train_loader,
        valid_loader=valid_loader,
        checkpoint=f"checkpoints/{key}.pt",
        n_epoch=config.n_epoch,
        learning_rate=config.learning_rate,
        weight_decay=1e-3,
    )

    # Test: AUC on the full test set, then on the seen and unseen anomaly subsets.
    # NOTE(review): this evaluates the in-memory ``model`` after ``fit``. Whether
    # that holds the last-epoch or the best checkpointed weights depends on
    # ``Trainer.fit`` -- confirm.
    test_score = utils.compute_auc(model=model, test_loader=test_loader, device=device)
    seen_anomaly_score = utils.compute_auc(model=model, test_loader=test_seen_loader, device=device)
    unseen_anomaly_score = utils.compute_auc(model=model, test_loader=test_unseen_loader, device=device)
    print("AUC:", test_score)
    print("AUC (seen):", seen_anomaly_score)
    print("AUC (unseen):", unseen_anomaly_score)

    result = utils.Result(
        train_loss=trainer.train_losses,
        valid_loss=trainer.valid_losses,
        test_score=test_score,
        seen_anomaly_score=seen_anomaly_score,
        unseen_anomaly_score=unseen_anomaly_score,
    )

    # Save and Plot Results
    utils.save_result(result=result, file_name=f"results/{key}.json")
    utils.plot_result(result=result, file_name=f"images/{key}.pdf")
