import numpy as np
import os
import sys
import torch
import scipy
import scipy.stats
from tqdm import tqdm
import argparse
from GBC import GBC
from gbc_tf import get_gbc_score
import random
import flowers102
# import caltech101.caltech_dataset as caltech101
import pretrainedmodels 
from resnet34_caltech import ResNet34 
import cub_200
import stanford_dogs
import oxford_pets
import caltech_dataset
import torch.nn as nn
import torchvision

# Number of target images sent through a model per forward pass.
BATCH_SIZE = 32

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Per-channel statistics handed to Normalize after ToTensor.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every target image: resize, centre-crop to
# 224 pixels, convert to a tensor, then normalise each channel.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

# (name, constructor) pairs. The order is significant: main() indexes its
# score vector by position in `models`, so it must not be reshuffled.
_MODEL_BUILDERS = (
    ('resnet_50', torchvision.models.resnet50),
    ('resnet_152', torchvision.models.resnet152),
    ('mobilenet_v2', torchvision.models.mobilenet_v2),
    ('densenet_201', torchvision.models.densenet201),
    ('densenet_169', torchvision.models.densenet169),
    ('densenet_121', torchvision.models.densenet121),
    ('resnet_101', torchvision.models.resnet101),
)

# Candidate source architectures, instantiated with pretrained weights at
# import time and keyed by the name main() uses to pick the classifier head.
models = {
    name: build(pretrained=True)
    for name, build in _MODEL_BUILDERS
}

def _load_target_scores(source_dataset_name, target_dataset_name, block_units,
                        source_ds_len, num_targets):
    """Return one score per target image, averaged over units and sources.

    For every entry of ``block_units`` a ``.npy`` matrix is loaded from
    ``/mnt2/imagenet_resnet50_results`` and summed into a
    ``(source_ds_len, num_targets)`` accumulator. The sum is divided by the
    number of block units and then averaged over axis 0, giving a vector of
    length ``num_targets``.
    """
    img_pair_scores = np.zeros((source_ds_len, num_targets), dtype=np.float64)
    for bu in tqdm(block_units):
        path = '/mnt2/imagenet_resnet50_results/{}_{}_{}.npy'.format(
            source_dataset_name, target_dataset_name, bu)
        with open(path, 'rb') as f:
            temp = np.load(f)
            print(temp.shape)
            img_pair_scores += temp

    img_pair_scores /= len(block_units)
    return np.mean(img_pair_scores, axis=0)


def _strip_classifier(model_name, model):
    """Replace the final linear layer of ``model`` with ``nn.Identity``.

    The layer to replace is chosen from the substring found in
    ``model_name`` ('resnet', 'mobilenet' or 'densenet'). The model is
    modified in place.

    Returns:
        The ``in_features`` of the removed layer, i.e. the width of the
        model's output after the replacement.

    Raises:
        ValueError: if ``model_name`` matches none of the known families.
            Without this check an unknown name would silently reuse a stale
            width from a previous model or fail later with a NameError.
    """
    if 'resnet' in model_name:
        out_dim = model.fc.in_features
        model.fc = nn.Identity()
    elif 'mobilenet' in model_name:
        out_dim = model.classifier[1].in_features
        model.classifier[1] = nn.Identity()
    elif 'densenet' in model_name:
        out_dim = model.classifier.in_features
        model.classifier = nn.Identity()
    else:
        raise ValueError('Unsupported architecture: {}'.format(model_name))
    return out_dim


def _extract_outputs(model, loader, num_samples, out_dim):
    """Run ``model`` over ``loader`` and collect softmaxed outputs and labels.

    Returns:
        A tuple ``(outputs, targets)``: ``outputs`` is a
        ``(num_samples, out_dim)`` float array holding the softmax (over
        dim 1) of the model output for each sample in loader order, and
        ``targets`` is an array of the labels yielded by the loader.
    """
    outputs = np.zeros((num_samples, out_dim))
    targets = []
    offset = 0
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for x, y in loader:
            curr_batch_size = len(x)
            out = model(x.to(device))
            out = torch.nn.functional.softmax(out, dim=1)
            # A running offset places each batch right after the previous
            # one, independent of the configured batch size.
            outputs[offset:offset + curr_batch_size] = out.cpu().numpy()
            # NOTE: labels are used as-is. The original author's notes say
            # the flowers dataset needs `y -= 1` here.
            targets.extend(y.tolist())
            offset += curr_batch_size
    return outputs, np.array(targets)


def main():
    """Correlate per-architecture GBC scores with fine-tuning accuracies.

    Steps:
      1. Load precomputed source/target image-pair scores, reduce them to
         one score per target image and sort target indices by that score.
      2. For every entry in the module-level ``models`` dict, strip the
         classifier head, compute softmaxed outputs on the target dataset
         and accumulate a (beta-weighted) ``GBC`` score over the selected
         target indices.
      3. Load the stored fine-tuning accuracies and print the Pearson
         correlation between them and the GBC scores.

    The entries of ``models`` are modified in place (their classifier layer
    is replaced by ``nn.Identity``).
    """
    print(device)
    source_dataset_name = 'imagenet'
    # NOTE: when switching the target dataset, update the dataset
    # construction below as well; flowers additionally needs `y -= 1`.
    target_dataset_name = 'pets'

    source_ds_len = 128116
    # Dataset and loader do not depend on the model, so build them once.
    target_ds = oxford_pets.OxfordIIITPets(root='/var/data/pets', split='trainval', transform=transform)
    target_loader = torch.utils.data.DataLoader(target_ds, shuffle=False, batch_size=BATCH_SIZE)

    block_units = ['block2_unit3', 'block3_unit3', 'block3_unit5', 'block4_unit1']
    target_scores = _load_target_scores(
        source_dataset_name, target_dataset_name, block_units,
        source_ds_len, len(target_ds))
    sorted_idxs = np.argsort(target_scores)

    multiplier = 0.9
    num_buckets = 1
    topK = int(len(target_ds))

    gbc_scores = np.zeros((len(models)))
    for ds_num, model_name in tqdm(enumerate(models), total=len(models)):
        source_model = models[model_name].to(device)
        out_dim = _strip_classifier(model_name, source_model)
        source_model.eval()

        dummy_dist, targets = _extract_outputs(
            source_model, target_loader, len(target_ds), out_dim)

        # Release device memory before the next architecture is moved over;
        # otherwise every model stays resident on the GPU.
        source_model.to('cpu')

        beta = 1
        for num in range(num_buckets):
            temp_idxs = sorted_idxs[:(num + 1) * topK]
            # Re-index the labels present in this bucket to 0..k-1.
            # return_inverse gives, for each label, its position in the
            # sorted unique labels.
            temp_classes, inverse = np.unique(targets[temp_idxs], return_inverse=True)
            modified_temp_targets = inverse.ravel().tolist()
            gbc_scores[ds_num] += beta * GBC(dummy_dist[temp_idxs], modified_temp_targets, num_classes=len(temp_classes))

            beta = multiplier * beta

    # NOTE (from the original author): use the "best" accuracies file for
    # cub and flowers.
    f_name = 'accuracies/{}_to_{}_arch_finetune_accuracies_seed0_best.npy'.format(source_dataset_name, target_dataset_name)
    with open(f_name, 'rb') as f:
        accuracies = np.load(f)

    pcc, _ = scipy.stats.pearsonr(gbc_scores, accuracies)
    print(pcc)

# Run the experiment only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()