import itertools
import numpy as np
import torch
import stanford_dogs
import torch.nn as nn
import torchvision
import os
import pretrainedmodels
import torch.nn.functional as F
from tqdm import tqdm
import scipy.stats
from LEEP import LEEP

BATCH_SIZE = 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class ResNet34(nn.Module):
    """ResNet-34 feature extractor from ``pretrainedmodels`` with a linear head.

    The backbone's ``features`` output is global-average-pooled, passed through
    dropout, and projected from 512 features to 101 outputs by ``self.l0``.

    Args:
        pretrained: when exactly ``True`` the backbone is created with
            ``pretrained='imagenet'``; any other value creates it with
            ``pretrained=None``.
    """

    def __init__(self, pretrained):
        super(ResNet34, self).__init__()
        # Identity check on purpose: only the literal ``True`` selects the
        # ImageNet weights, exactly as before.
        weights = 'imagenet' if pretrained is True else None
        self.model = pretrainedmodels.__dict__['resnet34'](pretrained=weights)

        self.l0 = nn.Linear(512, 101)
        # The tensor reaching dropout is 2-D (batch, features). Dropout2d on a
        # 2-D input is deprecated in PyTorch (slated to become an error) and
        # already behaves element-wise there, so plain Dropout is the
        # equivalent, supported module. It has no parameters, so state dicts
        # saved with the previous definition still load.
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return the 101 head outputs for a 4-D input batch ``x``."""
        # Only the batch size is needed; the 4-way unpack also rejects
        # inputs that are not 4-D.
        batch, _, _, _ = x.shape
        feats = self.model.features(x)
        # Global average pool to 1x1, then flatten to (batch, features).
        pooled = F.adaptive_avg_pool2d(feats, 1).reshape(batch, -1)
        pooled = self.dropout(pooled)
        return self.l0(pooled)
    
def initialize_transferred_models() -> list:
    """Return the names of the transferred models, in a fixed order.

    Every name has three underscore-separated parts:
    ``<source>_stanforddogs_<architecture>``.

    Returns:
        A new list of ten name strings on every call.
    """
    return [
        "imagenet_stanforddogs_densenet201",
        "flowers_stanforddogs_densenet201",
        "oxfordpets_stanforddogs_densenet201",
        "imagenet_stanforddogs_resnet101",
        "oxfordpets_stanforddogs_resnet101",
        "flowers_stanforddogs_resnet101",
        "cub_stanforddogs_resnet18",
        "cub_stanforddogs_vgg19",
        "imagenet_stanforddogs_mobilenetv2",
        "caltech_stanforddogs_resnet34",
    ]

def _load_cpu_checkpoint(path):
    """Deserialize a checkpoint onto the CPU so loading never requires a GPU."""
    return torch.load(path, map_location='cpu')


def initialize_source_models() -> tuple:
    """Build every source model and pair it with its number of output classes.

    Models registered under an ``imagenet_*`` name are created with
    ``pretrained=True`` and keep their 1000-way output. Every other model gets
    its final linear layer replaced to match the class count and then has its
    weights loaded from a checkpoint file under ``models/``.

    Returns:
        A ``(names, models_dict)`` tuple. ``models_dict`` maps a
        ``<source>_<architecture>`` name to ``(model, num_classes)``;
        ``names`` is the list of its keys in insertion order.
    """
    models_dict = dict()

    # ---- DenseNet201 ----
    # ImageNet
    model = torchvision.models.densenet201(pretrained=True)
    models_dict["imagenet_densenet201"] = (model, 1000)

    # Flowers 102
    model = torchvision.models.densenet201()
    in_features_final = model.classifier.in_features
    model.classifier = torch.nn.Linear(in_features=in_features_final, out_features=102)
    model.load_state_dict(_load_cpu_checkpoint('models/flowers102-pretrained-densenet201-best_scheduler.pth'))
    models_dict["flowers_densenet201"] = (model, 102)

    # Oxford Pets
    model = torchvision.models.densenet201()
    in_features_final = model.classifier.in_features
    model.classifier = torch.nn.Linear(in_features=in_features_final, out_features=37)
    model.load_state_dict(_load_cpu_checkpoint('models/oxfordpets-pretrained-densenet201-best_scheduler.pth'))
    models_dict["oxfordpets_densenet201"] = (model, 37)

    # ---- ResNet101 ----
    # ImageNet
    model = torchvision.models.resnet101(pretrained=True)
    models_dict["imagenet_resnet101"] = (model, 1000)

    # Oxford Pets
    model = torchvision.models.resnet101()
    in_features_final = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=in_features_final, out_features=37)
    model.load_state_dict(_load_cpu_checkpoint('models/oxfordpets-pretrained-resnet101-best_scheduler.pth'))
    models_dict["oxfordpets_resnet101"] = (model, 37)

    # Flowers 102
    model = torchvision.models.resnet101()
    in_features_final = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=in_features_final, out_features=102)
    model.load_state_dict(_load_cpu_checkpoint('models/flowers102-pretrained-resnet101-best_scheduler.pth'))
    models_dict["flowers_resnet101"] = (model, 102)

    # ---- ResNet18 ----
    # CUB200: this checkpoint is a dict that stores the weights under 'model'.
    model = torchvision.models.resnet18()
    in_features_final = model.fc.in_features
    model.fc = torch.nn.Linear(in_features=in_features_final, out_features=200)
    ckpt = _load_cpu_checkpoint(os.path.join("models", "cub_classifier-ckpt.pth"))
    model.load_state_dict(ckpt['model'])
    models_dict["cub_resnet18"] = (model, 200)

    # ---- VGG19 (batch-norm variant) ----
    # CUB200: the last classifier layer sits at index 6.
    model = torchvision.models.vgg19_bn()
    in_features_final = model.classifier[6].in_features
    model.classifier[6] = torch.nn.Linear(in_features=in_features_final, out_features=200)
    model.load_state_dict(_load_cpu_checkpoint('models/cub-pretrained-vgg19-best_scheduler.pth'))
    models_dict["cub_vgg19"] = (model, 200)

    # ---- MobileNetV2 ----
    # ImageNet
    model = torchvision.models.mobilenet_v2(pretrained=True)
    models_dict["imagenet_mobilenetv2"] = (model, 1000)

    # ---- ResNet34 (custom wrapper class defined in this file) ----
    model = ResNet34(pretrained=True)
    in_features_final = model.l0.in_features
    model.l0 = torch.nn.Linear(in_features=in_features_final, out_features=101)
    model.load_state_dict(_load_cpu_checkpoint('models/caltech101-pretrained.pth'))
    models_dict["caltech_resnet34"] = (model, 101)

    return list(models_dict.keys()), models_dict


def calc_leep_scores(models, target_loader, target_ds):
    """Compute a LEEP score on the target data for every source model.

    For each model, the softmax of its outputs on every target batch is
    collected into one matrix, which is passed to ``LEEP`` together with the
    target labels.

    Args:
        models: mapping of model name to a sequence whose first item is the
            model and whose second item is its number of output classes.
        target_loader: iterable yielding ``(inputs, labels)`` batches; it is
            iterated once per model.
        target_ds: the target dataset; only ``len(target_ds)`` is used, to
            size the prediction matrix.

    Returns:
        dict mapping each model name to the value returned by
        ``LEEP(predictions, labels)``.
    """
    # Fall back to the CPU instead of failing outright on hosts without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

    leep_scores = dict()
    for model_name, entry in tqdm(models.items()):
        source_model = entry[0].to(device)
        source_model.eval()
        source_classes = entry[1]

        # One row of source-class probabilities per target sample.
        dummy_dist = np.zeros((len(target_ds), source_classes))
        targets = []
        # Track the write position from the rows actually seen rather than
        # from the module-level BATCH_SIZE, so a loader built with any batch
        # size fills the matrix correctly.
        offset = 0

        with torch.no_grad():
            for x, y in target_loader:
                out = source_model(x.to(device))
                probs = torch.nn.functional.softmax(out, dim=1).cpu().numpy()
                end = offset + len(probs)
                dummy_dist[offset:end] = probs
                targets.append(np.asarray(y))
                offset = end

        # An empty loader yields an empty label array, as before.
        targets = np.concatenate(targets) if targets else np.array([])
        leep_scores[model_name] = LEEP(dummy_dist, targets)

    return leep_scores

def _source_model_name(transferred_name):
    """Map ``<source>_<target>_<arch>`` to the source key ``<source>_<arch>``."""
    parts = transferred_name.split("_")
    return parts[0] + "_" + parts[2]


def get_ensemble_acc_leep(target_model_names, leep_scores, target_loader, r=4,
                          accuracies_path='./target_stanforddogs.npy'):
    """Sum per-model LEEP scores for every size-``r`` ensemble of models.

    Ensembles are enumerated with ``itertools.combinations``; the stored
    accuracies are matched to them by position in that order.

    Args:
        target_model_names: transferred model names of the form
            ``<source>_<target>_<arch>``.
        leep_scores: mapping of ``<source>_<arch>`` to a LEEP score.
        target_loader: unused; kept so existing callers keep working.
        r: number of models per ensemble.
        accuracies_path: ``.npy`` file holding one accuracy per combination.

    Returns:
        ``(leep_scores_ensemble, accuracies)``: two 1-D float arrays with one
        entry per combination.

    Raises:
        ValueError: if the file does not hold exactly one accuracy per
            combination.
        KeyError: if a derived source name is missing from ``leep_scores``.
    """
    print("making ensemble combinations...")
    combinations = list(itertools.combinations(target_model_names, r))

    with open(accuracies_path, 'rb') as f:
        loaded = np.load(f)
    # Without this check a single stored value would silently broadcast to
    # every combination and make the later correlation meaningless.
    if np.size(loaded) != len(combinations):
        raise ValueError(
            f"{accuracies_path} holds {np.size(loaded)} accuracies but "
            f"there are {len(combinations)} combinations"
        )
    accuracies = np.zeros(len(combinations))
    accuracies[:] = loaded

    leep_scores_ensemble = np.zeros(len(combinations))
    for i, combination in enumerate(tqdm(combinations)):
        member_scores = np.asarray(
            [leep_scores[_source_model_name(name)] for name in combination]
        )
        leep = np.sum(member_scores)

        leep_scores_ensemble[i] = leep
        print("Combination: ",i,", Accuracy: ",accuracies[i],", LEEP Score: ", leep)

    print("Complete!")

    return leep_scores_ensemble, accuracies


def get_model_accuracy(model,target_loader):
    """Return the accuracy of ``model`` over ``target_loader`` as a percentage.

    The model is moved to the module-level ``device`` and put in eval mode.
    A prediction is the index of the largest output along dimension 1.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    model.to(device)
    model.eval()

    n_seen = 0
    n_correct = 0
    with torch.no_grad():
        for inputs, labels in target_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            # torch.max over dim 1 returns (values, indices); keep the indices.
            pred = torch.max(logits.data, 1)[1]
            n_seen += labels.shape[0]
            n_correct += (labels == pred).sum().item()

    return 100 * n_correct / n_seen


if __name__ == "__main__":

    # Source models (with class counts) and the names of their transferred
    # counterparts.
    source_model_names, source_models = initialize_source_models()
    finetuned_model_names = initialize_transferred_models()

    # Preprocessing: resize to 224x224, convert to tensor, normalize per
    # channel (presumably the standard ImageNet statistics - confirm).
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        normalize,
    ])

    # Target data: the training split, read in a fixed (unshuffled) order.
    target_ds = stanford_dogs.StanfordDogs(
        root='/var/data/stanford_dogs',
        train=True,
        transform=transform,
    )
    target_loader = torch.utils.data.DataLoader(
        target_ds,
        shuffle=False,
        num_workers=2,
        batch_size=BATCH_SIZE,
    )

    # Per-model LEEP scores, then their sums over every 4-model ensemble.
    leep_scores = calc_leep_scores(source_models, target_loader, target_ds)
    ensem_leep, acc = get_ensemble_acc_leep(
        finetuned_model_names, leep_scores, target_loader, 4
    )

    # Pearson correlation between ensemble LEEP sums and stored accuracies.
    pcc, p = scipy.stats.pearsonr(ensem_leep, acc)
    print(pcc)

    







    


