import os,sys,argparse,ot,time
import numpy as np
from sklearn.metrics import accuracy_score,recall_score,precision_score,f1_score
# Put the current working directory on the import path so that the
# project-local packages imported right after this resolve relative to the
# directory the script is launched from.
parent_dir = os.path.abspath(os.getcwd())
sys.path.append(parent_dir)
print(parent_dir)
from __alg.RWD_clustering import my_gonzalez_RWD,my_k_center_RWD,robustOT
from __tool.tool import save_data, load_data,dataset_name,select_elements_byIndexes,sample_method


def main(args):
    """Cluster a sampled dataset and score the resulting centers.

    The sampled dataset is loaded from ``args.noisy_dataset_file``. Initial
    centers come from ``my_gonzalez_RWD`` and are then passed through
    ``my_k_center_RWD``. The final centers are scored against the data loaded
    from ``args.clean_org_dataset_file`` with a max-min objective (for every
    data item take the distance to its nearest center, then take the worst
    such distance), once on the clean part and once on the noisy part. The
    timings and costs are reported to Weights & Biases.

    Args:
        args: Parsed command-line namespace. Reads ``clean_org_dataset_file``,
            ``noisy_dataset_file``, ``k``, ``zeta_a``, ``zeta_b`` and
            ``LS_Itermax``. The whole namespace is also stored as the W&B run
            config.

    Raises:
        Whatever loading, clustering or scoring raises. The W&B run is closed
        with a non-zero exit code before the exception is re-raised, so a
        failed experiment never leaves a dangling active run behind.
    """
    import wandb

    # Parse the input arguments.
    clean_org_dataset_file = args.clean_org_dataset_file
    noisy_dataset_file = args.noisy_dataset_file
    k = args.k
    zeta_a = args.zeta_a
    zeta_b = args.zeta_b
    LS_Itermax = args.LS_Itermax

    my_dataset_name = dataset_name(noisy_dataset_file)
    my_sample_method = sample_method(noisy_dataset_file)
    print("my_dataset_name = ", my_dataset_name)
    wandb.init(
        project="ICML25",  # project name
        name=my_dataset_name,  # experiment (run) name
        group="",
        dir="./myWandb",
        config=args,
    )

    try:
        # "runtime" is process CPU time (time.process_time) and spans loading
        # the sampled dataset plus both clustering stages.
        time0 = time.process_time()
        # Only the last three values returned by load_data are used here.
        _, _, _, noisy_points_list, noisy_weights_list, noisy_label_list = load_data(noisy_dataset_file)
        center_locs_list_1, center_weis_list_1, center_label_list_1 = my_gonzalez_RWD(
            noisy_points_list, noisy_weights_list, noisy_label_list, k, zeta_a, zeta_b
        )
        center_locs_list_11, center_weis_list_11, _ = my_k_center_RWD(
            noisy_points_list,
            noisy_weights_list,
            noisy_label_list,
            center_locs_list_1,
            center_weis_list_1,
            center_label_list_1,
            zeta_a,
            zeta_b,
            emd_Itermax=100000,
            tau=10,
            LS_Itermax=LS_Itermax,
        )
        time1 = time.process_time()
        runtime = time1 - time0

        # Evaluation data: the label lists are not needed for the cost.
        (
            clean_org_points_list,
            clean_org_weights_list,
            _,
            noisy_org_points_list,
            noisy_org_weights_list,
            _,
        ) = load_data(clean_org_dataset_file)

        # Rows index centers, columns index data items. Each entry is the
        # exact OT cost (ot.emd2) under the ground cost returned by ot.dist.
        clean_distance_matrix = [
            [
                ot.emd2(wei_a, wei_c, ot.dist(loc_a, loc_c))
                for loc_a, wei_a in zip(clean_org_points_list, clean_org_weights_list)
            ]
            for loc_c, wei_c in zip(center_locs_list_11, center_weis_list_11)
        ]
        # Nearest-center distance per data item (min over rows), then the worst one.
        cost_cd = np.max(np.min(clean_distance_matrix, axis=0))

        # NOTE(review): robustOT receives zeta_a and a literal 0 as its last
        # two entries (zeta_b is not used here) and its second return value is
        # taken as the distance -- confirm against robustOT's definition.
        noisy_distance_matrix = [
            [
                robustOT([loc_a, wei_a, loc_c, wei_c, zeta_a, 0])[1]
                for loc_a, wei_a in zip(noisy_org_points_list, noisy_org_weights_list)
            ]
            for loc_c, wei_c in zip(center_locs_list_11, center_weis_list_11)
        ]
        cost_nd = np.max(np.min(noisy_distance_matrix, axis=0))
        print("cost_cd,cost_nd = ", cost_cd, cost_nd)

        wandb.log({
            "runtime": runtime,
            "cost_cd": cost_cd,
            "cost_nd": cost_nd,
            "sample_method": my_sample_method,
            "sample_size": len(noisy_points_list)
        })
    except BaseException:
        # Close the run as failed so it is not left active (or reported as
        # successful) when the experiment aborts, then propagate the error.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()

    
    
if __name__ == "__main__":
    # Build the command-line interface.
    cli = argparse.ArgumentParser(
        description="Process noisy MNIST data and compute core-set."
    )

    cli.add_argument(
        "--clean_org_dataset_file",
        type=str,
        default="my_dataset/MNIST/my_MNIST/noisy_mnist1000_3_0.5_0_1234.pkl",
        help="Path to the original clean dataset .pkl file.",
    )
    # Alternative default for --noisy_dataset_file kept for reference:
    #   my_dataset/MNIST/my_MNIST/noisy_mnist1000_3_0.5_0_1234/CS/cs3.0_4_800.pkl
    cli.add_argument(
        "--noisy_dataset_file",
        type=str,
        default="my_dataset/MNIST/my_MNIST/noisy_mnist1000_3_0.5_0_1234/US/us800.pkl",
        help="Path to the noisy dataset .pkl file.",
    )
    cli.add_argument(
        "--k",
        type=int,
        default=30,
        help="Dimensionality parameter for core-set computation.",
    )

    # The remaining numeric options are registered without help text.
    for flag, caster, fallback in (
        ("--zeta_a", float, 0.1),
        ("--zeta_b", float, 0.1),
        ("--noise_std", float, 1),
        ("--LS_Itermax", int, 20),
    ):
        cli.add_argument(flag, type=caster, default=fallback)

    # Parse the arguments and hand them to the main routine.
    main(cli.parse_args())


    
