import torch
import numpy as np
import torchvision
from tqdm import tqdm
import os
import sys
import argparse
import datasets_main.datasets as datasets
from argparse import ArgumentParser
import torch.nn as nn
import time

# Pin this process to GPU index 4; torch.cuda.set_device makes it the
# current CUDA device, so later CUDA work without an explicit index uses it.
device = torch.device('cuda', 4)
torch.cuda.set_device(device)

def parse_args(argv=None):
    """Parse the command-line options of the activation-similarity script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` exactly as before; an explicit list
            lets the function be called from tests or notebooks.

    Returns:
        argparse.Namespace with ``dataset_name`` (str), ``batch_size`` (int)
        and ``resize`` (int).
    """
    parser = ArgumentParser(description="Collecting activations")
    # Data parameters
    parser.add_argument('-ds', '--dataset-name', help='dataset name', default="bdd100k")
    parser.add_argument('-b', '--batch-size', help='minibatch size', default=8, type=int)
    # Image formatting
    parser.add_argument('--resize', help='resize the image', default=128, type=int)
    return parser.parse_args(argv)

config = parse_args()

# The dataset named on the command line is the *target* dataset.
if not config.dataset_name:
    # Passing the message to sys.exit prints it to stderr and exits with a
    # non-zero status, so wrapper scripts can detect the failure (a bare
    # sys.exit() would report success).
    sys.exit('No Dataset mentioned')
dataset_name = config.dataset_name
target_dataset_name = config.dataset_name
target_loader, _ = datasets.load_dataset(config)

# Fixed source dataset against which the similarity matrices are computed.
# load_dataset is called with the same config object, presumably selecting the
# dataset via config.dataset_name, so that field is overwritten here; from
# this point on config.dataset_name refers to the source dataset.
source_dataset_name = 'pascalvoc'
config.dataset_name = source_dataset_name
source_loader, _ = datasets.load_dataset(config)

print(len(source_loader.dataset))
print(len(target_loader.dataset))

source_dataset_arrs = {}
target_dataset_arrs = {}

# Layers/blocks whose activations are compared between the two datasets.
blocks = {
    'block2': ['unit3'],
    'block3': ['unit5'],
    'block4': ['unit1'],
}

# Directory holding the pre-computed activation dumps, laid out as
# <activations_dir>/<block>/<unit>/<dataset_name>_acts.dat.
activations_dir = '/home/ImageSegmentation/fineTuneModel_src/src_hleep_bdd/pascalvoc_resnet50_fcn_activations/'

# Per-image (channels, height, width) of each stored layer. The spatial size
# is hard-coded; presumably it matches the --resize used when the activations
# were dumped — TODO confirm before changing --resize.
layer_feature_shape = {
    'block2_unit3': (512, 16, 16),
    'block3_unit3': (1024, 16, 16),
    'block3_unit5': (1024, 16, 16),
    'block4_unit1': (2048, 16, 16),
}
# Full memmap shapes: one leading row per image of the respective dataset.
source_layer_output_shape = {
    bu: (len(source_loader.dataset), *shape) for bu, shape in layer_feature_shape.items()
}
target_layer_output_shape = {
    bu: (len(target_loader.dataset), *shape) for bu, shape in layer_feature_shape.items()
}

# Border width (pixels per spatial side) cropped from each layer's maps
# before norms and dot products are computed.
layer_padding = {
    'block2_unit1': 2,
    'block2_unit3': 2,
    'block3_unit1': 1,
    'block3_unit3': 1,
    'block3_unit5': 1,
    'block4_unit1': 1,
}

