import os
import sys
sys.path.append('..')

import numpy as np
from PIL import Image
import torch
from torch import nn
import torchvision
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import cv2

import data
from config import opt
import classifier
from setup import load_model_weights
from simclr import SimCLR

# Architecture name -> classifier constructor. The prepare_* helpers below
# call the looked-up constructor as ``ctor(n_outputs=...)``.
classifier_dict = dict(
    alexnet=classifier.AlexNet,
    resnet18=classifier.ResNet18,
    resnet34=classifier.ResNet34,
    resnet50=classifier.ResNet50,
    inceptionv3=classifier.InceptionV3,
    vgg=classifier.VGG,
    densenet=classifier.DenseNet,
    pyramidnet=classifier.PyramidNet,
    resnext=classifier.ResNeXt,
    wrn=classifier.WideResNet,
    eweresnet50=classifier.EWEResNet50,
    plainresnet50=classifier.PlainResNet50,
    mobilenetv1=classifier.MobileNetV1,
    mobilenetv2=classifier.MobileNetV2,
)

# Dataset name -> dataset wrapper class; prepare_dataset() instantiates the
# looked-up class with ``input_size=`` and ``partition=`` keyword arguments.
dataset_dict = dict(
    cifar10=data.CIFAR10,
    cifar100=data.CIFAR100,
    TinyImageNet=data.TinyImageNet,
    noise=data.Noise,
    svhn=data.SVHN,
    ImageNet=data.ImageNet,
)

# Per-dataset channel statistics consumed by tensor2im() to undo input
# normalisation before visualisation. Presumably these mirror the transforms
# used by the ``data`` dataset classes — not verifiable from this file.
# NOTE(review): 'svhn' and 'noise' from dataset_dict have no entry here, and
# TinyImageNet reuses the CIFAR-100 statistics verbatim — confirm both.
normalize_dict = dict(
    cifar10=dict(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
    cifar100=dict(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),
    TinyImageNet=dict(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),
    ImageNet=dict(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
)


def weight_init(model):
    """Re-draw the parameters of every supported layer in ``model`` from N(0, 1).

    Walks ``model.modules()`` (``model`` itself plus all submodules):

    * ``nn.Linear`` and ``nn.BatchNorm2d``: weight and bias ~ N(0, 1);
    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight ~ N(0, 1), bias untouched.

    Parameters that do not exist (``bias=False`` layers, ``affine=False`` batch
    norms, where the attribute is ``None``) are skipped instead of raising.

    Note: this file also passes the function to ``nn.Module.apply``, which calls
    it once per submodule. Because it walks ``modules()`` itself, nested layers
    are then re-drawn several times; the values are still N(0, 1) samples.

    Args:
        model: the ``nn.Module`` to re-initialise in place.
    """
    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.BatchNorm2d)):
            params = (m.weight, m.bias)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            # Convolution biases keep their current values (bias init was
            # commented out in the original implementation).
            params = (m.weight,)
        else:
            continue
        for p in params:
            if p is not None:
                nn.init.normal_(p, mean=0, std=1)

def tensor2im(input_image, mean, std, imtype=np.float32):
    """Undo channel normalisation and convert a CHW tensor into an HWC numpy image.

    Parameters:
        input_image (tensor) -- image tensor laid out as (C, H, W); a single-channel
                                image is tiled to three channels. A numpy array is
                                only cast to ``imtype``; any other object is
                                returned unchanged.
        mean (sequence)      -- per-channel means to add back.
        std (sequence)       -- per-channel standard deviations to multiply back.
        imtype (type)        -- output numpy dtype. ``np.float32`` yields values in
                                [0, 1]; any other dtype yields values in [0, 255].

    Returns:
        An (H, W, C) array of dtype ``imtype`` for tensor inputs. The input tensor
        itself is never modified.
    """
    if isinstance(input_image, np.ndarray):
        # Already an image: only the dtype cast is applied.
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    # .copy(): for a CPU float tensor, .cpu()/.float() return the same tensor and
    # .numpy() shares its storage, so the in-place de-normalisation below would
    # otherwise silently overwrite the caller's tensor.
    image_numpy = input_image.detach().cpu().float().numpy().copy()
    if image_numpy.shape[0] == 1:  # grayscale to RGB
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    for i in range(len(mean)):
        image_numpy[i] = image_numpy[i] * std[i] + mean[i]
    # Rounding error (or statistics that don't exactly match the loader) can push
    # values outside [0, 1]. Clip so the uint8 cast cannot wrap around, and so
    # pytorch_grad_cam's show_cam_on_image (which rejects float images with a
    # max above 1) accepts the float32 result.
    image_numpy = np.clip(image_numpy, 0.0, 1.0)
    if imtype is not np.float32:
        image_numpy = image_numpy * 255
    image_numpy = np.transpose(image_numpy, (1, 2, 0))  # CHW -> HWC
    return image_numpy.astype(imtype)

