import os
import sys
import numpy as np
import torchvision
from tqdm import tqdm
import time
import torch
import argparse
import flowers102
import stanford_cars
import stanford_dogs
import cub_200
import chest_xray_dataset
import oxford_pets
import caltech_dataset

# Report which device torch would pick. The value is only printed here; the
# visible part of this script works on activations already stored on disk.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Command-line interface:
#   --source  name of the source model; its first '_'-separated token is taken
#             as the name of the dataset the model belongs to
#   --target  name of the target dataset
parser = argparse.ArgumentParser()
parser.add_argument('--source', dest='source', type=str)
parser.add_argument('--target', dest='target', type=str)
parser.set_defaults(dataset=None)
args = parser.parse_args()

if not (args.source and args.target):
    print('No Dataset mentioned')
    # Exit with a non-zero status so shells and job schedulers see the
    # missing arguments as a failure instead of a successful run.
    sys.exit(1)

source_model_name = args.source
source_dataset_name = source_model_name.split('_')[0]
# Dataset names that contain '_' themselves are cut short by the split above;
# 'stanford' is mapped back to 'stanford_dogs'.
# NOTE(review): 'stanford_cars_*' and 'chest_xray_*' model names cannot be
# recovered this way -- confirm that no such source models are used.
if source_dataset_name == 'stanford':
    source_dataset_name = 'stanford_dogs'
target_dataset_name = args.target

print(f'{source_dataset_name}_{target_dataset_name}')
    
def get_dataset(dataset_name, preprocess_fn):
    """Instantiate the dataset registered under ``dataset_name``.

    Args:
        dataset_name: Identifier of the dataset (e.g. ``'cifar10'``,
            ``'stanford_dogs'``, ``'pets'``).
        preprocess_fn: Passed to the dataset constructor as ``transform``.

    Returns:
        The constructed dataset object. Where the constructor takes a split
        argument, the train (or ``'trainval'``) split is requested; the
        ``ImageFolder`` entries point at a ``train`` directory.

    Raises:
        NotImplementedError: If ``dataset_name`` is not one of the known names
            (``'Dataset not recognized'`` is printed first).
    """
    # Ordered (name, factory) pairs. Factories are lazy so that only the
    # requested dataset is ever constructed (and, for the CIFAR/FashionMNIST
    # entries, downloaded).
    factories = (
        ('cifar100', lambda: torchvision.datasets.CIFAR100(
            root='/var/data/cifar100/', download=True, transform=preprocess_fn)),
        ('cifar10', lambda: torchvision.datasets.CIFAR10(
            root='/var/data/cifar10', download=True, transform=preprocess_fn)),
        ('fashionmnist', lambda: torchvision.datasets.FashionMNIST(
            root='/var/data/fashionmnist', download=True, transform=preprocess_fn)),
        ('caltech101', lambda: caltech_dataset.Caltech(
            root='/var/data/caltech101', split='train', transform=preprocess_fn)),
        ('flowers102', lambda: flowers102.Flowers102(
            root='/var/data/flowers102', split='train', transform=preprocess_fn)),
        ('stanford_cars', lambda: stanford_cars.StanfordCars(
            root='/var/data/stanford_cars', split='train', transform=preprocess_fn)),
        ('tiny-imagenet', lambda: torchvision.datasets.ImageFolder(
            root='/var/data/tiny-imagenet/tiny-imagenet-200/train', transform=preprocess_fn)),
        ('imagenet', lambda: torchvision.datasets.ImageFolder(
            root='/var/data/imagenet/subset_imgs/train', transform=preprocess_fn)),
        ('cub200', lambda: cub_200.CUB200(
            root='/var/data/cub_200/', train=True, transform=preprocess_fn)),
        ('stanford_dogs', lambda: stanford_dogs.StanfordDogs(
            root='/var/data/stanford_dogs', train=True, transform=preprocess_fn)),
        ('chest_xray', lambda: chest_xray_dataset.ChestXRayDataset(
            root='/var/data/chest_xray', train=True, transform=preprocess_fn)),
        ('pets', lambda: oxford_pets.OxfordIIITPets(
            root='/var/data/pets', split='trainval', transform=preprocess_fn)),
        ('imagenette', lambda: torchvision.datasets.ImageFolder(
            root='/var/data/imagenette/imagenette2/train', transform=preprocess_fn)),
    )

    for known_name, build in factories:
        if dataset_name == known_name:
            return build()

    print('Dataset not recognized')
    raise NotImplementedError

def get_layer_output_shape(m_name, layer, ds_len):
    """Look up the stored activation shape for one layer of a known backbone.

    Args:
        m_name: Model name; the backbone is detected by substring
            (``'vgg19'``, ``'resnet34'``, ``'resnet101'``, ``'resnet50'``).
        layer: 1-based layer number.
        ds_len: Number of samples; becomes the leading dimension.

    Returns:
        A 4-tuple ``(ds_len, d1, d2, d3)``; the trailing three entries are
        presumably (channels, height, width) -- confirm against the code that
        wrote the activation files.

    Raises:
        NotImplementedError: If no backbone substring matches, or the matching
            backbone has no entry for ``layer``.
    """
    # resnet101 and resnet50 share the same per-layer dimensions.
    resnet50_101_dims = ((512, 28, 28), (1024, 14, 14), (1024, 14, 14), (2048, 7, 7))

    # (backbone substring, trailing dims for layers 1, 2, 3, ...), checked in
    # this order.
    known_dims = (
        ('vgg19', ((256, 56, 56), (512, 28, 28), (512, 14, 14), (512, 7, 7))),
        ('resnet34', ((128, 28, 28), (256, 14, 14), (512, 7, 7))),
        ('resnet101', resnet50_101_dims),
        ('resnet50', resnet50_101_dims),
    )

    for backbone, per_layer_dims in known_dims:
        if backbone not in m_name:
            continue
        for number, dims in enumerate(per_layer_dims, start=1):
            if layer == number:
                return (ds_len,) + dims
        # No entry for this layer: keep scanning the remaining backbones.

    raise NotImplementedError

