import csv
import math
import random
import shutil
import os
import numpy as np
import torch.nn.functional as F
import torch
import os
from matplotlib import pyplot as plt
from src.utils.my_mkdir import mkdir
from sklearn import manifold,datasets
from sklearn.cluster import DBSCAN


def random_select(args,method_name, origin_labels, classifier_name, dataset_name, img_names, select_ratio, select_type, add_ratio):
    """Draw a uniformly random subset of the pool and record it in a CSV file.

    The number of samples drawn is ``int(add_ratio * args.num_train_set)``;
    it is based on the configured training-set size, not on ``len(img_names)``.
    If that number is larger than the pool, every pool sample is returned
    (the final log line still reports the requested number).

    Args:
        args: Object exposing ``num_train_set``.
        method_name: Used only as a path component of the output directory.
        origin_labels: Labels, indexable in parallel with ``img_names``.
        classifier_name: Used only as a path component of the output directory.
        dataset_name: Path component and prefix of the CSV file name.
        img_names: Indexable collection of sample identifiers (the pool).
        select_ratio: Used in the CSV file name and in the summary message.
        select_type: Used only as a path component of the output directory.
        add_ratio: Fraction of ``args.num_train_set`` to draw.

    Returns:
        A list of ``(img_name, label)`` tuples in the order they were drawn.
        The same rows are written to
        ``./Selection/<dataset>/<classifier>/<method>/<select_type>/<dataset>_<select_ratio>.csv``
        with the label converted to ``str``.
    """
    total_num = args.num_train_set
    select_num = int(add_ratio * total_num)
    print('select_num: {}'.format(select_num))
    print('img_names: {}'.format(len(img_names)))

    filepath = './Selection/{}/{}/{}/{}/'.format(dataset_name, classifier_name, method_name, select_type)
    os.makedirs(filepath, exist_ok=True)
    csv_path = os.path.join(filepath, '{}_{}.csv'.format(dataset_name, select_ratio))

    select_data = []
    with open(csv_path, 'w', newline='') as out_file:
        writer = csv.writer(out_file)

        # Shuffle the pool indices in place with NumPy's global RNG and keep
        # the leading ``select_num`` of them.
        shuffled = np.arange(len(img_names))
        np.random.shuffle(shuffled)

        for idx in shuffled[:select_num]:
            name = img_names[idx]
            label = origin_labels[idx]
            writer.writerow([name, str(label)])
            select_data.append((name, label))

    print(f"Selected {select_num} samples out of {total_num} ({select_ratio * 100:.1f}%) and saved to {csv_path}")
    return select_data


def _merge_candidates(embeddings, cluster_ids, index, radius):
    """Return ascending indices of samples in ``index``'s cluster closer than ``radius`` to it."""
    peers = np.flatnonzero(cluster_ids == cluster_ids[index])
    peers = peers[peers != index]
    if peers.size == 0:
        return []
    distances = np.linalg.norm(embeddings[peers] - embeddings[index], axis=1)
    return peers[distances < radius].tolist()


def random_select_with_merge(train_args,originlables, img_names, labels,select_type,pool_embeddinglist,add_ratio,num_classes, max_select=450):
    """Randomly select pool samples while dropping near-duplicates in embedding space.

    The pool is clustered with ``DBSCAN(eps=0.5, min_samples=4)``. Samples are
    then visited in random order (Python's ``random`` module). Each visited
    sample that has not been merged away is selected, and every other sample
    with the same cluster label whose Euclidean distance to it is below 1 is
    marked as merged and will never be selected. DBSCAN gives all noise points
    the label -1, so noise points are compared with each other like one cluster.

    Two CSV files are written under
    ``./Selcetion/<dataset>/<classifier>/<select_strategy>/<select_type>/``:
    ``<dataset>_<select_ratio>.csv`` holds the selected ``name,label`` rows and
    ``<dataset>_<select_ratio>_remove.csv`` holds the merged ones (each merged
    sample is listed once).

    Args:
        train_args: Object exposing ``dataset_name``, ``classifier_name``,
            ``select_strategy`` and ``select_ratio``.
        originlables: Labels, indexable in parallel with ``img_names``.
        img_names: Indexable collection of sample identifiers (the pool).
        labels: Not used by this function.
        select_type: Path component of the output directory.
        pool_embeddinglist: One embedding vector per pool sample; must be
            convertible to a 2-D array, as required by ``DBSCAN.fit``.
        add_ratio: Not used by this function.
        num_classes: Not used by this function.
        max_select: Stop once this many samples have been selected. The check
            runs after each selection, so a non-empty pool always yields at
            least one sample.

    Returns:
        A list of ``(img_name, label)`` tuples in selection order; empty when
        ``img_names`` is empty.
    """
    # NOTE(review): kept for backward compatibility -- the result is also
    # published as a module-level global; this looks unnecessary.
    global select_data

    # NOTE(review): the directory is spelled 'Selcetion' here while
    # random_select writes under 'Selection'. Left unchanged because other
    # code may read from this exact path -- confirm and align.
    filepath =  './Selcetion/{}/{}/{}/{}/'.format(train_args.dataset_name, train_args.classifier_name, train_args.select_strategy,select_type)
    os.makedirs(filepath, exist_ok=True)
    select_path = '{}{}_{}.csv'.format(filepath, train_args.dataset_name, train_args.select_ratio)
    remove_path = '{}{}_{}_{}.csv'.format(filepath, train_args.dataset_name, train_args.select_ratio, 'remove')

    merge_radius = 1  # samples closer than this (same cluster) count as duplicates
    dirlist = []
    lablelist = []

    # Context managers guarantee both files are flushed and closed, even if
    # clustering or indexing raises part-way through.
    with open(select_path, 'w', newline='') as ft, open(remove_path, 'w', newline='') as rt:
        ft_csv = csv.writer(ft)
        rt_csv = csv.writer(rt)

        # An empty pool would make DBSCAN raise; produce empty outputs instead.
        if len(img_names) > 0:
            embeddings = np.asarray(pool_embeddinglist)
            cluster_ids = DBSCAN(eps=0.5, min_samples=4).fit(embeddings).labels_

            order = list(range(len(img_names)))
            random.shuffle(order)

            merged = set()  # indices already absorbed by a selected neighbour
            for i in order:
                if i in merged:
                    continue

                ft_csv.writerow([img_names[i], str(originlables[i])])
                dirlist.append(img_names[i])
                lablelist.append(originlables[i])

                for j in _merge_candidates(embeddings, cluster_ids, i, merge_radius):
                    # A sample can be near several selected ones; record it once.
                    if j not in merged:
                        merged.add(j)
                        rt_csv.writerow([img_names[j], str(originlables[j])])

                if len(lablelist) >= max_select:
                    break

    select_data = list(zip(dirlist, lablelist))
    return select_data