import torch
# from torchvision import models
from pytorch_cam import GradCAM, \
    SimCAM

# import os
# os.environ['TORCH_HOME'] = './'

# Registry of the CAM implementations CamModel can build; CamModel looks up
# its `cam_method` argument here (methods[cam_method]).
methods = dict(
    gradcam=GradCAM,
    simcam=SimCAM,
)


class CamModel():
    """Attach a CAM saliency method from ``pytorch_cam`` to a classifier.

    The CAM is hooked onto the last block of the model's final residual stage
    (chosen by ``_find_target_layers``) and queried per batch via
    :meth:`get_cam`.
    """

    # Wrapper classes that keep the real network in their ``.module`` attribute.
    _WRAPPERS = ('DataParallel', 'DistributedDataParallel')

    def __init__(self, model, layer_mix, n_class, cam_method='gradcam', *,
                 use_cuda=True, batch_size=128):
        """Build the CAM extractor for ``model``.

        Args:
            model: Network to explain. Recognised by class name: ``ResNet``,
                ``CifarResNet``, ``WideResNet``, ``PreActResNet``, optionally
                wrapped in ``DataParallel`` / ``DistributedDataParallel``.
            layer_mix: Stored as ``self.layer_mix``; nothing in this class
                reads it.
            n_class: Number of classes; exclusive upper bound for the random
                targets drawn by ``get_cam(adv=True)``.
            cam_method: Key into the module-level ``methods`` table.
            use_cuda: Forwarded to the CAM constructor (default ``True``, the
                previously hard-coded value).
            batch_size: Assigned to ``self.cam.batch_size`` (default 128, the
                previously hard-coded value).

        Raises:
            ValueError: If no target layer can be chosen for ``model``.
            KeyError: If ``cam_method`` is not a key of ``methods``.
        """
        super(CamModel, self).__init__()
        print(f'------Using {cam_method} as the saliency detector.---------')
        self.model = model
        self.layer_mix = layer_mix
        self.cam_method = cam_method
        self.target_tensor_shape = None
        self.n_class = n_class

        # Resolve layers first so an unsupported model fails with a clear
        # message instead of an UnboundLocalError.
        target_layers = self._find_target_layers(model)

        # The wrapper (not the unwrapped module) is handed to the CAM, as
        # before; only the hooked layers come from the inner network.
        self.cam = methods[cam_method](
            model=self.model, target_layers=target_layers, use_cuda=use_cuda)
        self.cam.batch_size = batch_size

    @classmethod
    def _find_target_layers(cls, model):
        """Return ``[last block of the final stage]`` for a supported model.

        Wrapped models are unwrapped first, so wrapped and plain models follow
        exactly the same selection rules.

        Raises:
            ValueError: If the (unwrapped) model type is not supported.
        """
        if model.__class__.__name__ in cls._WRAPPERS:
            net = model.module
        else:
            net = model
        name = net.__class__.__name__

        if name in ('ResNet', 'CifarResNet'):
            # Some variants stop at layer3; fall back to it in that case.
            stage = net.layer4 if hasattr(net, 'layer4') else net.layer3
        elif name == 'WideResNet':
            stage = net.layer3
        elif name == 'PreActResNet':
            stage = net.layer4
        else:
            raise ValueError(
                f'Cannot pick a CAM target layer for model type {name!r}; '
                'supported: ResNet, CifarResNet, WideResNet, PreActResNet '
                '(optionally wrapped in DataParallel/DistributedDataParallel).')
        return [stage[-1]]

    def get_cam(self, input, target, aug_smooth=False, eigen_smooth=False, mixing=True, adv=False, s_size=None):
        """Compute CAM maps for a batch and return them as a ``torch`` tensor.

        Args:
            input: Batch passed to the CAM as ``input_tensor``.
            target: Class targets; used as-is unless ``adv`` is set.
            aug_smooth, eigen_smooth: Forwarded to the CAM call.
            mixing: Forwarded to the CAM; decides whether the channel-wise
                CAMs are summed.
            adv: If true, replace ``target`` by uniformly random classes in
                ``[0, n_class)`` with the same size and device as ``target``.
            s_size: Output size passed as ``target_tensor_shape``; defaults to
                ``self.cam.get_target_width_height(input)``.

        Returns:
            ``torch.tensor`` copy of whatever the CAM call returns (shape
            depends on the CAM implementation — presumably one map per
            sample; confirm against pytorch_cam).
        """
        if s_size is None:
            s_size = self.cam.get_target_width_height(input)

        if adv:
            # Keep random targets on the same device as the real labels.
            cam_target = torch.randint(low=0, high=self.n_class,
                                       size=target.size(), device=target.device)
        else:
            cam_target = target

        if self.cam_method == 'saliency':
            # NOTE(review): 'saliency' is not a key of the module-level
            # `methods` table, so this is unreachable unless that table is
            # extended elsewhere.
            grayscale_cam = self.cam(input_tensor=input,
                                     target_category=cam_target)
        else:
            grayscale_cam = self.cam(input_tensor=input,
                                     target_category=cam_target,
                                     aug_smooth=aug_smooth,
                                     eigen_smooth=eigen_smooth,
                                     mixing=mixing,
                                     target_tensor_shape=s_size)

        return torch.tensor(grayscale_cam)


def test():
    """Placeholder smoke test for this module; it currently performs no checks."""
    return None


if __name__ == "__main__":
    # Script entry point: run the (currently empty) smoke test.
    test()
