import os 
import numpy as np
import torchvision
from tqdm import tqdm
import time
import torch
import argparse
import flowers102
import stanford_cars
import stanford_dogs
import cub_200
import chest_xray_dataset
import oxford_pets
import caltech_dataset

# Report whether PyTorch can see a CUDA device; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# The source dataset is not loaded by this script: only its name (used to
# build file names) and a hard-coded sample count are needed.
# NOTE(review): 128116 is presumably the size of the ImageFolder under
# /var/data/imagenet/subset_imgs/train -- confirm before changing the subset.
source_dataset_name = 'imagenet'
source_dataset_len = 128116

# Command-line interface: the target dataset is chosen with `--dataset NAME`.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', dest='dataset', type=str)
parser.set_defaults(dataset=None)
args = parser.parse_args()

# A missing (or empty) --dataset is fatal: report it and stop with a non-zero
# status. SystemExit is raised directly because `sys` is not imported by this
# script, so calling `sys.exit()` here would fail with a NameError.
if not args.dataset:
    print('No Dataset mentioned')
    raise SystemExit(1)
target_dataset_name = args.dataset
    
# Maps each supported --dataset value to a zero-argument factory for its
# training split. Factories are lambdas so that only the requested dataset is
# actually constructed (and only its files are touched).
_TARGET_DATASET_BUILDERS = {
    'cifar100': lambda: torchvision.datasets.CIFAR100(root='/var/data/cifar100/'),
    'cifar10': lambda: torchvision.datasets.CIFAR10(root='/var/data/cifar10'),
    'fashionmnist': lambda: torchvision.datasets.FashionMNIST(root='/var/data/fashionmnist'),
    'caltech101': lambda: caltech_dataset.Caltech(root='/var/data/caltech101', split='train'),
    'flowers102': lambda: flowers102.Flowers102(root='/var/data/flowers102', split='train'),
    'stanford_cars': lambda: stanford_cars.StanfordCars(root='/var/data/stanford_cars', split='train'),
    'cub200': lambda: cub_200.CUB200(root='/var/data/cub_200/', train=True),
    'stanford_dogs': lambda: stanford_dogs.StanfordDogs(root='/var/data/stanford_dogs'),
    'chest_xray': lambda: chest_xray_dataset.ChestXRayDataset(root='/var/data/chest_xray', train=True),
    'pets': lambda: oxford_pets.OxfordIIITPets(root='/var/data/pets', split='trainval'),
    'imagenette': lambda: torchvision.datasets.ImageFolder(root='/var/data/imagenette/imagenette2/train'),
    'pacs_sketch': lambda: torchvision.datasets.ImageFolder(root='/var/data/pacs/pacs_data/sketch/train'),
}

_build_target_dataset = _TARGET_DATASET_BUILDERS.get(target_dataset_name)
if _build_target_dataset is None:
    print('Dataset not recognized')
    # SystemExit is raised directly because `sys` is not imported by this
    # script, so calling `sys.exit()` here would fail with a NameError.
    raise SystemExit(1)
target_dataset = _build_target_dataset()

# Report how many samples each side contributes.
print(source_dataset_len)
num_target_samples = len(target_dataset)
print(num_target_samples)

# Filled in later: layer name ('<block>_<unit>') -> memory-mapped activations.
source_dataset_arrs = dict()
target_dataset_arrs = dict()

# Layers to compare, grouped by block ('block1' / 'unit2' is currently disabled).
blocks = dict(
    block2=['unit3'],
    block3=['unit3', 'unit5'],
    block4=['unit1'],
)

# Trailing dimensions of one stored activation per layer; the leading sample
# axis is prepended below to obtain the shapes used when memory-mapping.
# NOTE(review): presumably (channels, height, width) -- confirm against the
# script that wrote the activation files.
# (Disabled 'block1_unit2' would be (256, 56, 56).)
_per_sample_shape = {
    'block2_unit1': (512, 28, 28),
    'block2_unit3': (512, 28, 28),
    'block3_unit1': (1024, 14, 14),
    'block3_unit3': (1024, 14, 14),
    'block3_unit5': (1024, 14, 14),
    'block4_unit1': (2048, 7, 7),
}

# Full array shape of every source-side activation file.
source_layer_output_shape = {
    layer: (source_dataset_len,) + trailing
    for layer, trailing in _per_sample_shape.items()
}

