# -*- coding: utf-8 -*-


import math
import os

import numpy
import numpy as np
import csv

from sklearn.cluster import DBSCAN
from tqdm import tqdm

from .metric_select import *
from src.utils.my_mkdir import mkdir
from .strategy import Strategy
from scipy.spatial.distance import pdist


class EntropySeries(Strategy):
    """Entropy-based selection of unlabeled samples.

    Samples are ranked by the entropy of their model outputs (computed by
    ``self.compute_entropy``) and taken in that order until
    ``self.add_ratio * self.num_train_set`` samples have been selected.
    ``entropy_select_with_merge`` additionally skips samples that are close
    (in cosine distance) to an already-selected sample of the same DBSCAN
    cluster.
    """

    def __init__(self, train_args, unlabeled_originlabels, unlabeled_outputlist, unlabeled_embeddinglist,
                 unlabeled_img_path, add_ratio, num_classes, unlabeled_target):
        super(EntropySeries, self).__init__(train_args, unlabeled_originlabels, unlabeled_img_path, add_ratio,
                                            num_classes, unlabeled_target)
        # Per-sample model outputs; entropies are computed from these.
        self.unlabeled_outputlist = unlabeled_outputlist
        # Per-sample embeddings; used for DBSCAN clustering and cosine distances.
        self.unlabeled_embeddinglist = unlabeled_embeddinglist

    def _entropy_order(self):
        """Return sample indices ordered according to ``self.select_type``.

        'ADD-GOOD' orders by descending entropy, 'ADD-BAD' by ascending entropy.

        Raises:
            ValueError: if ``self.select_type`` is neither of the above.
        """
        entropy = np.asarray(self.compute_entropy(np.array(self.unlabeled_outputlist)))
        if self.select_type == 'ADD-GOOD':
            return (-entropy).argsort()
        if self.select_type == 'ADD-BAD':
            return entropy.argsort()
        raise ValueError("Unsupported select_type {!r}; expected 'ADD-GOOD' or 'ADD-BAD'"
                         .format(self.select_type))

    def entropy_select_with_merge(self, R=1, eps=0.5, min_samples=4):
        """Select samples by entropy, merging away near-duplicates.

        Samples are visited in entropy order. Each sample not already merged
        is selected and written to the first CSV writer. Every other sample in
        the same DBSCAN cluster whose cosine distance to it is below ``R`` is
        marked as merged (never selected afterwards) and written to the second
        CSV writer. Selection stops once ``self.add_ratio * self.num_train_set``
        samples have been selected.

        Args:
            R: cosine-distance threshold below which a same-cluster sample is merged.
            eps: DBSCAN ``eps`` used to cluster the embeddings.
            min_samples: DBSCAN ``min_samples`` used to cluster the embeddings.

        Returns:
            list of ``(img_path, origin_label)`` tuples for the selected samples.

        Raises:
            ValueError: if ``self.select_type`` is not 'ADD-GOOD' or 'ADD-BAD'.
        """
        from scipy.spatial.distance import cdist

        ft_csv, rt_csv = self.make_csv_path(True)
        order = self._entropy_order()

        embeddings = np.asarray(self.unlabeled_embeddinglist)
        cluster_labels = DBSCAN(eps=eps, min_samples=min_samples).fit(embeddings).labels_
        limit = self.add_ratio * self.num_train_set

        dirlist = []
        lablelist = []
        merged = set()  # set, not list: membership is tested for every visited sample
        print('Selection begins')
        for i in tqdm(order):
            if i in merged:
                continue
            ft_csv.writerow([self.unlabeled_img_path[i]] + [str(self.unlabeled_originlabels[i])])
            dirlist.append(self.unlabeled_img_path[i])
            lablelist.append(self.unlabeled_originlabels[i])

            # Other members of i's cluster, in ascending index order.
            # NOTE(review): DBSCAN labels all noise points -1, so noise points are
            # treated as one cluster here (same as before) -- confirm that is intended.
            peers = np.flatnonzero(cluster_labels == cluster_labels[i])
            peers = peers[peers != i]
            if peers.size:
                # One vectorized call instead of one pdist call per pair.
                distances = cdist(embeddings[i:i + 1], embeddings[peers], 'cosine')[0]
                for j in peers[distances < R]:
                    merged.add(j)
                    rt_csv.writerow([self.unlabeled_img_path[j]] + [str(self.unlabeled_originlabels[j])])

            if len(lablelist) >= limit:
                break
        return list(zip(dirlist, lablelist))

    def entropy_select(self):
        """Select samples purely by entropy order.

        Each selected sample is written to the CSV writer returned by
        ``self.make_csv_path()``. Selection stops once
        ``self.add_ratio * self.num_train_set`` samples have been selected,
        using the same stopping rule as ``entropy_select_with_merge``.

        Returns:
            list of ``(str(img_path), origin_label)`` tuples; empty if there are
            no unlabeled samples.

        Raises:
            ValueError: if ``self.select_type`` is not 'ADD-GOOD' or 'ADD-BAD'.
        """
        ft_csv = self.make_csv_path()
        order = self._entropy_order()
        limit = self.add_ratio * self.num_train_set

        dirlist = []
        lablelist = []
        for i in order:
            path = str(self.unlabeled_img_path[i])
            label = self.unlabeled_originlabels[i]
            ft_csv.writerow([path] + [str(label)])
            dirlist.append(path)
            lablelist.append(label)
            # Count-based check: the previous `index >= limit` selected one extra sample.
            if len(lablelist) >= limit:
                break
        return list(zip(dirlist, lablelist))
