import lpips
import torch
import torch.nn.functional as F

from .base import Guidance
from classifiers.base import ClassifierBase

class DVCEDummyGuidance(Guidance):
    """Placeholder guidance that only keeps references to two classifiers.

    Both hook methods return ``None``; no conditioning module or
    conditioning function is built by this class.
    """

    def __init__(self, model: ClassifierBase, model_2: ClassifierBase, uses_target_clf: bool):
        """Store the two classifiers on the instance.

        Args:
            model: Classifier kept as ``self.clf``.
            model_2: Classifier kept as ``self.clf_2``.
            uses_target_clf: Accepted by the signature but not read or
                stored anywhere in this class.
        """
        super().__init__()
        self.clf_2 = model_2
        self.clf = model

    def get_cond_module(self):
        """Return ``None``; this dummy guidance has no conditioning module."""
        return None

    def get_cond_fn(self):
        """Return ``None``; this dummy guidance has no conditioning function."""
        return None
    
def from_m1p1_to_01(x):
    """Rescale a batch to ``[0, 1]`` when it contains negative values.

    If the minimum of at least one sample (taken over every dimension except
    the first) is negative, every sample of the batch is min-max normalised
    independently: its minimum is subtracted and the result is divided by
    the remaining maximum. Note that this is a per-sample min-max rescale,
    not the fixed affine map ``(x + 1) / 2``.

    If no sample has a negative minimum, ``x`` is returned unchanged (the
    very same tensor object).

    Args:
        x: Tensor with at least two dimensions; the first one is treated as
            the batch dimension.

    Returns:
        The rescaled tensor, or ``x`` itself when no rescaling is applied.
    """
    per_sample_min = x.flatten(start_dim=1).min(dim=1)[0]
    if not (per_sample_min < 0.).any():
        return x

    # Shape that broadcasts one statistic per sample over the trailing
    # dimensions; for 4-D input this is exactly (-1, 1, 1, 1).
    stat_shape = (-1,) + (1,) * (x.dim() - 1)

    shifted = x - per_sample_min.view(stat_shape)
    span = shifted.flatten(start_dim=1).max(dim=1)[0].view(stat_shape)

    # A constant sample has zero span and would produce 0 / 0 = NaN. Its
    # shifted values are all zero already, so dividing by one maps it to 0.
    safe_span = torch.where(span > 0, span, torch.ones_like(span))
    return shifted / safe_span

def from_01_to_m1p1(x):
    """Map a batch from ``[0, 1]`` to ``[-1, 1]`` via ``(x - 0.5) * 2``.

    The mapping is applied to the whole batch as soon as at least one sample
    has a non-negative minimum (taken over every dimension except the
    first). Otherwise ``x`` is returned unchanged (the same tensor object).

    Args:
        x: Tensor with at least two dimensions; the first one is treated as
            the batch dimension.

    Returns:
        The shifted and scaled tensor, or ``x`` itself when no sample
        triggers the conversion.
    """
    lowest_per_sample = x.flatten(start_dim=1).min(dim=1)[0]
    any_non_negative = (lowest_per_sample >= 0.).any()
    if not any_non_negative:
        return x
    centred = x - 0.5
    return centred * 2