# -*- coding: utf-8 -*-

import csv
import math
import shutil
import os
import numpy as np
import torch.nn.functional as F
import torch
import os
from matplotlib import pyplot as plt
from sklearn.metrics import pairwise_distances, normalized_mutual_info_score
from sklearn.cluster import DBSCAN
from tqdm import tqdm

from .strategy import Strategy


class DistanceEntropySeries(Strategy):
    """Active-learning selection strategies driven by distance-based entropy.

    Each strategy scores every unlabeled embedding by the entropy of a softmax
    over negative distances (to class prototypes, or to "edge" centres of the
    two nearest classes). It ranks the samples according to
    ``self.select_type`` (``'ADD-GOOD'``: highest score first, ``'ADD-BAD'``:
    lowest score first). It writes the chosen ``(image path, label)`` rows to
    the CSV writer(s) returned by ``self.make_csv_path``.

    Several attributes/methods are used here but not set in this class:
    ``add_ratio``, ``num_train_set``, ``select_type``, ``select_ratio``,
    ``begin_select_ratio``, ``num_classes``, ``unlabeled_img_path``,
    ``unlabeled_originlabels``, ``unlabeled_target``, ``compute_entropy``,
    ``compute_entropy_dim1`` and ``make_csv_path``. They are presumably
    provided by ``Strategy`` -- confirm against that class.
    """

    def __init__(self, train_args, unlabeled_originlabels, labeled_embeddinglist, unlabeled_embeddinglist,
                 unlabeled_img_path, add_ratio, labeled_protolist, num_classes, unlabeled_target):
        """Store embeddings/prototypes; forward everything else to ``Strategy``.

        Args:
            train_args: training configuration, passed through to ``Strategy``.
            unlabeled_originlabels: label of each unlabeled sample (written to the CSVs).
            labeled_embeddinglist: embeddings of the labeled data (see the layout note
                in ``_edge_entropy_scores``).
            unlabeled_embeddinglist: one embedding row per unlabeled sample.
            unlabeled_img_path: image path of each unlabeled sample.
            add_ratio: fraction of ``num_train_set`` to select per call.
            labeled_protolist: one prototype row per class.
            num_classes: number of classes.
            unlabeled_target: class index of each unlabeled sample.
        """
        super().__init__(train_args, unlabeled_originlabels, unlabeled_img_path,
                         add_ratio, num_classes, unlabeled_target)
        self.labeled_embeddinglist = labeled_embeddinglist
        self.unlabeled_embeddinglist = unlabeled_embeddinglist
        self.labeled_protolist = labeled_protolist

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _selection_budget(self):
        """Return how many samples one selection may add: ``ceil(add_ratio * num_train_set)``.

        A tiny tolerance absorbs float noise such as ``0.1 * 30 == 3.0000000000000004``.
        Without it that value would round up to 4. Never negative.
        """
        return max(0, math.ceil(self.add_ratio * self.num_train_set - 1e-9))

    def _rank(self, scores):
        """Return sample indices in selection order according to ``self.select_type``.

        Raises:
            Exception: if ``select_type`` is neither ``'ADD-GOOD'`` nor ``'ADD-BAD'``.
        """
        scores = np.asarray(scores)
        if self.select_type == 'ADD-GOOD':
            return (-scores).argsort()  # highest score first
        if self.select_type == 'ADD-BAD':
            return scores.argsort()  # lowest score first
        raise Exception('Select Type Error!')

    def _prototype_entropy(self):
        """Return ``self.compute_entropy`` of the per-sample softmax over negative prototype distances.

        Row ``i`` of the softmax is taken over the Euclidean distances from
        unlabeled embedding ``i`` to every prototype. Closer prototypes
        therefore get more probability mass.
        """
        distances = pairwise_distances(self.unlabeled_embeddinglist, self.labeled_protolist)
        probs = F.softmax(torch.from_numpy(-distances), dim=1).numpy()
        return self.compute_entropy(probs)

    @staticmethod
    def _entropy_bits(probs):
        """Return the base-2 Shannon entropy of ``probs``, treating ``0 * log(0)`` as 0."""
        probs = np.asarray(probs, dtype=np.float64)
        probs = probs[probs > 0]  # softmax can underflow to exactly 0; log2(0) would give nan
        return float(-np.sum(probs * np.log2(probs)))

    def _write_selected(self, ft_csv, i, dirlist, lablelist):
        """Write unlabeled sample ``i`` to ``ft_csv`` and record its path and label (as str)."""
        ft_csv.writerow([self.unlabeled_img_path[i]] + [str(self.unlabeled_originlabels[i])])
        dirlist.append(self.unlabeled_img_path[i])
        lablelist.append(str(self.unlabeled_originlabels[i]))

    @staticmethod
    def _finalize_selection(dirlist, lablelist):
        """Return the selection as a list of ``(path, label)`` tuples.

        The result is also stored in the module-level ``select_data`` global.
        This keeps backward compatibility with any code that reads that module
        attribute.
        """
        global select_data
        select_data = list(zip(dirlist, lablelist))
        return select_data

    def _select_with_merge(self, order, ft_csv, rt_csv, merge_radius=3, eps=0.5, min_samples=4):
        """Greedily select along ``order`` while suppressing near-duplicates of each pick.

        The unlabeled embeddings are clustered with DBSCAN. After a sample is
        selected, other samples are marked as merged if they are in the same
        (non-noise) cluster and within Euclidean distance ``merge_radius`` of
        it. A merged sample is written once to ``rt_csv`` and is never
        selected.

        Returns:
            ``(dirlist, lablelist)`` for the selected samples, at most
            ``self._selection_budget()`` long.
        """
        embeddings = np.asarray(self.unlabeled_embeddinglist)
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(embeddings).labels_
        budget = self._selection_budget()
        merged = set()
        dirlist, lablelist = [], []
        for i in order:
            if len(dirlist) >= budget:
                break
            if i in merged:
                continue
            self._write_selected(ft_csv, i, dirlist, lablelist)
            if labels[i] == -1:
                # DBSCAN marks noise with -1; noise points share no cluster, so nothing merges.
                continue
            same_cluster = np.flatnonzero(labels == labels[i])
            distances = np.linalg.norm(embeddings[same_cluster] - embeddings[i], axis=1)
            for j in same_cluster[(same_cluster != i) & (distances < merge_radius)]:
                if j not in merged:
                    merged.add(j)
                    rt_csv.writerow([self.unlabeled_img_path[j]] + [str(self.unlabeled_originlabels[j])])
        return dirlist, lablelist

    def _edge_entropy_scores(self):
        """Return the two-class "edge" entropy of every unlabeled embedding.

        For each embedding ``x``:
        1. Take the two prototypes nearest to ``x``.
        2. For each of those two classes, average the ``k`` labeled embeddings
           closest to ``x`` into an edge centre. Here
           ``k = int(alpha * class size)``, with a minimum of 1.
        3. The score is the base-2 entropy of the softmax over the negative
           distances from ``x`` to the two edge centres.

        ``alpha = (1 - (select_ratio - begin_select_ratio)) ** 3`` shrinks as
        ``select_ratio`` grows past ``begin_select_ratio``.
        """
        # NOTE(review): here labeled_embeddinglist is indexed by prototype/class index, and each
        # entry must be a 2-D (n_c, d) array. distance_entropy_v2 instead passes it to
        # pairwise_distances as a single 2-D matrix -- confirm which layout callers provide.
        labeled_by_class = [np.asarray(members) for members in self.labeled_embeddinglist]
        protos = np.asarray(self.labeled_protolist)
        alpha = (1 - (self.select_ratio - self.begin_select_ratio)) ** 3
        scores = []
        for x in np.asarray(self.unlabeled_embeddinglist):
            nearest_two = np.linalg.norm(x - protos, axis=1).argsort()[:2]
            edge_distances = []
            for cls in nearest_two:
                members = labeled_by_class[cls]
                closest_first = np.linalg.norm(x - members, axis=1).argsort()
                k = max(1, int(alpha * len(members)))  # an empty slice would make the mean nan
                edge_centre = members[closest_first[:k]].mean(axis=0)
                edge_distances.append(np.linalg.norm(x - edge_centre))
            probs = F.softmax(torch.from_numpy(-np.array(edge_distances)), dim=0).numpy()
            scores.append(self._entropy_bits(probs))
        return scores

    # ------------------------------------------------------------------ #
    # Selection strategies
    # ------------------------------------------------------------------ #
    def distance_entropy_select(self):
        """Select the top-ranked samples by prototype-distance entropy.

        The first ``ceil(add_ratio * num_train_set)`` samples in
        ``select_type`` order are written to the writer from
        ``self.make_csv_path()``.

        Returns:
            list of ``(image path, label string)`` tuples.

        Raises:
            Exception: on an unknown ``select_type``.
        """
        ft_csv = self.make_csv_path()
        order = self._rank(self._prototype_entropy())

        dirlist, lablelist = [], []
        print("Selection begins")
        for i in tqdm(order[:self._selection_budget()]):
            self._write_selected(ft_csv, i, dirlist, lablelist)
        return self._finalize_selection(dirlist, lablelist)

    def distance_entropy_with_merge(self):
        """Select by prototype-distance entropy, suppressing near-duplicates via DBSCAN.

        Selected rows go to ``ft_csv``. Samples merged into an earlier pick go
        to ``rt_csv``; both writers come from ``self.make_csv_path(True)``.

        Returns:
            list of ``(image path, label string)`` tuples.

        Raises:
            Exception: on an unknown ``select_type``.
        """
        ft_csv, rt_csv = self.make_csv_path(True)
        order = self._rank(self._prototype_entropy())
        dirlist, lablelist = self._select_with_merge(order, ft_csv, rt_csv)
        return self._finalize_selection(dirlist, lablelist)

    def distance_entropy_aver(self):
        """Select per class (by ``unlabeled_target``) using prototype-distance entropy.

        Each call also appends one line to
        ``./Data/aver_distanceEntropy_distribution.txt``. The line holds the
        mean entropy of each class (``nan`` for an empty class).

        Returns:
            list of ``(image path, label string)`` tuples, grouped class by class.

        Raises:
            Exception: on an unknown ``select_type``.
        """
        ft_csv, rt_csv = self.make_csv_path(True)  # rt_csv is not written by this strategy
        entropy = np.asarray(self._prototype_entropy())

        # Group global sample indices by class; unlabeled_target[i] must be a valid class index.
        class_members = [[] for _ in range(self.num_classes)]
        for idx in range(len(entropy)):
            class_members[self.unlabeled_target[idx]].append(idx)

        txtpath = './Data/aver_distanceEntropy_distribution.txt'
        os.makedirs(os.path.dirname(txtpath), exist_ok=True)
        with open(txtpath, "a") as myfile:
            for members in class_members:
                mean = np.mean(entropy[members]) if members else float('nan')
                myfile.write(str(mean) + ' ')
            myfile.write('\n')

        # NOTE(review): the budget applies to EACH class, so up to num_classes * budget samples
        # are selected in total -- confirm this is intended rather than budget / num_classes.
        budget = self._selection_budget()
        dirlist, lablelist = [], []
        print("Selection begins")
        for members in class_members:
            members = np.asarray(members, dtype=np.intp)
            for i in members[self._rank(entropy[members])][:budget]:
                self._write_selected(ft_csv, i, dirlist, lablelist)
        return self._finalize_selection(dirlist, lablelist)

    def distance_entropy_v2(self, mutual_inf):
        """Select by the entropy between each sample's two nearest prototypes.

        The score is ``self.compute_entropy_dim1`` applied to the softmax over
        the negative distances to the two closest prototypes.

        Args:
            mutual_inf: when truthy, apply two extra filters:
                * skip samples within Euclidean distance 1 of any labeled
                  embedding;
                * after each pick, mark every unlabeled sample within cosine
                  distance 0.1 of it as eliminated (written once to
                  ``rt_csv``, never selected).
                A coloured progress line is printed while selecting.

        Returns:
            list of ``(image path, label string)`` tuples.

        Raises:
            Exception: on an unknown ``select_type``.
        """
        ft_csv, rt_csv = self.make_csv_path(True)

        distances = pairwise_distances(self.unlabeled_embeddinglist, self.labeled_protolist)
        two_nearest = np.sort(distances, axis=1)[:, :2]
        probs = F.softmax(torch.from_numpy(-two_nearest), dim=1).numpy()
        scores = [self.compute_entropy_dim1(row) for row in probs]
        order = self._rank(scores)

        budget = self._selection_budget()
        dirlist, lablelist = [], []
        eliminated = set()
        count = 0
        for i in order:
            if len(dirlist) >= budget:
                break
            if not mutual_inf:
                self._write_selected(ft_csv, i, dirlist, lablelist)
                continue
            if i in eliminated:
                continue
            to_labeled = pairwise_distances([self.unlabeled_embeddinglist[i]], self.labeled_embeddinglist)[0]
            if np.any(to_labeled < 1):
                continue  # too close to already-labeled data
            count += 1
            # budget >= 1 here (len(dirlist) < budget), so the percentage cannot divide by zero.
            print(f'\r\033[36m{100 * count / budget}%[',
                  ">" * int(count * 1000 / self.num_train_set)
                  + "·" * int(self.add_ratio * 1000 - count * 1000 / self.num_train_set) + "]", end="")
            self._write_selected(ft_csv, i, dirlist, lablelist)
            to_unlabeled = pairwise_distances([self.unlabeled_embeddinglist[i]],
                                              self.unlabeled_embeddinglist, metric='cosine')[0]
            for k in np.flatnonzero(to_unlabeled < 0.1):
                if k != i and k not in eliminated:
                    eliminated.add(k)
                    rt_csv.writerow([self.unlabeled_img_path[k]] + [str(self.unlabeled_originlabels[k])])
        print('\033[0m' if mutual_inf else '')  # end the progress line and reset the terminal colour
        return self._finalize_selection(dirlist, lablelist)

    def edge_distance_entropy(self, od, od_count=1000):
        """Select by two-class edge entropy (see ``_edge_entropy_scores``).

        Args:
            od: when truthy, first drop the ``od_count`` unlabeled samples with
                the highest prototype-softmax entropy before scoring. The
                samples are removed from ``unlabeled_img_path``,
                ``unlabeled_originlabels`` and ``unlabeled_embeddinglist``,
                which mutates ``self``.
            od_count: number of samples removed when ``od`` is set.

        Returns:
            list of ``(image path, label string)`` tuples, at most
            ``ceil(add_ratio * num_train_set)`` long.

        Raises:
            Exception: on an unknown ``select_type``.
        """
        ft_csv = self.make_csv_path()

        if od:
            protos = np.asarray(self.labeled_protolist)
            od_scores = []
            for x in np.asarray(self.unlabeled_embeddinglist):
                distance_proto = np.linalg.norm(x - protos, axis=1)
                probs = F.softmax(torch.from_numpy(-distance_proto), dim=0).numpy()
                # Entropy of the probabilities themselves: negating them would hand log2
                # negative inputs and turn every score into nan.
                od_scores.append(self._entropy_bits(probs))
            drop = (-np.asarray(od_scores)).argsort()[:od_count]
            # NOTE(review): unlabeled_target is not filtered. After this it no longer lines up with
            # the arrays below if other code reads it on this instance -- confirm that is intended.
            self.unlabeled_img_path = np.delete(self.unlabeled_img_path, drop, axis=0)
            self.unlabeled_originlabels = np.delete(self.unlabeled_originlabels, drop, axis=0)
            self.unlabeled_embeddinglist = np.delete(self.unlabeled_embeddinglist, drop, axis=0)

        order = self._rank(self._edge_entropy_scores())
        dirlist, lablelist = [], []
        for i in order[:self._selection_budget()]:
            self._write_selected(ft_csv, i, dirlist, lablelist)
        return self._finalize_selection(dirlist, lablelist)

    def edge_distance_entropy_with_feature_merge(self):
        """Select by two-class edge entropy, suppressing near-duplicates via DBSCAN.

        Selected rows go to ``ft_csv``. Samples merged into an earlier pick go
        to ``rt_csv``; both writers come from ``self.make_csv_path(True)``.

        Returns:
            list of ``(image path, label string)`` tuples, at most
            ``ceil(add_ratio * num_train_set)`` long.

        Raises:
            Exception: on an unknown ``select_type``.
        """
        ft_csv, rt_csv = self.make_csv_path(True)
        order = self._rank(self._edge_entropy_scores())
        dirlist, lablelist = self._select_with_merge(order, ft_csv, rt_csv)
        return self._finalize_selection(dirlist, lablelist)