import numpy as np
import os
import sys
import torch
import torchvision
import scipy
import scipy.stats
from tqdm import tqdm
import argparse
from LEEP import LEEP
import random
import flowers102
# import caltech101.caltech_dataset as caltech101
import pretrainedmodels 
from resnet34_caltech import ResNet34 
import cub_200
import stanford_dogs
import oxford_pets
import torch.nn as nn

# Evaluation batch size shared by the target DataLoader in main().
BATCH_SIZE = 32

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

_tv_transforms = torchvision.transforms

# Per-channel normalisation statistics (the commonly used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize so the shorter side is 256, centre-crop to 224x224, convert to a
# tensor, then normalise each channel with the statistics above.
transform = _tv_transforms.Compose([
    _tv_transforms.Resize(256),
    _tv_transforms.CenterCrop(224),
    _tv_transforms.ToTensor(),
    _tv_transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Source architectures in the order main() enumerates them (and therefore the
# order of the saved LEEP scores). NOTE(review): that order presumably has to
# match the order of the stored fine-tune accuracies — confirm.
_ARCH_FACTORIES = (
    ('resnet_50', torchvision.models.resnet50),
    ('resnet_152', torchvision.models.resnet152),
    ('mobilenet_v2', torchvision.models.mobilenet_v2),
    ('densenet_201', torchvision.models.densenet201),
    ('densenet_169', torchvision.models.densenet169),
    ('densenet_121', torchvision.models.densenet121),
    ('resnet_101', torchvision.models.resnet101),
)

# Every architecture is instantiated with pretrained weights at import time.
models = {arch_name: build(pretrained=True) for arch_name, build in _ARCH_FACTORIES}

def _mean_block_unit_scores(source_dataset_name, target_dataset_name, n_targets,
                            block_units, n_source_images=128116):
    """Average the stored per-block-unit image-pair scores, then reduce over source rows.

    Loads one ``.npy`` file per entry of ``block_units`` from
    ``/mnt2/imagenet_resnet50_results/``, printing each array's shape. The
    arrays are summed into a float64 buffer of shape
    ``(n_source_images, n_targets)`` and divided by ``len(block_units)``. The
    function returns the mean over axis 0, which gives one value per target
    sample.

    The buffer is local, so it is released when the function returns rather
    than staying alive while the models run.

    NOTE(review): ``n_source_images`` (128116) is carried over from the original
    hard-coded value; presumably it is the number of source images that were
    scored — confirm.
    """
    img_pair_scores = np.zeros((n_source_images, n_targets))
    for bu in tqdm(block_units):
        path = '/mnt2/imagenet_resnet50_results/{}_{}_{}.npy'.format(
            source_dataset_name, target_dataset_name, bu)
        with open(path, 'rb') as f:
            temp = np.load(f)
        print(temp.shape)
        img_pair_scores += temp  # for cifar 100, fmnist
    img_pair_scores /= len(block_units)
    return np.mean(img_pair_scores, axis=0)


@torch.no_grad()  # inference only: don't build autograd graphs per batch
def _source_predictions(source_model, loader, n_samples, n_source_classes=1000):
    """Run ``source_model`` over ``loader`` and collect its softmax outputs and labels.

    Returns ``(probs, targets)``:

    - ``probs`` is an ``(n_samples, n_source_classes)`` float64 array. Each row
      is the softmax of the model's output for one sample, in loader order.
    - ``targets`` is a 1-D array of the loader's labels minus one.
    """
    probs = np.zeros((n_samples, n_source_classes))
    targets = []
    offset = 0
    for x, y in loader:
        out = torch.nn.functional.softmax(source_model(x.to(device)), dim=1)
        batch_size = len(x)
        probs[offset:offset + batch_size] = out.cpu().numpy()
        offset += batch_size
        # Labels are shifted down by one. Remove this shift for flowers.
        targets.append(np.asarray(y) - 1)
    if targets:
        return probs, np.concatenate(targets)
    return probs, np.array([], dtype=np.int64)


def main():
    """Compute LEEP scores for every model in ``models`` and correlate them with accuracies.

    Steps:

    1. Print the device in use.
    2. Load the Oxford-IIIT Pets ``trainval`` split.
    3. Average the stored block-unit score files (the result is not used below).
    4. For each model, compute softmax predictions on the target set and
       score them with ``LEEP``.
    5. Save the scores to ``./<source>_<target>_leep_scores_arch.npy``.
    6. Load the stored fine-tune accuracies and print the Pearson correlation
       coefficient between them and the scores.
    """
    print(device)
    source_dataset_name = 'imagenet'
    target_dataset_name = 'pets'

    # PSA: the target dataset is constructed only here. Also remove the label
    # shift in _source_predictions for flowers.
    target_ds = oxford_pets.OxfordIIITPets(root='/var/data/pets', split='trainval', transform=transform)
    # Build the loader once and reuse it for every model. shuffle=False keeps
    # the sample order identical on each pass.
    target_loader = torch.utils.data.DataLoader(target_ds, shuffle=False, num_workers=2, batch_size=BATCH_SIZE)
    leep_scores = np.zeros(len(models))

    block_units = ['block2_unit3', 'block3_unit3', 'block3_unit5', 'block4_unit1']
    target_scores = _mean_block_unit_scores(
        source_dataset_name, target_dataset_name, len(target_ds), block_units)
    # NOTE(review): neither target_scores nor sorted_idxs is used below. They
    # are kept from the original flow — confirm whether they are still needed.
    sorted_idxs = np.argsort(target_scores)

    for ds_num, model_name in tqdm(enumerate(models), total=len(models)):
        source_model = models[model_name].to(device)
        source_model.eval()

        probs, targets = _source_predictions(source_model, target_loader, len(target_ds))
        leep_scores[ds_num] = LEEP(probs, targets)

        # nn.Module.to() moves parameters in place, so the shared ``models``
        # entry would otherwise stay on the GPU. Move it back so the models
        # don't accumulate in device memory.
        source_model.cpu()

    with open(f'./{source_dataset_name}_{target_dataset_name}_leep_scores_arch.npy', 'wb') as f:
        np.save(f, leep_scores)

    f_name = 'accuracies/{}_to_{}_arch_finetune_accuracies_seed0_best.npy'.format(source_dataset_name, target_dataset_name)
    with open(f_name, 'rb') as f:
        accuracies = np.load(f)

    pcc, p = scipy.stats.pearsonr(leep_scores, accuracies)
    print(pcc)

# Run main() only when this file is executed as a script, so importing the
# module does not start the evaluation.
if __name__ == '__main__':
    main()