import gc
import os
import os.path as pt
import traceback
from abc import ABC
from copy import deepcopy
from typing import Callable, List, Mapping, Tuple
from datetime import datetime

import numpy as np
import shutil
import sklearn
import torch.nn
import torchvision.datasets
from sklearn.metrics import accuracy_score
from torch.nn import Module
from torch.nn.functional import cross_entropy
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.datasets.folder import default_loader
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm

from xad.counterfactual.eval import get_roc, compute_fid_scores
from xad.datasets.bases import CombinedDataset, TorchvisionDataset
from xad.models.bases import ConceptNN, ConditionalDiscriminator, ConditionalGenerator
from xad.utils.logger import Logger
from xad.utils.training_tools import NanGradientsError, int_set_to_str, weight_reset
from xad.utils.data_tools import random_split_tensor



def huber_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the element-wise pseudo-Huber distance (delta = 1) of *x* and *y*.

    Computes ``sqrt((x - y)**2 + 1) - 1`` per element, which is zero where
    the inputs coincide and grows roughly linearly for large differences.
    No reduction is applied; the result has the broadcast shape of the inputs.

    Args:
        x: First tensor.
        y: Second tensor; must be broadcastable against ``x``.

    Returns:
        Tensor of per-element distances.
    """
    diff = x - y
    shifted_square = diff.pow(2) + 1
    return shifted_square.sqrt() - 1


def hinge_disc_loss(logits: torch.Tensor, is_real: torch.Tensor) -> torch.Tensor:
    """Return the mean hinge loss of *logits* against the *is_real* indicator.

    The indicator is mapped to a sign via ``2 * is_real - 1`` and the loss is
    ``mean(max(0, 1 - logits * sign))``.

    Args:
        logits: Raw scores to be penalized.
        is_real: Indicator tensor broadcastable against ``logits``.
            NOTE(review): presumably 1 marks real and 0 marks fake samples,
            so that the sign becomes +1 / -1 -- confirm against callers.

    Returns:
        Scalar tensor holding the loss averaged over all elements.
    """
    sign = is_real * 2 - 1
    # Negative wherever the signed logit falls short of the margin of 1.
    margin_gap = logits * sign - 1.0
    # Keep only the violations (non-positive part), average, then flip the sign.
    violations = torch.clamp(margin_gap, max=0)
    return -violations.mean()


class XTrainer(ABC):
    """Base class deriving from :class:`abc.ABC` that declares no members.

    NOTE(review): judging by the name, this is meant as the common base for
    trainer implementations -- confirm against its subclasses.
    """