# Open each selected layer's activations read-only as float32 memmaps, so
# data is only read from disk when it is indexed.
for block, units in blocks.items():
    for unit in units:
        bu = f'{block}_{unit}'
        unit_dir = os.path.join(activations_dir, block, unit)
        source_dataset_arrs[bu] = np.memmap(
            os.path.join(unit_dir, f'{source_dataset_name}_acts.dat'),
            mode='r', dtype=np.float32, shape=source_layer_output_shape[bu],
        )
        target_dataset_arrs[bu] = np.memmap(
            os.path.join(unit_dir, f'{target_dataset_name}_acts.dat'),
            mode='r', dtype=np.float32, shape=target_layer_output_shape[bu],
        )
# L2 norm of every image's activations for each layer, taken over the
# border-cropped region (layer_padding[bu] pixels removed per spatial side).
source_dataset_norms = {
    bu: np.zeros(len(source_loader.dataset)) for bu in source_dataset_arrs
}
target_dataset_norms = {
    bu: np.zeros(len(target_loader.dataset)) for bu in target_dataset_arrs
}

print('Computing norms')
for bu in source_dataset_arrs:
    p = layer_padding[bu]
    # `-p or None` keeps the crop valid when p == 0; a plain `p:-p` would be
    # the empty slice 0:0 and every norm would come out as 0.
    crop = slice(p, -p or None)
    # One image at a time keeps only a single activation map in memory.
    # np.linalg.norm already ravels its input, so no explicit flatten copy.
    for i in tqdm(range(len(source_loader.dataset))):
        source_dataset_norms[bu][i] = np.linalg.norm(source_dataset_arrs[bu][i][:, crop, crop])
    for i in tqdm(range(len(target_loader.dataset))):
        target_dataset_norms[bu][i] = np.linalg.norm(target_dataset_arrs[bu][i][:, crop, crop])

# Source images are processed in chunks of this many rows, so only the cropped
# target activations plus one source chunk are held in memory at a time.
act_batch_size = 64

image_pair_scores = {
    bu: np.zeros((len(source_loader.dataset), len(target_loader.dataset))) for bu in source_dataset_arrs
}

print('Starting Dot Product calculations')
for bu in source_dataset_arrs:
    print(bu)
    p = layer_padding[bu]
    # Drop p pixels on each spatial border; `-p or None` keeps the slice
    # valid for p == 0 (plain `p:-p` would be empty then).
    crop = slice(p, -p or None)
    start = time.time()
    n_source = len(source_loader.dataset)
    n_target = len(target_loader.dataset)
    # Every source chunk is compared with all target images, so the cropped
    # target activations are read once and flattened to (n_target, features).
    target_flat = target_dataset_arrs[bu][:, :, crop, crop].reshape(n_target, -1)
    for lo in range(0, n_source, act_batch_size):
        hi = min(lo + act_batch_size, n_source)
        source_flat = source_dataset_arrs[bu][lo:hi, :, crop, crop].reshape(hi - lo, -1)
        # Raw dot products: rows are source images, columns are target images.
        image_pair_scores[bu][lo:hi] = source_flat @ target_flat.T
    # Dividing by both L2 norms turns the dot products into cosine
    # similarities. An image whose cropped activations are all zero has norm 0
    # and produces NaN/inf entries here.
    image_pair_scores[bu] /= source_dataset_norms[bu][:, None]
    image_pair_scores[bu] /= target_dataset_norms[bu][None, :]
    end = time.time()
    print(end - start)

print('Saving arrays')
# Output location for the per-layer similarity matrices.
results_dir = '/home/ImageSegmentation/fineTuneModel_src/src_hleep_bdd/pascalvoc_resnet50_fcn_results/'
# Ensure the folder exists; otherwise open() would fail only after all the
# (expensive) similarity matrices had already been computed.
os.makedirs(results_dir, exist_ok=True)
for bu, scores in image_pair_scores.items():
    # One .npy file per layer, e.g. pascalvoc_bdd100k_block2_unit3.npy.
    out_path = os.path.join(results_dir, f'{source_dataset_name}_{target_dataset_name}_{bu}.npy')
    with open(out_path, 'wb') as f:
        np.save(f, scores)