import numpy as np
from scipy.spatial import distance
import torch
from .strategy import Strategy
from .builder import STRATEGIES
import random
from datasets.base_dataset import BaseDataset
from architectures.resnet_nce import euclidean_distance_func


def generate_ps_lab(lab_u, lab_l, probs_mix_ul, pace):
    """Build soft pseudo-labels for mixed (unlabeled, labeled) samples.

    Rows where ``lab_u == lab_l`` get a one-hot target on ``lab_u``. Every
    other row gets ``1 - pace`` on the top-1 class and ``pace`` on the top-2
    class of ``probs_mix_ul``.

    Args:
        lab_u: 1-D tensor with one class index per row (unlabeled side).
        lab_l: 1-D tensor with one class index per row (labeled partner),
            same length as ``lab_u``.
        probs_mix_ul: 2-D tensor ``(rows, classes)`` of predictions on the
            mixed samples. It needs at least two classes because the top-2
            classes are read.
        pace: Weight placed on the second-ranked class of mismatched rows.

    Returns:
        A tensor shaped like ``probs_mix_ul`` holding the pseudo-labels.
    """
    same_class = lab_u == lab_l
    probs_mix_ps = torch.zeros_like(probs_mix_ul)
    top2 = torch.topk(probs_mix_ul, 2).indices

    # Boolean-mask indexing (``x[mask]``) returns a copy, so an in-place
    # ``scatter_`` on it never reaches ``probs_mix_ps``. Assign through
    # explicit (row, column) index tensors instead, which writes in place.
    diff_rows = torch.nonzero(~same_class, as_tuple=True)[0]
    probs_mix_ps[diff_rows, top2[diff_rows, 0]] = 1 - pace
    probs_mix_ps[diff_rows, top2[diff_rows, 1]] = pace

    same_rows = torch.nonzero(same_class, as_tuple=True)[0]
    probs_mix_ps[same_rows, lab_u[same_rows]] = 1.
    return probs_mix_ps


def prob_metric_performer(metric: str, dataset: BaseDataset, probs_ori_full: torch.Tensor,
                          probs_mix_ul: torch.Tensor, select_lab_list: list, pace: float,
                          probs_mix_ul2: torch.Tensor = None):
    """Score unlabeled samples by how their predicted probabilities change under mixup.

    Args:
        metric: One of ``'diventropy'``, ``'div'``, ``'firstorder'``,
            ``'secondorder'`` or ``'curorder'``.
        dataset: Dataset whose ``INDEX_ULB`` selects the unlabeled rows of
            ``probs_ori_full``.
        probs_ori_full: Predictions for every sample of the full training set.
        probs_mix_ul: Predictions on the mixups built with ``lam = 1 - pace``.
            There is one row per unlabeled sample, aligned with
            ``probs_ori_full[dataset.INDEX_ULB]``.
        select_lab_list: Row index (into ``probs_ori_full``) of the labeled
            sample mixed into each unlabeled sample. Only ``'diventropy'``
            uses it.
        pace: Mixing step size.
        probs_mix_ul2: Predictions on the mixups built with
            ``lam = 1 - 2 * pace``. ``'secondorder'`` and ``'curorder'``
            require it.

    Returns:
        A numpy array with one score per unlabeled sample.

    Raises:
        NotImplementedError: If ``metric`` is not recognised.
    """
    # Predictions on the un-mixed unlabeled samples, used as the reference point.
    probs_mix_ps = probs_ori_full[dataset.INDEX_ULB]

    if metric == 'diventropy':
        # Compare against a soft pseudo-label. It is built from the argmax
        # classes of the unlabeled sample and of its labeled partner.
        lab_all = torch.argmax(probs_ori_full, 1)
        lab_u = lab_all[dataset.INDEX_ULB]
        lab_l_to_u = lab_all[select_lab_list]
        probs_mix_ps = generate_ps_lab(lab_u, lab_l_to_u, probs_mix_ul, pace)
        return distance.jensenshannon(probs_mix_ul, probs_mix_ps, axis=1)
    if metric == 'div':
        return distance.jensenshannon(probs_mix_ul, probs_mix_ps, axis=1)
    if metric == 'firstorder':
        # First difference along the mixing path, weighted by the original prediction.
        return torch.sum(probs_mix_ps * torch.abs(probs_mix_ul - probs_mix_ps), axis=1).numpy() / pace
    if metric == 'secondorder':
        assert probs_mix_ul2 is not None
        # Second finite difference over the points lam = 1, 1 - pace, 1 - 2 * pace.
        return torch.sum(probs_mix_ps * torch.abs(probs_mix_ul2 - 2 * probs_mix_ul + probs_mix_ps),
                         axis=1).numpy()
    if metric == 'curorder':
        assert probs_mix_ul2 is not None
        first = (probs_mix_ul - probs_mix_ps) / pace
        second = torch.abs(probs_mix_ul2 - 2 * probs_mix_ul + probs_mix_ps) / (pace ** 2)
        # Curvature of a plane curve: |y''| / (1 + y'^2)^(3/2).
        curvature = second / ((1. + first ** 2) ** 1.5)
        return torch.sum(probs_mix_ps * curvature, axis=1).numpy()
    raise NotImplementedError('Unknown probability metric: {!r}'.format(metric))


