import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from tqdm import tqdm
import os
import sys
import argparse
import pretrainedmodels
import flowers102
import stanford_cars
import stanford_dogs
import cub_200
import chest_xray_dataset
import oxford_pets
import caltech_dataset

# Command-line interface: the dataset to process is chosen with --dataset.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', dest='dataset', type=str, default=None,
                    help='name of the dataset to run the models on')
args = parser.parse_args()

# An absent (or empty) --dataset is a usage error. Passing the message to
# sys.exit() writes it to stderr and terminates with a non-zero status, so a
# calling shell or job scheduler does not mistake the aborted run for success.
if not args.dataset:
    sys.exit('No Dataset mentioned')
ds_name = args.dataset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

def get_dataset(dataset_name, preprocess_fn):
    """Instantiate the dataset registered under `dataset_name`.

    Args:
        dataset_name: One of the identifiers handled below ('cifar100',
            'cifar10', 'fashionmnist', 'caltech101', 'flowers102',
            'stanford_cars', 'tiny-imagenet', 'imagenet', 'cub200',
            'stanford_dogs', 'chest_xray', 'pets', 'imagenette').
        preprocess_fn: Passed unchanged as the `transform` argument of the
            dataset constructor.

    Returns:
        The constructed dataset object. Where the constructor takes a split
        selector, the train ('train' / 'trainval' / train=True) split is
        requested; the ImageFolder variants point at a `train` directory.

    Raises:
        NotImplementedError: If `dataset_name` is not one of the identifiers
            above ('Dataset not recognized' is printed first).
    """
    # torchvision built-ins; these are downloaded into `root` when missing.
    if dataset_name == 'cifar100':
        return torchvision.datasets.CIFAR100(
            root='/var/data/cifar100/', download=True, transform=preprocess_fn)
    if dataset_name == 'cifar10':
        return torchvision.datasets.CIFAR10(
            root='/var/data/cifar10', download=True, transform=preprocess_fn)
    if dataset_name == 'fashionmnist':
        return torchvision.datasets.FashionMNIST(
            root='/var/data/fashionmnist', download=True, transform=preprocess_fn)

    # Project-local dataset wrappers.
    if dataset_name == 'caltech101':
        return caltech_dataset.Caltech(
            root='/var/data/caltech101', split='train', transform=preprocess_fn)
    if dataset_name == 'flowers102':
        return flowers102.Flowers102(
            root='/var/data/flowers102', split='train', transform=preprocess_fn)
    if dataset_name == 'stanford_cars':
        return stanford_cars.StanfordCars(
            root='/var/data/stanford_cars', split='train', transform=preprocess_fn)
    if dataset_name == 'cub200':
        return cub_200.CUB200(
            root='/var/data/cub_200/', train=True, transform=preprocess_fn)
    if dataset_name == 'stanford_dogs':
        return stanford_dogs.StanfordDogs(
            root='/var/data/stanford_dogs', train=True, transform=preprocess_fn)
    if dataset_name == 'chest_xray':
        return chest_xray_dataset.ChestXRayDataset(
            root='/var/data/chest_xray', train=True, transform=preprocess_fn)
    if dataset_name == 'pets':
        return oxford_pets.OxfordIIITPets(
            root='/var/data/pets', split='trainval', transform=preprocess_fn)

    # Datasets stored on disk as one sub-directory per class.
    if dataset_name == 'tiny-imagenet':
        return torchvision.datasets.ImageFolder(
            root='/var/data/tiny-imagenet/tiny-imagenet-200/train', transform=preprocess_fn)
    if dataset_name == 'imagenet':
        return torchvision.datasets.ImageFolder(
            root='/var/data/imagenet/subset_imgs/train', transform=preprocess_fn)
    if dataset_name == 'imagenette':
        return torchvision.datasets.ImageFolder(
            root='/var/data/imagenette/imagenette2/train', transform=preprocess_fn)

    print('Dataset not recognized')
    raise NotImplementedError
    
def _build_transform(resize_to, center_crop=None):
    """Return the preprocessing pipeline used by the models in this file.

    Args:
        resize_to: Forwarded to `torchvision.transforms.Resize` (an int or an
            `(h, w)` tuple).
        center_crop: When not None, a `CenterCrop` of this size is applied
            right after the resize.

    Returns:
        A `torchvision.transforms.Compose` that resizes, optionally crops,
        converts the image to a tensor and normalizes its three channels.
    """
    steps = [torchvision.transforms.Resize(resize_to)]
    if center_crop is not None:
        steps.append(torchvision.transforms.CenterCrop(center_crop))
    steps.append(torchvision.transforms.ToTensor())
    steps.append(torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ))
    return torchvision.transforms.Compose(steps)


def _load_weights(model, checkpoint_path):
    """Load the state dict stored at `checkpoint_path` into `model` in place."""
    # map_location='cpu' lets a checkpoint that was saved from a CUDA device be
    # read on a machine without a GPU; load_state_dict then copies the values
    # onto whatever device the model's parameters currently live on.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))


