import argparse
import os

import numpy as np
import timm
import torch

import dataloader as dt
import mlp
import utils


# Command-line options:
#   --attack       'True' (case-insensitive) evaluates adversarial inputs;
#                  anything else, or omitting it, evaluates clean inputs.
#   --attack_type  attack name forwarded to utils.get_attack; required when
#                  --attack is True. Forced to 'clean' otherwise.
#   --cuda         GPU index ('0', '1', ...) or a torch device string such as
#                  'cpu'; defaults to cuda:0.
parser = argparse.ArgumentParser()
parser.add_argument('--attack', type=str, help="'True' to attack the inputs; clean evaluation otherwise")
parser.add_argument('--attack_type', type=str, help='attack name passed to utils.get_attack (e.g. anda)')
parser.add_argument('--cuda', type=str, help="GPU index such as '0' or '1', or a torch device string (default: cuda:0)")

args = parser.parse_args()

# Always a bool: True only for an explicit 'True'-like value.
attack = args.attack is not None and args.attack.strip().lower() == 'true'
attack_type = args.attack_type if attack else 'clean'
if attack and attack_type is None:
	# Without a type the attack helper gets None and results land in
	# 'inception_None.npy'; fail fast instead.
	parser.error('--attack_type is required when --attack is True')

if args.cuda is None:
	device = 'cuda:0'
elif args.cuda.isdigit():
	device = f'cuda:{args.cuda}'
else:
	# Pass other values (e.g. 'cpu', 'cuda:2') straight to torch.
	device = args.cuda

torch.set_printoptions(sci_mode=False)

# Pretrained Inception-v3 from timm, switched to eval mode and moved to `device`.
model = timm.create_model('inception_v3', pretrained=True)
model.eval()
model = model.to(device)

# MLP scored against the last hooked activation in the evaluation loop.
# Input width 8217 = 2 * (3 * 35 * 35) + 3 * 17 * 17, i.e. the length of the
# concatenated feature vector built there. The args are presumably
# (in_features, out_features) — confirm against mlp.MLP.
# Other widths noted previously: 9015, and 7350 (= 2 * 3 * 35 * 35, early layers only).
model_mlp = mlp.MLP(8217, 2048).to(device)
model_mlp_dir = 'lr-models/model_inception.pt'
# map_location deserializes the checkpoint directly onto `device`, instead of
# onto whichever GPU it was saved from (which may not exist on this machine).
model_mlp.load_state_dict(torch.load(model_mlp_dir, map_location=device))
model_mlp.eval()

loss_fn = torch.nn.MSELoss()

# resize_dim=299 for inception, 224 for the rest
test_loader = dt.get_loader(split='test', resize_dim=299, shuffle=False, batch=1)

# Activations recorded by forward hooks, in the order the hooked modules run.
# The fixed indices used below (15, 25, 35, -1) depend on that order.
layers = []
# Hooks record only while this is True. It is switched off while the attack
# runs, so the attack's own forward passes don't accumulate activations
# that would be discarded anyway.
capture_activations = True


def _store_activation(_module, _inputs, output):
	"""Forward hook: append a detached copy of `output` to `layers` if capturing."""
	if capture_activations:
		layers.append(output.detach())


for name, module in model.named_modules():
	print(name)
	if "conv" in name or name == "head_drop":
		print(f"Module hook registered for: {name}")
		module.register_forward_hook(_store_activation)


# One MSE score per evaluated image (batch size is 1).
scores = []
for i, data in enumerate(test_loader):
	print(i)

	if attack_type == 'anda' and utils.skip_anda(i):
		continue

	inputs, labels = data
	inputs = inputs.to(device, dtype=torch.float32)
	labels = labels.to(device, dtype=torch.long)

	# Clean forward pass. Inference only, so no autograd graph is built;
	# the hooks refill `layers` with this pass's activations.
	layers.clear()
	with torch.no_grad():
		pre = model(inputs).argmax(dim=1)

	# Skip images the model already misclassifies.
	if pre.item() != labels.item():
		continue

	if attack:
		# The attack may call the model many times; don't record those passes.
		capture_activations = False
		try:
			inputs = utils.get_attack(inputs, labels, model, attack_type)
		finally:
			capture_activations = True

		# Forward the adversarial input so `layers` holds its activations.
		layers.clear()
		with torch.no_grad():
			pre_adv = model(inputs).argmax(dim=1)

		# Skip images whose prediction the attack failed to change.
		if pre_adv.item() == labels.item():
			continue

	# First 3 channels of three hooked activations (two 35x35, one 17x17),
	# flattened and concatenated into an 8217-long feature vector.
	# reshape (not view) also works if an activation is non-contiguous.
	f1 = layers[15][:, :3].reshape(1, 3 * 35 * 35)
	f2 = layers[25][:, :3].reshape(1, 3 * 35 * 35)
	f3 = layers[35][:, :3].reshape(1, 3 * 17 * 17)
	f_tot = torch.cat([f1, f2, f3], dim=1)

	# Target: the last recorded activation (the "head_drop" hook).
	o2 = layers[-1]

	with torch.no_grad():
		mlp_out = model_mlp(f_tot)
		mlp_loss = loss_fn(mlp_out, o2)

	score = mlp_loss.item()
	scores.append(score)
	print('mlp_loss: {}'.format(score))

	print()

# Create the output directory so a long run doesn't fail at the final save.
os.makedirs('results', exist_ok=True)
np.save('results/inception_{}.npy'.format(attack_type), scores)
