import itertools
import numpy as np
import torch
import caltech_dataset
import torch.nn as nn
import torchvision
import os
import pretrainedmodels
import torch.nn.functional as F
from tqdm import tqdm
import scipy.stats
from LEEP import LEEP

# Batch size shared by the DataLoaders built in the script entry point.
BATCH_SIZE = 64
# Module-level torch device string: CUDA when a GPU is visible, CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class ResNet34(nn.Module):
    """ResNet-34 backbone from ``pretrainedmodels`` with a 101-way linear head.

    Args:
        pretrained: Pass exactly ``True`` to build the backbone with the
            ``'imagenet'`` weights; any other value builds it with
            ``pretrained=None``.
    """

    def __init__(self, pretrained):
        super(ResNet34, self).__init__()
        # ``is True`` is kept on purpose: truthy values other than the bool
        # ``True`` have always selected the non-pretrained backbone.
        weights = 'imagenet' if pretrained is True else None
        self.model = pretrainedmodels.__dict__['resnet34'](pretrained=weights)

        # Linear classifier over the 512 pooled backbone features.
        self.l0 = nn.Linear(512, 101)
        # forward() applies dropout to a flattened (batch, features) tensor.
        # nn.Dropout2d is only specified for 3-D/4-D inputs and emits a
        # deprecation warning (future error) on 2-D input; nn.Dropout is the
        # element-wise equivalent PyTorch recommends for that case.
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return the head's raw outputs (no softmax) for a 4-D input batch.

        Args:
            x: Tensor with exactly four dimensions; only the first (batch)
                dimension is used directly, the rest goes to the backbone.

        Returns:
            Tensor of shape ``(batch, 101)``.
        """
        # Unpacking four values also rejects inputs that are not 4-D.
        batch, _, _, _ = x.shape
        x = self.model.features(x)
        # Global average pool to 1x1, then flatten to (batch, features).
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
        x = self.dropout(x)
        return self.l0(x)

def initialize_transferred_models() -> list:
    """Return the names of the fine-tuned (transferred) models.

    Every name has the form ``<source>_caltech_<architecture>``;
    ``get_ensemble_acc_leep`` drops the middle token to recover the
    matching source-model key ``<source>_<architecture>``.

    Returns:
        A new list of model-name strings. The order is significant because
        ensemble combinations are enumerated from it.
    """
    # The annotation used to say ``dict`` although a list is returned.
    return [
        "imagenet_caltech_densenet201",
        "flowers_caltech_densenet201",
        "oxfordpets_caltech_densenet201",
        "stanforddogs_caltech_densenet201",
        "imagenet_caltech_resnet101",
        "oxfordpets_caltech_resnet101",
        "flowers_caltech_resnet101",
        "stanforddogs_caltech_resnet18",
        "cub_caltech_resnet18",
        "stanforddogs_caltech_vgg19",
        "cub_caltech_vgg19",
        "imagenet_caltech_mobilenetv2",
    ]


def _load_finetuned_model(arch, num_classes, checkpoint_path, state_key=None):
    """Build torchvision ``arch`` with a ``num_classes``-way head and load a checkpoint into it.

    Args:
        arch: Name of a constructor in ``torchvision.models`` (a DenseNet,
            ResNet or VGG variant).
        num_classes: Output size of the replacement classification layer.
        checkpoint_path: File passed to ``torch.load``.
        state_key: When given, the state dict is read from this key of the
            loaded object instead of being the loaded object itself.

    Raises:
        NotImplementedError: If the head layout of ``arch`` is not known here.
    """
    model = getattr(torchvision.models, arch)()

    # Swap the final linear layer; its location depends on the architecture family.
    if arch.startswith('densenet'):
        model.classifier = torch.nn.Linear(
            in_features=model.classifier.in_features, out_features=num_classes)
    elif arch.startswith('resnet'):
        model.fc = torch.nn.Linear(
            in_features=model.fc.in_features, out_features=num_classes)
    elif arch.startswith('vgg'):
        model.classifier[6] = torch.nn.Linear(
            in_features=model.classifier[6].in_features, out_features=num_classes)
    else:
        raise NotImplementedError(arch)

    # Load onto the CPU so checkpoints saved from a GPU also open on
    # CPU-only hosts; load_state_dict copies into the CPU parameters anyway.
    state = torch.load(checkpoint_path, map_location='cpu')
    if state_key is not None:
        state = state[state_key]
    model.load_state_dict(state)
    return model


def initialize_source_models() -> tuple:
    """Instantiate every source model used for LEEP scoring.

    ImageNet models are created with torchvision's ``pretrained=True`` and
    keep their 1000-way head. All other models get a fresh linear head of
    the listed size and have their weights loaded from a checkpoint under
    ``models/``.

    Returns:
        ``(names, models_dict)`` where ``models_dict`` maps
        ``"<source>_<architecture>"`` to ``(model, num_source_classes)`` and
        ``names`` lists the keys in insertion order.
    """
    # (model name, torchvision constructor, classes, checkpoint, state-dict key).
    # A checkpoint of None selects torchvision's pretrained weights.
    specs = [
        # DenseNet201
        ("imagenet_densenet201", "densenet201", 1000, None, None),
        ("flowers_densenet201", "densenet201", 102,
         'models/flowers102-pretrained-densenet201-best_scheduler.pth', None),
        ("oxfordpets_densenet201", "densenet201", 37,
         'models/oxfordpets-pretrained-densenet201-best_scheduler.pth', None),
        ("stanforddogs_densenet201", "densenet201", 120,
         'models/stanforddogs-pretrained-densenet201-best_scheduler.pth', None),
        # ResNet101
        ("imagenet_resnet101", "resnet101", 1000, None, None),
        ("oxfordpets_resnet101", "resnet101", 37,
         'models/oxfordpets-pretrained-resnet101-best_scheduler.pth', None),
        ("flowers_resnet101", "resnet101", 102,
         'models/flowers102-pretrained-resnet101-best_scheduler.pth', None),
        # ResNet18
        ("stanforddogs_resnet18", "resnet18", 120,
         'models/stanforddogs-pretrained-resnet18-best_scheduler.pth', None),
        # This checkpoint stores the state dict under the 'model' key.
        ("cub_resnet18", "resnet18", 200,
         os.path.join("models", "cub_classifier-ckpt.pth"), 'model'),
        # VGG19 (batch-norm variant)
        ("stanforddogs_vgg19", "vgg19_bn", 120,
         'models/stanforddogs-pretrained-vgg19-best_scheduler.pth', None),
        ("cub_vgg19", "vgg19_bn", 200,
         'models/cub-pretrained-vgg19-best_scheduler.pth', None),
        # MobileNetV2
        ("imagenet_mobilenetv2", "mobilenet_v2", 1000, None, None),
    ]

    models_dict = dict()
    for model_name, arch, num_classes, checkpoint, state_key in specs:
        if checkpoint is None:
            model = getattr(torchvision.models, arch)(pretrained=True)
        else:
            model = _load_finetuned_model(arch, num_classes, checkpoint, state_key)
        models_dict[model_name] = (model, num_classes)

    return list(models_dict.keys()), models_dict

def get_model_layer_nums(model_name):
    """Print ``model_name`` and return its layer count.

    ``calc_leep_scores`` uses the returned value as the number of per-layer
    score files to load and average.

    Returns:
        ``3`` when the name contains ``'caltech'``, ``4`` when it contains
        one of the other known dataset tags.

    Raises:
        NotImplementedError: If no known dataset tag occurs in the name.
    """
    print(model_name)

    # 'caltech' is checked first, so it wins when several tags are present.
    if 'caltech' in model_name:
        return 3

    four_layer_tags = ('stanford', 'cub', 'pets', 'imagenet', 'flowers')
    if any(tag in model_name for tag in four_layer_tags):
        return 4

    raise NotImplementedError
    
def get_filename(source_model_name, target_ds_name, layer):
    """Return the path of the ``.npy`` score file for one source model and layer.

    The file is selected by the first dataset tag (in the order listed
    below) that occurs in ``source_model_name``; ``target_ds_name`` and
    ``layer`` are interpolated into the file name.

    Raises:
        NotImplementedError: If no known dataset tag occurs in the name.
    """
    # Order matters: the first matching tag decides the file.
    candidates = (
        ('cub', f'/mnt2/ensemble_results/cub200_vgg19_cub200_{target_ds_name}_layer{layer}.npy'),
        ('oxfordpets', f'/mnt2/ensemble_results/pets_resnet101_pets_{target_ds_name}_layer{layer}.npy'),
        ('stanford', f'/mnt2/ensemble_results/stanford_dogs_vgg19_stanford_dogs_{target_ds_name}_layer{layer}.npy'),
        ('caltech', f'/mnt2/ensemble_results/caltech101_resnet34_caltech101_{target_ds_name}_layer{layer}.npy'),
        ('imagenet', f'/mnt2/ensemble_results/imagenet_resnet50_imagenet_{target_ds_name}_layer{layer}.npy'),
        ('flowers', f'/mnt2/ensemble_results/flowers102_resnet101_flowers102_{target_ds_name}_layer{layer}.npy'),
    )
    for tag, path in candidates:
        if tag in source_model_name:
            return path
    raise NotImplementedError
        
def get_source_ds_length(source_model_name):
    """Return the source-dataset length registered for a model name.

    ``calc_leep_scores`` uses the value as the row count of the pairwise
    score matrix. The numbers are presumably the sizes of the source
    training sets -- confirm against the files that produce the scores.

    Args:
        source_model_name: Name containing one of the known dataset tags.

    Returns:
        The length for the first tag (in the order below) found in the name.

    Raises:
        NotImplementedError: If no known tag occurs in the name. This used
            to return ``None`` silently; it now matches the sibling lookups.
    """
    # Order matters: the first matching tag decides the result.
    lengths = (
        ('caltech', 5784),
        ('cub', 5994),
        ('stanford', 12000),
        ('pets', 3680),
        ('flowers', 1020),
        ('imagenet', 128116),
    )
    for tag, length in lengths:
        if tag in source_model_name:
            return length
    raise NotImplementedError(
        f'no source dataset length known for {source_model_name!r}')

def _mean_target_scores(model_name, target_dataset_name, num_targets):
    """Average a model's per-layer score matrices and reduce them to one score per target sample."""
    num_layers = get_model_layer_nums(model_name)
    # Accumulator of shape (source samples, target samples); every layer
    # file has to be addable to it.
    img_pair_scores = np.zeros((get_source_ds_length(model_name), num_targets))
    for layer in range(1, num_layers + 1):  # layer files are numbered from 1
        with open(get_filename(model_name, target_dataset_name, layer), 'rb') as f:
            img_pair_scores += np.load(f)
    img_pair_scores /= num_layers
    return np.mean(img_pair_scores, axis=0)


