import csv
import shutil
import os
import numpy as np
import torch.nn.functional as F
import torch
import os
from sklearn.cluster import DBSCAN
from matplotlib import pyplot as plt
from tqdm import tqdm
from src.utils.my_mkdir import mkdir


def learning_loss(method_name, originlables, classifier_name, dataset_name,
                  outputlist, img_names, select_ratio, labels, select_type,
                  add_ratio, num_classes, losslist, max_select=450):
    """Select up to ``max_select`` samples ranked by loss and record them in a CSV.

    NOTE(review): this module defines ``learning_loss`` a second time further
    down; that later ``def`` rebinds the name, so this version is not the one
    callers get via ``module.learning_loss``. Consider removing one of them.

    Args:
        method_name, classifier_name, dataset_name: Used only to build the
            output directory ``./Selcetion/<dataset>/<classifier>/<method>/<select_type>/``.
        originlables: Per-sample labels, indexed in parallel with ``img_names``.
        outputlist, labels, add_ratio, num_classes: Unused here; kept for
            signature compatibility with the other selection functions.
        img_names: Per-sample identifiers written to the CSV.
        select_ratio: Used only in the CSV file name.
        select_type: ``'ADD-GOOD'`` ranks by descending loss, ``'ADD-BAD'`` by
            ascending loss.
        losslist: Per-sample loss values; converted with ``np.array``.
        max_select: Maximum number of samples to select (was hard-coded 450).

    Returns:
        List of ``(img_name, label)`` tuples in ranking order. The same list is
        also stored in the module-level global ``select_data``.

    Raises:
        ValueError: If ``select_type`` is not one of the two supported values.
    """
    global scorelist, select_data

    # Validate before touching the filesystem so a bad call leaves no empty CSV behind.
    if select_type == 'ADD-GOOD':
        order = (-np.array(losslist)).argsort()
    elif select_type == 'ADD-BAD':
        order = np.array(losslist).argsort()
    else:
        raise ValueError(f"Invalid select_type: {select_type}")

    filepath = './Selcetion/{}/{}/{}/{}/'.format(dataset_name, classifier_name, method_name, select_type)
    os.makedirs(filepath, exist_ok=True)
    csv_path = '{}{}_{}.csv'.format(filepath, dataset_name, select_ratio)

    dirlist = []
    lablelist = []
    print("Selection begins")
    # `with` guarantees the CSV is flushed and closed even if indexing raises.
    with open(csv_path, 'w', newline='') as ft:
        ft_csv = csv.writer(ft)
        for idx in order[:max_select]:
            ft_csv.writerow([img_names[idx], str(originlables[idx])])
            dirlist.append(img_names[idx])
            lablelist.append(originlables[idx])

    select_data = list(zip(dirlist, lablelist))
    print(len(select_data))
    return select_data


def learning_loss(method_name, originlables, classifier_name, dataset_name,
                  outputlist, img_names, select_ratio, labels, select_type,
                  add_ratio, num_classes, losslist):
    """Select ``int(add_ratio * len(img_names))`` samples ranked by loss and write them to a CSV.

    Args:
        method_name, classifier_name, dataset_name: Used only to build the
            output directory ``./Selcetion/<dataset>/<classifier>/<method>/<select_type>/``.
        originlables: Per-sample labels, indexed in parallel with ``img_names``.
        outputlist, labels, num_classes: Unused here; kept for signature
            compatibility with the other selection functions.
        img_names: Per-sample identifiers written to the CSV.
        select_ratio: Used only in the CSV file name.
        select_type: ``'ADD-GOOD'`` ranks by descending loss, ``'ADD-BAD'`` by
            ascending loss.
        add_ratio: Fraction of ``len(img_names)`` to select (truncated to int).
        losslist: Per-sample loss values; converted with ``np.array``.

    Returns:
        List of ``(img_name, label)`` tuples in ranking order. The same list is
        also stored in the module-level global ``select_data``.

    Raises:
        ValueError: If ``select_type`` is not one of the two supported values.
    """
    global scorelist, select_data

    # Validate first so an invalid call does not leave an empty CSV behind.
    if select_type == 'ADD-GOOD':
        order = (-np.array(losslist)).argsort()
    elif select_type == 'ADD-BAD':
        order = np.array(losslist).argsort()
    else:
        raise ValueError(f"Invalid select_type: {select_type}")

    filepath = './Selcetion/{}/{}/{}/{}/'.format(dataset_name, classifier_name, method_name, select_type)
    os.makedirs(filepath, exist_ok=True)
    csv_path = '{}{}_{}.csv'.format(filepath, dataset_name, select_ratio)

    total_select = int(add_ratio * len(img_names))
    print(f"Starting learning_loss selection; selecting {total_select} samples in total.")

    # Slice up front so exactly `total_select` rows are taken. An
    # append-then-check loop would always take at least one row, even when
    # total_select is 0. max(..., 0) keeps a negative ratio from slicing
    # "all but the last k" items.
    chosen = order[:max(total_select, 0)]

    dirlist = []
    lablelist = []
    with open(csv_path, 'w', newline='') as ft:
        ft_csv = csv.writer(ft)
        for idx in tqdm(chosen, desc="Selecting by loss", ncols=80):
            ft_csv.writerow([img_names[idx], str(originlables[idx])])
            dirlist.append(img_names[idx])
            lablelist.append(originlables[idx])

    select_data = list(zip(dirlist, lablelist))
    print(f"Selection complete. Final number of selected samples: {len(select_data)}")
    return select_data