# Target-side shapes are only declared for the layers enabled in `blocks`.
_target_layers = ('block2_unit3', 'block3_unit3', 'block3_unit5', 'block4_unit1')
target_layer_output_shape = {
    layer: (num_target_samples,) + _per_sample_shape[layer]
    for layer in _target_layers
}

# Number of border rows/columns trimmed from each side of the last two axes
# before norms and dot products are computed.
layer_padding = dict(
    block2_unit1=2,
    block2_unit3=2,
    block3_unit1=1,
    block3_unit3=1,
    block3_unit5=1,
    block4_unit1=1,
)

def _open_activations(dataset_name, block_name, unit_name, shape):
    """Memory-map (read-only, float32) one layer's stored activations.

    The file is expected at
    /mnt2/imagenet_resnet50_activations/<block>/<unit>/<dataset>_acts.dat
    and is interpreted with the given array `shape`.
    """
    acts_path = os.path.join(
        '/mnt2/imagenet_resnet50_activations/',
        block_name,
        unit_name,
        f'{dataset_name}_acts.dat',
    )
    return np.memmap(acts_path, mode='r', dtype=np.float32, shape=shape)


# Open the source and target activation files of every enabled layer.
for block_name, unit_names in blocks.items():
    for unit_name in unit_names:
        layer = f'{block_name}_{unit_name}'
        source_dataset_arrs[layer] = _open_activations(
            source_dataset_name, block_name, unit_name,
            source_layer_output_shape[layer],
        )
        target_dataset_arrs[layer] = _open_activations(
            target_dataset_name, block_name, unit_name,
            target_layer_output_shape[layer],
        )

# One zero-initialised norm slot per sample, for every opened layer.
source_dataset_norms = {}
for layer in source_dataset_arrs:
    source_dataset_norms[layer] = np.zeros(source_dataset_len)

target_dataset_norms = {}
for layer in target_dataset_arrs:
    target_dataset_norms[layer] = np.zeros(len(target_dataset))

def _fill_cropped_norms(acts, norms, border, count):
    """Store the Euclidean norm of each border-cropped activation in `norms`.

    For every index below `count`, `border` entries are trimmed from both ends
    of the last two axes of `acts[index]`; the remainder is flattened and its
    2-norm is written to `norms[index]`. A progress bar is shown while
    iterating.
    """
    for index in tqdm(range(count)):
        cropped = acts[index][:, border:-border, border:-border]
        norms[index] = np.linalg.norm(cropped.reshape(-1))


print('Computing norms')
for layer, source_acts in source_dataset_arrs.items():
    border = layer_padding[layer]
    # Source samples first, then target samples, layer by layer.
    _fill_cropped_norms(
        source_acts, source_dataset_norms[layer], border, source_dataset_len
    )
    _fill_cropped_norms(
        target_dataset_arrs[layer], target_dataset_norms[layer], border,
        len(target_dataset),
    )

# NOTE(review): act_batch_size is not referenced anywhere in this script; the
# dot products below are computed for a whole layer in a single call.
act_batch_size = 64

# layer name -> (source samples x target samples) float64 score matrix.
image_pair_scores = {
    bu: np.zeros((source_dataset_len, len(target_dataset)))
    for bu in source_dataset_arrs
}

print('Starting Dot Product calculations')
for bu in source_dataset_arrs:
    print(bu)
    p = layer_padding[bu]
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    start = time.perf_counter()
    scores = image_pair_scores[bu]
    # Dot product of every (source, target) pair over all non-sample axes,
    # after trimming `p` border entries from each side of the last two axes.
    scores[:] = np.tensordot(
        source_dataset_arrs[bu][:, :, p:-p, p:-p],
        target_dataset_arrs[bu][:, :, p:-p, p:-p],
        axes=([1, 2, 3], [1, 2, 3]),
    )
    # Dividing by both norms turns the dot products into cosine similarities.
    # The divisions are done in place: an out-of-place division would allocate
    # another full (source x target) float64 matrix each time.
    scores /= source_dataset_norms[bu][:, None]
    scores /= target_dataset_norms[bu][None, :]
    end = time.perf_counter()
    print(end - start)

# Write one .npy file per layer, named <source>_<target>_<layer>.npy.
print('Saving arrays')
for layer, scores in image_pair_scores.items():
    out_name = f'{source_dataset_name}_{target_dataset_name}_{layer}.npy'
    out_path = os.path.join('/mnt2/imagenet_resnet50_results', out_name)
    with open(out_path, 'wb') as out_file:
        np.save(out_file, scores)