import itertools
import numpy as np
import torch
import cub_200
import torch.nn as nn
import torchvision
import os
import sys
import pretrainedmodels
import torch.nn.functional as F
from tqdm import tqdm
import scipy.stats
from LEEP import LEEP
import calc_gaussian_hard_subsets
import argparse

BATCH_SIZE = 64
# Compute device for model inference; falls back to CPU when CUDA is unavailable.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Command-line options. ``--k`` is multiplied by the target-set size to decide
# how many samples (taken from the tail of the hardness ordering) LEEP is
# evaluated on.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--k',
    dest='k',
    type=float,
    help='fraction of the target training set used as the hard subset (required)',
)
parser.set_defaults(k=None)
args = parser.parse_args()

if args.k is None:
    # Report on stderr and exit non-zero so callers (shell loops, sweeps)
    # can tell the run failed instead of silently "succeeding".
    print('K is required', file=sys.stderr)
    sys.exit(1)


class ResNet34(nn.Module):
    """ResNet-34 backbone from ``pretrainedmodels`` with a 101-way linear head.

    The head size matches the Caltech-101 checkpoint that
    ``initialize_source_models`` loads into this class.
    """

    def __init__(self, pretrained):
        super(ResNet34, self).__init__()
        # pretrainedmodels is called with the string 'imagenet' or None here;
        # any truthy ``pretrained`` selects the ImageNet weights.
        weights = 'imagenet' if pretrained else None
        self.model = pretrainedmodels.__dict__['resnet34'](pretrained=weights)

        self.l0 = nn.Linear(512, 101)
        # The pooled features are 2-D (batch, 512). Dropout2d on 2-D input is
        # deprecated in PyTorch and behaves exactly like plain Dropout, so use
        # Dropout directly (it has no parameters, so checkpoints are unaffected).
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return ``(batch, 101)`` logits for a 4-D ``(batch, C, H, W)`` input."""
        # Unpacking four dims also rejects non-4-D inputs early.
        batch, _, _, _ = x.shape
        feats = self.model.features(x)
        pooled = F.adaptive_avg_pool2d(feats, 1).reshape(batch, -1)
        return self.l0(self.dropout(pooled))
    
def initialize_transferred_models() -> list:
    """Return the names of the fine-tuned models whose ensembles are evaluated.

    Each name has the form ``<source>_cub_<arch>``; ``get_ensemble_acc_leep``
    maps it back to the source-model key ``<source>_<arch>``. The list order
    fixes the order in which ``itertools.combinations`` enumerates ensembles,
    which presumably must match the order of the stored ensemble accuracies.
    """
    return [
        "imagenet_cub_densenet201",
        "flowers_cub_densenet201",
        "oxfordpets_cub_densenet201",
        "stanforddogs_cub_densenet201",
        "imagenet_cub_resnet101",
        "oxfordpets_cub_resnet101",
        "flowers_cub_resnet101",
        "stanforddogs_cub_resnet18",
        "stanforddogs_cub_vgg19",
        "imagenet_cub_mobilenetv2",
        "caltech_cub_resnet34",
    ]


def _load_weights(model, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``model`` and return it.

    ``map_location='cpu'`` lets checkpoints saved from a GPU load on a CPU-only
    host; the models are moved to the compute device later, when scored.
    """
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    return model


