import copy
import csv
import os

from .strategy import Strategy
import pdb
from scipy import stats
import numpy as np
from sklearn.metrics import pairwise_distances
import torch
from sklearn.cluster import KMeans
from tqdm import tqdm




class BadgeSampling(Strategy):
    """Query strategy that picks unlabeled samples via gradient embeddings.

    ``query`` builds one gradient-embedding vector per unlabeled sample from
    its feature embedding and its class outputs, then ``chose`` draws a
    diverse subset with k-means++ style seeding and stores it as a CSV file.

    Besides the attributes assigned in ``__init__``, the methods read
    ``dataset_name``, ``classifier_name``, ``select_strategy``,
    ``select_type``, ``select_ratio``, ``num_train_set``, ``add_ratio``,
    ``unlabeled_img_path`` and ``unlabeled_originlabels`` from ``self``.
    NOTE(review): these are presumably set by ``Strategy.__init__``, which is
    not visible in this file -- confirm.
    """

    def __init__(self, train_args, unlabeled_originlables, labeled_embeddinglist, unlabeled_embeddinglist,
                 unlabeled_outputlist, unlabeled_img_path, add_ratio, embDim, num_classes, unlabeled_target):
        """Store the embeddings/outputs and forward the rest to ``Strategy``.

        Args:
            train_args: Forwarded unchanged to ``Strategy.__init__``.
            unlabeled_originlables: Forwarded unchanged to ``Strategy.__init__``.
            labeled_embeddinglist: Embeddings of the labeled samples (stored,
                not used by the methods of this class).
            unlabeled_embeddinglist: One feature embedding per unlabeled
                sample, each with ``embDim`` values.
            unlabeled_outputlist: Per-sample class outputs, indexable as
                ``[sample][class]``.  Presumably softmax probabilities --
                TODO confirm against the caller.
            unlabeled_img_path: Forwarded unchanged to ``Strategy.__init__``.
            add_ratio: Forwarded unchanged to ``Strategy.__init__``.
            embDim: Length of a single feature embedding.
            num_classes: Number of classes.
            unlabeled_target: Forwarded unchanged to ``Strategy.__init__``.
        """
        super(BadgeSampling, self).__init__(train_args, unlabeled_originlables,
                                            unlabeled_img_path, add_ratio,
                                            num_classes, unlabeled_target)
        self.labeled_embeddinglist = labeled_embeddinglist
        self.unlabeled_embeddinglist = unlabeled_embeddinglist
        self.unlabeled_outputlist = unlabeled_outputlist
        self.embDim = embDim
        self.num_classes = num_classes

    def query(self):
        """Build the gradient embeddings and return the selected samples.

        For sample ``j`` and class ``c`` the block
        ``[embDim * c, embDim * (c + 1))`` of the gradient embedding is the
        feature embedding scaled by ``1 - p[j][c]`` when ``c`` is the argmax
        class of sample ``j`` and by ``-p[j][c]`` otherwise.  The result is
        stored in ``self.gradEmbedding`` (float64).

        Returns:
            The list produced by ``chose()``.
        """
        sample_count = len(self.unlabeled_embeddinglist)
        flat_dim = self.embDim * self.num_classes

        print('Creating gradient embeddings:')
        if sample_count == 0:
            # Nothing to embed; let chose() produce the (empty) selection.
            self.gradEmbedding = np.zeros([0, flat_dim])
            return self.chose()

        probs = np.asarray(self.unlabeled_outputlist, dtype=float)[:sample_count]
        features = np.asarray(self.unlabeled_embeddinglist, dtype=float).reshape(sample_count, -1)

        # Per-class scale factor: -p everywhere, plus 1 on the argmax class.
        max_inds = np.argmax(probs, 1)
        coef = -probs[:, :self.num_classes]
        hit = max_inds < self.num_classes
        coef[np.flatnonzero(hit), max_inds[hit]] += 1.0

        # Broadcast (n, C, 1) * (n, 1, D) and flatten the class blocks side by
        # side; this replaces the per-sample / per-class Python loops.
        self.gradEmbedding = (coef[:, :, None] * features[:, None, :]).reshape(sample_count, flat_dim)
        return self.chose()

    def chose(self):
        """Select samples by k-means++ seeding on ``self.gradEmbedding``.

        The first pick is the sample whose gradient embedding has the largest
        L2 norm.  Every further pick is drawn at random with probability
        proportional to the squared distance to the closest sample picked so
        far.  Each pick is written as an ``[image path, label]`` row to
        ``./Selcetion/<dataset>/<classifier>/<strategy>/<type>/<dataset>_<ratio>.csv``.

        The budget is ``int(add_ratio * num_train_set)``, capped at the number
        of available samples.  The first pick counts towards the budget and is
        part of both the CSV file and the returned list.  If every remaining
        sample coincides with an already picked one, the next pick is drawn
        uniformly from the samples not picked yet.

        Returns:
            list[tuple]: ``(image_path, str(label))`` for every selected
            sample, in selection order.
        """
        # NOTE: the directory name spelling ('Selcetion') is kept as is, since
        # other code may read the file from exactly this location.
        filepath = './Selcetion/{}/{}/{}/{}/'.format(self.dataset_name, self.classifier_name, self.select_strategy,
                                                    self.select_type)
        os.makedirs(filepath, exist_ok=True)
        csv_path = str('{}{}_{}.csv'.format(filepath, self.dataset_name, self.select_ratio))

        grad = np.asarray(self.gradEmbedding, dtype=float)
        num_points = len(grad)
        # Never ask for more samples than exist, otherwise the loop below
        # could only finish by picking duplicates.
        target_len = min(int(self.add_ratio * self.num_train_set), num_points)

        select_data = []
        # The context manager guarantees the rows are flushed and the handle
        # is released even if the selection fails half way through.
        with open(csv_path, 'w', newline='') as ft:
            ft_csv = csv.writer(ft)
            if target_len <= 0:
                return select_data

            def record(index):
                path = self.unlabeled_img_path[index]
                label = str(self.unlabeled_originlabels[index])
                ft_csv.writerow([path, label])
                select_data.append((path, label))

            all_indices = np.arange(num_points)
            chosen = [int(np.argmax(np.linalg.norm(grad, axis=1)))]
            min_dist = None

            with tqdm(total=target_len, desc="KMeans++ selection") as pbar:
                record(chosen[0])
                pbar.update(1)
                while len(chosen) < target_len:
                    last = chosen[-1]
                    new_dist = pairwise_distances(grad, grad[last:last + 1]).ravel().astype(float)
                    # Keep, per sample, the distance to its nearest pick.
                    min_dist = new_dist if min_dist is None else np.minimum(min_dist, new_dist)
                    # Rounding can leave a tiny non-zero self-distance; pin it
                    # to zero so a picked sample is never drawn again.
                    min_dist[last] = 0.0

                    weights = min_dist ** 2
                    total = weights.sum()
                    if total > 0:
                        Ddist = weights / total
                    else:
                        # Degenerate pool: all remaining samples sit exactly on
                        # a picked one.  Fall back to a uniform draw over the
                        # samples that have not been picked yet.
                        remaining = np.ones(num_points, dtype=float)
                        remaining[chosen] = 0.0
                        Ddist = remaining / remaining.sum()

                    customDist = stats.rv_discrete(name='custm', values=(all_indices, Ddist))
                    ind = int(customDist.rvs(size=1)[0])
                    chosen.append(ind)
                    record(ind)
                    pbar.update(1)

        return select_data