import numpy as np
import os
import itertools
import sys
import torch
import torchvision
import scipy
import scipy.stats
from tqdm import tqdm
import argparse
from LEEP import LEEP
import random
import flowers102
# import caltech101.caltech_dataset as caltech101
import pretrainedmodels 
from resnet34_caltech import ResNet34 
import cub_200
import stanford_dogs
import oxford_pets
import torch.nn as nn

# Number of target images fed through a source model per DataLoader batch.
BATCH_SIZE = 32
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def get_dataset(dataset_name, preprocess_fn):
    """Return the training split of the named target dataset.

    Args:
        dataset_name: Dataset identifier. One of ``'cifar100'``, ``'cifar10'``,
            ``'fashionmnist'``, ``'caltech101'``, ``'flowers102'``,
            ``'stanford_cars'``, ``'tiny-imagenet'``, ``'imagenet'``,
            ``'cub200'``, ``'stanford_dogs'``, ``'chest_xray'``, ``'pets'``,
            ``'imagenette'``.
        preprocess_fn: Passed unchanged as the dataset's ``transform``.

    Returns:
        The dataset instance, rooted at a hard-coded path under ``/var/data``.
        For ``'pets'`` this is the ``trainval`` split. The torchvision
        CIFAR/FashionMNIST datasets are downloaded if missing.

    Raises:
        NotImplementedError: If ``dataset_name`` is not recognised.
    """
    if dataset_name == 'cifar100':
        return torchvision.datasets.CIFAR100(
            root='/var/data/cifar100/', download=True, transform=preprocess_fn)
    if dataset_name == 'cifar10':
        return torchvision.datasets.CIFAR10(
            root='/var/data/cifar10', download=True, transform=preprocess_fn)
    if dataset_name == 'fashionmnist':
        return torchvision.datasets.FashionMNIST(
            root='/var/data/fashionmnist', download=True, transform=preprocess_fn)
    if dataset_name == 'caltech101':
        # The module-level import of this loader is commented out at the top of
        # the file, so `caltech_dataset` was never defined. Import it lazily so
        # the other datasets keep working when it is absent.
        # NOTE(review): path taken from the commented-out
        # `caltech101.caltech_dataset` import; confirm it is still correct.
        from caltech101 import caltech_dataset
        return caltech_dataset.Caltech(
            root='/var/data/caltech101', split='train', transform=preprocess_fn)
    if dataset_name == 'flowers102':
        return flowers102.Flowers102(
            root='/var/data/flowers102', split='train', transform=preprocess_fn)
    if dataset_name == 'stanford_cars':
        # Never imported at module level. Loaded lazily here.
        # NOTE(review): module name assumed to follow the other top-level
        # loaders (`stanford_dogs`, `cub_200`, ...); confirm.
        import stanford_cars
        return stanford_cars.StanfordCars(
            root='/var/data/stanford_cars', split='train', transform=preprocess_fn)
    if dataset_name == 'tiny-imagenet':
        return torchvision.datasets.ImageFolder(
            root='/var/data/tiny-imagenet/tiny-imagenet-200/train',
            transform=preprocess_fn)
    if dataset_name == 'imagenet':
        return torchvision.datasets.ImageFolder(
            root='/var/data/imagenet/subset_imgs/train', transform=preprocess_fn)
    if dataset_name == 'cub200':
        return cub_200.CUB200(
            root='/var/data/cub_200/', train=True, transform=preprocess_fn)
    if dataset_name == 'stanford_dogs':
        return stanford_dogs.StanfordDogs(
            root='/var/data/stanford_dogs', train=True, transform=preprocess_fn)
    if dataset_name == 'chest_xray':
        # Never imported at module level. Loaded lazily here.
        # NOTE(review): module name assumed; confirm.
        import chest_xray_dataset
        return chest_xray_dataset.ChestXRayDataset(
            root='/var/data/chest_xray', train=True, transform=preprocess_fn)
    if dataset_name == 'pets':
        return oxford_pets.OxfordIIITPets(
            root='/var/data/pets', split='trainval', transform=preprocess_fn)
    if dataset_name == 'imagenette':
        return torchvision.datasets.ImageFolder(
            root='/var/data/imagenette/imagenette2/train', transform=preprocess_fn)

    raise NotImplementedError(f'Dataset not recognized: {dataset_name!r}')

# ImageNet channel statistics used by every source model's normalisation step.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _resize_transform():
    """Return a transform that resizes directly to 224x224 and then normalises."""
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


def _resize_crop_transform():
    """Return a transform that resizes to 256, center-crops 224 and then normalises."""
    return torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ])