def initialize_source_models() -> tuple:
    """Build every source model whose transferability is scored.

    Returns:
        ``(names, models)`` where ``models`` maps a ``<source>_<arch>`` key to a
        ``(model, num_source_classes)`` tuple and ``names`` lists those keys in
        insertion order.
    """
    models_dict = dict()

    # ---- DenseNet201 ----
    # ImageNet: torchvision weights with the original 1000-way classifier.
    models_dict["imagenet_densenet201"] = (torchvision.models.densenet201(pretrained=True), 1000)

    # Flowers 102
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=102)
    _load_weights(model, 'models/flowers102-pretrained-densenet201-best_scheduler.pth')
    models_dict["flowers_densenet201"] = (model, 102)

    # Oxford Pets
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=37)
    _load_weights(model, 'models/oxfordpets-pretrained-densenet201-best_scheduler.pth')
    models_dict["oxfordpets_densenet201"] = (model, 37)

    # Stanford Dogs
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=120)
    _load_weights(model, 'models/stanforddogs-pretrained-densenet201-best_scheduler.pth')
    models_dict["stanforddogs_densenet201"] = (model, 120)

    # ---- ResNet101 ----
    # ImageNet
    models_dict["imagenet_resnet101"] = (torchvision.models.resnet101(pretrained=True), 1000)

    # Oxford Pets
    model = torchvision.models.resnet101()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=37)
    _load_weights(model, 'models/oxfordpets-pretrained-resnet101-best_scheduler.pth')
    models_dict["oxfordpets_resnet101"] = (model, 37)

    # Flowers 102
    model = torchvision.models.resnet101()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=102)
    _load_weights(model, 'models/flowers102-pretrained-resnet101-best_scheduler.pth')
    models_dict["flowers_resnet101"] = (model, 102)

    # ---- ResNet18 ----
    # Stanford Dogs
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=120)
    _load_weights(model, 'models/stanforddogs-pretrained-resnet18-best_scheduler.pth')
    models_dict["stanforddogs_resnet18"] = (model, 120)

    # ---- VGG19 (batch-norm variant) ----
    # Stanford Dogs: the last classifier layer (index 6) is the output head.
    model = torchvision.models.vgg19_bn()
    model.classifier[6] = torch.nn.Linear(in_features=model.classifier[6].in_features, out_features=120)
    _load_weights(model, 'models/stanforddogs-pretrained-vgg19-best_scheduler.pth')
    models_dict["stanforddogs_vgg19"] = (model, 120)

    # ---- MobileNetV2 ----
    # ImageNet
    models_dict["imagenet_mobilenetv2"] = (torchvision.models.mobilenet_v2(pretrained=True), 1000)

    # ---- ResNet34 (pretrainedmodels backbone) ----
    # Caltech 101: ResNet34 already has a 101-way ``l0`` head, and the
    # checkpoint overwrites every parameter, so no head replacement is needed.
    model = ResNet34(pretrained=True)
    _load_weights(model, 'models/caltech101-pretrained.pth')
    models_dict["caltech_resnet34"] = (model, 101)

    return list(models_dict.keys()), models_dict

def _collect_outputs(model, loader, num_samples, num_classes):
    """Run ``model`` over ``loader`` once; return ``(probs, logits, targets)`` as NumPy arrays.

    Rows are filled with a running offset, so the loader's batch size does not
    need to equal ``BATCH_SIZE``.

    Raises:
        ValueError: if the loader yields a different number of samples than ``num_samples``.
    """
    probs = np.zeros((num_samples, num_classes))
    logits = np.zeros((num_samples, num_classes))
    labels = []
    offset = 0
    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(device))
            n = out.shape[0]
            probs[offset:offset + n] = torch.nn.functional.softmax(out, dim=1).cpu().numpy()
            logits[offset:offset + n] = out.cpu().numpy()
            labels.append(np.asarray(y))
            offset += n
    if offset != num_samples:
        raise ValueError(f"loader yielded {offset} samples, expected {num_samples}")
    targets = np.concatenate(labels) if labels else np.zeros(0, dtype=np.int64)
    return probs, logits, targets