def feat_metric_performer(metric: str, dataset: BaseDataset, feats_ori_full: torch.Tensor,
                          feats_mix_ul: torch.Tensor, pace: float,
                          feats_mix_ul2: torch.Tensor = None):
    """Score unlabeled samples by how their embeddings change under mixup.

    Args:
        metric: One of ``'discenter'``, ``'dis'``, ``'firstorder'``,
            ``'secondorder'`` or ``'curorder'``.
        dataset: Dataset exposing ``INDEX_ULB``, ``INDEX_LB``, ``CLASSES`` and
            ``DATA_INFOS['train']``. The entries of ``DATA_INFOS['train']``
            must be aligned with ``feats_ori_full[INDEX_LB]``.
        feats_ori_full: Embeddings for every sample of the full training set.
        feats_mix_ul: Embeddings of the mixups built with ``lam = 1 - pace``.
            There is one row per unlabeled sample.
        pace: Mixing step size.
        feats_mix_ul2: Embeddings of the mixups built with
            ``lam = 1 - 2 * pace``. ``'secondorder'`` and ``'curorder'``
            require it.

    Returns:
        A numpy array with one score per unlabeled sample.

    Raises:
        NotImplementedError: If ``metric`` is not recognised.
    """
    feats_ori_u = feats_ori_full[dataset.INDEX_ULB]

    if metric == 'discenter':
        # Distance from each unlabeled embedding to the nearest labeled class center.
        feats_ori_lab = feats_ori_full[dataset.INDEX_LB]
        ori_lab = [data_info['gt_label'] for data_info in dataset.DATA_INFOS['train']]
        centers = []
        for class_idx in range(len(dataset.CLASSES)):
            member_idx = [idx for idx, lab in enumerate(ori_lab) if lab == class_idx]
            # A class with no labeled samples falls back to the mean of all
            # labeled features, so every class still has a center.
            members = feats_ori_lab[member_idx] if member_idx else feats_ori_lab
            centers.append(torch.sum(members, axis=0).unsqueeze(0) / len(members))
        dist_mat = euclidean_distance_func(feats_ori_u, torch.cat(centers))
        # Return numpy like every other branch. The caller accumulates these
        # scores with numpy operations.
        return torch.min(dist_mat, 1)[0].numpy()
    if metric in ('dis', 'firstorder'):
        return torch.norm(feats_ori_u - feats_mix_ul, dim=1).numpy()
    if metric == 'secondorder':
        assert feats_mix_ul2 is not None
        # Second finite difference over the points lam = 1, 1 - pace, 1 - 2 * pace.
        return torch.norm(feats_mix_ul2 - 2 * feats_mix_ul + feats_ori_u, dim=1).numpy()
    if metric == 'curorder':
        assert feats_mix_ul2 is not None
        first = (feats_mix_ul - feats_ori_u) / pace
        second = torch.abs(feats_mix_ul2 - 2 * feats_mix_ul + feats_ori_u) / (pace ** 2)
        # Per-dimension curvature |y''| / (1 + y'^2)^(3/2), then its L2 norm.
        curvature = second / ((1. + first ** 2) ** 1.5)
        return torch.norm(curvature, dim=1).numpy()
    raise NotImplementedError('Unknown feature metric: {!r}'.format(metric))