# def get_cam_img(model,target_layer,img_path,target_category=None):
#     img = cv2.imread(img_path)
#     rgb_img = cv2.resize(img, (112, 112))
#     rgb_img = np.float32(rgb_img) / 255
#     input_tensor = cam_dataloader(img_path)
#     cam = eval(cam_type)(model=model, target_layer=target_layer, use_cuda=True)
#     grayscale_cam = cam(input_tensor=input_tensor, target_category=target_category)
#     grayscale_cam = grayscale_cam[0, :]
#     cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
#     cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
#     return cam_image


def prepare_dataset(dataset, input_size=32):
    """Instantiate the dataset wrapper described by a ``'<name>[-<partition>]'`` spec.

    Parameters:
        dataset (str)    -- key of ``dataset_dict``, optionally followed by
                            ``-<partition>``; only the first ``-`` suffix is used.
        input_size (int) -- forwarded to the dataset constructor (defaults to 32,
                            the value previously hard-coded here).

    Returns:
        ``dataset_dict[name](input_size=input_size, partition=partition)``, where
        ``partition`` is None when the spec has no ``-`` suffix.

    Raises:
        KeyError -- if ``name`` is not a key of ``dataset_dict``.
    """
    dataset_name, *suffixes = dataset.split('-')
    partition = suffixes[0] if suffixes else None
    try:
        dataset_cls = dataset_dict[dataset_name]
    except KeyError:
        raise KeyError(
            f'unknown dataset {dataset_name!r}; expected one of {sorted(dataset_dict)}'
        ) from None
    return dataset_cls(input_size=input_size, partition=partition)


def prepare_rand_init_model(model_name, n_classes=10):
    """Build ``model_name`` with ``n_classes`` outputs, then re-draw its weights.

    Architecture dispatch:
      * ``'efficientnet*'`` -> ``EfficientNet.from_pretrained``;
      * ``'wrn-<depth>'``   -> ``classifier.WideResNet`` with that depth;
      * anything else       -> ``classifier_dict[model_name]``.

    The result is passed through ``model.apply(weight_init)``, so its supported
    layers hold N(0, 1) samples rather than any pretrained/default values.
    """
    if model_name.startswith('efficientnet'):
        # NOTE(review): `EfficientNet` is never imported in this file, so this
        # branch raises NameError as written.
        net = EfficientNet.from_pretrained(model_name, num_classes=n_classes)
    elif model_name.startswith('wrn'):
        wrn_depth = int(model_name.rsplit('-', 1)[-1])
        net = classifier.WideResNet(n_outputs=n_classes, depth=wrn_depth)
    else:
        ctor = classifier_dict[model_name]
        net = ctor(n_outputs=n_classes)

    net.apply(weight_init)
    return net


def prepare_victim_model(victim_model, victim_dataset, victim_n_classes=10,
                         ckpt_dir='[HOME_DIR]/substitute_attack_demo/checkpoints'):
    """Build a victim classifier and load its trained weights from disk.

    Parameters:
        victim_model (str)     -- architecture key: ``'efficientnet*'``,
                                  ``'wrn-<depth>'`` or a key of ``classifier_dict``.
        victim_dataset (str)   -- dataset name; only used to form the checkpoint
                                  filename.
        victim_n_classes (int) -- number of outputs of the classifier.
        ckpt_dir (str)         -- directory holding
                                  ``victim_<model>_<dataset>_state_dict`` files
                                  (defaults to the previously hard-coded path).

    Returns:
        The model with the checkpoint loaded via ``load_state_dict``.
    """
    if victim_model.startswith('efficientnet'):
        # NOTE(review): `EfficientNet` is never imported in this file, so this
        # branch raises NameError as written.
        victim = EfficientNet.from_pretrained(
            victim_model,
            num_classes=victim_n_classes)
    elif victim_model.startswith('wrn'):
        depth = int(victim_model.split('-')[-1])
        victim = classifier.WideResNet(
            n_outputs=victim_n_classes,
            depth=depth
        )
    else:
        victim = classifier_dict[victim_model](
            n_outputs=victim_n_classes
        )
    model_name = 'victim_%s_%s' % (victim_model, victim_dataset)
    ckpt_path = os.path.join(ckpt_dir, f'{model_name}_state_dict')
    # Load tensors onto the CPU: the freshly built model lives there anyway, and
    # this keeps GPU-saved checkpoints loadable on machines without CUDA.
    victim.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    return victim