def calc_leep_scores(models, target_loader, target_ds, k=None):
    """Compute a LEEP score for each source model on a hard subset of the target set.

    Args:
        models: mapping ``name -> (model, num_source_classes)``.
        target_loader: non-shuffled loader over ``target_ds``.
        target_ds: target dataset; only its length is used.
        k: fraction of the target set to keep; defaults to the ``--k`` CLI value.

    Returns:
        dict mapping each model name to its LEEP score on the selected subset.

    Raises:
        ValueError: if ``k`` selects no samples.
    """
    if k is None:
        k = args.k
    print(device)

    num_samples = len(target_ds)
    # Clamp so k > 1 means "all samples"; a zero count would make the slice
    # ``[-0:]`` silently select the whole dataset, so reject it explicitly.
    num_hard = min(int(num_samples * k), num_samples)
    if num_hard <= 0:
        raise ValueError(f"k={k} selects no samples from a target set of size {num_samples}")

    leep_scores = dict()
    for model_name, (source_model, source_classes) in tqdm(models.items()):
        source_model = source_model.to(device)
        source_model.eval()

        # One inference pass yields both the softmax outputs used by LEEP and
        # the raw final-layer logits used for the hardness ordering.
        probs, logits, targets = _collect_outputs(
            source_model, target_loader, num_samples, source_classes)

        # Free accelerator memory before the next model is loaded onto it.
        source_model.to('cpu')

        # The tail of this ordering is treated as the hard subset
        # (assumes easiest-to-hardest ordering — confirm in calc_gaussian_hard_subsets).
        sorted_idxs = calc_gaussian_hard_subsets.get_hardness_ordering(logits, targets)
        hard_idxs = sorted_idxs[-num_hard:]
        leep_scores[model_name] = LEEP(probs[hard_idxs], targets[hard_idxs])

    return leep_scores

def _source_model_key(finetuned_name):
    """Map ``<source>_<target>_<arch>`` to the source-model key ``<source>_<arch>``."""
    parts = finetuned_name.split("_")
    return parts[0] + "_" + parts[2]


def get_ensemble_acc_leep(target_model_names, leep_scores, target_loader, r=4,
                          accuracies_path='./target_cub200.npy'):
    """Score every size-``r`` ensemble by the sum of its members' source LEEP scores.

    Args:
        target_model_names: fine-tuned model names ``<source>_<target>_<arch>``.
        leep_scores: mapping ``<source>_<arch> -> LEEP score``.
        target_loader: unused; kept for backward compatibility.
        r: ensemble size.
        accuracies_path: ``.npy`` file holding one accuracy per ensemble, in
            ``itertools.combinations(target_model_names, r)`` order.

    Returns:
        ``(leep_scores_ensemble, accuracies)``: float arrays with one entry per ensemble.

    Raises:
        ValueError: if the accuracy file does not hold exactly one value per ensemble.
    """
    print("making ensemble combinations...")
    combinations = list(itertools.combinations(target_model_names, r))

    with open(accuracies_path, 'rb') as f:
        accuracies = np.asarray(np.load(f), dtype=np.float64).reshape(-1)
    # Reject size mismatches instead of letting NumPy broadcast a single value
    # across every ensemble.
    if accuracies.size != len(combinations):
        raise ValueError(
            f"{accuracies_path} holds {accuracies.size} accuracies, "
            f"expected {len(combinations)} (one per ensemble)")

    leep_scores_ensemble = np.zeros(len(combinations))
    for i, combination in enumerate(combinations):
        member_scores = np.asarray([leep_scores[_source_model_key(name)] for name in combination])
        leep_scores_ensemble[i] = np.sum(member_scores)

    return leep_scores_ensemble, accuracies

if __name__ == "__main__":
    # Models scored with LEEP, and the fine-tuned models whose ensembles are ranked.
    source_model_names, source_models = initialize_source_models()
    finetuned_model_names = initialize_transferred_models()

    # ImageNet-style preprocessing: fixed 224x224 resize, then per-channel normalization.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    preprocess_steps = [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    transform = torchvision.transforms.Compose(preprocess_steps)

    cub_root = '/var/data/cub_200'

    # Training split: the target set the source models are scored on.
    target_ds = cub_200.CUB200(root=cub_root, train=True, transform=transform)
    target_loader = torch.utils.data.DataLoader(
        target_ds, shuffle=False, num_workers=2, batch_size=BATCH_SIZE)

    # Test split: handed to the ensemble step.
    test_ds = cub_200.CUB200(root=cub_root, train=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)

    leep_scores = calc_leep_scores(source_models, target_loader, target_ds)
    ensem_leep, acc = get_ensemble_acc_leep(finetuned_model_names, leep_scores, test_loader, 4)

    # Pearson correlation between ensemble LEEP and ensemble accuracy.
    pcc, _pvalue = scipy.stats.pearsonr(ensem_leep, acc)
    print(args.k)
    print(pcc)

    







    


