import os
# Hide every CUDA device from this process. This must run *before* the
# project imports below, so that any GPU framework they load initialises
# on CPU only (which framework pbr/dkwt use is not visible here -- TODO confirm).
os.environ['CUDA_VISIBLE_DEVICES'] = ''

# NOTE(review): PBR is not referenced in the visible part of this script;
# possibly kept for ad-hoc runs -- confirm before removing.
from pbr import PBR
from dkwt import DKWT
import numpy as np
import pandas as pd
import gc

# Experiment data folders to process. Each folder must contain `vecs.txt` and
# `q.txt`. Earlier datasets are kept commented out so they can be re-enabled
# quickly. `folder` may be a single path or a list of paths.
# parent_dir = '/root/main/ranking/sim'
# folder = "/root/main/ranking/sim/random_1000_16_100"
# folder = "/root/main/ranking/sim/separated_8_100"
# folder = "/root/main/ranking/sim/separated_32_100"
# folder = "/root/main/ranking/sim/g2_32_100"
folder = [f"/root/main/ranking/sim/g2_32_{overlap}_trunc" for overlap in [10, 30, 70, 100]]
# folder = "/root/main/ranking/sim/synthetic_cluster"

# Normalise to a list of folders. `isinstance` (rather than `type(...) ==`)
# also accepts str subclasses and pathlib paths as a single folder.
if isinstance(folder, (str, os.PathLike)):
    folder_list = [folder]
else:
    folder_list = list(folder)

def append_dict_to_df(df, dict_row):
    """Return a new DataFrame equal to ``df`` with ``dict_row`` appended as one row.

    Args:
        df: Existing results frame (may be empty).
        dict_row: Mapping of column name -> value for the new row.

    Returns:
        A new DataFrame with a fresh 0..N-1 index; ``df`` itself is not modified.
    """
    dict_df = pd.DataFrame([dict_row])
    if df.empty:
        # Concatenating with an empty frame is deprecated for dtype inference
        # in recent pandas; starting from the new row avoids that path.
        return dict_df
    return pd.concat([df, dict_df], ignore_index=True)


def _results_path(folder, var_name):
    """Return the checkpoint CSV path for sweep variable ``var_name`` in ``folder``."""
    return os.path.join(folder, f'dkwt_lenient_results_{var_name}.csv')


def _load_results(path):
    """Load previously saved results from ``path``.

    Returns an empty DataFrame when the file does not exist yet or has no
    content, so fresh and resumed runs go through the same code path.
    """
    if not os.path.exists(path):
        return pd.DataFrame()
    try:
        return pd.read_csv(path)
    except pd.errors.EmptyDataError:
        return pd.DataFrame()


def _already_done(results, var_name, q_idx, value):
    """Return True if ``results`` already holds a row for (``q_idx``, ``value``)."""
    # A checkpoint missing either key column cannot prove the entry exists.
    if results.empty or not {'q_idx', var_name} <= set(results.columns):
        return False
    mask = (results['q_idx'] == q_idx) & (results[var_name] == value)
    return bool(mask.any())


# Parameter sweeps: (DKWT keyword, values to try). 'n' is handled specially by
# truncating `vecs` to its first n rows instead of passing a keyword argument.
var_tup_list = [('eps', [0.2, 0.1, 0.05, 0.02]),
                ('n', [50, 200, 500]),
                ('k', [5, 10, 20, 40]),
                ('sharpness', [1, 5])]
# Only the first sweep ('eps') is currently enabled; widen the slice to run others.
var_tup_list = var_tup_list[:1]

for folder in folder_list:
    print(f'=======================RUNNING FOLDER: {folder} ============================')
    vecs = np.loadtxt(os.path.join(folder, 'vecs.txt'))
    # loadtxt returns a 1-D array for a single-row file; force 2-D so that
    # q.shape[0] is the number of queries and q[i, :] is always valid.
    q = np.atleast_2d(np.loadtxt(os.path.join(folder, 'q.txt')))

    for var_name, values in var_tup_list:
        print((var_name, values))
        out_path = _results_path(folder, var_name)
        # Load the checkpoint once per sweep and keep it in memory, rather than
        # re-reading the whole CSV on every iteration.
        results = _load_results(out_path)

        for v in values:
            for i in range(q.shape[0]):
                print(f'\n\n\n -------ITERATION {i}----------------------- \n\n\n')
                if _already_done(results, var_name, i, v):
                    print('Entry already exists, skipping ...')
                    continue
                dkwt = None
                try:
                    if var_name == 'n':
                        # lenient=True to match the "lenient" results file and
                        # the other sweeps (it was previously omitted here).
                        dkwt = DKWT(vecs[:v], lenient=True)
                    else:
                        dkwt = DKWT(vecs, **{var_name: v}, lenient=True)
                    dkwt.run_sim(q[i, :])
                    row = dkwt.get_result_dict()
                    row['q_idx'] = i
                    # Guarantee the sweep column exists so the skip check and
                    # the sort below work even if DKWT does not report it.
                    row.setdefault(var_name, v)
                    results = append_dict_to_df(results, row).sort_values(['q_idx', var_name])
                    # Checkpoint after every successful run so work survives crashes.
                    results.to_csv(out_path, index=False)
                except Exception as e:
                    # Best-effort sweep: a failing configuration must not stop the rest.
                    print(f'{e}: continuing to next iteration')
                finally:
                    # Drop the simulator before the next run, even after a failure.
                    dkwt = None
                    gc.collect()

    