def get_model_layer_nums(model_name):
    """Return how many layers are tracked for the backbone named in ``model_name``.

    The backbone is detected by substring match, tested in the order
    resnet34, vgg19, resnet101, resnet50.

    Raises:
        NotImplementedError: If ``model_name`` contains none of the known
            backbone names.
    """
    layer_counts = (
        ('resnet34', 3),
        ('vgg19', 4),
        ('resnet101', 4),
        ('resnet50', 4),
    )
    for backbone, count in layer_counts:
        if backbone in model_name:
            return count

    raise NotImplementedError
    
# ---------------------------------------------------------------------------
# Main script body: for every tracked layer of the source model, compute the
# cosine similarity between each source-dataset activation and each
# target-dataset activation, and save one (num_source, num_target) matrix per
# layer.
# ---------------------------------------------------------------------------


def _cropped_norms(acts, pad, count):
    """Return the L2 norm of each of the first ``count`` samples of ``acts``.

    ``pad`` entries are cropped from both ends of the last two axes of every
    sample before the norm is taken, matching the crop used for the dot
    products below.
    """
    norms = np.zeros(count)
    for idx in tqdm(range(count)):
        norms[idx] = np.linalg.norm(acts[idx][:, pad:-pad, pad:-pad].flatten())
    return norms


num_layers = get_model_layer_nums(source_model_name)
# The datasets are opened without a transform; below only their lengths are
# used, to size the activation arrays.
source_dataset = get_dataset(source_dataset_name, preprocess_fn=None)
target_dataset = get_dataset(target_dataset_name, preprocess_fn=None)

num_source = len(source_dataset)
num_target = len(target_dataset)
print(num_source)
print(num_target)

layer_names = [f'layer{i}' for i in range(1, num_layers + 1)]

# Both shape tables are keyed on the source model: the target activations are
# read from files whose names also start with the source model name.
source_layer_output_shape = {
    name: get_layer_output_shape(source_model_name, number, num_source)
    for number, name in enumerate(layer_names, start=1)
}

target_layer_output_shape = {
    name: get_layer_output_shape(source_model_name, number, num_target)
    for number, name in enumerate(layer_names, start=1)
}

# Number of entries cropped from each side of the last two axes, per layer.
layer_padding = {
    'layer1': 2,
    'layer2': 2,
    'layer3': 1,
    'layer4': 1
}

# Open the activation files read-only as float32 memory maps.
# NOTE(review): presumably written by a separate extraction script -- the
# shapes above must match what that script stored.
source_dataset_arrs = {}
target_dataset_arrs = {}
for name in layer_names:
    source_dataset_arrs[name] = np.memmap(
        os.path.join('/mnt2/ensemble_activations/', f'{source_model_name}_{source_dataset_name}_{name}_acts.dat'),
        mode='r', dtype=np.float32, shape=source_layer_output_shape[name]
    )

    target_dataset_arrs[name] = np.memmap(
        os.path.join('/mnt2/ensemble_activations/', f'{source_model_name}_{target_dataset_name}_{name}_acts.dat'),
        mode='r', dtype=np.float32, shape=target_layer_output_shape[name]
    )

source_dataset_norms = {}
target_dataset_norms = {}

print('Computing norms')
for name in layer_names:
    p = layer_padding[name]
    source_dataset_norms[name] = _cropped_norms(source_dataset_arrs[name], p, num_source)
    target_dataset_norms[name] = _cropped_norms(target_dataset_arrs[name], p, num_target)

print('Starting Dot Product calculations')
for layer in layer_names:
    print(layer)
    p = layer_padding[layer]
    start = time.time()
    # Contract every non-sample axis, giving a (num_source, num_target) matrix
    # of raw dot products. The result is promoted to float64 so the saved
    # files keep the dtype they have always had.
    scores = np.tensordot(
        source_dataset_arrs[layer][:, :, p:-p, p:-p],
        target_dataset_arrs[layer][:, :, p:-p, p:-p],
        axes=([1, 2, 3], [1, 2, 3])
    ).astype(np.float64)
    # Normalise in place: rows by the source norms, columns by the target norms.
    # NOTE(review): a sample whose cropped activation is all zeros has norm 0
    # and yields nan/inf entries here.
    scores /= source_dataset_norms[layer][:, None]
    scores /= target_dataset_norms[layer][None, :]
    end = time.time()
    print(end - start)

    with open(os.path.join('/mnt2/ensemble_results', f'{source_model_name}_{source_dataset_name}_{target_dataset_name}_{layer}.npy'), 'wb') as f:
        np.save(f, scores)

    # Each matrix is only needed until it is written; release it so that the
    # matrices of earlier layers do not pile up in memory.
    del scores