import itertools
import numpy as np
import torch
import flowers102
import torch.nn as nn
import torchvision
import os
import pretrainedmodels
import torch.nn.functional as F
from tqdm import tqdm
import scipy.stats
from LEEP import LEEP


# Mini-batch size used for the target DataLoader built under __main__.
BATCH_SIZE = 64

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class ResNet34(nn.Module):
    """ResNet-34 backbone from ``pretrainedmodels`` with a 101-way linear head.

    The backbone's convolutional feature map is globally average-pooled, flattened
    to one vector per sample (the head expects 512 features), passed through
    dropout, and projected to 101 logits.
    """

    def __init__(self, pretrained):
        """Build the network.

        Args:
            pretrained: If truthy, initialize the backbone with the ``'imagenet'``
                weights from ``pretrainedmodels``; otherwise build it without
                pretrained weights.
        """
        super(ResNet34, self).__init__()
        self.model = pretrainedmodels.__dict__['resnet34'](
            pretrained='imagenet' if pretrained else None
        )

        self.l0 = nn.Linear(512, 101)
        # The pooled features are a flat (batch, features) tensor, so plain
        # element-wise dropout is the correct layer. nn.Dropout2d on 2-D input is
        # deprecated by PyTorch and behaves identically to nn.Dropout for such input.
        # Dropout has no parameters, so saved state dicts are unaffected.
        self.dropout = nn.Dropout(0.4)

    def forward(self, x):
        """Return logits for a 4-D image batch ``x`` of shape (batch, c, h, w)."""
        # Unpacking enforces a 4-D input; only the batch size is needed.
        batch, _, _, _ = x.shape
        features = self.model.features(x)
        pooled = F.adaptive_avg_pool2d(features, 1).reshape(batch, -1)
        return self.l0(self.dropout(pooled))
    
def initialize_transferred_models() -> list:
    """Return the names of the fine-tuned models, in a fixed order.

    Every name follows the pattern ``<source>_flowers_<architecture>``: a model
    pretrained on ``<source>`` and transferred to Flowers102.
    A new list is returned on every call.
    """
    return [
        # DenseNet201
        "imagenet_flowers_densenet201",
        "stanforddogs_flowers_densenet201",
        "oxfordpets_flowers_densenet201",
        # ResNet101
        "imagenet_flowers_resnet101",
        "oxfordpets_flowers_resnet101",
        # ResNet18
        "cub_flowers_resnet18",
        "stanforddogs_flowers_resnet18",
        # VGG19
        "cub_flowers_vgg19",
        "stanforddogs_flowers_vgg19",
        # MobileNetV2
        "imagenet_flowers_mobilenetv2",
        # ResNet34
        "caltech_flowers_resnet34",
    ]


def _load_weights(model, path, key=None):
    """Load the checkpoint at *path* into *model* (strictly) and return *model*.

    The checkpoint is mapped onto the CPU so that weights saved from a GPU run can
    be restored on a CPU-only machine; the freshly built model lives on the CPU
    anyway. When *key* is given, the state dict is read from ``checkpoint[key]``.
    """
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint if key is None else checkpoint[key])
    return model


