#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import inspect
import torch
from torch.utils.data import DataLoader,Dataset

from torch.distributions import Uniform
import pytest

import logging
from ats.attacks_meta import AttackMeta, AttackFactory, get_attack_instance, get_attack_args_kwargs
from tqdm import tqdm

# Module-level logger for this test file; the level is set to DEBUG so the
# per-attack debug messages emitted below are not filtered by this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class dummy_model(torch.nn.Module):
    """Small two-conv, three-linear CNN used as a stand-in classifier.

    Expects input of shape N x 3 x 32 x 32 and returns logits of shape
    N x 10. Any other spatial size makes the first linear layer raise a
    ``RuntimeError`` instead of silently changing the batch dimension.
    """

    def __init__(self):
        super().__init__()
        # input is N x 3 x 32 x 32
        self.conv1 = torch.nn.Conv2d(3, 6, 5)
        # after conv1: N x 6 x 28 x 28
        self.pool = torch.nn.MaxPool2d(2, 2)
        # after pool: N x 6 x 14 x 14
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        # after conv2: N x 16 x 10 x 10; forward() pools once more, so the
        # tensor fed to fc1 is N x 16 x 5 x 5 (400 features per sample).
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 84)
        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Return the N x 10 logits for a batch ``x`` of N x 3 x 32 x 32 images."""
        x = self.pool(torch.nn.functional.relu(self.conv1(x)))
        x = self.pool(torch.nn.functional.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 400)`` this can never fold surplus features into the
        # batch axis when the input has an unexpected spatial size.
        x = torch.flatten(x, 1)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = self.fc3(x)

        return x
 


def test_attacks():
    """Smoke-test every registered attack on random data with a dummy CNN.

    For each entry in ``AttackMeta.registered_attacks`` (except the names
    skipped below) an attack instance is built via ``get_attack_instance``
    and called once on a batch of 10 random 3 x 32 x 32 images.

    The test is best-effort: a ``RuntimeError`` raised while building or
    running one attack is reported and the remaining attacks are still
    exercised. Any other exception type propagates and fails the test.
    """
    N, C, H, W = 10, 3, 32, 32
    images = torch.rand(N, C, H, W)
    labels = torch.randint(0, 10, (N,))
    print(labels)

    model = dummy_model()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    images = images.to(device)
    labels = labels.to(device)

    # Names of attacks that raised RuntimeError, reported once at the end.
    failed = []

    for attack_name, _attack_class in tqdm(AttackMeta.registered_attacks.items()):
        print(f"Testing {attack_name}...")

        # NOTE(review): "" and "LGV" are skipped unconditionally; the reason
        # is not visible in this file -- confirm whether it still applies.
        if attack_name in ["", "LGV"]:
            continue
        logger.debug("Testing %s...", attack_name)

        # The try block is per attack so that one failing attack does not
        # prevent all the following ones from being exercised.
        try:
            kwargs = {}
            arg_names, kwarg_names = get_attack_args_kwargs(attack_name)
            print(arg_names, kwarg_names)
            if "trainloader" in arg_names:
                train_set = torch.utils.data.TensorDataset(images, labels)
                # num_workers=0: the tensors may already live on the GPU and
                # CUDA tensors should not be handed to loader worker
                # processes; the dataset is tiny, so workers bring nothing.
                trainloader = torch.utils.data.DataLoader(
                    train_set, batch_size=128, shuffle=True, num_workers=0
                )
                kwargs["trainloader"] = trainloader

            attack = get_attack_instance(attack_type=attack_name, model=model, **kwargs)
            attack.set_return_type(type="float")
            results = attack(images, labels)
        except RuntimeError as e:
            print(f"Cannot test attack {attack_name}: ", e)
            failed.append(attack_name)
            continue

        logger.debug(results)
        logger.debug("Testing %s...OK", attack_name)

    if failed:
        logger.warning("Attacks that could not be tested: %s", ", ".join(failed))
        


if __name__ == '__main__':
    # Run this file's tests through pytest and propagate pytest's exit code,
    # so that running the script directly reports failures to the shell.
    sys.exit(pytest.main([__file__]))
