import numpy as np
from .strategy import Strategy
from .builder import STRATEGIES
import copy


def distance_matrix(features1, features2):
    """Return pairwise squared Euclidean distances between two sets of row vectors.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b``. Because of
    floating-point cancellation, entries for (near-)identical rows may be
    slightly negative instead of exactly zero.

    Args:
        features1: 2-D array of shape (n1, d).
        features2: 2-D array of shape (n2, d).

    Returns:
        Array of shape (n1, n2) whose entry [i, j] is the squared distance
        between ``features1[i]`` and ``features2[j]``.
    """
    G = np.dot(features1, features2.T)
    H1 = np.einsum('ij,ij->i', features1, features1)
    H2 = np.einsum('ij,ij->i', features2, features2)
    # Broadcast the squared row norms (column + row vector) instead of
    # materializing two tiled (n1, n2) copies of them.
    return H1[:, None] + H2[None, :] - G * 2


@STRATEGIES.register_module()
class CoreSet(Strategy):
    """CoreSet active-learning strategy (greedy k-center selection in embedding space)."""

    def __init__(self, dataset, net, args, logger, timestamp):
        super(CoreSet, self).__init__(dataset, net, args, logger, timestamp)

    @staticmethod
    def _greedy_k_center(features_ulb, features_lb, n):
        """Greedily pick rows of ``features_ulb`` that are farthest from covered points.

        Each round selects the candidate whose squared distance to its nearest
        covered point (a labeled row, or a row picked in an earlier round) is
        largest. Ties go to the lowest row index.

        Args:
            features_ulb: 2-D array (num_ulb, d) of candidate embeddings.
            features_lb: 2-D array (num_lb, d) of labeled embeddings; may have
                zero rows.
            n: number of picks; clipped to the range [0, num_ulb].

        Returns:
            List of distinct row indices into ``features_ulb``, in selection order.
        """
        num_ulb = features_ulb.shape[0]
        n = max(0, min(int(n), num_ulb))

        # min_dist[i]: squared distance from candidate i to its nearest covered
        # point. It is updated incrementally rather than re-reduced every round.
        if features_lb.shape[0] > 0:
            min_dist = distance_matrix(features_ulb, features_lb).min(axis=1)
            min_dist = min_dist.astype(np.float64)
        else:
            # Nothing is covered yet, so all candidates tie and row 0 is picked first.
            min_dist = np.full(num_ulb, np.inf)

        picks = []
        for _ in range(n):
            idx = int(np.argmax(min_dist))
            picks.append(idx)
            # The new center may now be the nearest covered point for other candidates.
            dist_to_new = distance_matrix(features_ulb, features_ulb[idx:idx + 1])[:, 0]
            min_dist = np.minimum(min_dist, dist_to_new)
            # Remove the chosen row from future rounds.
            min_dist[idx] = -np.inf
        return picks

    def query(self, n):
        """Select ``n`` unlabeled samples with the greedy k-center (CoreSet) rule.

        If ``args.aug_ratio`` / ``args.mixup_ratio`` are positive, this first
        generates augmented / mixup samples for the ``'train_u'`` split. It then
        embeds the labeled ``'train'`` split and the unlabeled pool
        (``self.get_ulb_list()``) with ``self.clf`` and selects greedily.

        Args:
            n: number of samples to select. Clipped to the size of the unlabeled pool.

        Returns:
            1-D int array of distinct row indices into the unlabeled embedding,
            in selection order. The embedding is presumably ``'train_u'`` followed
            by any ``'train_u_aug_*'`` splits; confirm against ``get_ulb_list``.
        """
        if self.args.aug_ratio > 0:
            self.generate_aug(2, split='train_u')
        if self.args.mixup_ratio > 0:
            self.generate_mixup(2, split='train_u')

        features_past = self.get_embedding(self.clf,
                                           split='train').cpu().numpy()
        features_now = self.get_embedding(self.clf,
                                          split=self.get_ulb_list()).cpu().numpy()

        selected_samples = self._greedy_k_center(features_now, features_past, n)
        return np.array(selected_samples, dtype=int)
