import os
# Restrict OpenBLAS to a single thread. This is set before torch is imported below,
# presumably so OpenBLAS reads it when the library is loaded — TODO confirm.
os.environ['OPENBLAS_NUM_THREADS'] = '1'
from typing import Tuple

import torch
import torchvision.transforms as transforms
from torchvision.transforms import Compose

from xad.models.resnets.concept_resnets import ConceptResNet64
from xad.models.epsilon import EpsilonGenerator
from xad.models.resnets.resnet import WideResNet
from xad.main.bases import default_comment, main
from xad.models.bases import ADNN, ConditionalGenerator, ConditionalDiscriminator, ConceptNN


def modify_parser(parser):
    """Install the defaults for the ImageNet / DiffEdit experiment on ``parser``.

    Only default values are changed (through ``parser.set_defaults``); nothing
    is returned. ``parser`` is presumably an ``argparse.ArgumentParser`` built
    by the shared runner — TODO confirm.

    :param parser: parser whose defaults are overwritten in place.
    """
    # Settings for training the anomaly detector itself.
    detector_defaults = {
        'comment': '{obj}_imagenet_{OE}OE_{CFOE}_{XMTHD}_conclmb{XDCONC}_{NormCls}',
        'objective': 'bce',
        'dataset': 'imagenet',
        'oe_dataset': ['imagenet21k'],
        'epochs': 150,
        'learning_rate': 1e-3,
        'weight_decay': 0,
        'milestones': [100, 125],
        'batch_size': 128,
        'devices': [0, 1],
        'classes': [2],
        'iterations': 2,
    }
    # Settings prefixed with ``x_`` (explanation / counterfactual stage).
    explanation_defaults = {
        'x_discrete_anomaly_scores': 1,
        'x_batch_size': 10,
        'x_epochs': 1,  # OE only in DiffEdit, without a discriminator
        'x_learning_rate': 2e-4,
        'x_weight_decay': 0,
        'x_milestones': [0.3, 0.45, 0.6, 0.75, 0.9],
        'x_milestone_alpha': 0.5,
        'x_method': "diffedit",
        'x_gen_every': 2,
        'x_disc_every': 1,
        'x_lamb_dist': 1e-3,
        'x_lamb_conc': 0.5,
        'x_lamb_gen': 1e-1,
        'x_mask_encode_strength': 0.4,
        'x_mask_thresholding_ratio': 1.5,
        'x_diffusion_inference_steps': 40,
        'x_diffusion_resolution': 512,
    }
    parser.set_defaults(**detector_defaults, **explanation_defaults)


def get_transforms() -> Tuple[Compose, Compose]:
    """Return the ``(train_transform, val_transform)`` image pre-processing pipelines.

    Both pipelines apply ``Resize(256)``, then take a 224-pixel crop and finish
    with ``ToTensor()``. Training uses a random crop plus a random horizontal
    flip (p=0.5). Validation uses a deterministic center crop.
    """
    train_steps = [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        # Optional augmentations currently disabled:
        # ColorJitter(0.01, 0.01, 0.01, 0.01) and RandomGaussianNoise(0.001).
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
    ]
    val_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    # NOTE(review): no normalization step is applied in either pipeline;
    # presumably handled downstream if needed — confirm.
    return Compose(train_steps), Compose(val_steps)


def get_models(args) -> Tuple[ADNN, ConditionalGenerator, ConceptNN]:
    """Create the anomaly detector, the explanation generator and the concept classifier.

    :param args: parsed experiment arguments. This function reads ``objective``,
        ``x_diffusion_resolution`` and ``x_concepts``.
    :return: ``(model, gen, concept_classifier)``, where ``model`` is a ``WideResNet``,
        ``gen`` is an ``EpsilonGenerator`` and ``concept_classifier`` is a ``ConceptResNet64``.
    :raises ValueError: if ``args.objective`` is ``'dsvdd'`` or ``'mdsvdd'``
        (bias not implemented for WideResNet).
    """
    # Explicit raise instead of ``assert`` so the check still runs under ``python -O``.
    if args.objective in ('dsvdd', 'mdsvdd'):
        raise ValueError('bias not implemented for WideResNet')
    # ``clf`` is enabled only for the 'bce' and 'focal' objectives.
    model = WideResNet(clf=args.objective in ('bce', 'focal'))
    # Flattened latent size: 4 * (resolution // 8) ** 2, e.g. 16384 for a resolution of 512.
    # NOTE(review): presumably assumes a diffusion latent with 4 channels at 1/8 of the
    # image resolution — TODO confirm against the diffusion model actually used.
    latent_dim = 4 * (args.x_diffusion_resolution // 8) ** 2
    gen = EpsilonGenerator(latent_dim, torch.Size([args.x_concepts]))
    # The 64 variant is used deliberately here (author's note: "64 is fine").
    concept_classifier = ConceptResNet64(args.x_concepts)
    return model, gen, concept_classifier


if __name__ == '__main__':
    # Pass this experiment's parser defaults, transforms and model factory to the
    # shared `main` runner imported from xad.main.bases.
    main(modify_parser, get_transforms, get_models)
