import os

from easydict import EasyDict
import torch
from torch import nn

import fling.dataset
from fling.utils.attack_utils import DLGAttacker
from fling.utils.registry_utils import DATASET_REGISTRY
from fling.utils.registry_utils import MODEL_REGISTRY


class ToyModel(nn.Module):
    """
    Overview:
        A small convolutional classifier used to demonstrate attack results.
        Three 3x3 convolutions (padding 1, so spatial size is preserved) widen the
        3-channel input to 64 and then 128 channels. Global average pooling then
        collapses each feature map to one value, and a linear head maps the
        resulting 128-dim vector to 10 logits.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stable so state_dict keys and module order stay the same.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()

        # (N, 128, H, W) -> (N, 128, 1, 1) -> (N, 128)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flat = nn.Flatten()

        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        # Apply each conv followed by its activation, in order.
        stages = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
        )
        features = x
        for conv, activation in stages:
            features = activation(conv(features))
        pooled = self.pool(features)
        return self.fc(self.flat(pooled))


def attack_demo(test_dataset, img_idx, path_head, parameter_args, *, class_number=100, device='cuda'):
    """
    Overview:
        Restore a resnet8 from ``<path_head>/before_model.ckpt`` and run ``DLGAttacker`` on ``test_dataset``.
    Arguments:
        - test_dataset: Samples to attack; forwarded unchanged to ``DLGAttacker.attack``.
        - img_idx: Index of the attacked image. It names the output sub-directory and is forwarded as \
            ``img_idx_in``.
        - path_head: Experiment directory. Weights are read from ``<path_head>/before_model.ckpt``; \
            attacker output goes to ``<path_head>/dlg_attacker/<layer_name>/<img_idx>``.
        - parameter_args: Dict forwarded to the attacker. ``{"name": "all"}`` yields the sub-directory \
            ``all``; any other ``"name"`` requires a ``"keywords"`` entry, which is used as the sub-directory \
            (presumably the attacker selects parameters by that keyword — confirm in ``DLGAttacker``).
        - class_number: Number of output classes. Used for both the model and the attacker so they agree.
        - device: Device passed to ``DLGAttacker.attack``.
    """
    # Step 2: prepare the model.
    model = MODEL_REGISTRY.build('resnet8', input_channel=3, class_number=class_number)
    # Map to CPU so the checkpoint loads even if it was saved on a device unavailable here;
    # ``load_state_dict`` copies the values into the freshly built model anyway.
    state_dict = torch.load(os.path.join(path_head, 'before_model.ckpt'), map_location='cpu')
    model.load_state_dict(state_dict)

    # Step 3: initialize the attacker. Each (parameter selection, image) pair gets its own output directory.
    layer_name = parameter_args["name"] if parameter_args["name"] == "all" else parameter_args["keywords"]
    working_dir = os.path.join(path_head, 'dlg_attacker', layer_name, str(img_idx))
    attacker = DLGAttacker(iteration=3000, working_dir=working_dir,
                           iteration_per_save=100, distance_measure='euclid', tv_weight=5000)

    # Step 4: attack.
    attacker.attack(model, test_dataset, device=device, class_number=class_number, parameter_args=parameter_args,
                    save_img=True, optim_backend='adam', img_idx_in=img_idx)

    # Alternative: to use ``lbfgs`` as the optim backend, replace Steps 3 and 4 with the setting below.
    # Note: The variance of performance of lbfgs may be quite large. Please repeat the experiments more times.
    # attacker = DLGAttacker(iteration=300, working_dir=working_dir,
    #                        iteration_per_save=10, distance_measure='euclid')
    # attacker.attack(model, test_dataset, device=device, class_number=class_number, save_img=True,
    #                 optim_backend='lbfgs')


if __name__ == '__main__':
    # Step 1: build the CIFAR-100 test split that attacked images are drawn from.
    dataset_config = EasyDict(dict(data=dict(data_path='./data/cifar100', transforms=dict())))
    dataset = DATASET_REGISTRY.build('cifar100', dataset_config, train=False)

    # Full list of layers that can be targeted individually:
    # layers = ['pre_conv', 'layers.0.0.conv1', 'layers.0.0.conv2', 'layers.1.0.conv1', 'layers.1.0.conv2',
    #           'layers.1.0.downsample.0', 'layers.2.0.conv1', 'layers.2.0.conv2', 'layers.2.0.downsample.0', 'fc']
    layers = ['pre_conv', 'layers.2.0.conv2', 'fc']
    # One "contain" selection per layer keyword (interpreted by DLGAttacker).
    part_args = [{"name": "contain", "keywords": layer} for layer in layers]

    # Experiment tag -> parameter selections to try. The tag is also the directory under ./visualize/
    # from which attack_demo loads before_model.ckpt.
    names_args = {
        'avg100': [{"name": "all"}],
        'warm5_part5': part_args,
    }
    # names_args = {'warm5_part5': part_args}

    for img_idx in range(4, 10):
        # Attack one image at a time.
        single_sample = [dataset[img_idx]]
        for tag, selections in names_args.items():
            experiment_dir = f'./visualize/{tag}'
            for selection in selections:
                attack_demo(
                    test_dataset=single_sample,
                    img_idx=img_idx,
                    path_head=experiment_dir,
                    parameter_args=selection,
                )