# The 2nd step: determine the best parameters based on RF accuracy.
# Run after Run_adagraph.py
import scipy.io as sio
import numpy as np

import sys
sys.path.append('../')
from Utils.model_evaluation import run_baseline,check_best
import torch


if __name__ == "__main__":
    # Script entry point.
    #
    # Stage 1: for every neighbour count, evaluate each (lr, epsilon)
    #          combination stored in the AdaGraph result file with the baseline
    #          model, keep the best combination per feature count, and save it.
    # Stage 2: reload the per-neighbour files written by stage 1 and keep, per
    #          feature count, the best neighbour count and its parameters.

    # Local import: only needed here to create the output directory.
    import os

    random_iter = 1  # number of repeated baseline runs per parameter combination
    device = torch.device('cpu')

    tasks = ['classification']
    baseline_models = ['RF']
    data_list = ['madelon']
    lrs = [1e-4,1e-3,1e-2,1e-1,1e0,1e1]
    epsilons = [1e-3,1e-2,1e-1]
    num_neighbors = [2,3,4,5]

    # Make sure the output directory exists; savemat does not create it.
    result_dir = './Selected_optimal_params/'
    os.makedirs(result_dir, exist_ok=True)

    for task, baseline_model in zip(tasks, baseline_models):
        print('===========>'+task+':',flush=True)
        # Clustering/classification scores are maximised, anything else minimised.
        higher_flag = task in ['clustering','classification']

        for fname in data_list:
            print('======>'+fname+':',flush=True)
            if fname == 'madelon':
                feanums = [5,10,15,20]
            else:
                feanums = [25,50,75,100,150,200,300]

            result_prefix = result_dir+'Best_para_'+baseline_model+'_Our_'+fname

            # ---- Stage 1: best parameters under each neighbour number ----
            for num_neighbor in num_neighbors:
                data = sio.loadmat('./Results_all_params/AdaGraph_'+fname+'_num_neighbor'+str(num_neighbor)+'.mat')
                indices = data['indices']
                X_train_ori = data['X_train']
                X_test_ori = data['X_test']
                y_train = data['y_train']
                y_test = data['y_test']

                # The first three axes of `indices` must line up with the
                # parameter grids used below (feature count, lr, epsilon).
                expected_shape = (len(feanums), len(lrs), len(epsilons))
                if tuple(indices.shape[:3]) != expected_shape:
                    raise ValueError(
                        'indices has leading shape '+str(tuple(indices.shape[:3]))
                        +', expected '+str(expected_shape))

                total_res_mean = []
                best_params = []
                for i, numfea in enumerate(feanums):
                    # Sentinel starting values handed to check_best.
                    best_res_mean = [0,] if higher_flag else [9999,]
                    best_res_std = [9999,]
                    # Reset per feature count so a value from a previous
                    # feature count can never be recorded by accident.
                    best_para = None
                    for j, lr in enumerate(lrs):
                        for k, epsilon in enumerate(epsilons):
                            # extract the first `numfea` selected features
                            selected_ind = indices[i,j,k][:numfea].squeeze()
                            X_train = X_train_ori[:,selected_ind]
                            X_test = X_test_ori[:,selected_ind]

                            # evaluate features based on the baseline model with repeated trials
                            iter_res = []
                            for trial in range(random_iter):
                                if task in ['clustering','classification']:
                                    res = run_baseline(X_train,X_test,y_train,y_test,baseline_model,trial)
                                else:
                                    res = run_baseline(X_train,X_test,X_train_ori,X_test_ori,baseline_model, trial, device = device)
                                iter_res.append(res)

                            # compare and record the best result
                            # (check_best presumably reports whether iter_res
                            # beats the current best -- defined in Utils.model_evaluation)
                            if check_best(best_res_mean,best_res_std,iter_res,higher_flag):
                                best_res_mean,best_res_std = np.mean(iter_res,axis = 0), np.std(iter_res,axis = 0)
                                best_para = [lr,epsilon]
                    if best_para is None:
                        raise RuntimeError(
                            'No parameter combination was accepted by check_best for '
                            +fname+' with '+str(numfea)+' features and '
                            +str(num_neighbor)+' neighbors')
                    total_res_mean.append(best_res_mean)
                    best_params.append(best_para)

                # Persist this neighbour number's results: stage 2 reads exactly
                # this file. One row per feature count is enforced so that
                # `data['total_res_mean'][i]` stays valid after loadmat even if
                # each result is a scalar (savemat stores 1-D data as a row).
                sio.savemat(
                    result_prefix+'_num_neighbor'+str(num_neighbor)+'.mat',
                    {'best_params':best_params,
                     'total_res_mean':np.asarray(total_res_mean).reshape(len(feanums),-1)})

            # ---- Stage 2: best neighbour number and its parameter combination ----
            # Load each per-neighbour file once instead of once per feature count.
            neighbor_results = [
                (num_neighbor, sio.loadmat(result_prefix+'_num_neighbor'+str(num_neighbor)+'.mat'))
                for num_neighbor in num_neighbors
            ]

            total_res_mean = []
            best_params = []
            for i, numfea in enumerate(feanums):
                best_res_mean = [0,] if higher_flag else [9999,]
                best_res_std = [9999,]
                best_para = None
                for num_neighbor, data in neighbor_results:
                    # Wrapped in a list so it has the same "list of trial
                    # results" form that check_best receives in stage 1.
                    res = [data['total_res_mean'][i],]
                    if check_best(best_res_mean,best_res_std,res,higher_flag):
                        best_res_mean,best_res_std = np.mean(res,axis = 0), np.std(res,axis = 0)
                        best_para = [num_neighbor,]+data['best_params'][i].tolist()
                if best_para is None:
                    raise RuntimeError(
                        'No neighbor number was accepted by check_best for '
                        +fname+' with '+str(numfea)+' features')
                best_params.append(best_para)
                total_res_mean.append(best_res_mean)
            sio.savemat(result_prefix+'_all.mat',{'best_params':best_params,'total_res_mean':total_res_mean})
