import numpy as np
from sklearn.decomposition import PCA
from scipy.spatial import distance
import os
import torch
# # import torchvision
# import flowers102
# # import caltech101.caltech_dataset as caltech101
# # from resnet34_caltech import ResNet34 
# import cub_200
# import stanford_dogs
# import oxford_pets
# import caltech_dataset
from tqdm import tqdm 

# Module-level settings. BATCH_SIZE is not referenced by the functions below.
BATCH_SIZE = 32

# Prefer the GPU when torch reports one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# transform = torchvision.transforms.Compose([
#     torchvision.transforms.Resize(256),
#     torchvision.transforms.CenterCrop(224),
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225])
# ])

def get_gaussian(features):
    """Estimate the mean vector and covariance matrix of ``features``.

    Rows are treated as observations and columns as variables
    (``rowvar=False``). The covariance uses numpy's default unbiased
    (N - 1) normalisation.

    Args:
        features: 2-D array-like of shape (n_samples, n_features).

    Returns:
        Tuple ``(mean, cov)``: the per-column mean and the column covariance
        matrix.
    """
    return np.mean(features, axis=0), np.cov(features, rowvar=False)


def classwise_hard_samples(logits, targets, K=2.0):
    """Flag, per class, the samples lying far from their class's Gaussian.

    ``n_classes`` is the number of distinct values in ``targets``. For each
    class label ``c`` in ``range(n_classes)``, the method has three steps:
    fit a Gaussian (via ``get_gaussian``) to the rows of ``logits`` whose
    target equals ``c``; invert its covariance; mark every sample of class
    ``c`` whose Mahalanobis distance to that Gaussian exceeds ``K``.

    NOTE(review): labels are visited as ``0 .. n_classes - 1``, so this
    presumably expects contiguous integer labels; any label outside that
    range is never visited. TODO confirm with callers.

    Args:
        logits: 2-D array (one row per sample), indexable by a boolean mask.
        targets: per-sample class labels, same length as ``logits``.
        K: Mahalanobis-distance threshold; strictly greater counts as hard.

    Returns:
        dict mapping ``str(c)`` to a boolean array of length ``len(targets)``
        that is True exactly for samples of class ``c`` farther than ``K``.

    Raises:
        numpy.linalg.LinAlgError: if a class covariance is singular, e.g.
            when a class has too few samples relative to the feature count.
    """
    targets_arr = np.asarray(targets)
    n_classes = len(np.unique(targets_arr))
    hard_samples = {}
    for target in range(n_classes):
        # One vectorized comparison instead of a Python scan over all targets.
        class_mask = targets_arr == target
        mean, cov = get_gaussian(logits[class_mask])
        icov = np.linalg.inv(cov)

        hard_idx = np.zeros(len(targets), dtype=bool)
        # Only this class's own samples can be flagged for this class.
        for i in np.flatnonzero(class_mask):
            if distance.mahalanobis(logits[i], mean, icov) > K:
                hard_idx[i] = True
        hard_samples[str(target)] = hard_idx

    return hard_samples

def get_hard_subset(logits, targets, K):
    """Return the indices of samples flagged as hard for any class.

    Combines the per-class masks from ``classwise_hard_samples`` with a
    logical OR and converts the result to sample indices.

    Args:
        logits: 2-D array of per-sample scores, one row per sample.
        targets: per-sample class labels, forwarded unchanged.
        K: Mahalanobis-distance threshold, forwarded unchanged.

    Returns:
        1-D integer array of sample indices, in ascending order, that are
        True in at least one class mask. It is empty when nothing is flagged,
        including when there are no classes at all.

    Raises:
        ValueError: if ``logits`` is not 2-D.
    """
    # The per-row distance computation needs one feature vector per sample.
    if len(logits.shape) != 2:
        raise ValueError(f"logits must be 2-D, got shape {tuple(logits.shape)}")

    classwise_hard = classwise_hard_samples(logits, targets, K=K)

    # Start from an all-False mask so an empty result stays a proper 1-D
    # empty index array rather than depending on a None sentinel.
    subset_mask = np.zeros(len(targets), dtype=bool)
    for class_mask in classwise_hard.values():
        subset_mask |= class_mask

    return np.flatnonzero(subset_mask)

def get_hardness_ordering(logits, targets):
    """Order samples by Mahalanobis distance to their own class's Gaussian.

    A Gaussian is fitted (via ``get_gaussian``) to the ``logits`` rows of
    each distinct label in ``targets``. Each sample's distance is then
    measured against the Gaussian of *its own* label.

    Args:
        logits: 2-D array (one row per sample), indexable by a boolean mask.
        targets: per-sample class labels, same length as ``logits``.

    Returns:
        Integer array of sample indices sorted by ascending distance, so the
        samples farthest from their class centre come last.

    Raises:
        numpy.linalg.LinAlgError: if a class covariance is singular.
    """
    targets_arr = np.asarray(targets)

    means = {}
    icov = {}
    # Iterate over the labels actually present, so non-contiguous labels get
    # their own statistics instead of being skipped.
    for label in np.unique(targets_arr):
        mean, cov = get_gaussian(logits[targets_arr == label])
        means[label] = mean
        icov[label] = np.linalg.inv(cov)

    dists = np.zeros(len(targets_arr))
    # Read labels from the numpy array so they hash like the dict keys above
    # (a torch tensor element would not).
    for i, curr_target in enumerate(targets_arr):
        # Use this sample's own class statistics, not the last class fitted.
        dists[i] = distance.mahalanobis(
            logits[i], means[curr_target], icov[curr_target]
        )

    return np.argsort(dists)