# -*- coding: utf-8 -*-

import csv
import numpy as np
from .strategy import Strategy
import os
from sklearn.metrics import pairwise_distances
import torch.nn.functional as F
import torch

class CoreWeightEntropy(Strategy):
    """Selection strategy driven by the entropy of distances to class cores.

    For every unlabeled embedding, distances to each class prototype are
    passed through a row-wise softmax, and ``self.compute_entropy`` (inherited
    from ``Strategy``) turns that into a score. Samples are picked in
    descending score order, in rounds of at most ``selection_batch_size``.
    After each round, every picked embedding is appended to the pool of its
    pseudo-label class (argmax of ``unlabeled_outputlist``). The prototypes
    are then recomputed as the pool means, and the scores are refreshed.
    """

    # Maximum number of samples taken before prototypes/scores are refreshed.
    selection_batch_size = 1000

    def __init__(self, train_args, labeled_per_embeddinglist, unlabeled_originlabels, unlabeled_embeddinglist,
                 unlabeled_outputlist, unlabeled_img_path, add_ratio, labeled_protolist, num_classes,
                 unlabeled_target):
        super(CoreWeightEntropy, self).__init__(train_args, unlabeled_originlabels, unlabeled_img_path,
                                                add_ratio, num_classes, unlabeled_target)
        self.unlabeled_embeddinglist = unlabeled_embeddinglist
        self.unlabeled_outputlist = unlabeled_outputlist

        self.labeled_protolist = labeled_protolist
        self.num_classes = num_classes
        # One embedding pool per class; grown as samples are selected.
        self.labeled_per_embeddinglist = labeled_per_embeddinglist

    def _core_weight_entropy(self, prototypes):
        """Score each unlabeled embedding against ``prototypes``.

        Computes the entropy (via ``self.compute_entropy``) of the row-wise
        softmax of the pairwise distances from each embedding to each
        prototype.
        """
        distancelist = pairwise_distances(self.unlabeled_embeddinglist, prototypes)
        distance_softmax = F.softmax(torch.from_numpy(distancelist), dim=1)
        return self.compute_entropy(distance_softmax)

    def _class_prototypes(self):
        """Return the mean embedding of each class pool as a 2-D array.

        A class whose pool is empty keeps its original prototype, because the
        mean of an empty pool is NaN and would break the distance computation.
        """
        new_core_group = np.zeros(shape=(self.num_classes, len(self.labeled_protolist[0])))
        for i, pool in enumerate(self.labeled_per_embeddinglist):
            if len(pool) > 0:
                new_core_group[i] = np.mean(pool, axis=0)
            else:
                new_core_group[i] = self.labeled_protolist[i]
        return new_core_group

    def chose(self):
        """Select ``int(add_ratio * num_train_set)`` unlabeled samples.

        Each selected sample is written as ``[img_path, str(label)]`` to the
        writer returned by ``self.make_csv_path()``.

        Returns:
            list of ``(img_path, str(label))`` tuples in selection order.
            Fewer than the target are returned only if every unlabeled
            sample has already been selected.
        """
        ft_csv = self.make_csv_path()
        target = int(self.add_ratio * self.num_train_set)
        # Pseudo-labels decide which class pool a selected embedding joins.
        pseudo_unlabeled_target = np.argmax(self.unlabeled_outputlist, axis=1)
        core_weight_entropy = self._core_weight_entropy(self.labeled_protolist)

        print("Selection begins")
        selected = set()  # O(1) membership test instead of scanning a list
        dirlist = []
        lablelist = []
        while len(dirlist) < target:
            added = 0
            for idx in np.argsort(-core_weight_entropy):
                # Stop at the round limit, and never exceed the overall target.
                if added >= self.selection_batch_size or len(dirlist) >= target:
                    break
                if idx in selected:
                    continue
                selected.add(idx)
                added += 1

                ft_csv.writerow([self.unlabeled_img_path[idx]] + [str(self.unlabeled_originlabels[idx])])
                dirlist.append(self.unlabeled_img_path[idx])
                lablelist.append(str(self.unlabeled_originlabels[idx]))

                done = len(dirlist)
                print(f'\r\033[36m{100 * done / target}%[',
                      ">" * int(done / 50) + "·" * int(target / 50 - done / 50) + "]", end="")

                # Rebind (rather than append) so the caller's per-class list
                # object is not mutated in place; list() also accepts array pools.
                cls = pseudo_unlabeled_target[idx]
                self.labeled_per_embeddinglist[cls] = (
                    list(self.labeled_per_embeddinglist[cls]) + [self.unlabeled_embeddinglist[idx]]
                )

            if added == 0:
                # Every unlabeled sample is already selected; the target is
                # unreachable, so stop instead of looping forever.
                break
            if len(dirlist) < target:
                core_weight_entropy = self._core_weight_entropy(self._class_prototypes())

        # Finish the progress line and reset the terminal colour set above.
        print('\033[0m')
        return list(zip(dirlist, lablelist))