def initialize_source_models() -> tuple:
    """Build every source (pre-transfer) model used for LEEP scoring.

    Returns:
        A ``(names, models)`` pair. ``models`` maps each model name
        (``<source>_<architecture>``) to ``(model, num_source_classes)``, where
        ``num_source_classes`` is the output width of the model's final layer.
        ``names`` lists the keys of ``models`` in insertion order.
    """
    models_dict = dict()

    # ---- DenseNet201 ----
    # ImageNet: torchvision weights with the original 1000-way classifier.
    models_dict["imagenet_densenet201"] = (torchvision.models.densenet201(pretrained=True), 1000)

    # Stanford Dogs (120 classes).
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=120)
    _load_weights(model, 'models/stanforddogs-pretrained-densenet201-best_scheduler.pth')
    models_dict["stanforddogs_densenet201"] = (model, 120)

    # Oxford Pets (37 classes).
    model = torchvision.models.densenet201()
    model.classifier = torch.nn.Linear(in_features=model.classifier.in_features, out_features=37)
    _load_weights(model, 'models/oxfordpets-pretrained-densenet201-best_scheduler.pth')
    models_dict["oxfordpets_densenet201"] = (model, 37)

    # ---- ResNet101 ----
    # ImageNet: torchvision weights with the original 1000-way classifier.
    models_dict["imagenet_resnet101"] = (torchvision.models.resnet101(pretrained=True), 1000)

    # Oxford Pets (37 classes).
    model = torchvision.models.resnet101()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=37)
    _load_weights(model, 'models/oxfordpets-pretrained-resnet101-best_scheduler.pth')
    models_dict["oxfordpets_resnet101"] = (model, 37)

    # ---- ResNet18 ----
    # CUB-200 (200 classes); this checkpoint stores the weights under 'model'.
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=200)
    _load_weights(model, os.path.join("models", "cub_classifier-ckpt.pth"), key='model')
    models_dict["cub_resnet18"] = (model, 200)

    # Stanford Dogs (120 classes).
    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=120)
    _load_weights(model, 'models/stanforddogs-pretrained-resnet18-best_scheduler.pth')
    models_dict["stanforddogs_resnet18"] = (model, 120)

    # ---- VGG19 (batch-norm variant) ----
    # CUB-200 (200 classes).
    model = torchvision.models.vgg19_bn()
    model.classifier[6] = torch.nn.Linear(in_features=model.classifier[6].in_features, out_features=200)
    _load_weights(model, 'models/cub-pretrained-vgg19-best_scheduler.pth')
    models_dict["cub_vgg19"] = (model, 200)

    # Stanford Dogs (120 classes).
    model = torchvision.models.vgg19_bn()
    model.classifier[6] = torch.nn.Linear(in_features=model.classifier[6].in_features, out_features=120)
    _load_weights(model, 'models/stanforddogs-pretrained-vgg19-best_scheduler.pth')
    models_dict["stanforddogs_vgg19"] = (model, 120)

    # ---- MobileNetV2 ----
    # ImageNet: torchvision weights with the original 1000-way classifier.
    models_dict["imagenet_mobilenetv2"] = (torchvision.models.mobilenet_v2(pretrained=True), 1000)

    # ---- ResNet34 ----
    # Caltech-101 (101 classes). ResNet34 already carries a 512 -> 101 head, and
    # the strict load below overwrites every parameter, so there is no point in
    # downloading ImageNet weights first.
    model = ResNet34(pretrained=False)
    _load_weights(model, 'models/caltech101-pretrained.pth')
    models_dict["caltech_resnet34"] = (model, 101)

    return list(models_dict.keys()), models_dict


def calc_leep_scores(models, target_loader, target_ds):
    """Compute the LEEP score of every source model on the target data.

    Each source model is run over ``target_loader`` on the module-level ``device``.
    Its softmax outputs are collected into a ``(len(target_ds), num_source_classes)``
    matrix, which is passed with the target labels to ``LEEP``.

    Args:
        models: Mapping from model name to ``(model, num_source_classes)``.
            Each model is moved to ``device`` and switched to eval mode in place.
        target_loader: Iterable of ``(images, labels)`` tensor batches over the
            target dataset.
        target_ds: The dataset behind ``target_loader``; only its length is used.

    Returns:
        Dict mapping each model name to the value returned by ``LEEP``.
    """
    # Use the module-level device (CUDA when available, otherwise CPU) instead of
    # hard-coding 'cuda', which crashed on machines without a GPU.
    print(device)

    leep_scores = dict()
    for model_name in tqdm(models.keys()):
        source_model, source_classes = models[model_name]
        source_model = source_model.to(device)
        source_model.eval()

        dummy_dist = np.zeros((len(target_ds), source_classes))
        targets = []
        # Row where the next batch is written. Tracking it directly, rather than
        # computing batch_idx * BATCH_SIZE, works for any loader batch size.
        offset = 0

        with torch.no_grad():
            for x, y in target_loader:
                curr_batch_size = len(x)
                out = torch.nn.functional.softmax(source_model(x.to(device)), dim=1)
                dummy_dist[offset:offset + curr_batch_size] = out.cpu().numpy()
                offset += curr_batch_size
                # Shift labels down by one (the original "y-=1 for flowers" rule);
                # assumes this Flowers102 wrapper yields 1-based labels — confirm.
                targets.append((y - 1).numpy())

        targets = np.concatenate(targets) if targets else np.array([], dtype=np.int64)
        leep_scores[model_name] = LEEP(dummy_dist, targets)

    return leep_scores

def _source_model_name(finetuned_name):
    """Map a fine-tuned name ``<source>_<target>_<arch>`` to its source key ``<source>_<arch>``."""
    split = finetuned_name.split("_", -1)
    return split[0] + "_" + split[2]


def _load_predictions(path):
    """Load saved per-batch predictions from *path* and stack them into one 2-D array.

    Each entry of the loaded array is treated as one batch of per-class scores.
    ``allow_pickle=True`` is kept from the original loader, presumably because the
    batches are stored as a ragged object array. NOTE(review): unpickling can execute
    arbitrary code, so only load prediction files produced by this project.
    """
    batches = np.load(path, allow_pickle=True)
    return np.concatenate([np.asarray(batch) for batch in batches], axis=0)


