import math
import random
import torch
import kmedoids
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import pycls.datasets.utils as ds_utils
from scipy.spatial import distance
from sklearn.metrics.pairwise import rbf_kernel 


class Active_Silhouette(nn.Module):
    """Greedy sampler that scores candidates with a silhouette-like ratio.

    For a given ``gamma`` the sampler repeatedly moves one point from the
    unlabelled pool to the labelled set, choosing the candidate with the
    smallest score (see ``_best_candidate``).  ``select_samples`` repeats this
    greedy pass for several values of ``gamma`` and keeps the pass with the
    lowest average score.
    """

    # Dataset names whose features are loaded together with labels.  The
    # entries are lowercase because the configured name is lowercased before
    # the lookup; mixed-case entries could never match.
    _LABELEXCLU_DATASETS = frozenset({
        'cifar10_imbalanced',
        'cifar10_all_imbalanced',
        'trpb',
        'optdigits',
        'phishing',
    })

    def __init__(self, cfg, lSet, uSet, budgetSize, delta, p=2):
        """Load the dataset features and slice them into labelled/unlabelled.

        Args:
            cfg: Config object; ``cfg['DATASET']['NAME']`` and
                ``cfg['RNG_SEED']`` are read.
            lSet: Indices (into the feature matrix) of labelled samples.
            uSet: Indices (into the feature matrix) of unlabelled samples.
            budgetSize: Number of samples ``select_samples`` should pick.
            delta: Stored on the instance; not used by this class.
            p: Stored on the instance; not used by this class.
        """
        super().__init__()
        self.cfg = cfg
        self.ds_name = self.cfg['DATASET']['NAME']
        self.seed = self.cfg['RNG_SEED']
        if self.ds_name.lower() in self._LABELEXCLU_DATASETS:
            self.all_features, self.all_labels = ds_utils.load_features_labelexclu(
                self.ds_name, self.seed, train=True, normalized=False)
        else:
            self.all_features = ds_utils.load_features(self.ds_name, self.seed)
        self.lSet = lSet
        self.uSet = uSet
        self.budgetSize = budgetSize
        self.delta = delta

        self.lfeat = self.all_features[np.array(self.lSet).astype(int)]
        self.ufeat = self.all_features[np.array(self.uSet).astype(int)]
        self.p = p

    @torch.no_grad()
    def dist(self, x, y, gamma=0.1):
        """Return pairwise ``torch.cdist`` distances minus ``gamma``, floored at 0."""
        return (torch.cdist(x, y) - gamma).clamp(min=0)

    def _best_candidate(self, gamma, pool_feats, lab_feats, batch_size, macro_avg):
        """Return ``(index, score)`` of the lowest-scoring pool candidate.

        Requires at least two labelled rows, because the score uses both the
        nearest and the second-nearest labelled distance of every pool point.
        """
        n_lab = lab_feats.shape[0]
        d_min2, idx_min2 = torch.topk(
            self.dist(lab_feats, pool_feats, gamma=gamma), k=2, dim=0, largest=False)
        nearest, second = d_min2[0], d_min2[1]

        best_idx, best_coef = None, 1.1
        # Candidates are processed in batches to bound the size of ``d_uu``.
        for start in range(0, pool_feats.shape[0], batch_size):
            cand = pool_feats[start:start + batch_size]
            d_uu = self.dist(cand, pool_feats, gamma=gamma)
            # Ratio of the nearest labelled distance to the smaller of
            # (candidate distance, second-nearest labelled distance).
            # 0/0 becomes 1; the clamp then folds values above 1 to 1/value.
            ratio = (nearest / d_uu.clamp(max=second)).nan_to_num(1)
            ratio = ratio.clamp(max=1 / ratio)
            if macro_avg:
                # Group pool points by their nearest labelled point; points
                # that would be closer to the candidate form an extra group.
                group = idx_min2[0].repeat(d_uu.shape[0], 1)
                group[nearest > d_uu] = n_lab
                ratio = ratio.new_zeros((cand.shape[0], n_lab + 1)).scatter_reduce(
                    1, group, ratio, reduce='mean', include_self=False)
            coef, idx = ratio.mean(1).min(0)
            if coef < best_coef:
                best_idx = idx.item() + start
                best_coef = coef.item()
        if best_idx is None:
            raise RuntimeError('no candidate scored below the initial threshold')
        return best_idx, best_coef

    def _greedy_pass(self, gamma, pool_feats, lab_feats, seed_medoids, batch_size, macro_avg):
        """Run one greedy selection of ``budgetSize`` samples for ``gamma``.

        Returns ``(active, remaining_ids, coef_sum)`` where ``coef_sum`` adds
        up the scores of the scored (non-seed) picks.
        """
        pool_ids = np.asarray(self.uSet)
        # Original pool position of every row still in ``pool_feats``; needed
        # because the seed medoids are indices into the *original* pool.
        pool_pos = np.arange(pool_feats.shape[0])
        active = []
        coef_sum = 0
        for _ in range(self.budgetSize):
            if lab_feats.shape[0] < 2:
                medoid = int(seed_medoids[lab_feats.shape[0]])
                pick = int(np.flatnonzero(pool_pos == medoid)[0])
            else:
                pick, coef = self._best_candidate(
                    gamma, pool_feats, lab_feats, batch_size, macro_avg)
                coef_sum += coef
            lab_feats = torch.vstack([lab_feats, pool_feats[pick]])
            keep = torch.ones(pool_feats.shape[0], dtype=torch.bool, device=pool_feats.device)
            keep[pick] = False
            pool_feats = pool_feats[keep]
            active.append(pool_ids[pick])
            pool_ids = np.delete(pool_ids, [pick])
            pool_pos = np.delete(pool_pos, [pick])
        return active, pool_ids, coef_sum

    @torch.no_grad()
    def select_samples(self, batch_size=4096, device='cuda', macro_avg=True, fix_gamma=None):
        """Select ``budgetSize`` samples from the unlabelled pool.

        Args:
            batch_size: Number of candidates scored at once.
            device: Torch device used for the distance computations.
            macro_avg: If true, per-candidate scores are averaged per group
                (nearest labelled point, plus one group for the candidate)
                before averaging across groups; otherwise they are averaged
                over all pool points directly.
            fix_gamma: If given, a single greedy pass is run with this gamma
                and no search is performed.

        Returns:
            ``(active_set, remaining_set, best_gamma)`` where ``active_set``
            is the list of selected ids, ``remaining_set`` is the sorted array
            of ids left in the pool and ``best_gamma`` is the gamma of the
            returned pass.

        Raises:
            ValueError: If the budget is smaller than 1 or larger than the
                unlabelled pool.
        """
        pool_size = len(self.uSet)
        if self.budgetSize < 1:
            raise ValueError(f'budgetSize must be at least 1, got {self.budgetSize}')
        if self.budgetSize > pool_size:
            raise ValueError(
                f'budgetSize ({self.budgetSize}) exceeds the unlabelled pool ({pool_size})')

        print(f'Start selecting {self.budgetSize} samples.')
        golden_number_1 = (3 + 5 ** 0.5) / 2
        gamma_tested = [0.7, 3.4]  # initial bracket for the gamma search
        if fix_gamma is not None:
            gamma = fix_gamma
        else:
            gamma = gamma_tested[0] + (gamma_tested[1] - gamma_tested[0]) / golden_number_1
        best_gamma = gamma
        best_sihlo_coef = 1.1
        best_activeSet = None
        best_uset = None

        seed_medoids = None
        if self.lfeat.shape[0] < 2:
            # Scoring needs two labelled points; until then take the medoids
            # of a 2-medoid clustering of the pool as seeds.
            cpu_pool = torch.tensor(self.ufeat).float()
            seed_medoids = kmedoids.fasterpam(
                torch.cdist(cpu_pool, cpu_pool).numpy(), 2).medoids
            print(seed_medoids)

        # The greedy pass never modifies these tensors in place, so they are
        # built once and shared by all gamma rounds.
        pool_feats = torch.tensor(self.ufeat).float().to(device)
        lab_feats = torch.tensor(self.lfeat).float().to(device)

        for _ in range(20):
            activeSet, uset, sihlo_coef_sum = self._greedy_pass(
                gamma, pool_feats, lab_feats, seed_medoids, batch_size, macro_avg)
            score = sihlo_coef_sum / self.budgetSize
            print('gamma:{0}, silhouette:{1}'.format(gamma, score))
            # Lower score wins; on an exact tie the smaller gamma wins.
            if score < best_sihlo_coef or (score == best_sihlo_coef and gamma < best_gamma):
                best_gamma = gamma
                best_sihlo_coef = score
                best_activeSet = activeSet
                best_uset = uset
            if fix_gamma is not None:
                # Must stop before the bracket update: a fixed gamma may lie
                # outside the bracket, leaving no tested value on one side.
                break
            gamma_tested.append(gamma)
            above_best = min(g for g in gamma_tested if best_gamma < g)
            below_best = max(g for g in gamma_tested if best_gamma > g)
            gamma = below_best + (above_best - below_best) / golden_number_1
            if abs(gamma - best_gamma) < 1e-7:
                gamma = above_best - (above_best - below_best) / golden_number_1
        print(f'Finished the selection of {len(best_activeSet)} samples.')
        print(f'Active set is {best_activeSet}')
        return best_activeSet, np.array(sorted(best_uset)), best_gamma
    
    


