import os

import torch
import torch.nn as nn
import torchvision
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

def calc_resnet_repr(model, x):
    """Return the flattened output of ``model.avgpool`` for input ``x``.

    ``x`` is passed, in order, through ``model.conv1``, ``bn1``, ``relu``,
    ``maxpool``, ``layer1`` .. ``layer4`` and ``avgpool``; ``model.fc`` is
    not applied.  The result is flattened from dimension 1 onwards, so the
    leading (batch) dimension is kept.
    """
    # Order matters: each stage consumes the previous stage's output.
    stage_names = (
        'conv1', 'bn1', 'relu', 'maxpool',
        'layer1', 'layer2', 'layer3', 'layer4',
        'avgpool',
    )
    out = x
    for stage_name in stage_names:
        out = getattr(model, stage_name)(out)
    return torch.flatten(out, 1)

# Name of the checkpoint variant to visualize; it selects both the weights
# file that is loaded and the name of the figure that is written.
# Alternative names kept from the original script (swap one in to visualize
# a different checkpoint):
#   'naive'
#   'reglast_100.0'
#   'reglast_1.0'
#   'jacALL_0.01'
#   'jacALL_0.003'
#   'jacALL_0.001'
#   'jacALL_0.0001'
MODEL_NAME = 'reglast_20.0'

print(MODEL_NAME)

# Build a ResNet-18 on the GPU and overwrite its weights with the checkpoint.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
model = torchvision.models.resnet18().to('cuda')
model.load_state_dict(
    torch.load(
        './saved_model/imagenet_%s.pth' % MODEL_NAME,
        map_location='cuda:0',
    )
)
# Alternative checkpoint layout (weights stored under the 'state_dict' key):
#model.load_state_dict(torch.load('./saved_model/imagenet_%s_ckpt.pth.tar'%MODEL_NAME)['state_dict'])
print(model.fc)
# Switch to evaluation mode before any forward pass.
model.eval()

# Loss module; nothing in the visible part of this script uses it.
criterion = nn.CrossEntropyLoss()

# Fixed NumPy seed so the same 100 distinct feature indices are drawn on
# every run.  512 is hard-coded as the number of entries in the pooled
# representation (NOTE(review): matches torchvision resnet18's fc input size;
# confirm if the architecture changes).
np.random.seed(0)
visualize_indices = np.random.choice(512, 100, replace=False)

# VIS_LR: step length of the normalized-gradient ascent on the input.
# VIS_DECAY: per-step blend factor that pulls the image towards 0.5.
VIS_LR, VIS_DECAY = 1e1, 1e-2

# Make sure the output directory exists *before* the long optimisation, so a
# missing 'figures/' folder cannot throw away all the work at savefig time.
output_path = 'figures/%s_repre_visualize_singlegen.pdf' % MODEL_NAME
os.makedirs(os.path.dirname(output_path), exist_ok=True)

#fig = plt.figure(figsize=(6,6))
fig = plt.figure(figsize=(12, 12))
for i, idx in enumerate(tqdm(visualize_indices)):
    # One cell of the 10x10 grid per visualized feature.
    plt.subplot(10, 10, i + 1)

    # Draw uniform-noise images (at most 100 tries) until the target feature
    # is strictly positive -- presumably so the ascent does not start from a
    # point where the feature is switched off.  Only the value is needed
    # here, so skip autograd bookkeeping: the model parameters require grad
    # and would otherwise make every probe build an unused graph.
    with torch.no_grad():
        for _ in range(100):
            x = torch.FloatTensor(1, 3, 224, 224).uniform_(0, 1).to('cuda')
            #init_feature_val = model.calc_representation(x).squeeze()[idx]
            init_feature_val = calc_resnet_repr(model, x).squeeze()[idx]
            if init_feature_val > 0:
                break
    if init_feature_val == 0:
        print ("Starting from 0")

    # Gradient ascent on the input image to increase feature `idx`.
    #for _ in range(2000):
    for _ in range(500):
        x.requires_grad_()
        with torch.enable_grad():
            feature_val = calc_resnet_repr(model, x).squeeze()[idx]
        grad = torch.autograd.grad(feature_val, [x])[0]
        # Scale the gradient by its per-sample L2 norm; the clamp keeps the
        # divisor inside [1e-6, 9999], which avoids division by zero.
        grad_norm = grad.view(grad.shape[0], -1).norm(dim=1).view(-1, 1, 1, 1)
        grad = grad / grad_norm.clamp(1e-6, 9999)
        # Take the step and keep pixel values inside [0, 1].
        x = torch.clamp(x + VIS_LR * grad.detach(), 0, 1)
        # Blend slightly towards mid-gray (0.5).
        x = x * (1 - VIS_DECAY) + 0.5 * VIS_DECAY
        # Drop the graph so the next iteration starts from a leaf tensor.
        x = x.detach()

    print (x.min(), x.max())
    # (1, 3, H, W) tensor -> (H, W, 3) array for imshow.
    plt.imshow(x.cpu().squeeze().numpy().transpose(1, 2, 0), vmin=0, vmax=1)
    plt.axis('off')
    #if i > 2:
    #    break
fig.savefig(output_path)
