import torch
import numpy as np
import torchvision
from tqdm import tqdm
import os
import sys
import argparse
import flowers102
import stanford_cars
import stanford_dogs
import cub_200
import chest_xray_dataset
import oxford_pets
import caltech_dataset

# Command-line interface: --dataset selects which dataset's activations to dump.
parser = argparse.ArgumentParser(
    description='Dump intermediate ResNet-50 activations for a dataset to disk.')
parser.add_argument('--dataset', dest='dataset', type=str, default=None,
                    help='name of the dataset to process')
args = parser.parse_args()

if not args.dataset:
    # sys.exit(<str>) prints the message to stderr and exits with status 1,
    # so shell scripts / job schedulers can detect the failure.
    sys.exit('No Dataset mentioned')
dataset_name = args.dataset

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Pretrained torchvision ResNet-50 whose intermediate outputs are captured.
model = torchvision.models.resnet50(pretrained=True).to(device)

def _convert_image_to_rgb(image):
    return image.convert("RGB")

# Resize/crop to 224x224, force 3 channels, convert to a tensor, and normalize
# with the standard ImageNet mean/std.
# The RGB conversion must run on the PIL image before ToTensor. Single-channel
# sources (e.g. FashionMNIST) or RGBA images would otherwise yield tensors that
# don't match the 3-element Normalize below or ResNet's 3-channel input layer.
# NOTE(review): assumes every dataset's __getitem__ hands a PIL image to this
# transform — confirm for the custom dataset modules imported at the top.
preprocess_fn = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    _convert_image_to_rgb,
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

# Maps each supported --dataset value to a zero-argument factory. The factories
# are lambdas so that only the selected dataset is constructed (and, for the
# torchvision datasets with download=True, only that one is downloaded).
_dataset_factories = {
    'cifar100': lambda: torchvision.datasets.CIFAR100(
        root='/var/data/cifar100/', download=True, transform=preprocess_fn),
    'cifar10': lambda: torchvision.datasets.CIFAR10(
        root='/var/data/cifar10', download=True, transform=preprocess_fn),
    'fashionmnist': lambda: torchvision.datasets.FashionMNIST(
        root='/var/data/fashionmnist', download=True, transform=preprocess_fn),
    'caltech101': lambda: caltech_dataset.Caltech(
        root='/var/data/caltech101', split='train', transform=preprocess_fn),
    'flowers102': lambda: flowers102.Flowers102(
        root='/var/data/flowers102', split='train', transform=preprocess_fn),
    'stanford_cars': lambda: stanford_cars.StanfordCars(
        root='/var/data/stanford_cars', split='train', transform=preprocess_fn),
    'tiny-imagenet': lambda: torchvision.datasets.ImageFolder(
        root='/var/data/tiny-imagenet/tiny-imagenet-200/train', transform=preprocess_fn),
    'imagenet': lambda: torchvision.datasets.ImageFolder(
        root='/var/data/imagenet/subset_imgs/train', transform=preprocess_fn),
    'cub200': lambda: cub_200.CUB200(
        root='/var/data/cub_200/', train=True, transform=preprocess_fn),
    'stanford_dogs': lambda: stanford_dogs.StanfordDogs(
        root='/var/data/stanford_dogs', train=True, transform=preprocess_fn),
    'chest_xray': lambda: chest_xray_dataset.ChestXRayDataset(
        root='/var/data/chest_xray', train=True, transform=preprocess_fn),
    'pets': lambda: oxford_pets.OxfordIIITPets(
        root='/var/data/pets', split='trainval', transform=preprocess_fn),
    'imagenette': lambda: torchvision.datasets.ImageFolder(
        root='/var/data/imagenette/imagenette2/train', transform=preprocess_fn),
    'pacs_sketch': lambda: torchvision.datasets.ImageFolder(
        root='/var/data/pacs/pacs_data/sketch/train', transform=preprocess_fn),
}

_make_dataset = _dataset_factories.get(dataset_name)
if _make_dataset is None:
    # sys.exit(<str>) reports on stderr and exits with status 1, so an unknown
    # dataset name is visible as a failure to whatever launched the script.
    sys.exit('Dataset not recognized')
dataset = _make_dataset()
    