def _finetuned_resnet101(num_classes, checkpoint_path):
    """Build a torchvision ResNet-101 with a `num_classes`-way head and load its weights."""
    model = torchvision.models.resnet101(pretrained=True)
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
    _load_weights(model, checkpoint_path)
    return model


def _finetuned_vgg19_bn(num_classes, checkpoint_path):
    """Build a torchvision VGG-19 (batch-norm) with a `num_classes`-way head and load its weights."""
    model = torchvision.models.vgg19_bn(pretrained=True)
    # classifier[6] is the final Linear layer of the VGG classifier head.
    final_in_features = model.classifier[6].in_features
    model.classifier[6] = torch.nn.Linear(in_features=final_in_features, out_features=num_classes)
    _load_weights(model, checkpoint_path)
    return model


def get_model_transform(model_name):
    """Build the model registered under `model_name` and its input transform.

    Supported names: 'pets_resnet101', 'imagenet_resnet50',
    'caltech101_resnet34', 'cub200_vgg19', 'flowers102_resnet101' and
    'stanford_dogs_vgg19'. All but 'imagenet_resnet50' load fine-tuned
    weights from a checkpoint under './models/'.

    Args:
        model_name: Identifier of the model to build.

    Returns:
        A `(model, transform)` tuple. Only the 'caltech101_resnet34' model is
        moved to the module-level `device` here; the others are returned on
        the device they were constructed on.

    Raises:
        NotImplementedError: If `model_name` is not one of the supported names.
    """
    if model_name == 'pets_resnet101':
        model = _finetuned_resnet101(
            37, './models/oxfordpets-pretrained-resnet101-best_scheduler.pth')
        return model, _build_transform((224, 224))

    if model_name == 'imagenet_resnet50':
        # Stock torchvision weights: no checkpoint, resize-then-crop preprocessing.
        model = torchvision.models.resnet50(pretrained=True)
        return model, _build_transform(256, center_crop=224)

    if model_name == 'caltech101_resnet34':
        class ResNet34(nn.Module):
            """ResNet-34 backbone from `pretrainedmodels` with a 101-way linear head."""

            def __init__(self, pretrained):
                super().__init__()
                if pretrained is True:
                    self.model = pretrainedmodels.__dict__['resnet34'](pretrained='imagenet')
                else:
                    self.model = pretrainedmodels.__dict__['resnet34'](pretrained=None)

                self.l0 = nn.Linear(512, 101)
                # NOTE(review): Dropout2d is applied to the flattened
                # (batch, features) tensor in forward(); left unchanged so the
                # module matches the one the checkpoint was trained with.
                self.dropout = nn.Dropout2d(0.4)

            def forward(self, x):
                # get the batch size only, ignore (c, h, w)
                batch, _, _, _ = x.shape
                x = self.model.features(x)
                # Global average pool, then flatten to (batch, channels).
                x = F.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
                x = self.dropout(x)
                return self.l0(x)

        model = ResNet34(pretrained=True).to(device)
        _load_weights(model, './models/caltech101-pretrained.pth')
        return model, _build_transform((224, 224))

    if model_name == 'cub200_vgg19':
        model = _finetuned_vgg19_bn(
            200, './models/cub-pretrained-vgg19-best_scheduler.pth')
        return model, _build_transform((224, 224))

    if model_name == 'flowers102_resnet101':
        model = _finetuned_resnet101(
            102, './models/flowers102-pretrained-resnet101-best_scheduler.pth')
        return model, _build_transform((224, 224))

    if model_name == 'stanford_dogs_vgg19':
        model = _finetuned_vgg19_bn(
            120, './models/stanforddogs-pretrained-vgg19-best_scheduler.pth')
        return model, _build_transform((224, 224))

    raise NotImplementedError(f'Model not recognized: {model_name}')
        

def get_layer_output_shape(m_name, layer, ds_len):
    """Return the full activation-array shape for one tapped layer.

    Args:
        m_name: Model identifier; the architecture family is detected by
            substring ('vgg19', 'resnet34', 'resnet101', 'resnet50').
        layer: 1-based index of the tapped layer within that family.
        ds_len: Number of samples, used as the leading dimension.

    Returns:
        A tuple `(ds_len, channels, height, width)`.

    Raises:
        NotImplementedError: If no family found in `m_name` defines `layer`.
    """
    # Families are tried in this order; entry k-1 of each tuple is the
    # (channels, height, width) of layer k.
    family_feature_shapes = (
        ('vgg19', ((256, 56, 56), (512, 28, 28), (512, 14, 14), (512, 7, 7))),
        ('resnet34', ((128, 28, 28), (256, 14, 14), (512, 7, 7))),
        ('resnet101', ((512, 28, 28), (1024, 14, 14), (1024, 14, 14), (2048, 7, 7))),
        ('resnet50', ((512, 28, 28), (1024, 14, 14), (1024, 14, 14), (2048, 7, 7))),
    )

    for family, feature_shapes in family_feature_shapes:
        if family not in m_name:
            continue
        for layer_idx, chw in enumerate(feature_shapes, start=1):
            if layer == layer_idx:
                return (ds_len,) + chw
        # A matching family without this layer falls through to the next one.

    raise NotImplementedError