def _collect_softmax_outputs(source_model, target_loader, num_targets, source_classes, device):
    """Run ``source_model`` over ``target_loader``; return ``(softmax outputs, labels)`` as arrays."""
    dummy_dist = np.zeros((num_targets, source_classes))
    targets = []
    offset = 0
    with torch.no_grad():
        for x, y in target_loader:
            curr_batch_size = len(x)
            out = torch.nn.functional.softmax(source_model(x.to(device)), dim=1)
            targets.extend(list(y))
            # A running offset keeps the rows aligned whatever batch size the
            # loader uses, rather than assuming it equals the global BATCH_SIZE.
            dummy_dist[offset:offset + curr_batch_size] = out.detach().cpu().numpy()
            offset += curr_batch_size
    return dummy_dist, np.array(targets)


def calc_leep_scores(models, target_loader, target_ds):
    """Compute one LEEP score per source model on a subset of the target data.

    For every model the per-layer score matrices on disk are averaged into
    one score per target sample. The ``len(target_ds) / 50`` samples with
    the lowest scores are selected, and ``LEEP`` is evaluated on the
    model's softmax outputs and the labels of those samples.

    Args:
        models: Mapping ``name -> (model, num_source_classes)``.
        target_loader: Iterable of ``(inputs, labels)`` batches. Rows are
            stored in iteration order and indexed with positions taken from
            the score files, so the loader presumably has to iterate
            ``target_ds`` unshuffled -- confirm against the score files.
        target_ds: Target dataset; only its length is used.

    Returns:
        Dict mapping each model name to its LEEP score.
    """
    # Use the GPU when present rather than assuming one exists.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    target_dataset_name = 'caltech101'
    num_targets = len(target_ds)

    # Bucketed weighting: bucket b covers the (b+1)*topK lowest-scoring
    # samples and is weighted by multiplier**b. With one bucket this is a
    # plain LEEP over the topK lowest-scoring samples.
    num_buckets = 1
    topK = int(num_targets / 50)
    multiplier = 0.9

    leep_scores = dict()
    for model_name in tqdm(models.keys()):
        source_model = models[model_name][0].to(device)
        source_model.eval()
        source_classes = models[model_name][1]

        target_scores = _mean_target_scores(model_name, target_dataset_name, num_targets)
        sorted_idxs = np.argsort(target_scores)  # ascending: lowest scores first

        dummy_dist, targets = _collect_softmax_outputs(
            source_model, target_loader, num_targets, source_classes, device)

        curr_leep = 0
        beta = 1
        for b in range(num_buckets):
            hardest_idxs = list(sorted_idxs[:(b + 1) * topK])
            print(len(hardest_idxs))
            curr_leep += beta * LEEP(dummy_dist[hardest_idxs], targets[hardest_idxs])
            beta = beta * multiplier
        leep_scores[model_name] = curr_leep

    return leep_scores

