from tqdm import tqdm
import pandas as pd
from muss.src.candidate_retrieval.baselines import (
    run_experiment_clustering,
    run_experiment_DPP,
    run_experiment_MMR,
    run_random_selection,
    run_DGDS
)
from muss.src.candidate_retrieval.helpers import (
    run_MUSS,
    save_result
)

# Values that can be passed as ``div_type`` to the MUSS / DGDS runners below
# (only the SUM variant is actually used in this script).
MMR_DIVERSITY_TYPE_SUM = "sum"
MMR_DIVERSITY_TYPE_MIN = "min"

# Tag attached to every experiment run on this dataset.
dataset_name = 'home'

# Number of items every method is asked to select.
k = 500

# Cumulative table of evaluation rows produced by all experiment sections.
output_df = pd.DataFrame()

# Load the home dataset; the path is resolved relative to the current working directory.
selected_df = pd.read_parquet('../data/home_dataset.parquet', engine='auto')
print("dataset_name", dataset_name, selected_df.shape)

# run the main MUSS algorithm ===============================================================
# Grid search over seed x lamb x lamb_c x (m, n_partitions). Each run's evaluation rows are
# collected in a list and merged into output_df once at the end, rather than concatenating
# and de-duplicating the ever-growing frame after every run (quadratic in the number of runs).
n_rows = selected_df.shape[0]

# The (m, n_partitions) filter does not depend on the seed or the lambdas, so it is computed
# once. A combination is kept only when n_partitions <= n_rows / 5 (i.e. at least five rows
# per partition on average) and m < n_partitions. Order is m-major, as in the nested loops.
muss_settings = [
    (m_value, n_partition_value)
    for m_value in [50, 100, 200]
    for n_partition_value in [100, 200, 500]
    if n_partition_value <= n_rows / 5 and m_value < n_partition_value
]

muss_outputs = []
for rand_val in tqdm([1, 2, 3, 4, 5]):
    for lamb in tqdm([0.1, 0.3, 0.5, 0.7, 0.9]):
        for lamb_c in [0.1, 0.3, 0.5, 0.7, 0.9]:
            for m_value, n_partition_value in muss_settings:
                # run_MUSS returns a pair; only the first element (evaluation rows) is kept.
                eval_output, _ = run_MUSS(
                    selected_df,
                    k=k,
                    k_within=50,
                    m=m_value,
                    n_partitions=n_partition_value,
                    n_jobs=25,
                    lamb=lamb,
                    lamb_c=lamb_c,
                    dataset=dataset_name,
                    random_state=rand_val,
                    div_type=MMR_DIVERSITY_TYPE_SUM,
                )
                muss_outputs.append(eval_output)

# One concat + de-dup. drop_duplicates keeps the first occurrence of each row, so the
# resulting rows and their order match incremental merging.
output_df = pd.concat([output_df, *muss_outputs]).drop_duplicates().reset_index(drop=True)

save_result(output_df)

# DGDS baseline =====================================================
# Grid over seed x lamb x n_partitions. Evaluation rows are collected in a list and merged
# into output_df once, instead of re-concatenating the growing frame after every run.
n_rows = selected_df.shape[0]

# Partition counts are independent of the seed and lamb, so filter them once. A count is
# kept only if n_partitions <= n_rows / 5 and the average partition holds more than 30 rows
# (the second condition already implies the first; both are kept for clarity).
dgds_partitions = [
    n_partition_value
    for n_partition_value in [100, 200, 500]
    if n_partition_value <= n_rows / 5 and n_rows / n_partition_value > 30
]

dgds_outputs = []
for rand_val in tqdm([1, 2, 3, 4, 5]):
    for lamb in tqdm([0.1, 0.3, 0.5, 0.7, 0.9]):
        for n_partition_value in dgds_partitions:
            eval_output = run_DGDS(
                selected_df,
                k=k,
                k_within=50,
                n_partitions=n_partition_value,
                n_jobs=25,
                lamb=lamb,
                dataset=dataset_name,
                random_state=rand_val,
                div_type=MMR_DIVERSITY_TYPE_SUM,
            )
            dgds_outputs.append(eval_output)

# One concat + de-dup (keep-first), equivalent to merging after every run.
output_df = pd.concat([output_df, *dgds_outputs]).drop_duplicates().reset_index(drop=True)

save_result(output_df)


# run clustering baseline===========================================================================
# One clustering run per seed. All runs are appended to the cumulative table in a single
# concat, duplicate rows are dropped (first occurrence kept), and the table is persisted
# via save_result (output format is decided by save_result).
clustering_runs = [
    run_experiment_clustering(selected_df, k=k, random_state=seed, dataset=dataset_name)
    for seed in tqdm([1, 2, 3, 4, 5])
]
output_df = pd.concat([output_df, *clustering_runs]).drop_duplicates().reset_index(drop=True)
save_result(output_df)
    


# run random baseline ==================================================================================
# One random-selection run per seed, merged into the cumulative table in a single concat
# (duplicate rows dropped, first occurrence kept), then persisted via save_result.
random_runs = [
    run_random_selection(selected_df, k=k, random_state=seed, dataset=dataset_name)
    for seed in tqdm([1, 2, 3, 4, 5])
]
output_df = pd.concat([output_df, *random_runs]).drop_duplicates().reset_index(drop=True)
save_result(output_df)


# run DPP if the datasize < 50k ======================================================================
# DPP is only attempted when the dataset has fewer than 50,000 rows (presumably because it
# is too expensive on larger data — TODO confirm). Otherwise the section is skipped entirely:
# no runs and no extra save.
if selected_df.shape[0] < 50000:
    dpp_runs = [
        run_experiment_DPP(selected_df, k=k, dataset=dataset_name, random_state=seed)
        for seed in tqdm([1, 2, 3, 4, 5])
    ]
    output_df = pd.concat([output_df, *dpp_runs]).drop_duplicates().reset_index(drop=True)
    save_result(output_df)

# run MMR baseline ===============================================================================
# One MMR run per lamb value (no seed is passed; topk=None). All runs are merged into the
# cumulative table in a single concat (duplicates dropped, first kept), then persisted.
mmr_runs = [
    run_experiment_MMR(selected_df, k=k, topk=None, lamb=lamb_value, dataset=dataset_name)
    for lamb_value in tqdm([0.1, 0.3, 0.5, 0.7, 0.9])
]
output_df = pd.concat([output_df, *mmr_runs]).drop_duplicates().reset_index(drop=True)
save_result(output_df)

