# -*- coding: utf-8 -*-

import csv
import os

import numpy as np
from .strategy import Strategy
from datetime import datetime
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
class CoreSet(Strategy):
    """Core-set sample selection by greedy farthest-first traversal.

    Every unlabeled embedding is scored by its distance to the nearest
    already-covered embedding (initially the labeled set). The sample with
    the largest such distance is picked, after which the scores are lowered
    to account for the newly picked sample, and the process repeats.
    """

    def __init__(self, train_args, unlabeled_originlabels, labeled_embeddinglist,
                 unlabeled_embeddinglist, unlabeled_img_path, add_ratio,
                 labeled_protolist, num_classes, unlabeled_target):
        """Store the selection configuration and the embedding pools.

        Args:
            train_args: Configuration object; must expose ``dataset_name``,
                ``classifier_name``, ``select_strategy``, ``select_type``,
                ``select_ratio`` and ``num_train_set``.
            unlabeled_originlabels: Per-sample labels of the unlabeled pool,
                indexable by position.
            labeled_embeddinglist: Embeddings of the labeled samples, in a
                form accepted by ``sklearn.metrics.pairwise_distances``.
            unlabeled_embeddinglist: Embeddings of the unlabeled samples,
                convertible to a 2-D array (one row per sample).
            unlabeled_img_path: Per-sample image paths of the unlabeled pool,
                aligned with ``unlabeled_embeddinglist``.
            add_ratio: Forwarded to ``Strategy.__init__``; ``chose`` later
                reads ``self.add_ratio`` (presumably set by the base class --
                verify against ``Strategy``).
            labeled_protolist: Stored as-is; not used by ``chose``.
            num_classes: Stored as-is; not used by ``chose``.
            unlabeled_target: Stored as-is; not used by ``chose``.
        """
        super().__init__(train_args, unlabeled_originlabels,
                         unlabeled_img_path, add_ratio,
                         num_classes, unlabeled_target)

        self.train_args = train_args

        self.dataset_name = train_args.dataset_name
        self.classifier_name = train_args.classifier_name
        self.select_strategy = train_args.select_strategy
        self.select_type = train_args.select_type
        self.select_ratio = train_args.select_ratio
        self.num_train_set = train_args.num_train_set

        self.labeled_embeddinglist = labeled_embeddinglist
        self.unlabeled_embeddinglist = unlabeled_embeddinglist
        self.unlabeled_img_path = unlabeled_img_path
        self.unlabeled_originlabels = unlabeled_originlabels

        self.labeled_protolist = labeled_protolist
        self.num_classes = num_classes
        self.unlabeled_target = unlabeled_target

    def chose(self):
        """Select unlabeled samples and record them in a CSV file.

        Picks ``int(add_ratio * num_train_set)`` samples (capped at the size
        of the unlabeled pool), each time taking the unlabeled sample that is
        farthest from everything covered so far. Each pick is written as a
        ``[image_path, label]`` row to
        ``./Selcetion/<dataset>/<classifier>/<strategy>/<type>/<dataset>_<select_ratio>.csv``;
        the directory is created if missing and an existing file is
        overwritten.

        Returns:
            list[tuple]: ``(image_path, label_as_str)`` pairs in selection
            order.
        """
        out_dir = './Selcetion/{}/{}/{}/{}/'.format(
            self.dataset_name, self.classifier_name, self.select_strategy, self.select_type
        )
        os.makedirs(out_dir, exist_ok=True)
        csv_path = '{}{}_{}.csv'.format(out_dir, self.dataset_name, self.select_ratio)

        # Convert once; the loop below indexes rows of this array repeatedly.
        unlabeled = np.asarray(self.unlabeled_embeddinglist)

        # Distance from each unlabeled sample to its nearest labeled sample.
        dist_to_labeled = pairwise_distances(unlabeled, self.labeled_embeddinglist)
        min_dist = np.amin(dist_to_labeled, axis=1)

        # Never ask for more samples than exist: once every sample has been
        # picked all distances are zero and argmax would keep returning the
        # same index, producing duplicate rows.
        total_select = min(int(self.add_ratio * self.num_train_set), unlabeled.shape[0])

        selected = []
        # The context manager guarantees the CSV is flushed and closed even
        # if a distance computation fails midway.
        with open(csv_path, 'w', newline='') as csv_file:
            writer = csv.writer(csv_file)

            for _ in tqdm(range(total_select), desc="Selecting samples", ncols=80):
                idx = int(min_dist.argmax())
                img_path = self.unlabeled_img_path[idx]
                label = str(self.unlabeled_originlabels[idx])

                writer.writerow([img_path, label])
                selected.append((img_path, label))

                # The picked sample now counts as covered: shrink every
                # score to the distance to it where that is smaller. Its own
                # score becomes 0 (distance to itself).
                dist_to_new = pairwise_distances(unlabeled, unlabeled[[idx], :])
                np.minimum(min_dist, dist_to_new[:, 0], out=min_dist)

        return selected