@STRATEGIES.register_module()
class AugMixSampling(Strategy):
    """Query strategy that scores unlabeled samples by how the model output
    changes when they are mixed with labeled samples.

    ``args.aug_metric`` is a ``'-'``-separated string. Its second field is
    the level (``'prob'`` or ``'feat'``), its second-to-last field is the mode
    and its last field is the metric, which is forwarded to
    ``prob_metric_performer`` / ``feat_metric_performer``. The mode is read by
    fixed character positions:

    * ``mode[:3]`` combines the scores of the ``args.aug_trials`` trials
      (``'max'`` or ``'sum'``).
    * ``mode[3:9]`` chooses labeled partners: ``'random'`` draws from any
      labeled sample, ``'refine'`` makes one pass per class.
    * ``mode[9:12]`` combines the per-class scores, for ``'refine'`` only
      (``'max'`` or ``'sum'``).
    """

    # Temporary DATA_INFOS split describing the mixup pairs to run inference on.
    _MIXUP_SPLIT = 'train_full_aug_mixup'

    def __init__(self, dataset, net, args, logger, timestamp, n_drop=1):
        super(AugMixSampling, self).__init__(dataset, net, args, logger, timestamp)
        # Forwarded to ``self.predict`` as ``n_drop``.
        self.n_drop = n_drop

    def _mixed_outputs(self, level, global_ulab_indices, select_lab_list, lam):
        """Return model outputs on mixups of unlabeled and labeled samples.

        Mixup ``i`` pairs ``global_ulab_indices[i]`` with
        ``select_lab_list[i]`` using mixing coefficient ``lam``.
        ``level == 'prob'`` returns ``self.predict`` probabilities; any other
        level returns ``self.get_embedding`` features. The temporary split is
        always removed from ``DATA_INFOS`` afterwards, even if inference raises.
        """
        infos = self.dataset.DATA_INFOS
        infos[self._MIXUP_SPLIT] = [
            dict(split='train_full', idx_a=global_ulab_indices[i], idx_b=idx_b, lam=lam)
            for i, idx_b in enumerate(select_lab_list)]
        try:
            if level == 'prob':
                return self.predict(self.clf, [self._MIXUP_SPLIT], 'prob', n_drop=self.n_drop).cpu()
            return self.get_embedding(self.clf, self._MIXUP_SPLIT).cpu()
        finally:
            del infos[self._MIXUP_SPLIT]

    def generate_mixed_output(self, global_ulab_indices, select_lab_list):
        """Return class probabilities on mixups built with ``lam = 1 - args.pace``."""
        return self._mixed_outputs('prob', global_ulab_indices, select_lab_list, 1 - self.args.pace)

    def _trial_scores(self, level, metric, global_ulab_indices, select_lab_list,
                      probs_ori_full, feats_ori_full):
        """Score every unlabeled sample against one choice of labeled partners."""
        pace = self.args.pace
        mixed = self._mixed_outputs(level, global_ulab_indices, select_lab_list, 1 - pace)
        mixed2 = None
        if metric in ['secondorder', 'curorder']:
            # Second-difference metrics need a second point, twice as far along the mix.
            mixed2 = self._mixed_outputs(level, global_ulab_indices, select_lab_list, 1 - 2 * pace)
        if level == 'prob':
            return prob_metric_performer(metric, self.dataset, probs_ori_full, mixed,
                                         select_lab_list, pace, mixed2)
        return feat_metric_performer(metric, self.dataset, feats_ori_full, mixed, pace, mixed2)

    @staticmethod
    def _combine(how, acc, new):
        """Merge score vector ``new`` into ``acc`` by element-wise ``'max'`` or ``'sum'``."""
        if how == 'max':
            return np.maximum(acc, new)
        if how == 'sum':
            return acc + new
        raise NotImplementedError('Unknown combination mode: {!r}'.format(how))

    def query(self, n):
        """Return the indices of ``n`` unlabeled samples, ranked by aggregated score.

        Returns:
            A long tensor of positions in ``[0, len(DATA_INFOS['train_u']))``,
            i.e. positions within the unlabeled pool. These are not
            ``train_full`` indices.

        Raises:
            NotImplementedError: If the level or mode parsed from
                ``args.aug_metric`` is not recognised.
        """
        method_total = self.args.aug_metric.split('-')
        level = method_total[1]
        mode = method_total[-2]
        metric = method_total[-1]
        trial_mode, pick_mode, class_mode = mode[:3], mode[3:9], mode[9:12]

        # Validate the configuration before any expensive inference runs.
        if level not in ('prob', 'feat'):
            raise NotImplementedError('Unknown level: {!r}'.format(level))
        if pick_mode not in ('random', 'refine'):
            raise NotImplementedError('Unknown partner selection: {!r}'.format(pick_mode))
        if trial_mode not in ('max', 'sum'):
            raise NotImplementedError('Unknown combination mode: {!r}'.format(trial_mode))
        if pick_mode == 'refine' and class_mode not in ('max', 'sum'):
            raise NotImplementedError('Unknown combination mode: {!r}'.format(class_mode))

        n_full = len(self.dataset.DATA_INFOS['train_full'])
        global_lab_indices = np.arange(n_full)[self.dataset.INDEX_LB]
        global_ulab_indices = np.arange(n_full)[self.dataset.INDEX_ULB]
        n_ulb = len(self.dataset.DATA_INFOS['train_u'])

        # Global (train_full) indices of the labeled samples, grouped by ground-truth class.
        label_2_idx = {i: [] for i in range(len(self.dataset.CLASSES))}
        for idx, data_elem in enumerate(self.dataset.DATA_INFOS['train']):
            gt_label = data_elem['gt_label']
            # gt_label may be a non-int scalar (anything with .item()).
            if type(gt_label) != int:
                gt_label = gt_label.item()
            label_2_idx[gt_label].append(global_lab_indices[idx])

        probs_ori_full = None
        feats_ori_full = None
        if level == 'prob':
            probs_ori_full = self.predict(self.clf, 'train_full', 'prob', n_drop=self.n_drop).cpu()
        else:
            feats_ori_full = self.get_embedding(self.clf, 'train_full').cpu()

        info_u = np.zeros(n_ulb)
        for _trial in range(self.args.aug_trials):
            if pick_mode == 'random':
                select_lab_list = [random.choice(global_lab_indices) for _ in range(n_ulb)]
                info_trial = self._trial_scores(level, metric, global_ulab_indices, select_lab_list,
                                                probs_ori_full, feats_ori_full)
            else:  # 'refine': one pass per class that has labeled samples
                info_trial = np.zeros(n_ulb)
                for class_indices in label_2_idx.values():
                    if not class_indices:
                        continue
                    select_lab_list = [random.choice(class_indices) for _ in range(n_ulb)]
                    class_scores = self._trial_scores(level, metric, global_ulab_indices,
                                                      select_lab_list, probs_ori_full, feats_ori_full)
                    info_trial = self._combine(class_mode, info_trial, class_scores)
            info_u = self._combine(trial_mode, info_u, info_trial)

        # NOTE(review): the ascending sort picks the n LOWEST-scoring samples.
        # Confirm this direction is intended for every metric.
        info_u = torch.tensor(info_u)
        return info_u.sort()[1][:n]