def learning_loss_with_merge(method_name, originlables, classifier_name, dataset_name,
                             outputlist, img_names, select_ratio, labels, select_type,
                             add_ratio, num_classes, losslist, pool_embeddinglist,
                             max_select=450, merge_threshold=1, eps=0.5, min_samples=4,
                             cluster_labels=None):
    """Select samples by loss ranking while merging away near-duplicate embeddings.

    Samples are visited in loss order. Each sample not already merged away is
    selected and written to ``<dataset>_<select_ratio>.csv``. Every other sample
    that meets both conditions below is then marked as merged:

    * it shares the selected sample's cluster id;
    * its Euclidean embedding distance to the selected sample is below
      ``merge_threshold``.

    A merged sample is written once to ``<dataset>_<select_ratio>_remove.csv`` and
    is skipped when its turn in the ranking comes. Selection stops after
    ``max_select`` samples (after processing the last one's neighbours).

    Args:
        method_name, classifier_name, dataset_name: Used only to build the
            output directory ``./Selcetion/<dataset>/<classifier>/<method>/<select_type>/``.
        originlables: Per-sample labels, indexed in parallel with ``img_names``.
        outputlist, labels, add_ratio, num_classes: Unused here; kept for
            signature compatibility with the other selection functions.
        img_names: Per-sample identifiers written to the CSVs.
        select_ratio: Used only in the CSV file names.
        select_type: ``'ADD-GOOD'`` ranks by descending loss, ``'ADD-BAD'`` by
            ascending loss.
        losslist: Per-sample loss values; converted with ``np.array``.
        pool_embeddinglist: Per-sample embeddings. They are passed to
            ``DBSCAN.fit`` and indexed row-wise, so they are presumably an
            (n_samples, dim) array-like.
        max_select: Maximum number of samples to select (was hard-coded 450).
        merge_threshold: Strict distance bound for merging (was hard-coded 1).
        eps, min_samples: DBSCAN parameters (were hard-coded 0.5 and 4).
        cluster_labels: Optional precomputed per-sample cluster ids. When given,
            DBSCAN is not run.

    Returns:
        List of ``(img_name, label)`` tuples in selection order. The same list is
        also stored in the module-level global ``select_data``.

    Raises:
        ValueError: If ``select_type`` is not one of the two supported values.
    """
    global scorelist, select_data

    if select_type == 'ADD-GOOD':
        order = (-np.array(losslist)).argsort()
    elif select_type == 'ADD-BAD':
        order = np.array(losslist).argsort()
    else:
        raise ValueError(f"Invalid select_type: {select_type}")

    filepath = './Selcetion/{}/{}/{}/{}/'.format(dataset_name, classifier_name, method_name, select_type)
    os.makedirs(filepath, exist_ok=True)
    keep_path = '{}{}_{}.csv'.format(filepath, dataset_name, select_ratio)
    remove_path = '{}{}_{}_{}.csv'.format(filepath, dataset_name, select_ratio, 'remove')

    lablelist = []
    dirlist = []
    merged = set()  # set, not list: membership is tested for every ranked sample

    with open(keep_path, 'w', newline='') as ft, open(remove_path, 'w', newline='') as rt:
        ft_csv = csv.writer(ft)
        rt_csv = csv.writer(rt)
        print("Selection begins")

        if cluster_labels is None:
            cluster_labels = DBSCAN(eps=eps, min_samples=min_samples).fit(pool_embeddinglist).labels_
        # NOTE(review): DBSCAN labels noise points -1, so all noise points are
        # treated as one shared "cluster" by the equality test below. Confirm
        # this is intended.
        cluster_ids = np.asarray(cluster_labels)
        embeddings = np.asarray(pool_embeddinglist)

        for idx in order.tolist():
            if idx in merged:
                continue
            ft_csv.writerow([img_names[idx], str(originlables[idx])])
            dirlist.append(img_names[idx])
            lablelist.append(originlables[idx])

            # Same-cluster peers (excluding idx itself), in ascending index order.
            peers = np.flatnonzero(cluster_ids == cluster_ids[idx])
            peers = peers[peers != idx]
            distances = np.linalg.norm(embeddings[peers] - embeddings[idx], axis=1)
            for j in peers[distances < merge_threshold].tolist():
                # Record each merged sample once, even if it is close to
                # several selected samples.
                if j not in merged:
                    merged.add(j)
                    rt_csv.writerow([img_names[j], str(originlables[j])])

            if len(lablelist) >= max_select:
                break

    select_data = list(zip(dirlist, lablelist))
    print(select_data)
    return select_data
