import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

import os
import argparse
import random

def seed_torch(seed=0):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU RNG and the RNGs
    of all CUDA devices. It also forces cuDNN into deterministic, non-benchmark
    mode so repeated runs produce the same results.

    Args:
        seed: integer seed applied to every generator.
    """
    # Setting PYTHONHASHSEED at runtime does not change the hash seed of the
    # already-running interpreter; it only affects child processes.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.benchmark = False
    cudnn_backend.deterministic = True

seed_torch(0)

# Prefer the second GPU (the original target). On a single-GPU machine
# 'cuda:1' raises "invalid device ordinal", so fall back to the first GPU there.
# Without CUDA, run on the CPU.
if not torch.cuda.is_available():
    device = 'cpu'
elif torch.cuda.device_count() > 1:
    device = 'cuda:1'
else:
    device = 'cuda:0'

# Per-channel normalization statistics applied to both splits.
# NOTE(review): these look like the commonly used CIFAR-10 mean/std, not SVHN or
# ImageNet statistics, although the backbone is ImageNet-pretrained. Confirm this
# is intended.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)


def _make_transform():
    """Return a fresh ToTensor + Normalize pipeline."""
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ])


# Augmentation (RandomCrop(32, padding=4), RandomHorizontalFlip) is disabled, so
# the train pipeline is the same as the test pipeline.
transform_train = _make_transform()
transform_test = _make_transform()

_SVHN_ROOT = './data_svhn'
_LOADER_KWARGS = dict(batch_size=100, num_workers=2)

trainset = torchvision.datasets.SVHN(
    root=_SVHN_ROOT, split='train', download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_LOADER_KWARGS)

testset = torchvision.datasets.SVHN(
    root=_SVHN_ROOT, split='test', download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_LOADER_KWARGS)

# ImageNet-pretrained ResNet-18, used only as a frozen feature extractor.
net = models.resnet18(pretrained=True)
net.requires_grad_(False)  # freeze every parameter; nothing here is trained
net = net.to(device)

# Per-sample output shape of the truncated backbone, assuming 32x32 SVHN
# inputs: conv1 (stride 2) -> 16x16, maxpool (stride 2) -> 8x8, and layer1
# keeps 64 channels at 8x8.
hidden_size = 64
dim2 = 8

# Keep the stem and the first residual stage, i.e. the first five children of
# ResNet-18 (everything except layer2, layer3, layer4, avgpool and fc).
feature_extractor = torch.nn.Sequential(
    net.conv1, net.bn1, net.relu, net.maxpool, net.layer1)

def _extract_features(loader, extractor, feature_shape, device):
    """Run ``extractor`` over every sample in ``loader`` and collect the outputs.

    Args:
        loader: DataLoader yielding ``(inputs, targets)`` batches.
        extractor: module mapping an input batch to a feature batch.
        feature_shape: per-sample feature shape, e.g. ``(hidden_size, dim2, dim2)``.
        device: device the inputs are moved to before the forward pass.

    Returns:
        ``(features, labels)`` as float64 NumPy arrays with one row per sample,
        in the order the loader yields them.

    Raises:
        RuntimeError: if the loader yields a different number of samples than
            ``len(loader.dataset)``.
    """
    # Size the buffers from the dataset itself instead of hard-coded counts.
    n_samples = len(loader.dataset)
    # float64 buffers kept so the saved file's dtype is unchanged.
    features = np.zeros((n_samples, *feature_shape))
    labels = np.zeros(n_samples)

    # Inference mode: BatchNorm must use its stored running statistics rather
    # than per-batch statistics, and must not update them during extraction.
    extractor.eval()
    offset = 0
    with torch.no_grad():
        for inputs, targets in loader:
            batch_features = extractor(inputs.to(device)).cpu().numpy()
            # A running offset handles a short final batch without special cases.
            end = offset + batch_features.shape[0]
            features[offset:end] = batch_features
            labels[offset:end] = targets.cpu().numpy()
            offset = end

    if offset != n_samples:
        raise RuntimeError(
            f'loader yielded {offset} samples, expected {n_samples}')
    return features, labels


d = {}

# Test split first, then train, matching the original extraction order.
test_features, test_labels = _extract_features(
    testloader, feature_extractor, (hidden_size, dim2, dim2), device)
d['resnet18_test_features'] = test_features
d['test_labels'] = test_labels

# Train rows follow the loader's shuffled order; features and labels stay aligned.
train_features, train_labels = _extract_features(
    trainloader, feature_extractor, (hidden_size, dim2, dim2), device)
d['resnet18_train_features'] = train_features
d['train_labels'] = train_labels

for k, v in d.items():
    print(k, v.shape)

torch.save(d, 'svhn_without_pca_l5.pth')


