import argparse
import os
os.environ['OPENBLAS_NUM_THREADS'] = '1'
from typing import Tuple

import torch
import torchvision.transforms as transforms
from torchvision.transforms import Compose

from xad.models.cnn import GenericCNN28
from xad.models.resnets.concept_resnets import ConceptResNet64
from xad.models.resnets.resgan import WideResNetGenerator32, \
    WideSNResNetProjectionDiscriminator32
from xad.main.bases import default_comment, main
from xad.models.bases import ADNN, ConditionalGenerator, ConditionalDiscriminator, ConceptNN
from xad.datasets import COLORED_MNIST_CSETS


def modify_parser(parser):
    """Register the CNN28 architecture options and set this experiment's defaults.

    Adds ``--cnn28-conv-layers``, ``--cnn28-linear-layers`` and ``--cnn28-ksize``
    (read back by ``get_models``), then overrides defaults via
    ``parser.set_defaults``. The remaining option names are presumably declared by
    the shared parser built in ``xad.main.bases`` (not visible here).

    :param parser: an ``argparse.ArgumentParser``-like object exposing
        ``add_argument`` and ``set_defaults``.
    """
    cnn28_options = (
        ('--cnn28-conv-layers', dict(nargs='+', type=int, default=(16, 32, 64))),
        ('--cnn28-linear-layers', dict(nargs='+', type=int, default=(64, 32))),
        ('--cnn28-ksize', dict(type=int, default=5)),
    )
    for flag, options in cnn28_options:
        parser.add_argument(flag, **options)

    experiment_defaults = dict(
        comment=default_comment(__file__),
        # Detector training.
        objective='bce',
        dataset='coloredmnist',
        oe_dataset=['coloredemnist'],
        epochs=120,
        learning_rate=5e-5,
        weight_decay=1e-6,
        # NOTE(review): 150 is larger than epochs=120, so that milestone looks
        # unreachable if milestones are epoch indices -- confirm intent.
        milestones=[100, 150],
        batch_size=128,
        devices=[0],
        classes=['red+or+one', ],
        iterations=2,
        # Explanation (x_*) settings.
        x_discrete_anomaly_scores=3,
        # x_normal_training_only=True,
        x_batch_size=64,
        x_epochs=350,
        x_learning_rate=2e-4,
        x_milestones=[300, 325],
        x_lamb_gen=1,
        x_lamb_asc=1,
        x_lamb_cyc=100,
        x_lamb_conc=10,
        x_gen_every=5,
        # x_cluster_ncc=True,
    )
    parser.set_defaults(**experiment_defaults)


def _resolve_csets(cset_strs):
    """Map named ColoredMNIST class sets to '+'-joined class-id strings.

    Entries that are keys of ``COLORED_MNIST_CSETS`` are replaced by the sorted
    ids of that set, stringified and joined with ``'+'``. Any other entry is
    passed through unchanged.

    :param cset_strs: iterable of class-set strings.
    :return: a new list with the resolved strings, in input order.
    """
    return [
        "+".join(str(class_id) for class_id in sorted(COLORED_MNIST_CSETS[cset_str]))
        if cset_str in COLORED_MNIST_CSETS else cset_str
        for cset_str in cset_strs
    ]


def modify_args(args: argparse.Namespace):
    """Post-process the parsed arguments in place.

    - Fills the ``{NormCls}`` placeholder in ``args.comment`` with the class names
      exactly as given, before they are resolved.
    - Resolves named class sets in ``args.classes`` to '+'-joined id strings.
    - Does the same for ``args.oe_classes`` when the first OE dataset is
      ``'coloredmnist'`` and ``oe_classes`` is a list or tuple.

    :param args: the parsed namespace; mutated, nothing is returned.
    """
    # Uses the human-readable names, so this must run before args.classes is rewritten.
    args.comment = args.comment.replace("{NormCls}", "Norm-" + "-".join(args.classes))
    args.classes = _resolve_csets(args.classes)
    # Truthiness covers both an empty sequence and None.
    if (
        args.oe_dataset
        and args.oe_dataset[0] == 'coloredmnist'
        and isinstance(args.oe_classes, (list, tuple))
    ):
        args.oe_classes = _resolve_csets(args.oe_classes)


def get_transforms() -> Tuple[Compose, Compose]:
    """Return ``(train_transform, val_transform)`` for this experiment.

    Both are separate ``Compose`` instances built from an empty transform list,
    so neither pipeline adds any transform of its own.
    """
    return Compose([]), Compose([])


def get_models(args) -> Tuple[ADNN, ConditionalGenerator, ConditionalDiscriminator, ConceptNN]:
    """Build the networks for the ColoredMNIST experiment.

    All four networks are constructed with ``grayscale=False``.

    :param args: parsed arguments; reads ``objective``, the ``cnn28_*`` options,
        ``x_discrete_anomaly_scores`` and ``x_concepts``.
    :return: ``(model, gen, disc, concept_classifier)``.
        - ``model`` is a ``GenericCNN28``. Its bias is disabled for the
          ``dsvdd``/``mdsvdd`` objectives, and its ``clf`` flag is enabled for
          ``bce``/``focal``.
        - ``gen`` is conditioned on shape ``(x_discrete_anomaly_scores, x_concepts)``.
        - ``disc`` is conditioned on shape ``(x_discrete_anomaly_scores,)``.
    """
    objective = args.objective
    use_bias = objective not in ('dsvdd', 'mdsvdd')
    as_classifier = objective in ('bce', 'focal')
    model = GenericCNN28(
        conv_layer=args.cnn28_conv_layers,
        fc_layer=args.cnn28_linear_layers,
        ksize=args.cnn28_ksize,
        bias=use_bias,
        clf=as_classifier,
        grayscale=False,
    )

    n_scores = args.x_discrete_anomaly_scores
    n_concepts = args.x_concepts
    gen = WideResNetGenerator32(1024, torch.Size([n_scores, n_concepts]), grayscale=False)
    disc = WideSNResNetProjectionDiscriminator32(1024, torch.Size([n_scores]), grayscale=False)
    # The original author notes that the 64-variant concept classifier is fine here.
    concept_classifier = ConceptResNet64(n_concepts, grayscale=False)
    return model, gen, disc, concept_classifier


if __name__ == '__main__':
    # Hand this experiment's hooks to the shared entry point, in the order it expects.
    experiment_hooks = (modify_parser, get_transforms, get_models, modify_args)
    main(*experiment_hooks)