def prepare_simclr_model(sub_model, sub_dataset, victim_n_classes=10, pre_sub_n_classes=100,
                         ckpt_dir='[HOME_DIR]/substitute_attack_demo/checkpoints',
                         projection_dim=64):
    """Build a substitute classifier initialised from a SimCLR-pretrained encoder.

    Two networks of architecture ``sub_model`` are built: one is wrapped in
    ``SimCLR`` and loaded from ``simclr_<model>_<dataset>_state_dict``; its
    encoder weights are then transferred into the other with
    ``load_model_weights``, and that network's ``fc`` is replaced by a fresh
    ``nn.Linear`` head.

    Parameters:
        sub_model (str)         -- architecture key: ``'efficientnet*'``,
                                   ``'wrn-<depth>'`` or a key of ``classifier_dict``.
        sub_dataset (str)       -- dataset the SimCLR checkpoint was trained on;
                                   only used to form the checkpoint filename.
        victim_n_classes (int)  -- number of outputs of the returned classifier.
        pre_sub_n_classes (int) -- unused; kept for backward compatibility.
        ckpt_dir (str)          -- directory holding the SimCLR checkpoints
                                   (defaults to the previously hard-coded path).
        projection_dim (int)    -- projection size passed to ``SimCLR``; must match
                                   the checkpoint (previously hard-coded to 64).

    Returns:
        The substitute network with a newly created ``fc`` layer.
    """
    def build_backbone():
        # Same architecture dispatch as prepare_victim_model.
        if sub_model.startswith('efficientnet'):
            # NOTE(review): `EfficientNet` is never imported in this file, so
            # this branch raises NameError as written.
            return EfficientNet.from_pretrained(sub_model, num_classes=victim_n_classes)
        if sub_model.startswith('wrn'):
            depth = int(sub_model.split('-')[-1])
            return classifier.WideResNet(n_outputs=victim_n_classes, depth=depth)
        return classifier_dict[sub_model](n_outputs=victim_n_classes)

    substitute = build_backbone()

    # load simclr model; the fc input width is read before SimCLR wraps the encoder
    encoder = build_backbone()
    n_features = encoder.fc.in_features
    model = SimCLR(encoder, projection_dim, n_features)
    model_name = 'simclr_%s_%s' % (sub_model, sub_dataset)
    ckpt_path = os.path.join(ckpt_dir, f'{model_name}_state_dict')
    # Load onto the CPU, where the freshly built model lives; keeps GPU-saved
    # checkpoints loadable on machines without CUDA.
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    # NOTE(review): load_model_weights (from `setup`) presumably copies the
    # matching encoder weights into `substitute` — its semantics aren't visible here.
    substitute = load_model_weights(model.encoder, substitute)
    # Size the head from the argument; this used to read the global
    # opt.victim_n_classes, which could disagree with victim_n_classes.
    substitute.fc = nn.Linear(substitute.fc.in_features, victim_n_classes)

    return substitute

def _resolve_cam_class(vis_method):
    """Map a CAM class name to the class imported from ``pytorch_grad_cam``.

    Replaces ``eval(vis_method)`` so an arbitrary expression can never be executed.

    Raises:
        ValueError -- if ``vis_method`` is not one of the supported names.
    """
    cam_classes = {
        'GradCAM': GradCAM,
        'ScoreCAM': ScoreCAM,
        'GradCAMPlusPlus': GradCAMPlusPlus,
        'AblationCAM': AblationCAM,
        'XGradCAM': XGradCAM,
        'EigenCAM': EigenCAM,
    }
    try:
        return cam_classes[vis_method]
    except KeyError:
        raise ValueError(
            f'unknown vis_method {vis_method!r}; expected one of {sorted(cam_classes)}'
        ) from None


def _iter_limited_batches(dataloader, limit):
    """Yield ``(offset, imgs, targets)`` from ``dataloader`` until ``limit`` samples are covered.

    ``offset`` is the global index of ``imgs[0]``; the last batch is trimmed so no
    more than ``limit`` samples are produced, whatever the loader's batch size.
    """
    seen = 0
    for imgs, targets in dataloader:
        if seen >= limit:
            break
        take = min(len(imgs), limit - seen)
        yield seen, imgs[:take], targets[:take]
        seen += take


def _save_original_images(save_dir, dataloader, n_images, mean, std):
    """Write the de-normalised inputs as ``{idx}_original_{label}.png``."""
    for offset, imgs, targets in _iter_limited_batches(dataloader, n_images):
        for j, (img, target) in enumerate(zip(imgs, targets)):
            rgb_img = tensor2im(img, mean, std, np.uint8)
            # cv2.imwrite expects BGR channel order.
            bgr_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR)
            save_path = os.path.join(save_dir, f'{offset + j}_original_{int(target)}.png')
            cv2.imwrite(save_path, bgr_img)