def _load_checkpoint(path):
    """Load the checkpoint at *path* onto the CPU.

    ``map_location='cpu'`` lets GPU-saved checkpoints load on CPU-only
    machines. ``load_state_dict`` then copies the weights into the model's
    parameters on whatever device the model lives on.
    """
    return torch.load(path, map_location='cpu')


def _replace_fc(model, num_classes):
    """Replace ``model.fc`` with a fresh ``num_classes``-way linear layer."""
    model.fc = torch.nn.Linear(
        in_features=model.fc.in_features, out_features=num_classes)
    return model


def _replace_classifier(model, num_classes):
    """Replace ``model.classifier`` with a fresh ``num_classes``-way linear layer."""
    model.classifier = torch.nn.Linear(
        in_features=model.classifier.in_features, out_features=num_classes)
    return model


def _replace_vgg_head(model, num_classes):
    """Replace ``model.classifier[6]`` with a fresh ``num_classes``-way linear layer."""
    model.classifier[6] = torch.nn.Linear(
        in_features=model.classifier[6].in_features, out_features=num_classes)
    return model


class _CaltechResNet34(nn.Module):
    """ResNet-34 backbone from ``pretrainedmodels`` with a 101-way linear head.

    The attribute names ``model``, ``l0`` and ``dropout`` must not change:
    they are the key prefixes of the saved Caltech-101 state_dict.
    """

    def __init__(self, pretrained):
        super().__init__()
        weights = 'imagenet' if pretrained is True else None
        self.model = pretrainedmodels.__dict__['resnet34'](pretrained=weights)
        self.l0 = nn.Linear(512, 101)
        self.dropout = nn.Dropout2d(0.4)

    def forward(self, x):
        # Unpacking into four names enforces a 4-D (batch, c, h, w) input.
        batch, _, _, _ = x.shape
        x = self.model.features(x)
        # The original referenced an undefined `F`; use torch.nn.functional.
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
        x = self.dropout(x)
        return self.l0(x)


# Stock ImageNet models, used without modification.
# Maps model name -> torchvision constructor name.
_IMAGENET_MODELS = {
    'imagenet_resnet50': 'resnet50',
    'imagenet_densenet201': 'densenet201',
    # torchvision names this constructor `mobilenet_v2`; `mobilenetv2` is not a model builder.
    'imagenet_mobilenetv2': 'mobilenet_v2',
}

# Fine-tuned models that all follow one recipe:
#   1. build the ImageNet-pretrained torchvision model;
#   2. swap its head for an N-way linear layer;
#   3. load a plain state_dict;
#   4. preprocess with a direct 224x224 resize.
# Maps model name -> (torchvision constructor, head replacer, N, checkpoint path).
_FINETUNED_MODELS = {
    'pets_resnet101': (
        'resnet101', _replace_fc, 37,
        './models/oxfordpets-pretrained-resnet101-best_scheduler.pth'),
    'pets_densenet201': (
        'densenet201', _replace_classifier, 37,
        './models/oxfordpets-pretrained-densenet201-best_scheduler.pth'),
    'cub200_vgg19': (
        'vgg19_bn', _replace_vgg_head, 200,
        './models/cub-pretrained-vgg19-best_scheduler.pth'),
    'flowers102_resnet101': (
        'resnet101', _replace_fc, 102,
        './models/flowers102-pretrained-resnet101-best_scheduler.pth'),
    'flowers102_densenet201': (
        'densenet201', _replace_classifier, 102,
        './models/flowers102-pretrained-densenet201-best_scheduler.pth'),
    'stanford_dogs_vgg19': (
        'vgg19_bn', _replace_vgg_head, 120,
        './models/stanforddogs-pretrained-vgg19-best_scheduler.pth'),
    'stanford_dogs_resnet18': (
        'resnet18', _replace_fc, 120,
        './models/stanforddogs-pretrained-resnet18-best_scheduler.pth'),
    'stanford_dogs_densenet201': (
        'densenet201', _replace_classifier, 120,
        './models/stanforddogs-pretrained-densenet201-best_scheduler.pth'),
}