def get_ensemble_acc_leep(target_model_names, leep_scores, target_loader, r=4):
    """Sum source-model LEEP scores over every size-``r`` ensemble.

    Accuracies are not computed here: they are read from
    ``./target_caltech.npy`` and assigned, in order, to the combinations
    produced by ``itertools.combinations(target_model_names, r)``.

    Args:
        target_model_names: Names of the form ``<source>_<x>_<arch>``; the
            first and third ``_``-separated tokens form the key looked up
            in ``leep_scores``.
        leep_scores: Mapping ``"<source>_<arch>" -> LEEP score``.
        target_loader: Unused; kept for interface compatibility.
        r: Ensemble size.

    Returns:
        ``(leep_scores_ensemble, accuracies)``: two 1-D arrays with one
        entry per combination.
    """
    print("making ensemble combinations...")
    ensembles = list(itertools.combinations(target_model_names, r))
    n_ensembles = len(ensembles)

    leep_scores_ensemble = np.zeros(n_ensembles)
    accuracies = np.zeros(n_ensembles)
    with open('./target_caltech.npy','rb') as f:
        # Slice assignment keeps the float array and enforces a compatible shape.
        accuracies[:] = np.load(f)

    for i, combination in enumerate(tqdm(ensembles)):
        member_scores = []
        for model_name in combination:
            parts = model_name.split("_")
            # Drop the middle token to recover the source-model key.
            member_scores.append(leep_scores[parts[0] + "_" + parts[2]])

        accuracy = accuracies[i]
        leep = np.sum(np.asarray(member_scores))
        leep_scores_ensemble[i] = leep
        print("Combination: ",i,", Accuracy: ",accuracy,", LEEP Score: ", leep)

    print("Complete!")

    return leep_scores_ensemble, accuracies