def _save_cam_images(save_dir, tag, model, cam_cls, dataloader, n_images, mean, std):
    """Run ``cam_cls`` on ``model.layer4[-1]`` and write ``{idx}_{tag}_{label}.png`` overlays."""
    # One CAM object per model: constructing it registers hooks on the target
    # layer, so it must not be re-created for every batch.
    cam = cam_cls(model=model, target_layer=model.layer4[-1], use_cuda=True)
    for offset, imgs, targets in _iter_limited_batches(dataloader, n_images):
        grayscale_cams = cam(input_tensor=imgs, target_category=targets)
        print(grayscale_cams.shape)
        for j, (img, target) in enumerate(zip(imgs, targets)):
            rgb_img = tensor2im(img, mean, std)
            cam_image = show_cam_on_image(rgb_img, grayscale_cams[j], use_rgb=True)
            cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
            save_path = os.path.join(save_dir, f'{offset + j}_{tag}_{int(target)}.png')
            cv2.imwrite(save_path, cam_image)


def compare_victim_simclr_cam(save_dir, model, victim_dataset, sub_dataset, vis_method='GradCAM',
                              extra_sub_datasets=('TinyImageNet', 'ImageNet')):
    """Save CAM overlays of victim, random-init and SimCLR models on the same test images.

    For the first 500 test images of ``victim_dataset`` this writes, into
    ``save_dir`` (created if missing):
      * ``{idx}_original_{label}.png``              -- the de-normalised input;
      * ``{idx}_victim_{label}.png``                -- CAM of the trained victim;
      * ``{idx}_rand_init_{label}.png``             -- CAM of a randomly initialised net;
      * ``{idx}_simclr_{dataset}_{label}.png``      -- CAM of SimCLR-initialised nets
        pretrained on ``sub_dataset`` and on each of ``extra_sub_datasets``.

    Parameters:
        save_dir (str)             -- output directory.
        model (str)                -- architecture key shared by all compared models;
                                      must expose ``layer4`` and ``fc`` (ResNet-style).
        victim_dataset (str)       -- ``'<name>[-<partition>]'`` dataset spec.
        sub_dataset (str)          -- SimCLR pretraining dataset of the main substitute.
        vis_method (str)           -- CAM class name, e.g. ``'GradCAM'``.
        extra_sub_datasets (tuple) -- further SimCLR pretraining datasets to compare.

    Raises:
        ValueError -- if ``vis_method`` is not a supported CAM class name.
    """
    cam_cls = _resolve_cam_class(vis_method)

    # CIFAR-100 victims have 100 classes; everything else is treated as 10-class.
    if 'cifar' in victim_dataset and '100' in victim_dataset:
        n_classes, pre_sub_n_classes = 100, 10
    else:
        n_classes, pre_sub_n_classes = 10, 100

    victim_model = prepare_victim_model(model, victim_dataset, n_classes)
    rand_init_model = prepare_rand_init_model(model, n_classes)
    rand_init_model.fc.apply(weight_init)
    simclr_model = prepare_simclr_model(model, sub_dataset, n_classes, pre_sub_n_classes)
    simclr_model.fc.apply(weight_init)

    dataset_obj = prepare_dataset(victim_dataset)
    dataloader = dataset_obj.test_dataloader()

    # normalize_dict is keyed by the bare name, without prepare_dataset's '-partition' suffix.
    stats = normalize_dict[victim_dataset.split('-')[0]]
    mean, std = stats['mean'], stats['std']

    batch_size = 100
    n_batch = 5
    n_images = batch_size * n_batch

    # cv2.imwrite returns False (no exception) when the directory is missing.
    os.makedirs(save_dir, exist_ok=True)

    _save_original_images(save_dir, dataloader, n_images, mean, std)

    for tag, cam_model in (('victim', victim_model),
                           ('rand_init', rand_init_model),
                           (f'simclr_{sub_dataset}', simclr_model)):
        _save_cam_images(save_dir, tag, cam_model, cam_cls, dataloader, n_images, mean, std)

    for extra_dataset in extra_sub_datasets:
        simclr_model = prepare_simclr_model(model, extra_dataset, n_classes, pre_sub_n_classes)
        simclr_model.fc.apply(weight_init)
        _save_cam_images(save_dir, f'simclr_{extra_dataset}', simclr_model, cam_cls,
                         dataloader, n_images, mean, std)

if __name__ == '__main__':
    # Demo run: ResNet-34 victim trained on CIFAR-10 vs. SimCLR substitutes.
    compare_victim_simclr_cam(
        save_dir='[HOME_DIR]/substitute_attack_demo/results/cam/',
        model='resnet34',
        victim_dataset='cifar10',
        sub_dataset='cifar100',
        vis_method='GradCAM',
    )
    