import torch
import torchvision
import torchvision.transforms as transforms

from collections import OrderedDict
import os
import sys
import numpy as np
sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))

class GammaMapping:
    """Callable transform that raises a tensor element-wise to a fixed power.

    The input is converted to ``torch.float32`` before exponentiation, so the
    result is always a float32 tensor regardless of the input dtype.
    """

    def __init__(self, gamma):
        # Exponent applied to every element in ``__call__``.
        self.gamma = gamma

    def __call__(self, img):
        """Return ``img`` cast to float32 and raised to ``self.gamma``."""
        as_float = img.to(torch.float32)
        return torch.pow(as_float, self.gamma)

def robustbench_load_model(method, dataset):
    """Load a pretrained Linf-robust model through ``robustbench_utils``.

    Args:
        method: Short method key; one of ``'trades'``, ``'mart'``, ``'awp'``
            or ``'hat'``. Any other key raises ``KeyError``.
        dataset: Dataset identifier forwarded unchanged to
            ``robustbench_utils.load_model``.

    Returns:
        Whatever ``robustbench_utils.load_model`` returns for the selected
        model name.

    NOTE(review): ``robustbench_utils`` is not imported in the visible part of
    this file; it must be available at module scope for this call to work.
    """
    # Short method key -> model name passed as ``model_name``.
    model_names = {
        'trades': 'Zhang2019Theoretically',
        'mart': 'Wang2020Improving',
        'awp': 'Wu2020Adversarial_extra',
        'hat': 'Rade2021Helper_extra',
    }
    loader = robustbench_utils.load_model
    return loader(
        model_name=model_names[method],
        dataset=dataset,
        threat_model='Linf',
    )

def sample_latent(latent_r, lambda_r):
    """Perturb ``latent_r`` with Gaussian noise scaled element-wise by ``lambda_r``.

    Computes ``latent_r + lambda_r * eps`` where ``eps`` is standard-normal
    noise with the same shape as ``lambda_r`` (the product then broadcasts
    against ``latent_r``).

    This is a plain module-level function: a ``@staticmethod`` decorator has
    no meaning outside a class body, and on Python < 3.10 it makes the name
    uncallable, so it is intentionally not used here.

    Args:
        latent_r: Tensor to perturb.
        lambda_r: Tensor of per-element noise scales, broadcastable to
            ``latent_r``.

    Returns:
        The perturbed tensor.
    """
    # Draw the noise directly on lambda_r's device instead of forcing CUDA,
    # so the function works on CPU and on non-default GPUs as well.
    eps = torch.randn(lambda_r.size(), device=lambda_r.device)
    return latent_r + lambda_r.mul(eps)

def find_features_wip(model, robust_index, inputs, labels, pop_number, forward_version=False):
    """Compare predictions from the full, robust-only and non-robust-only latents.

    The intermediate latent is obtained with
    ``model(inputs, intermediate_propagate=0, pop=4)``, split into two
    complementary parts with ``robust_index`` as a multiplicative mask, and
    each part is fed back through
    ``model(..., intermediate_propagate=pop_number, pop=0)``.

    Args:
        model: Callable accepting ``(tensor, intermediate_propagate=..., pop=...)``.
            The meaning of those keyword arguments is model-specific and not
            visible here.
        robust_index: Mask multiplied with the latent. A ``numpy.ndarray`` is
            converted to a tensor on the latent's device and reshaped to
            ``(1, C, 1, 1)``; a tensor is used as given and must already
            broadcast against the latent.
            NOTE(review): presumably a 0/1 mask, since ``1 - robust_index`` is
            used as its complement -- confirm with callers.
        inputs: Batch passed to ``model``.
        labels: Unused; kept for interface compatibility.
        pop_number: Value passed as ``intermediate_propagate`` when re-injecting
            the masked latents.
        forward_version: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(orig_predicted, robust_predicted, non_robust_predicted)`` of
        argmax indices along dim 1 of the respective model outputs.
    """
    latent_r = model(inputs, intermediate_propagate=0, pop=4)

    if isinstance(robust_index, np.ndarray):
        # Follow the latent's device rather than assuming the default CUDA
        # device, and infer the channel count instead of hard-coding it.
        robust_index = torch.from_numpy(robust_index).to(latent_r.device)
        robust_index = robust_index.view(1, -1, 1, 1)

    non_robust_index = 1 - robust_index

    robust_latent_z = latent_r * robust_index
    non_robust_latent_z = latent_r * non_robust_index

    orig_outputs = model(inputs, intermediate_propagate=0, pop=0).detach()
    _, orig_predicted = orig_outputs.max(1)

    # Clones keep the masked latents intact in case the model works in place.
    robust_outputs = model(
        robust_latent_z.clone(), intermediate_propagate=pop_number, pop=0
    ).detach()
    _, robust_predicted = robust_outputs.max(1)

    non_robust_outputs = model(
        non_robust_latent_z.clone(), intermediate_propagate=pop_number, pop=0
    ).detach()
    _, non_robust_predicted = non_robust_outputs.max(1)

    return orig_predicted, robust_predicted, non_robust_predicted