import torch
import numpy as np
import torchvision
from tqdm import tqdm
import os
import sys
import argparse
import datasets_main.datasets as datasets
from argparse import ArgumentParser
import torch.nn as nn

def parse_args(argv=None):
    """Parse command-line options for activation collection.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list makes the parser usable from tests and notebooks.

    Returns:
        ``argparse.Namespace`` with ``dataset_name``, ``batch_size``,
        ``out_channels`` and ``resize`` attributes.
    """
    parser = ArgumentParser(description="Collecting activations")
    # Data parameters
    parser.add_argument('-ds', '--dataset-name', help='dataset name', default="bdd100k")
    parser.add_argument('-b', '--batch-size', help='minibatch size', default=8, type=int)
    parser.add_argument('-out', '--out-channels', help='output channels', default=34, type=int)
    # Image formatting
    parser.add_argument('--resize', help='resize the image', default=128, type=int)
    return parser.parse_args(argv)

config = parse_args()

# Guard against an explicitly empty dataset name (e.g. `-ds ""`); the argparse
# default is "bdd100k", so this only triggers when the flag is overridden.
if not config.dataset_name:
    # sys.exit(<str>) prints the message to stderr and exits with status 1,
    # so shell scripts / schedulers see the failure (the bare sys.exit()
    # used before exited with status 0).
    sys.exit('No Dataset mentioned')
dataset_name = config.dataset_name

# All work happens on one fixed GPU (index 4).
# TODO(review): consider exposing the GPU index as a CLI flag.
gpu_index = 4
device = torch.device('cuda:{}'.format(gpu_index))
torch.cuda.set_device(device)

# Build FCN-ResNet50, replace its last classifier layer (classifier[4]) with a
# 1x1 conv producing config.out_channels maps, then load the fine-tuned
# checkpoint over the whole model (load_state_dict is strict by default, so
# the checkpoint must match this modified architecture).
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
last_layer_input = model.classifier[4].in_channels
replaced_last_layer = nn.Conv2d(
    in_channels=last_layer_input,
    out_channels=config.out_channels,
    kernel_size=(1, 1),
    stride=(1, 1),
)
model.classifier[4] = replaced_last_layer

checkpoint_path = "/home/ImageSegmentation/ckpts/idd/fcn_resnet50-03-09-2022-0618/dataset-idd-model-fcn_resnet50-epoch1-0.04053.pt"
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model = model.to(device)

    
# Activations are collected over the training split only; the validation
# loader returned by load_dataset is not used in this script.
train_data_loader, val_data_loader = datasets.load_dataset(config)
data_loader = train_data_loader

num_batches = len(data_loader)
# Write cursor shared with the forward hooks: each batch is stored in rows
# [curr_batch_start, curr_batch_start + curr_batch_size) of every memmap.
curr_batch_start = 0
curr_batch_size = config.batch_size

# Root directory; one sub-directory per hooked layer is created beneath it.
ACTIVATIONS_ROOT = '/home/ImageSegmentation/fineTuneModel_src/src_hleep_bdd/idd_resnet50_fcn_activations'

# Spatial size of the hooked feature maps, derived from --resize instead of
# being hard-coded. Block 1 sits at output stride 4 and blocks 2-4 at output
# stride 8 (torchvision's fcn_resnet50 dilates layer3/layer4 rather than
# striding them). That is why the original sizes were 32 and 16 for the
# default --resize of 128. Ceil division matches the padded stride-2 convs.
# NOTE(review): assumes load_dataset yields square config.resize x
# config.resize images, as the previous hard-coded sizes did; confirm.
stride4_size = -(-config.resize // 4)
stride8_size = -(-config.resize // 8)

# One entry per hooked unit: (block, unit, channels, spatial size).
# Units in each block are 0-indexed; blocks are 1-indexed
# (blockN/unitK corresponds to model.backbone.layerN[K]).
LAYER_SPECS = [
    (1, 2, 256, stride4_size),
    (2, 1, 512, stride8_size),
    (2, 3, 512, stride8_size),
    (3, 1, 1024, stride8_size),
    (3, 3, 1024, stride8_size),
    (3, 5, 1024, stride8_size),
    (4, 1, 2048, stride8_size),
]

paths = [
    os.path.join(ACTIVATIONS_ROOT, 'block{}'.format(block), 'unit{}'.format(unit))
    for block, unit, _, _ in LAYER_SPECS
]
for path in paths:
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(path, exist_ok=True)

num_samples = len(data_loader.dataset)
print(num_samples)
# One float32 memmap per hooked unit, shaped (num_samples, channels, size, size)
# and keyed 'block{N}_unit{K}_relu'; mode='w+' creates/overwrites the file.
layer_outputs = {
    'block{}_unit{}_relu'.format(block, unit): np.memmap(
        os.path.join(path, '{}_acts.dat'.format(dataset_name)),
        shape=(num_samples, channels, size, size),
        mode='w+',
        dtype=np.float32,
    )
    for (block, unit, channels, size), path in zip(LAYER_SPECS, paths)
}



def forward_hook_builder(layer_name):
    """Return a forward hook that stores a module's output in ``layer_outputs``.

    The returned hook writes the module's output for the current minibatch
    into ``layer_outputs[layer_name]`` at rows
    ``[curr_batch_start, curr_batch_start + curr_batch_size)``. Both cursor
    values are module-level globals that the collection loop advances.

    Args:
        layer_name: Key into the module-level ``layer_outputs`` dict.

    Returns:
        A callable with the ``hook(module, input, output)`` signature used by
        ``torch.nn.Module.register_forward_hook``; it returns ``None`` so the
        module's output is left unchanged.
    """
    def save_hook(module, inputs, output_tensor):
        # The cursor globals are only read here, so no `global` statement is
        # needed. Slice-assigning into the memmap already copies the data, so
        # no extra np.copy() is required.
        rows = slice(curr_batch_start, curr_batch_start + curr_batch_size)
        layer_outputs[layer_name][rows] = output_tensor.detach().cpu().numpy()

    return save_hook


# Hook the outputs of the selected residual units. Blocks are 1-indexed
# (blockN -> model.backbone.layerN); units are 0-indexed within the layer.
# Each key must match a memmap in layer_outputs.
hooked_units = [
    (model.backbone.layer1[2], 'block1_unit2_relu'),
    (model.backbone.layer2[1], 'block2_unit1_relu'),
    (model.backbone.layer2[3], 'block2_unit3_relu'),
    (model.backbone.layer3[1], 'block3_unit1_relu'),
    (model.backbone.layer3[3], 'block3_unit3_relu'),
    (model.backbone.layer3[5], 'block3_unit5_relu'),
    (model.backbone.layer4[1], 'block4_unit1_relu'),
]
for unit_module, layer_name in hooked_units:
    unit_module.register_forward_hook(forward_hook_builder(layer_name))

# eval() puts normalization/dropout layers into inference mode. Without it,
# BatchNorm normalizes with per-batch statistics, so activations would depend
# on batch composition, and it would also overwrite its running stats.
# no_grad() avoids building autograd graphs that are never used.
# NOTE(review): row i of each memmap matches dataset index i only if
# data_loader does not shuffle; confirm in datasets.load_dataset.
model.eval()
with torch.no_grad():
    for img, _ in tqdm(data_loader, total=num_batches):
        img = img.to(device)
        # The hooks read these module-level cursors to place the batch.
        curr_batch_size = img.shape[0]
        model(img)
        curr_batch_start += curr_batch_size

# Explicitly push all memmapped activations to disk.
for activations in layer_outputs.values():
    activations.flush()
