import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from torch.utils.data.dataset import Dataset

import os
import argparse

import random

from numpy import load

def seed_torch(seed=0):
    """Make runs reproducible by seeding every RNG this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and torch's
    CPU and CUDA generators with ``seed``. It also exports ``PYTHONHASHSEED``
    to the environment and switches cuDNN to deterministic, non-benchmarking
    mode.

    Args:
        seed: value applied to every generator (default 0).
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)

    # torch: CPU generator, current CUDA device, then all CUDA devices
    # (the last one matters for multi-GPU runs).
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False      # no autotuned kernel selection
    cudnn_backend.deterministic = True

class PathMNIST(Dataset):
    """One split of PathMNIST, read from ``pathmnist.npz`` in the working dir.

    Images are kept as a float32 tensor of shape (N, C, H, W) holding the raw
    pixel values from the file, with no scaling or normalisation. Labels are
    kept as an int64 tensor with the same shape as the file's label array.

    Args:
        train: load the ``train_*`` arrays when True, the ``test_*`` arrays
            otherwise.
        transform: stored on ``self.transform``. NOTE(review): ``__getitem__``
            never applies it, so samples come back un-normalised. Applying a
            ``ToTensor``-based transform would also need the raw HWC arrays,
            because ``ToTensor`` rejects tensors.

    Side effect: construction reseeds all global RNGs through ``seed_torch``,
    with seed 0 for the train split and 1 for the test split.
    """

    def __init__(self, train, transform):
        super().__init__()
        # Global reseed so anything drawn after construction (e.g. loader
        # shuffling) is reproducible; the split selects the seed.
        seed_torch(0 if train else 1)
        self.train = train
        self.transform = transform
        self.num_classes = 9

        prefix = 'train' if train else 'test'
        # np.load returns an NpzFile that keeps the archive open; the
        # context manager closes it once the arrays have been copied out.
        with load('pathmnist.npz') as data:
            self.data, self.labels = self._load_split(data, prefix)
        # Derived from the file rather than hard-coded (89996 / 7180).
        self.data_size = self.data.shape[0]

    @staticmethod
    def _load_split(data, prefix):
        """Return ``(images, labels)`` for ``{prefix}_images``/``{prefix}_labels``.

        Images move from channels-last (N, H, W, C) to channels-first
        (N, C, H, W) float32. Labels become int64 with their shape unchanged.
        """
        images = torch.as_tensor(data[prefix + '_images'], dtype=torch.float32)
        labels = torch.as_tensor(data[prefix + '_labels']).long()
        # One permute replaces the per-channel copy loop. Unlike the former
        # (N, C, H, H) buffer, it also handles non-square images.
        images = images.permute(0, 3, 1, 2).contiguous()
        return images, labels

    def __getitem__(self, i):
        """Return ``(i, image, label)``; the index lets callers place results."""
        return i, self.data[i], self.labels[i]

    def __len__(self):
        """Number of samples in this split."""
        return self.data.shape[0]


# Keep the original preference for the second GPU. Fall back to cuda:0 on a
# single-GPU machine, where 'cuda:1' is an invalid device ordinal and the
# first .to(device) would raise. Use the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Standard ImageNet channel mean/std.
# NOTE(review): PathMNIST.__getitem__ in this file never applies the
# transform it is given, so these pipelines currently have no effect.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Each PathMNIST construction reseeds the global RNGs (seed 0 for train,
# 1 for test), so the loaders' shuffling is reproducible.
trainset = PathMNIST(train = True, transform = transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=100, shuffle=True, num_workers=2)

testset = PathMNIST(train = False, transform = transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=True, num_workers=2)

# Frozen, ImageNet-pretrained ResNet-18; only its early layers are used below.
# NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
# `weights=...`; it is kept here for compatibility with older versions.
net = models.resnet18(pretrained=True)

for param in net.parameters():
    param.requires_grad = False

net = net.to(device)
# Put BatchNorm in inference mode. In the default train mode, bn1 normalises
# with per-batch statistics and overwrites its pretrained running stats on
# every forward pass; requires_grad=False does not prevent that. The
# extracted features would then depend on batch composition and order.
net.eval()

hidden_size = 64  # channels of the truncated network's output
dim2 = 7          # spatial height/width of that output
n_components = 64

# Keep every child except the last five. For torchvision's resnet18 this is
# presumably conv1, bn1, relu, maxpool and layer1; confirm if the model changes.
feature_extractor = torch.nn.Sequential(*list(net.children())[:-5])
feature_extractor.eval()


def _extract_features(loader, num_samples):
    """Run every batch of ``loader`` through ``feature_extractor``.

    Args:
        loader: yields ``(indices, images, labels)`` batches, as produced by
            ``PathMNIST.__getitem__``.
        num_samples: number of samples in the loader's dataset.

    Returns:
        A ``(features, labels)`` pair of float64 arrays. ``features`` has
        shape ``(num_samples, hidden_size, dim2, dim2)`` and ``labels`` has
        shape ``(num_samples,)``. Row ``k`` holds dataset sample ``k``, so the
        output follows dataset order even when the loader shuffles. This
        replaces the hard-coded last-batch special cases.
    """
    features_out = np.zeros((num_samples, hidden_size, dim2, dim2))
    labels_out = np.zeros((num_samples,))
    with torch.no_grad():  # pure inference; skip autograd bookkeeping
        for indices, inputs, targets in loader:
            rows = indices.numpy()
            batch_features = feature_extractor(inputs.to(device))
            features_out[rows] = batch_features.cpu().numpy()
            # One label per sample: flatten (B, 1) -> (B,) without squeeze,
            # which would also drop the batch axis when B == 1.
            labels_out[rows] = targets.reshape(-1).numpy()
    return features_out, labels_out


d = {}

test_features, test_labels = _extract_features(testloader, len(testset))
d['resnet18_test_features'] = test_features
d['test_labels'] = test_labels

train_features, train_labels = _extract_features(trainloader, len(trainset))
d['resnet18_train_features'] = train_features
d['train_labels'] = train_labels

for k, v in d.items():
    print(k, v.shape)

torch.save(d, 'pathmnist_without_pca_l5.pth')