#!/usr/bin/env python
# coding=utf-8
#!/usr/bin/env python
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.autograd import Variable
import numpy as np
import time
import os
import copy
import argparse
from PIL import Image
from scipy.spatial.distance import cdist
from sklearn.metrics import confusion_matrix
from utils_pytorch import *

def compute_features(tg_feature_model, evalloader, num_samples, num_features, device=None):
    """Run ``tg_feature_model`` over ``evalloader`` and stack its outputs row-wise.

    Args:
        tg_feature_model: ``nn.Module`` whose forward output for a batch of
            ``B`` inputs contains ``num_features`` values per sample (e.g.
            shape ``(B, F)`` or ``(B, F, 1, 1)``). The model is switched to
            eval mode and left in eval mode.
        evalloader: Iterable yielding ``(inputs, _, targets)`` triples; only
            ``inputs`` is used.
        num_samples: Total number of samples the loader is expected to yield.
        num_features: Number of feature values per sample.
        device: Device to move ``inputs`` to. Defaults to the device of the
            model's first parameter, or ``cuda:0``/``cpu`` (whichever is
            available) if the model has no parameters.

    Returns:
        ``np.ndarray`` of shape ``(num_samples, num_features)`` (float64).

    Raises:
        ValueError: if a batch's output does not hold ``num_features`` values
            per sample, or the loader yields more or fewer than
            ``num_samples`` samples.
    """
    if device is None:
        # Follow the model's placement so inputs and weights end up on the
        # same device; fall back to the old default for parameter-less models.
        try:
            device = next(tg_feature_model.parameters()).device
        except StopIteration:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tg_feature_model.eval()

    features = np.zeros([num_samples, num_features])
    start_idx = 0
    with torch.no_grad():
        for inputs, _, _ in evalloader:
            batch_size = inputs.shape[0]
            outputs = tg_feature_model(inputs.to(device))
            # reshape instead of squeeze: squeeze would also drop the batch
            # axis (batch_size == 1) or the feature axis (num_features == 1).
            batch_feats = outputs.detach().cpu().numpy().reshape(batch_size, -1)
            if batch_feats.shape[1] != num_features:
                raise ValueError(
                    f"model produced {batch_feats.shape[1]} features per sample, "
                    f"expected {num_features}"
                )
            end_idx = start_idx + batch_size
            if end_idx > num_samples:
                raise ValueError(
                    f"evalloader yielded more than num_samples={num_samples} samples"
                )
            features[start_idx:end_idx, :] = batch_feats
            start_idx = end_idx
    if start_idx != num_samples:
        raise ValueError(
            f"evalloader yielded {start_idx} samples, expected {num_samples}"
        )
    return features

def compute_feats(tg_model, evalloader, num_samples, num_features, device=None):
    """Collect the features returned by ``tg_model(..., return_feats=True)``.

    For each batch this calls ``tg_model(inputs, return_feats=True)`` and keeps
    element ``[2]`` of the result (presumably the feature tensor — confirm
    against the model's ``forward``), stacking them row-wise.

    Args:
        tg_model: ``nn.Module`` whose forward accepts ``return_feats=True`` and
            returns an indexable result whose element ``[2]`` holds
            ``num_features`` values per sample. The model is switched to eval
            mode and left in eval mode.
        evalloader: Iterable yielding ``(inputs, _, targets)`` triples; only
            ``inputs`` is used.
        num_samples: Total number of samples the loader is expected to yield.
        num_features: Number of feature values per sample.
        device: Device to move ``inputs`` to. Defaults to the device of the
            model's first parameter, or ``cuda:0``/``cpu`` (whichever is
            available) if the model has no parameters.

    Returns:
        ``np.ndarray`` of shape ``(num_samples, num_features)`` (float64).

    Raises:
        ValueError: if a batch's features do not hold ``num_features`` values
            per sample, or the loader yields more or fewer than
            ``num_samples`` samples.
    """
    if device is None:
        # Follow the model's placement so inputs and weights end up on the
        # same device; fall back to the old default for parameter-less models.
        try:
            device = next(tg_model.parameters()).device
        except StopIteration:
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tg_model.eval()

    features = np.zeros([num_samples, num_features])
    start_idx = 0
    with torch.no_grad():
        for inputs, _, _ in evalloader:
            batch_size = inputs.shape[0]
            feats = tg_model(inputs.to(device), return_feats=True)[2]
            # reshape instead of squeeze: squeeze would also drop the batch
            # axis (batch_size == 1) or the feature axis (num_features == 1).
            batch_feats = feats.detach().cpu().numpy().reshape(batch_size, -1)
            if batch_feats.shape[1] != num_features:
                raise ValueError(
                    f"model produced {batch_feats.shape[1]} features per sample, "
                    f"expected {num_features}"
                )
            end_idx = start_idx + batch_size
            if end_idx > num_samples:
                raise ValueError(
                    f"evalloader yielded more than num_samples={num_samples} samples"
                )
            features[start_idx:end_idx, :] = batch_feats
            start_idx = end_idx
    if start_idx != num_samples:
        raise ValueError(
            f"evalloader yielded {start_idx} samples, expected {num_samples}"
        )
    return features