# Activations are written to memmap rows by a running offset, so the loader
# must yield samples in dataset index order (shuffle=False, which is also the
# DataLoader default).
batch_size = 64
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

num_batches = len(data_loader)
# Row offset and size of the batch currently being processed; updated by the
# inference loop and read by the forward hooks.
curr_batch_start = 0
curr_batch_size = batch_size


print(len(dataset))

# Root directory for all activation dumps.
_activation_root = '/mnt2/imagenet_resnet50_activations'

# layer_outputs key -> (block number, unit number, per-sample activation shape).
# Units in each block are 0-indexed; blocks are 1-indexed.
_activation_specs = {
    'block1_unit2_relu': (1, 2, (256, 56, 56)),
    'block2_unit1_relu': (2, 1, (512, 28, 28)),
    'block2_unit3_relu': (2, 3, (512, 28, 28)),
    'block3_unit1_relu': (3, 1, (1024, 14, 14)),
    'block3_unit3_relu': (3, 3, (1024, 14, 14)),
    'block3_unit5_relu': (3, 5, (1024, 14, 14)),
    'block4_unit1_relu': (4, 1, (2048, 7, 7)),
}

# One float32 memmap per captured layer, with one row per dataset sample,
# stored at <root>/block<B>/unit<U>/<dataset_name>_acts.dat.
layer_outputs = {}
for _key, (_block, _unit, _sample_shape) in _activation_specs.items():
    _out_dir = os.path.join(_activation_root, 'block{}'.format(_block), 'unit{}'.format(_unit))
    # np.memmap(mode='w+') creates the file but not missing parent directories.
    os.makedirs(_out_dir, exist_ok=True)
    layer_outputs[_key] = np.memmap(
        os.path.join(_out_dir, '{}_acts.dat'.format(dataset_name)),
        shape=(len(dataset),) + _sample_shape, mode='w+', dtype=np.float32)



def forward_hook_builder(layer_name):
    """Build a forward hook that stores a module's output in ``layer_outputs``.

    The returned hook detaches the module's output, moves it to the CPU, and
    writes it into rows ``[curr_batch_start, curr_batch_start + curr_batch_size)``
    of ``layer_outputs[layer_name]``. Both offsets are module-level globals that
    the inference loop updates for each batch.

    Args:
        layer_name: Key into ``layer_outputs`` selecting the destination array.

    Returns:
        A callable with the ``(module, inputs, output)`` signature expected by
        ``torch.nn.Module.register_forward_hook``. It returns None, so the
        module's output is left unchanged.
    """
    def save_hook(module, input_tensor, output_tensor):
        start = curr_batch_start
        stop = start + curr_batch_size
        # Slice assignment already copies into the destination array, so an
        # intermediate np.copy of the (potentially very large) batch is
        # unnecessary.
        layer_outputs[layer_name][start:stop] = output_tensor.detach().cpu().numpy()

    return save_hook


# Attach a saving hook to each captured unit. Each entry is
# (ResNet stage attribute, unit index within that stage, layer_outputs key);
# the keys use 1-indexed block numbers and 0-indexed unit numbers.
_hook_targets = (
    ('layer1', 2, 'block1_unit2_relu'),
    ('layer2', 1, 'block2_unit1_relu'),
    ('layer2', 3, 'block2_unit3_relu'),
    ('layer3', 1, 'block3_unit1_relu'),
    ('layer3', 3, 'block3_unit3_relu'),
    ('layer3', 5, 'block3_unit5_relu'),
    ('layer4', 1, 'block4_unit1_relu'),
)
for _stage_name, _unit_idx, _output_key in _hook_targets:
    _target_unit = getattr(model, _stage_name)[_unit_idx]
    _target_unit.register_forward_hook(forward_hook_builder(_output_key))

# Pure inference. eval() makes BatchNorm use its stored running statistics
# instead of per-batch statistics, and stops it from updating them. no_grad()
# avoids building autograd graphs that are never used.
model.eval()
with torch.no_grad():
    for img, _ in tqdm(data_loader, total=num_batches):
        img = img.to(device)
        # The forward hooks read these globals to find this batch's rows.
        curr_batch_size = img.shape[0]
        model(img)
        curr_batch_start += curr_batch_size

# Explicitly write all memmapped activations back to their files.
for _acts in layer_outputs.values():
    _acts.flush()