def _collect_labels(target_loader):
    """Return every label from *target_loader*, shifted down by one, as a single array."""
    labels = [(y - 1).numpy() for _, y in target_loader]
    return np.concatenate(labels) if labels else np.array([], dtype=np.int64)


def get_ensemble_acc_leep(target_model_names, leep_scores, target_loader, r=4,
                          predictions_dir="model_predictions"):
    """Score every size-``r`` ensemble of fine-tuned models by accuracy and summed LEEP.

    For each combination of ``r`` names from ``target_model_names``:

    * the members' saved predictions ``<predictions_dir>/<name>.npy`` are averaged;
    * the arg-max class of the average is compared with the loader's labels
      (shifted down by one);
    * the LEEP scores of the members' source models are summed.

    Args:
        target_model_names: Fine-tuned model names of the form ``<source>_<target>_<arch>``.
        leep_scores: Mapping from source-model name ``<source>_<arch>`` to LEEP score.
        target_loader: Iterable of ``(images, labels)`` batches. Predictions are
            matched to labels by position, so this is assumed to iterate in the
            same order used when the prediction files were written.
        r: Number of models per ensemble (any positive size, not only 4).
        predictions_dir: Directory holding the ``.npy`` prediction files.

    Returns:
        ``(leep_scores_ensemble, accuracies)``: float arrays with one entry per
        combination, in ``itertools.combinations`` order. Accuracies are percentages.

    Raises:
        ValueError: if a combination's predictions do not cover exactly the
            samples yielded by ``target_loader``.
    """
    print("making ensemble combinations...")
    combinations = list(itertools.combinations(target_model_names, r))

    leep_scores_ensemble = np.zeros(len(combinations))
    accuracies = np.zeros(len(combinations))

    print("calculating accuracies...")

    # Labels do not depend on the combination. Read them once instead of
    # re-iterating (and re-decoding) the whole dataset for every combination.
    labels = _collect_labels(target_loader)
    total = len(labels)
    predictions = {}  # model name -> stacked predictions; each file is read once

    for i, combination in tqdm(enumerate(combinations), total=len(combinations)):
        member_preds = []
        for model_name in combination:
            if model_name not in predictions:
                predictions[model_name] = _load_predictions(
                    os.path.join(predictions_dir, model_name + ".npy"))
            member_preds.append(predictions[model_name])

        # Average over all r members and score every sample. The original code
        # left the final partial batch at zeros, which scored it as class 0.
        ensemble = np.mean(np.stack(member_preds), axis=0)
        if len(ensemble) != total:
            raise ValueError(
                f"predictions for {combination} cover {len(ensemble)} samples, "
                f"but the loader yielded {total} labels")
        correct = int((ensemble.argmax(axis=1) == labels).sum())
        accuracy = 100 * correct / total

        leep = np.sum(np.asarray(
            [leep_scores[_source_model_name(name)] for name in combination]))

        leep_scores_ensemble[i] = leep
        accuracies[i] = accuracy
        print("Combination: ",i,", Accuracy: ",accuracy,", LEEP Score: ", leep)

    print("Complete!")

    return leep_scores_ensemble, accuracies


if __name__ == "__main__":

    # Source models (and their class counts) plus the names of the fine-tuned models.
    source_model_names, source_models = initialize_source_models()
    finetuned_model_names = initialize_transferred_models()

    # Resize to 224x224, convert to a tensor, and normalize with the usual
    # ImageNet per-channel mean/std.
    resize = torchvision.transforms.Resize((224, 224))
    to_tensor = torchvision.transforms.ToTensor()
    normalize = torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    transform = torchvision.transforms.Compose([resize, to_tensor, normalize])

    # Flowers102 training split, iterated in a fixed order (no shuffling).
    target_ds = flowers102.Flowers102(
        root='/var/data/flowers102', split="train", transform=transform)
    target_loader = torch.utils.data.DataLoader(
        target_ds, shuffle=False, num_workers=2, batch_size=BATCH_SIZE)

    # LEEP per source model, then summed LEEP vs. accuracy for every 4-model ensemble.
    leep_scores = calc_leep_scores(source_models, target_loader, target_ds)
    ensem_leep, acc = get_ensemble_acc_leep(
        finetuned_model_names, leep_scores, target_loader, r=4)

    # Pearson correlation between ensemble LEEP and ensemble accuracy.
    pcc, p = scipy.stats.pearsonr(ensem_leep, acc)
    print(pcc)

    







    