def get_model_layer_nums(model_name):
    """Return how many layers are tapped for the architecture in `model_name`.

    The architecture family is detected by substring; the first family found,
    in the order listed below, decides the result.

    Raises:
        NotImplementedError: If `model_name` contains none of the families.
    """
    tapped_layers_per_family = (
        ('resnet34', 3),
        ('vgg19', 4),
        ('resnet101', 4),
        ('resnet50', 4),
    )
    for family, layer_count in tapped_layers_per_family:
        if family in model_name:
            return layer_count

    raise NotImplementedError
    
# Models whose intermediate activations are dumped; (un)comment entries to
# choose which ones run.
models = [
#     'pets_resnet101',
#     'caltech101_resnet34',
#     'imagenet_resnet50',
#     'cub200_vgg19',
    'stanford_dogs_vgg19',
#     'flowers102_resnet101'
]
batch_size = 32

for m_name in tqdm(models):
    print(m_name)
    model, preprocess_fn = get_model_transform(m_name)
    num_layers = get_model_layer_nums(m_name)

    dataset = get_dataset(ds_name, preprocess_fn)
    # No shuffling is requested, so row i of every activation file corresponds
    # to sample i of the dataset.
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    num_batches = len(data_loader)
    # Row offset of the current batch inside the memmaps, and its true size
    # (the last batch may be smaller than batch_size). Both are read by the
    # forward hooks below while the model runs.
    curr_batch_start = 0
    curr_batch_size = batch_size

    # One disk-backed float32 array per tapped layer, sized for the whole dataset.
    layer_outputs = {
        f'layer{i}': np.memmap(
            f'/mnt2/ensemble_activations/{m_name}_{ds_name}_layer{i}_acts.dat',
            shape=get_layer_output_shape(m_name, i, len(dataset)),
            mode='w+',
            dtype=np.float32,
        )
        for i in range(1, num_layers + 1)
    }

    def forward_hook_builder(layer_name):
        """Return a forward hook that stores a module's output under `layer_name`."""
        def save_hook(self, input_tensor, output_tensor):
            # Slice assignment copies the values into the memmap, so no extra
            # intermediate copy of the activations is needed.
            batch_acts = output_tensor.detach().cpu().numpy()
            layer_outputs[layer_name][curr_batch_start:curr_batch_start + curr_batch_size] = batch_acts

        return save_hook

    # Attach the hooks to the modules whose outputs are recorded as layer1..layerN.
    if 'vgg19' in m_name:
        model.features[25].register_forward_hook(forward_hook_builder('layer1'))
        model.features[38].register_forward_hook(forward_hook_builder('layer2'))
        model.features[51].register_forward_hook(forward_hook_builder('layer3'))
        model.avgpool.register_forward_hook(forward_hook_builder('layer4'))
    elif 'resnet34' in m_name:
        model.model.layer2[3].register_forward_hook(forward_hook_builder('layer1'))
        model.model.layer3[5].register_forward_hook(forward_hook_builder('layer2'))
        model.model.layer4[2].register_forward_hook(forward_hook_builder('layer3'))
    elif 'resnet101' in m_name:
        model.layer2[3].register_forward_hook(forward_hook_builder('layer1'))
        model.layer3[11].register_forward_hook(forward_hook_builder('layer2'))
        model.layer3[22].register_forward_hook(forward_hook_builder('layer3'))
        model.layer4[2].register_forward_hook(forward_hook_builder('layer4'))
    elif 'resnet50' in m_name:
        model.layer2[3].register_forward_hook(forward_hook_builder('layer1'))
        model.layer3[3].register_forward_hook(forward_hook_builder('layer2'))
        model.layer3[5].register_forward_hook(forward_hook_builder('layer3'))
        model.layer4[1].register_forward_hook(forward_hook_builder('layer4'))

    model = model.to(device)
    model.eval()
    # Inference only: disabling autograd avoids building a graph that is never
    # used, which keeps memory consumption down during the forward passes.
    with torch.no_grad():
        for img, _labels in tqdm(data_loader, total=num_batches):
            img = img.to(device)
            curr_batch_size = img.shape[0]
            model(img)  # the output is discarded; the hooks do the recording
            curr_batch_start += curr_batch_size

    # Make sure everything written for this model reaches the files on disk
    # before moving on to the next model or exiting.
    for layer_acts in layer_outputs.values():
        layer_acts.flush()