import itertools
import numpy as np
import torch
import flowers102
import torch.nn as nn
import torchvision
import os
import pretrainedmodels
import torch.nn.functional as F
from tqdm import tqdm
import scipy.stats
from LEEP import LEEP

# Mini-batch size used by the DataLoaders built in the __main__ section.
BATCH_SIZE = 64
# Compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class ResNet34(nn.Module):
    """ResNet-34 backbone from ``pretrainedmodels`` with a dropout + linear head.

    The backbone's ``features`` map is global-average-pooled, passed through
    dropout and projected to ``num_classes`` logits.  The default of 101
    outputs matches the Caltech-101 checkpoint this script loads.

    Args:
        pretrained: if truthy, the backbone is built with ``pretrained='imagenet'``;
            otherwise with ``pretrained=None`` (no pretrained weights loaded).
        num_classes: number of output logits (default 101).
    """

    def __init__(self, pretrained, num_classes=101):
        super(ResNet34, self).__init__()
        self.model = pretrainedmodels.__dict__['resnet34'](
            pretrained='imagenet' if pretrained else None
        )
        # 512 = channel count of the ResNet-34 feature map after pooling.
        self.l0 = nn.Linear(512, num_classes)
        # Element-wise dropout: the pooled features are 2-D (batch, 512).
        # nn.Dropout2d on a 2-D input is deprecated and, in current PyTorch,
        # treats the batch as channels and zeroes entire samples.
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for an image batch ``x``.

        Assumes ``x`` is ``(batch, channels, height, width)``.
        """
        batch = x.shape[0]
        x = self.model.features(x)
        # Global average pool to (batch, 512, 1, 1), then flatten to (batch, 512).
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
        x = self.dropout(x)
        return self.l0(x)
    
def initialize_transferred_models() -> list:
    """Return the names of the fine-tuned (transferred) Flowers-102 models.

    Each name has the form ``<source>_flowers_<architecture>``.  The list order
    determines the order of the ensemble combinations built from it.
    """
    return [
        "imagenet_flowers_densenet201",
        "stanforddogs_flowers_densenet201",
        "oxfordpets_flowers_densenet201",
        "imagenet_flowers_resnet101",
        "oxfordpets_flowers_resnet101",
        "cub_flowers_resnet18",
        "stanforddogs_flowers_resnet18",
        "cub_flowers_vgg19",
        "stanforddogs_flowers_vgg19",
        "imagenet_flowers_mobilenetv2",
        "caltech_flowers_resnet34",
    ]

def initialize_source_models() -> tuple:
    """Instantiate every source model whose LEEP score is computed on Flowers-102.

    ImageNet models use torchvision's bundled pretrained weights.  Every other
    model has its classification head resized to the source dataset's class
    count and its weights restored from a checkpoint under ``models/``.

    Returns:
        ``(names, models_dict)``: ``models_dict`` maps a ``<source>_<arch>`` key
        to ``(model, num_source_classes)``; ``names`` lists those keys in
        insertion order.
    """

    def load_checkpoint(path):
        # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only
        # hosts and avoids placing every model on the GPU at start-up.
        return torch.load(path, map_location='cpu')

    models_dict = dict()

    # ---- Architecture: DenseNet201 ----

    # ImageNet (torchvision pretrained weights, 1000 classes).
    models_dict["imagenet_densenet201"] = (torchvision.models.densenet201(pretrained=True), 1000)

    # Stanford Dogs (120 classes).
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=120)
    model.load_state_dict(load_checkpoint('models/stanforddogs-pretrained-densenet201-best_scheduler.pth'))
    models_dict["stanforddogs_densenet201"] = (model, 120)

    # Oxford Pets (37 classes).
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=37)
    model.load_state_dict(load_checkpoint('models/oxfordpets-pretrained-densenet201-best_scheduler.pth'))
    models_dict["oxfordpets_densenet201"] = (model, 37)

    # ---- Architecture: ResNet101 ----

    # ImageNet.
    models_dict["imagenet_resnet101"] = (torchvision.models.resnet101(pretrained=True), 1000)

    # Oxford Pets.
    model = torchvision.models.resnet101()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=37)
    model.load_state_dict(load_checkpoint('models/oxfordpets-pretrained-resnet101-best_scheduler.pth'))
    models_dict["oxfordpets_resnet101"] = (model, 37)

    # ---- Architecture: ResNet18 ----

    # CUB-200: this checkpoint is a dict whose 'model' entry holds the state dict.
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=200)
    model.load_state_dict(load_checkpoint(os.path.join("models", "cub_classifier-ckpt.pth"))['model'])
    models_dict["cub_resnet18"] = (model, 200)

    # Stanford Dogs.
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=120)
    model.load_state_dict(load_checkpoint('models/stanforddogs-pretrained-resnet18-best_scheduler.pth'))
    models_dict["stanforddogs_resnet18"] = (model, 120)

    # ---- Architecture: VGG19 (batch-norm variant) ----

    # CUB-200.
    model = torchvision.models.vgg19_bn()
    model.classifier[6] = torch.nn.Linear(in_features=model.classifier[6].in_features, out_features=200)
    model.load_state_dict(load_checkpoint('models/cub-pretrained-vgg19-best_scheduler.pth'))
    models_dict["cub_vgg19"] = (model, 200)

    # Stanford Dogs.
    model = torchvision.models.vgg19_bn()
    model.classifier[6] = torch.nn.Linear(in_features=model.classifier[6].in_features, out_features=120)
    model.load_state_dict(load_checkpoint('models/stanforddogs-pretrained-vgg19-best_scheduler.pth'))
    models_dict["stanforddogs_vgg19"] = (model, 120)

    # ---- Architecture: MobileNetV2 ----

    # ImageNet.
    models_dict["imagenet_mobilenetv2"] = (torchvision.models.mobilenet_v2(pretrained=True), 1000)

    # ---- Architecture: ResNet34 (Caltech-101) ----

    # No ImageNet weights are downloaded here: load_state_dict is strict, so the
    # checkpoint must (and does) overwrite every parameter anyway.
    model = ResNet34(pretrained=False)
    model.l0 = torch.nn.Linear(in_features=model.l0.in_features, out_features=101)
    model.load_state_dict(load_checkpoint('models/caltech101-pretrained.pth'))
    models_dict["caltech_resnet34"] = (model, 101)

    return list(models_dict.keys()), models_dict

def get_model_layer_nums(model_name):
    """Return how many per-layer score files exist for ``model_name``.

    The source dataset is recognised by substring in ``model_name``; the
    first matching entry wins.  The name is printed as a progress trace.

    Raises:
        NotImplementedError: if no known source dataset matches the name.
    """
    print(model_name)
    # Ordered (substring, number of layer files) pairs.
    layer_counts = (
        ('caltech', 3),
        ('stanford', 4),
        ('cub', 4),
        ('pets', 4),
        ('imagenet', 4),
    )
    for key, num_layers in layer_counts:
        if key in model_name:
            return num_layers
    raise NotImplementedError(f"No layer count known for source model {model_name!r}")
    
def get_filename(source_model_name, target_ds_name, layer, root='/mnt2/ensemble_results'):
    """Return the path of a precomputed per-layer pair-score ``.npy`` file.

    NOTE: the file is chosen by the source *dataset* only; e.g. every 'cub'
    model reads the ``cub200_vgg19`` files, whatever its own architecture.

    Args:
        source_model_name: source model key; its dataset is detected by substring
            (first match wins).
        target_ds_name: target dataset name embedded in the filename.
        layer: layer index embedded in the filename.
        root: directory containing the score files.

    Raises:
        NotImplementedError: if the source dataset is not recognised.
    """
    prefixes = (
        ('cub', 'cub200_vgg19_cub200'),
        ('oxfordpets', 'pets_resnet101_pets'),
        ('stanford', 'stanford_dogs_vgg19_stanford_dogs'),
        ('caltech', 'caltech101_resnet34_caltech101'),
        ('imagenet', 'imagenet_resnet50_imagenet'),
    )
    for key, prefix in prefixes:
        if key in source_model_name:
            return f'{root}/{prefix}_{target_ds_name}_layer{layer}.npy'
    raise NotImplementedError(f"No score file known for source model {source_model_name!r}")
        
def get_source_ds_length(source_model_name):
    """Return the number of source images the pair-score files were computed over.

    This is used as the row count of each per-layer ``(source, target)`` score
    matrix.  The dataset is recognised by substring, and the first match wins.

    Raises:
        NotImplementedError: if the source dataset is not recognised (previously
            ``None`` was returned, which only failed later inside ``np.zeros``).
    """
    lengths = (
        ('caltech', 5784),
        ('cub', 5994),
        ('stanford', 12000),
        ('pets', 3680),
        ('flowers', 1020),
        ('imagenet', 128116),
    )
    for key, length in lengths:
        if key in source_model_name:
            return length
    raise NotImplementedError(f"No dataset length known for source model {source_model_name!r}")

def _mean_target_scores(model_name, target_dataset_name, num_targets):
    """Average the per-layer pair-score matrices, then average over source images.

    Returns a 1-D array with one score per target image.
    """
    num_layers = get_model_layer_nums(model_name)
    pair_scores = np.zeros((get_source_ds_length(model_name), num_targets))
    for layer in range(1, num_layers + 1):
        pair_scores += np.load(get_filename(model_name, target_dataset_name, layer))
    pair_scores /= num_layers
    return np.mean(pair_scores, axis=0)


def _predict_on_target(source_model, target_loader, num_targets, num_source_classes):
    """Run ``source_model`` over ``target_loader`` and collect softmax outputs and labels.

    Returns:
        ``(probs, targets)``: ``probs`` has shape ``(num_targets, num_source_classes)``
        with rows in loader order; ``targets`` holds the labels shifted by -1.
    """
    probs = np.zeros((num_targets, num_source_classes))
    label_batches = []
    offset = 0  # running row index; does not assume a particular loader batch size
    with torch.no_grad():
        for x, y in target_loader:
            out = F.softmax(source_model(x.to(device)), dim=1)
            n = out.shape[0]
            probs[offset:offset + n] = out.cpu().numpy()
            offset += n
            # Flowers labels are shifted down by one (presumably 1-based -> 0-based).
            label_batches.append((y - 1).cpu().numpy())
    targets = np.concatenate(label_batches) if label_batches else np.zeros(0, dtype=np.int64)
    return probs, targets


def calc_leep_scores(models, target_loader, target_ds):
    """Compute a LEEP score for each source model on its "hardest" target images.

    For each model, target images are ranked by their mean pair score
    (ascending), and LEEP is computed on the lowest-scoring fifth of the
    target set.

    Args:
        models: dict mapping model name -> ``(model, num_source_classes, ...)``.
        target_loader: loader over ``target_ds``; it must not shuffle, because
            predictions are matched to pair-score columns by position.
        target_ds: the target dataset; only its length is used.

    Returns:
        dict mapping model name -> LEEP score.
    """
    target_dataset_name = 'flowers102'
    num_targets = len(target_ds)

    leep_scores = dict()
    for model_name in tqdm(models.keys()):
        # Uses the module-level device (CUDA if available, else CPU) instead
        # of a hard-coded 'cuda', so the script also runs on CPU-only hosts.
        source_model = models[model_name][0].to(device)
        source_model.eval()
        source_classes = models[model_name][1]

        target_scores = _mean_target_scores(model_name, target_dataset_name, num_targets)
        sorted_idxs = np.argsort(target_scores)

        probs, targets = _predict_on_target(source_model, target_loader, num_targets, source_classes)

        # Bucketed weighting; with num_buckets = 1 only the first fifth is used.
        num_buckets = 1
        top_k = num_targets // 5
        beta = 1
        multiplier = 0.9
        curr_leep = 0
        for b in range(num_buckets):
            hardest_idxs = sorted_idxs[:(b + 1) * top_k]
            curr_leep += beta * LEEP(probs[hardest_idxs], targets[hardest_idxs])
            beta = beta * multiplier
        leep_scores[model_name] = curr_leep

    return leep_scores

def get_ensemble_acc_leep(target_model_names, leep_scores, target_loader, test_loader, r=4,
                          accuracies_path='./target_flowers.npy'):
    """Score every size-``r`` ensemble by summed source LEEP and pair it with its accuracy.

    Each fine-tuned model name ``<source>_<target>_<arch>`` is mapped to the
    source-model key ``<source>_<arch>`` to look up its LEEP score.  An
    ensemble's score is the sum over its members.  Accuracies are not computed
    here.  They are read from ``accuracies_path``, and its i-th entry must
    correspond to the i-th combination in ``itertools.combinations`` order.

    Args:
        target_model_names: names of the fine-tuned models.
        leep_scores: dict mapping source-model key -> LEEP score.
        target_loader: unused; kept for interface compatibility.
        test_loader: unused; kept for interface compatibility.
        r: ensemble size.
        accuracies_path: ``.npy`` file of precomputed ensemble accuracies.

    Returns:
        ``(leep_scores_ensemble, accuracies)`` arrays of equal length.

    Raises:
        ValueError: if the stored accuracies do not match the number of combinations.
    """
    print("making ensemble combinations...")
    combinations = list(itertools.combinations(target_model_names, r))

    accuracies = np.load(accuracies_path)
    if len(accuracies) != len(combinations):
        raise ValueError(
            f"{accuracies_path} holds {len(accuracies)} accuracies but there are "
            f"{len(combinations)} combinations of size {r}"
        )

    print("calculating accuracies...")

    leep_scores_ensemble = np.zeros(len(combinations))
    for i in tqdm(range(len(combinations))):
        leep_score_array = []
        for model_name in combinations[i]:
            parts = model_name.split("_")
            # Drop the middle (target dataset) component to get the source key.
            leep_score_array.append(leep_scores[parts[0] + "_" + parts[2]])
        leep = np.sum(np.asarray(leep_score_array))

        leep_scores_ensemble[i] = leep
        accuracy = accuracies[i]
        print("Combination: ",i,", Accuracy: ",accuracy,", LEEP Score: ", leep)

    print("Complete!")

    return leep_scores_ensemble, accuracies


def get_model_accuracy(model, target_loader, label_offset=1):
    """Return the top-1 accuracy, in percent, of ``model`` over ``target_loader``.

    The model is moved to the module-level device and put in eval mode.
    Labels from the loader are shifted down by ``label_offset`` before
    comparison.  The default of 1 matches the 1-based Flowers labels used
    elsewhere in this script.

    Raises:
        ValueError: if the loader yields no samples (accuracy is undefined).
    """
    model.to(device)
    model.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for x, y in target_loader:
            x = x.to(device)
            y = (y - label_offset).to(device)
            pred = model(x).argmax(dim=1)
            total += y.shape[0]
            correct += (pred == y).sum().item()

    if total == 0:
        raise ValueError("target_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total


if __name__ == "__main__":

    # Source models (scored with LEEP) and the fine-tuned models whose
    # ensembles are evaluated.
    source_model_names, source_models = initialize_source_models()
    finetuned_model_names = initialize_transferred_models()

    # Resize to 224x224, convert to a tensor, then normalise with the usual
    # ImageNet channel mean/std.
    resize = torchvision.transforms.Resize((224, 224))
    to_tensor = torchvision.transforms.ToTensor()
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    transform = torchvision.transforms.Compose([resize, to_tensor, normalize])

    flowers_root = '/var/data/flowers102'
    # Unshuffled loaders keep predictions aligned with dataset order.
    loader_kwargs = dict(shuffle=False, num_workers=2, batch_size=BATCH_SIZE)

    target_ds = flowers102.Flowers102(root=flowers_root, split="train", transform=transform)
    target_loader = torch.utils.data.DataLoader(target_ds, **loader_kwargs)

    test_ds = flowers102.Flowers102(root=flowers_root, split='test', transform=transform)
    test_loader = torch.utils.data.DataLoader(test_ds, **loader_kwargs)

    leep_scores = calc_leep_scores(source_models, target_loader, target_ds)
    ensem_leep, acc = get_ensemble_acc_leep(finetuned_model_names, leep_scores, target_loader, test_loader, 4)

    # Pearson correlation between ensemble LEEP scores and ensemble accuracies.
    pcc, _p_value = scipy.stats.pearsonr(ensem_leep, acc)
    print(pcc)

    







    