def get_model_transform(model_name):
    """Build a source model and the preprocessing transform it expects.

    Args:
        model_name: A key of ``_IMAGENET_MODELS`` or ``_FINETUNED_MODELS``, or
            ``'caltech101_resnet34'`` / ``'cub200_resnet18'``.

    Returns:
        ``(model, transform)``:
        - ImageNet models use the resize-256 / center-crop-224 transform.
        - ``cub200_resnet18`` also uses resize-256 / center-crop-224.
        - All other models use a direct 224x224 resize.
        Fine-tuned models have their weights loaded from ``./models``.

    Raises:
        NotImplementedError: If ``model_name`` is not recognised.
    """
    if model_name in _IMAGENET_MODELS:
        builder = getattr(torchvision.models, _IMAGENET_MODELS[model_name])
        return builder(pretrained=True), _resize_crop_transform()

    if model_name in _FINETUNED_MODELS:
        arch, replace_head, n_classes, ckpt_path = _FINETUNED_MODELS[model_name]
        model = replace_head(getattr(torchvision.models, arch)(pretrained=True), n_classes)
        model.load_state_dict(_load_checkpoint(ckpt_path))
        return model, _resize_transform()

    if model_name == 'caltech101_resnet34':
        model = _CaltechResNet34(pretrained=True).to(device)
        model.load_state_dict(_load_checkpoint('./models/caltech101-pretrained.pth'))
        return model, _resize_transform()

    if model_name == 'cub200_resnet18':
        model = _replace_fc(torchvision.models.resnet18(pretrained=True), 200)
        # This checkpoint stores the weights under a 'model' key, not as a bare state_dict.
        ckpt = _load_checkpoint(os.path.join('models', 'cub_classifier-ckpt.pth'))
        model.load_state_dict(ckpt['model'])
        return model, _resize_crop_transform()

    raise NotImplementedError(f'Model not recognized: {model_name!r}')


    
def _score_source_model(model_name, target_dataset_name, source_classes):
    """Return the LEEP score of one source model on the target dataset.

    The model's softmax outputs over its ``source_classes`` labels are
    collected for every target sample, in dataset order (``shuffle=False``).
    They are passed to ``LEEP`` together with the target labels.
    """
    source_model, preprocess_fn = get_model_transform(model_name)
    source_model = source_model.to(device)
    source_model.eval()

    target_ds = get_dataset(target_dataset_name, preprocess_fn)
    target_loader = torch.utils.data.DataLoader(
        target_ds, shuffle=False, num_workers=2, batch_size=BATCH_SIZE)

    # The original script always ran `y -= 1` with a note to remember it
    # "for flowers". It is tied to flowers102 explicitly here.
    # NOTE(review): presumably that loader's labels start at 1. Confirm before
    # adding other 1-indexed target datasets.
    label_offset = 1 if target_dataset_name == 'flowers102' else 0

    dummy_dist = np.zeros((len(target_ds), source_classes))
    targets = []
    start = 0
    with torch.no_grad():
        for x, y in target_loader:
            out = torch.nn.functional.softmax(source_model(x.to(device)), dim=1)
            end = start + len(x)
            dummy_dist[start:end] = out.cpu().numpy()
            targets.append((y - label_offset).numpy())
            start = end

    return LEEP(dummy_dist, np.concatenate(targets))


def main():
    """Score every size-``k`` ensemble of source models on the target dataset.

    Each source model gets one LEEP score on the target dataset. An
    ensemble's score is the sum of its members' scores. The scores of all
    ``itertools.combinations(range(len(models)), k)`` ensembles, in that
    order, are saved to ``./<target_dataset_name>_ensemble_leep_scores.npy``.
    """
    print(device)
    target_dataset_name = 'flowers102'
    k = 4

    # Source model name -> size of its output layer. Only models listed here
    # are scored. Insertion order fixes each model's index in the
    # combinations below.
    num_classes = {
        'imagenet_resnet50': 1000,
        'imagenet_densenet201': 1000,
        'imagenet_mobilenetv2': 1000,
        'caltech101_resnet34': 101,
        'cub200_resnet18': 200,
        'cub200_vgg19': 200,
        'stanford_dogs_vgg19': 120,
        'stanford_dogs_resnet18': 120,
        'stanford_dogs_densenet201': 120,
        # Presumably excluded because flowers102 is the current target:
        # 'flowers102_resnet101': 102,
        # 'flowers102_densenet201': 102,
        'pets_densenet201': 37,
        'pets_resnet101': 37,
    }
    models = list(num_classes)

    model_leep_scores = np.zeros(len(models))
    for idx, model_name in enumerate(tqdm(models)):
        model_leep_scores[idx] = _score_source_model(
            model_name, target_dataset_name, num_classes[model_name])

    final_leep_scores = [
        model_leep_scores[list(combo)].sum()
        for combo in itertools.combinations(range(len(models)), k)
    ]

    # Use the dataset name. The original interpolated the `flowers102` module object.
    out_path = f'./{target_dataset_name}_ensemble_leep_scores.npy'
    with open(out_path, 'wb') as f:
        np.save(f, np.array(final_leep_scores))


if __name__ == '__main__':
    main()