import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

from pbr import PBR
from dkwt import DKWT
import numpy as np
import pandas as pd
import gc

# parent_dir = '/root/main/ranking/sim'
# folder = "/root/main/ranking/sim/random_1000_16_100"
# folder = "/root/main/ranking/sim/separated_8_100"
# folder = "/root/main/ranking/sim/separated_32_100"
# folder = "/root/main/ranking/sim/g2_32_100"
# folder =[f"/root/main/ranking/sim/g2_32_{overlap}_trunc" for overlap in [10,30,70,100]]
# folder = "/root/main/ranking/sim/synthetic_cluster"
# Folder(s) to process: either a single path string or a list of paths.
# The commented-out assignments above are alternative experiment sets.
folder = [
    "/root/main/ranking/sim/random_1000_16_100",
    "/root/main/ranking/sim/separated_32_100",
] + [f"/root/main/ranking/sim/g2_32_{overlap}_trunc" for overlap in [10, 30, 70, 100]]

# Normalise to a list so the main loop can always iterate over folders.
# isinstance (not type() ==) is the idiomatic check; list(...) also accepts a
# tuple or any other iterable of paths.
if isinstance(folder, str):
    folder_list = [folder]
else:
    folder_list = list(folder)

import traceback


def append_dict_to_df(df, dict_row):
    """Return a new DataFrame consisting of ``df`` followed by ``dict_row``.

    ``df`` itself is not modified: ``pd.concat`` builds a new frame, and
    ``ignore_index=True`` renumbers the resulting index 0..n-1.
    """
    dict_df = pd.DataFrame([dict_row])
    return pd.concat([df, dict_df], ignore_index=True)


# Parameter sweep: (result column / DKWT keyword name, values to try).
# 'n' is special-cased below: it truncates ``vecs`` instead of being passed
# to DKWT as a keyword argument.
var_tup = ('eps', [0.2, 0.1, 0.05, 0.02])
#var_tup = ('n', [50,200,500])
#var_tup = ('k', [5, 10, 20, 40])
#var_tup = ('sharpness', [1,5])

for folder in folder_list[:]:
    print(f'=======================RUNNING FOLDER: {folder} ============================')
    vecs = np.loadtxt(os.path.join(folder, 'vecs.txt'))
    # ndmin=2 keeps q two-dimensional even when q.txt holds a single row.
    # Without it loadtxt returns a 1-D array, q.shape[0] becomes the vector
    # length and q[i, :] raises IndexError on every iteration.
    q = np.loadtxt(os.path.join(folder, 'q.txt'), ndmin=2)

    print(var_tup)

    var_name, var_values = var_tup
    results_path = os.path.join(folder, f'dkwt_final_results_{var_name}.csv')

    for v in var_values:
        for i in range(q.shape[0]):
            print(f'\n\n\n -------ITERATION {i}----------------------- \n\n\n')
            # Re-read the results file each iteration so the skip check always
            # reflects what is on disk; this lets an interrupted sweep resume.
            if os.path.exists(results_path):
                results = pd.read_csv(results_path)
                already_done = (results['q_idx'] == i) & (results[var_name] == v)
                if already_done.any():
                    print('Entry already exists, skipping ...')
                    continue
            else:
                results = pd.DataFrame()

            dkwt = tmp = None
            try:
                if var_name == 'n':
                    dkwt = DKWT(vecs[:v])
                else:
                    dkwt = DKWT(vecs, **{var_name: v})
                dkwt.run_sim(q[i, :])
                tmp = dkwt.get_result_dict()
                tmp['q_idx'] = i
                # The skip check and the sort key on this column. If the result
                # dict does not report the swept value (e.g. 'n', which is
                # applied by slicing rather than passed to DKWT), record it
                # here. A value DKWT already reports is left untouched.
                if var_name not in tmp:
                    tmp[var_name] = v
                results = append_dict_to_df(results, tmp)
                results.sort_values(['q_idx', var_name]).to_csv(results_path, index=False)
            except Exception as e:
                # Best-effort sweep: report the failure (with traceback) and move on.
                print(f'{e}: continuing to next iteration')
                traceback.print_exc()
            finally:
                # Release the per-iteration objects on both success and failure
                # before starting the next simulation.
                del results, dkwt, tmp
                gc.collect()

    