class Kernelherding:
    """Kernel-herding sampler over precomputed dataset features.

    Each step picks the unlabelled point whose mean Gaussian similarity to the
    remaining unlabelled pool, minus its averaged similarity to the labelled
    set, is largest.
    """

    # Dataset names are stored lowercase because the configured name is
    # lowercased before the lookup; mixed-case entries could never match.
    # Loaded with ``load_features_labelexclu(..., train=True)``.
    _LABELEXCLU_TRAIN_DATASETS = frozenset({
        'cifar10_imbalanced', 'cifar100_imbalanced',
        'cifar10_all_imbalanced', 'cifar100_all_imbalanced',
        'trpb', 'trpb_umap', 'octanoate', 'butyrate', 'acetate',
        '60butyrate', '90butyrate', 'optdigits', 'phishing',
    })
    # Loaded with ``load_features_labelexclu(..., train=False)``.
    _LABELEXCLU_NOTRAIN_DATASETS = frozenset({
        'octanoate_700_2', 'butyrate_700_2', 'acetate_700_2',
        '60butyrate_700_2', '90butyrate_700_2',
        'octanoate_700_1028', 'butyrate_700_1028', 'acetate_700_1028',
        '60butyrate_700_1028', '90butyrate_700_1028', 'sysdata',
    })

    def __init__(self, cfg, lSet, uSet, budgetSize, delta):
        """Load the dataset features and slice them into labelled/unlabelled.

        Args:
            cfg: Config object; ``cfg['DATASET']['NAME']`` and
                ``cfg['RNG_SEED']`` are read.
            lSet: Indices (into the feature matrix) of labelled samples.
            uSet: Indices (into the feature matrix) of unlabelled samples.
            budgetSize: Number of samples ``select_samples`` should pick.
            delta: Stored on the instance; not used by this class.
        """
        self.cfg = cfg
        self.ds_name = self.cfg['DATASET']['NAME']
        self.seed = self.cfg['RNG_SEED']
        name = self.ds_name.lower()
        if name in self._LABELEXCLU_TRAIN_DATASETS:
            self.all_features, self.all_labels = ds_utils.load_features_labelexclu(
                self.ds_name, self.seed, train=True, normalized=False)
        elif name in self._LABELEXCLU_NOTRAIN_DATASETS:
            self.all_features, self.all_labels = ds_utils.load_features_labelexclu(
                self.ds_name, self.seed, train=False, normalized=False)
        else:
            self.all_features = ds_utils.load_features(self.ds_name, self.seed)
        self.lSet = lSet
        self.uSet = uSet
        self.budgetSize = budgetSize
        self.delta = delta

        self.lfeat = self.all_features[np.array(self.lSet).astype(int)]
        self.ufeat = self.all_features[np.array(self.uSet).astype(int)]

    def gaussian_calculate(self, x, y, gamma=0.03):
        """Return ``exp(-gamma * d(x, y)**2)`` as float32 (NumPy/SciPy); prints gamma."""
        print('gamma: ', gamma)
        return np.float32(np.exp(-gamma * (distance.cdist(x, y)**2)))

    def gaussian_calculate_gpu(self, x, y, gamma=0.03):
        """Return ``exp(-gamma * d(x, y)**2)`` as a float tensor (torch)."""
        return torch.exp(-gamma * (torch.cdist(x, y)**2)).float()

    def _pool_similarity(self, batch_size, device=None):
        """Return the pool-vs-pool Gaussian similarity matrix as a NumPy array.

        Rows are computed ``batch_size`` at a time to limit device memory.
        ``device=None`` uses CUDA when available and the CPU otherwise.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        feats = torch.tensor(self.ufeat).to(device).float()
        # Stepping by ``batch_size`` also covers a final partial batch and a
        # pool smaller than one batch.
        blocks = [
            self.gaussian_calculate_gpu(feats[start:start + batch_size], feats).cpu().numpy()
            for start in range(0, feats.shape[0], batch_size)
        ]
        if not blocks:
            return np.zeros((0, 0), dtype=np.float32)
        return np.vstack(blocks)

    def select_samples(self, batch_size=10, device=None):
        """Select ``budgetSize`` samples by kernel herding.

        Args:
            batch_size: Number of pool rows whose similarities are computed
                at once.
            device: Torch device for the pool similarity computation;
                ``None`` uses CUDA when available, else the CPU.

        Returns:
            ``(activeSet, remainSet)``: the list of selected ids in selection
            order and the sorted array of ids left in the pool.

        Raises:
            ValueError: If the budget exceeds the unlabelled pool.
        """
        pool_feats = np.asarray(self.ufeat)
        pool_ids = np.asarray(self.uSet)
        if self.budgetSize > pool_feats.shape[0]:
            raise ValueError(
                f'budgetSize ({self.budgetSize}) exceeds the unlabelled pool '
                f'({pool_feats.shape[0]})')

        if np.asarray(self.lfeat).shape[0] != 0:
            similarity_lu = self.gaussian_calculate(self.lfeat, self.ufeat)
        else:
            similarity_lu = np.zeros((0, pool_feats.shape[0]))

        print('Start calculating the first term')
        similarity_uu = self._pool_similarity(batch_size, device)
        print('Finished calculating the first term')

        print(f'Start selecting {self.budgetSize} samples.')

        n_labelled = similarity_lu.shape[0]
        sum_similarity_lu = np.sum(similarity_lu, axis=0)
        activeSet = []

        for _ in range(self.budgetSize):
            mean_uu = np.mean(similarity_uu, axis=0)
            mean_lu = sum_similarity_lu / (n_labelled + 1)
            pick = int((mean_uu - mean_lu).argmax())
            activeSet.append(pool_ids[pick])

            # ``pick`` indexes the current (shrunken) pool, so the picked
            # feature must be read from ``pool_feats``, not from ``self.ufeat``.
            picked_sim = self.gaussian_calculate(pool_feats[[pick]], pool_feats)[0]
            sum_similarity_lu = np.delete(sum_similarity_lu + picked_sim, pick)
            n_labelled += 1

            # Drop the picked point from every pool-indexed structure.
            similarity_uu = np.delete(similarity_uu, [pick], axis=0)
            similarity_uu = np.delete(similarity_uu, [pick], axis=1)
            pool_feats = np.delete(pool_feats, [pick], axis=0)
            pool_ids = np.delete(pool_ids, [pick])

        remainSet = np.array(sorted(set(self.uSet) - set(activeSet)))

        print(f'Finished the selection of {len(activeSet)} samples.')
        print(f'Active set is {activeSet}')
        return activeSet, remainSet
    

class kmeidoid_select_initial:
    """Pick an initial labelled set as the medoids of a k-medoids clustering.

    The unlabelled pool is clustered into ``budgetSize`` clusters with
    ``kmedoids.fasterpam`` and the medoids are returned as the selection.
    """

    # Dataset names are stored lowercase because the configured name is
    # lowercased before the lookup; mixed-case entries could never match.
    # Loaded with ``load_features_labelexclu(..., train=True)``.
    _LABELEXCLU_TRAIN_DATASETS = frozenset({
        'cifar10_imbalanced', 'cifar100_imbalanced',
        'cifar10_all_imbalanced', 'cifar100_all_imbalanced',
        'trpb', 'trpb_umap', 'octanoate', 'butyrate', 'acetate',
        '60butyrate', '90butyrate', 'optdigits', 'phishing',
    })
    # Loaded with ``load_features_labelexclu(..., train=False)``.
    _LABELEXCLU_NOTRAIN_DATASETS = frozenset({
        'octanoate_700_2', 'butyrate_700_2', 'acetate_700_2',
        '60butyrate_700_2', '90butyrate_700_2',
        'octanoate_700_1028', 'butyrate_700_1028', 'acetate_700_1028',
        '60butyrate_700_1028', '90butyrate_700_1028', 'sysdata',
    })

    def __init__(self, cfg, lSet, uSet, budgetSize):
        """Load the dataset features and slice them into labelled/unlabelled.

        Args:
            cfg: Config object; ``cfg['DATASET']['NAME']`` and
                ``cfg['RNG_SEED']`` are read.
            lSet: Indices (into the feature matrix) of labelled samples.
            uSet: Indices (into the feature matrix) of unlabelled samples.
            budgetSize: Number of medoids ``select_samples`` should return.
        """
        self.cfg = cfg
        self.ds_name = self.cfg['DATASET']['NAME']
        self.seed = self.cfg['RNG_SEED']
        name = self.ds_name.lower()
        if name in self._LABELEXCLU_TRAIN_DATASETS:
            self.all_features, self.all_labels = ds_utils.load_features_labelexclu(
                self.ds_name, self.seed, train=True, normalized=False)
        elif name in self._LABELEXCLU_NOTRAIN_DATASETS:
            self.all_features, self.all_labels = ds_utils.load_features_labelexclu(
                self.ds_name, self.seed, train=False, normalized=False)
        else:
            self.all_features = ds_utils.load_features(self.ds_name, self.seed)
        self.lSet = lSet
        self.uSet = uSet
        self.budgetSize = budgetSize

        self.lfeat = self.all_features[np.array(self.lSet).astype(int)]
        self.ufeat = self.all_features[np.array(self.uSet).astype(int)]

    def gaussian_calculate(self, x, y, gamma=1):
        """Return ``exp(-gamma * d(x, y)**2)`` as float32 (NumPy/SciPy)."""
        return np.float32(np.exp(-gamma * (distance.cdist(x, y)**2)))

    def gaussian_calculate_gpu(self, x, y, gamma=1):
        """Return ``exp(-gamma * d(x, y)**2)`` as a float tensor (torch)."""
        return torch.exp(-gamma * (torch.cdist(x, y)**2)).float()

    def _pairwise_similarity(self, batch_size, device=None):
        """Return the pool-vs-pool Gaussian similarity matrix as a NumPy array.

        Rows are computed ``batch_size`` at a time to limit device memory.
        ``device=None`` uses CUDA when available and the CPU otherwise.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        feats = torch.tensor(self.ufeat).to(device).float()
        # Stepping by ``batch_size`` also covers a final partial batch and a
        # pool smaller than one batch.
        blocks = [
            self.gaussian_calculate_gpu(feats[start:start + batch_size], feats).cpu().numpy()
            for start in range(0, feats.shape[0], batch_size)
        ]
        if not blocks:
            return np.zeros((0, 0), dtype=np.float32)
        return np.vstack(blocks)

    def select_samples(self, batch_size=10, device=None):
        """Select ``budgetSize`` medoids of the unlabelled pool.

        Args:
            batch_size: Number of pool rows whose similarities are computed
                at once.
            device: Torch device for the similarity computation; ``None``
                uses CUDA when available, else the CPU.

        Returns:
            ``(activeSet, remainSet)``: the array of selected ids and the
            sorted array of ids left in the pool.

        Raises:
            ValueError: If the budget exceeds the unlabelled pool.
        """
        uSet = np.asarray(self.uSet)
        if self.budgetSize > len(uSet):
            raise ValueError(
                f'budgetSize ({self.budgetSize}) exceeds the unlabelled pool ({len(uSet)})')

        similarity_uu = self._pairwise_similarity(batch_size, device)
        print(f'Start selecting {self.budgetSize} samples.')

        # fasterpam minimises dissimilarity, so the Gaussian similarity
        # (1 for identical points, towards 0 for distant ones) is flipped.
        dissimilarity_uu = 1.0 - similarity_uu
        fp = kmedoids.fasterpam(dissimilarity_uu, self.budgetSize)
        selected = np.asarray(fp.medoids).astype(np.int64)

        activeSet = uSet[selected]
        remainSet = np.array(sorted(set(uSet) - set(activeSet)))

        print(f'Active set is {activeSet}')
        return activeSet, remainSet