if __name__ == "__main__":
    # Source models to score, and the fine-tuned models that form the ensembles.
    source_model_names, source_models = initialize_source_models()
    finetuned_model_names = initialize_transferred_models()

    # Preprocessing applied to both splits: resize to 224x224, convert to a
    # tensor, then normalise each channel with the mean/std below.
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224,224)),
        torchvision.transforms.ToTensor(),
        normalize,
    ])

    caltech_root = '/var/data/caltech101'

    # Train split: iterated without shuffling by calc_leep_scores.
    target_ds = caltech_dataset.Caltech(root=caltech_root, split="train", transform=transform)
    target_loader = torch.utils.data.DataLoader(
        target_ds, shuffle=False, num_workers=2, batch_size=BATCH_SIZE)

    # Test split: handed to get_ensemble_acc_leep.
    test_ds = caltech_dataset.Caltech(root=caltech_root, split="test", transform=transform)
    test_loader = torch.utils.data.DataLoader(test_ds, shuffle=False, batch_size=BATCH_SIZE)

    leep_scores = calc_leep_scores(source_models, target_loader, target_ds)

    ensem_leep, acc = get_ensemble_acc_leep(finetuned_model_names, leep_scores, test_loader, 4)

    # Pearson correlation between ensemble LEEP sums and ensemble accuracies.
    pcc, _p_value = scipy.stats.pearsonr(ensem_leep, acc)
    print(pcc)

    







    


