import numpy as np
from scipy.spatial import distance
import torch
from .strategy import Strategy
from .builder import STRATEGIES
from architectures.resnet_nce import euclidean_distance_func


@STRATEGIES.register_module()
class AugSingleSampling(Strategy):
    """Augmentation-based query strategy over the unlabeled ``train_u`` split.

    ``query`` runs ``args.aug_trials`` rounds. Each round generates one
    augmented copy of ``train_u`` (``generate_aug(1, ...)`` with
    ``args.aug_ulb`` as the intensity), scores every unlabeled sample on it,
    and deletes the augmented copy again. The per-round scores are folded into
    one score per sample, and the indices of the ``n`` samples with the
    *smallest* folded score are returned.

    ``args.aug_metric`` has the form ``'<ignored>-<level>-<metric>'``:

    * level ``'prob'``:
        - ``'max'`` / ``'avg'``: Jensen-Shannon distance between the predicted
          probabilities on the original and on the augmented samples, folded
          by max / sum over rounds.
        - ``'{min,avg,max}{entropy,lc,margin}'``: the measure named by the
          suffix, as returned by ``self.predict`` on the augmented samples,
          folded by min / sum / max over rounds.
    * level ``'fea'``:
        - ``'distori'``: squared L2 distance between the original and the
          augmented embedding, folded by max.
        - ``'maxdistclscenter'`` / ``'mindistclscenter'``: per-sample minimum
          of ``euclidean_distance_func(augmented embedding, class centers)``,
          folded by max / min. Class centers are mean embeddings of the
          labeled ``train`` split.
    """

    _PROB_METRICS = ('max', 'avg', 'minentropy', 'minlc', 'minmargin',
                     'avgentropy', 'avglc', 'avgmargin', 'maxentropy', 'maxlc', 'maxmargin')
    _FEA_METRICS = ('distori', 'maxdistclscenter', 'mindistclscenter')

    def __init__(self, dataset, net, args, logger, timestamp, n_drop=1):
        super().__init__(dataset, net, args, logger, timestamp)
        # Forwarded as ``n_drop`` to every ``self.predict`` call made by ``query``.
        self.n_drop = n_drop

    def query(self, n):
        """Return the indices (into ``train_u``) of the ``n`` lowest-scoring samples.

        Args:
            n: number of samples to query.

        Returns:
            A 1-D tensor holding the indices of the ``n`` smallest aggregated scores.

        Raises:
            ValueError: if ``args.aug_metric`` does not split into exactly three
                ``'-'``-separated parts, or if the level is neither ``'prob'``
                nor ``'fea'``.
            AssertionError: if the metric is not valid for the chosen level.
        """
        _, level, metric = self.args.aug_metric.split('-')
        num_u = len(self.dataset.DATA_INFOS['train_u'])
        if level == 'prob':
            info_u = self._prob_scores(metric, num_u)
        elif level == 'fea':
            info_u = self._fea_scores(metric, num_u)
        else:
            # ValueError subclasses Exception, so existing broad handlers still catch it.
            raise ValueError(f'level {level} does not exists!')
        # Ascending sort: the smallest aggregated scores are queried first.
        return torch.as_tensor(info_u).sort()[1][:n]

    def _prob_scores(self, metric, num_u):
        """Fold prediction-level scores over all augmentation rounds (one per sample)."""
        assert metric in self._PROB_METRICS
        agg, measure = metric[:3], metric[3:]
        probs_ori_u = None
        if not measure:
            # Only the bare 'max'/'avg' metrics compare against predictions on
            # the un-augmented samples; skip this full pass otherwise.
            probs_ori_u = self.predict(self.clf, 'train_u', 'prob', n_drop=self.n_drop).cpu()
        info_u = np.zeros(num_u)
        for no_trial in range(self.args.aug_trials):
            self.generate_aug(1, split='train_u', aug_intensity=self.args.aug_ulb)
            try:
                if measure:
                    info_trial = self.predict(self.clf, ['train_u_aug_single'], measure,
                                              n_drop=self.n_drop).cpu().numpy()
                else:
                    probs_aug_u = self.predict(self.clf, ['train_u_aug_single'], 'prob',
                                               n_drop=self.n_drop).cpu().numpy()
                    # Row-wise JS distance; presumably probs are (num_u, num_classes).
                    # NOTE(review): with the ascending sort in ``query`` the samples
                    # whose predictions change *least* under augmentation are
                    # queried first -- confirm this is intended.
                    info_trial = distance.jensenshannon(probs_ori_u, probs_aug_u, axis=1)
            finally:
                # Always drop the augmented copy, even if scoring fails.
                self.del_aug(split='train_u')
            info_u = self._fold(agg, info_u, info_trial, first=no_trial == 0)
        return info_u

    def _fea_scores(self, metric, num_u):
        """Fold embedding-level scores over all augmentation rounds (one per sample)."""
        assert metric in self._FEA_METRICS
        # Compute only the reference embeddings the chosen metric needs.
        if metric == 'distori':
            agg = 'max'
            fea_ori_u = self.get_embedding(self.clf, 'train_u').cpu()
        else:
            agg = metric[:3]
            centers = self._class_centers()
        info_u = np.zeros(num_u)
        for no_trial in range(self.args.aug_trials):
            self.generate_aug(1, split='train_u', aug_intensity=self.args.aug_ulb)
            try:
                fea_trial = self.get_embedding(self.clf, ['train_u_aug_single']).cpu()
            finally:
                # Drop the augmented copy after every round, as the 'prob' path does.
                self.del_aug(split='train_u')
            if metric == 'distori':
                info_trial = torch.sum((fea_ori_u - fea_trial) ** 2, dim=1)
            else:
                # Min over dim 1; presumably the distance matrix is
                # (num_u, num_classes), i.e. distance to the nearest center.
                info_trial = torch.min(euclidean_distance_func(fea_trial, centers), 1)[0]
            info_u = self._fold(agg, info_u, self._to_numpy(info_trial), first=no_trial == 0)
        return info_u

    def _class_centers(self):
        """Stack the per-class mean embeddings of the labeled ``train`` split.

        A class with no labeled samples falls back to the mean over all
        labeled samples.
        """
        fea_lab = self.get_embedding(self.clf, 'train').cpu()
        labels = np.asarray([data_info['gt_label'] for data_info in self.dataset.DATA_INFOS['train']])
        fallback = fea_lab.sum(dim=0) / len(fea_lab)
        centers = []
        for cls_idx in range(len(self.dataset.CLASSES)):
            members = np.flatnonzero(labels == cls_idx).tolist()
            if members:
                centers.append(fea_lab[members].sum(dim=0) / len(members))
            else:
                centers.append(fallback)
        return torch.stack(centers)

    @staticmethod
    def _fold(agg, info_u, info_trial, first):
        """Fold one round's per-sample scores into the running aggregate.

        ``'avg'`` keeps a running sum. Every sample gets the same number of
        rounds, so dividing by the round count would not change the ranking.
        ``'min'`` and ``'max'`` start from the first round's scores rather than
        from zeros, so negative scores are not clipped to 0.
        """
        if agg not in ('min', 'avg', 'max'):
            raise NotImplementedError
        if agg == 'avg':
            return info_u + info_trial
        if first:
            return info_trial
        if agg == 'min':
            return np.minimum(info_u, info_trial)
        return np.maximum(info_u, info_trial)

    @staticmethod
    def _to_numpy(values):
        """Convert a tensor or array-like of per-sample scores to a numpy array."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)
