import numpy as np
import os
import sys
import torch
import torchvision
import scipy
import scipy.stats
from tqdm import tqdm
import argparse
from LEEP import LEEP
import random
import flowers102
# import caltech101.caltech_dataset as caltech101
import pretrainedmodels 
from resnet34_caltech import ResNet34 
import cub_200
import stanford_dogs
import oxford_pets
import caltech_dataset
import torch.nn as nn
import calc_gaussian_hard_subsets

# Number of images pushed through a model per forward pass.
BATCH_SIZE = 32

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Per-channel statistics used to normalize the input tensors
# (presumably the usual ImageNet values -- confirm against the model docs).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every target image, in order: resize to 256,
# center-crop to 224, convert to a tensor, then normalize per channel.
_preprocessing_steps = [
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = torchvision.transforms.Compose(_preprocessing_steps)

# (key, torchvision constructor) pairs. The order matters: scores computed in
# ``main`` are stored by position in this mapping.
_ARCHITECTURES = (
    ('resnet_50', torchvision.models.resnet50),
    ('resnet_152', torchvision.models.resnet152),
    ('mobilenet_v2', torchvision.models.mobilenet_v2),
    ('densenet_201', torchvision.models.densenet201),
    ('densenet_169', torchvision.models.densenet169),
    ('densenet_121', torchvision.models.densenet121),
    ('resnet_101', torchvision.models.resnet101),
)

# Candidate source models, each instantiated with pretrained weights at
# import time.
models = {arch_name: build(pretrained=True) for arch_name, build in _ARCHITECTURES}

def main():
    """Score every model in ``models`` with LEEP and correlate with accuracy.

    For each model, the softmax outputs over the Oxford-IIIT Pets train split
    are collected, the samples are ordered by
    ``calc_gaussian_hard_subsets.get_hardness_ordering``, and the selected
    subset (currently all samples) is handed to ``LEEP`` together with labels
    re-indexed to ``0..C-1``. The scores are saved to
    ``./imagenet_pets_leep_scores_arch.npy``; afterwards the stored fine-tune
    accuracies are loaded and the Pearson correlation between the two is
    printed.

    Raises:
        FileNotFoundError: if the accuracies ``.npy`` file does not exist.
            The LEEP scores have already been written at that point.
    """
    print(device)
    source_dataset_name = 'imagenet'
    target_dataset_name = 'pets'
    # Width of the probability buffer below; it assumes each model emits this
    # many scores per image.
    num_source_classes = 1000

    # NOTE: per the original author's reminders, the flowers dataset needs
    # ``y -= 1`` applied to its labels; no such shift is applied for pets.
    # The dataset and loader do not depend on the model, so build them once
    # instead of once per architecture.
    target_ds = oxford_pets.OxfordIIITPets(root='/var/data/pets', split='train', transform=transform)
    target_loader = torch.utils.data.DataLoader(target_ds, shuffle=False, batch_size=BATCH_SIZE)

    leep_scores = np.zeros(len(models))
    for model_idx, source_model in tqdm(enumerate(models.values()), total=len(models)):
        source_model = source_model.to(device)
        source_model.eval()

        dummy_dist = np.zeros((len(target_ds), num_source_classes))
        targets = []
        offset = 0
        # Inference only: no_grad avoids keeping an autograd graph per batch.
        with torch.no_grad():
            for x, y in target_loader:
                probs = torch.nn.functional.softmax(source_model(x.to(device)), dim=1)
                batch_len = len(x)
                # A running offset places the rows correctly without relying
                # on every batch but the last being exactly BATCH_SIZE long.
                dummy_dist[offset:offset + batch_len] = probs.cpu().numpy()
                offset += batch_len
                targets.extend(list(y))

        sorted_idxs = calc_gaussian_hard_subsets.get_hardness_ordering(dummy_dist, targets)
        # Original author's note: the last K samples are used because, with
        # the gaussian ordering, further is harder (the opposite of the
        # similarity-based method). K currently keeps every sample.
        num_hard = len(sorted_idxs)
        # Explicit start index: unlike ``[-num_hard:]`` this stays correct
        # if num_hard is ever set to 0.
        hard_idxs = sorted_idxs[len(sorted_idxs) - num_hard:]

        print(len(hard_idxs) / len(targets))
        targets = np.array(targets)
        # Re-index the labels present in the subset to 0..C-1 (sorted order).
        _, relabelled_targets = np.unique(targets[hard_idxs], return_inverse=True)
        leep_scores[model_idx] = LEEP(dummy_dist[hard_idxs], relabelled_targets)

    # Persist the scores before touching the accuracies file, so a missing or
    # unreadable file does not discard the computed results.
    with open(f'./{source_dataset_name}_{target_dataset_name}_leep_scores_arch.npy', 'wb') as f:
        np.save(f, leep_scores)

    # NOTE (original author): remember to use "best" for cub and flowers.
    f_name = 'accuracies/{}_to_{}_arch_finetune_accuracies_seed0.npy'.format(source_dataset_name, target_dataset_name)
    with open(f_name, 'rb') as f:
        accuracies = np.load(f)

    pcc, _ = scipy.stats.pearsonr(leep_scores, accuracies)
    print(target_dataset_name)
    print(pcc)

# Run the LEEP evaluation only when this file is executed as a script.
if __name__ == '__main__':
    main()