# ## Verify the effect of UFS

# %%
import os
import sys
sys.path.append('../')

import numpy as np
import scipy
from scipy.stats import norm
import scipy.io as sio
from sklearn.datasets import  make_circles

import torch
from torch.utils import data

from Modules.ada_graph import *
from Utils.utils import set_seed, try_gpu
from Utils.run_experiments import *


# %% Generate data
def create_circles_dataset(n, p):
    """Build a two-circles toy dataset padded with Gaussian noise features.

    The first two columns are the 2-D ``make_circles`` coordinates, which are
    the only columns the labels depend on. The remaining ``p - 2`` columns are
    i.i.d. standard-normal noise. An ideal feature selector therefore picks
    columns 0 and 1.

    Args:
        n: Number of samples (rows).
        p: Total number of features (columns); must be at least 2.

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape ``(n, p)`` and ``y`` holds the
        ``make_circles`` label of each row.

    Raises:
        ValueError: If ``p < 2``, since the two circle coordinates would not fit.
    """
    # Validate before drawing any random numbers so a bad call does not
    # advance the global RNG state.
    if p < 2:
        raise ValueError(f"p must be >= 2 to hold the circle coordinates, got {p}")
    # No random_state is passed, so both generators below draw from the global
    # NumPy RNG; seed it beforehand for reproducibility. shuffle=False keeps
    # the samples ordered by circle.
    relevant, y = make_circles(n_samples=n, shuffle=False, noise=0.03, factor=0.5)
    noise_features = norm.rvs(loc=0, scale=1, size=[n, p - 2])
    # Named `features` rather than `data` to avoid shadowing the
    # module-level `torch.utils.data` import.
    features = np.concatenate([relevant, noise_features], axis=1)
    return features, y


# %%
if __name__ == '__main__':
    # Restrict CUDA to GPU 0. NOTE(review): this only takes effect if CUDA has
    # not already been initialised by the imports above; confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    # set seed
    seed_num = 42
    set_seed(seed_num)
    # Devices the network is deployed on (a single GPU index here).
    gpu_list = [0]
    devices = [try_gpu(i) for i in gpu_list]

    n_size = 500
    p_size = 20
    X_data, y_data = create_circles_dataset(n_size, p_size)

    # %%
    # Standardize every feature column (zero mean, unit variance).
    XN = scipy.stats.zscore(X_data, axis=0)

    # %%
    X_train = torch.tensor(XN)
    Y_train = torch.tensor(y_data)
    train_data = (X_train, Y_train)

    args = {}
    # data attributes
    args['fname'] = 'two_circles'
    args['train_num'] = X_train.shape[0]
    args['fea_dim'] = X_train.shape[1]
    # task parameter
    args['selected_num'] = 2
    args['save_model'] = False
    args['save_period'] = None

    # learning parameter
    args['num_epochs'] = 250
    args['num_neighbours'] = 5
    # top-k parameter
    args['num_iter'] = 200
    args['manual_flag'] = True

    # pretrain parameter
    args['pretrain_flag'] = False
    args['pretrained_model'] = ""

    args['lr'] = 0.1
    args['epsilon'] = 0.001
    net = Ada_Graph_Fixed_Concrete_no_orth(args['train_num'], args['fea_dim'], args['selected_num'],
                                           args['num_neighbours'], args['epsilon'], args['num_iter'],
                                           args['manual_flag'])

    I, S, FS_mat = train(net, train_data, devices, args, XN, toy=True)

    # calculate and sort the feature importance
    score = I.squeeze()
    # Feature indices ordered by descending importance. Only the indices are
    # needed; the old code bound the sorted values to `sorted`, which shadowed
    # the builtin.
    selected_ind = torch.sort(score, descending=True).indices
    res = selected_ind.cpu().detach().numpy()[:args['selected_num']]
    # Columns 0 and 1 of X_data are the informative circle coordinates, so
    # report which columns the model picked.
    print('selected features:', res)

    I = I.cpu().detach().numpy()
    S = S.cpu().detach().numpy()
    FS_mat = FS_mat.cpu().detach().numpy()

    sio.savemat('toy_circles_no_orth.mat', {'I': I, 'S': S, 'FS_mat': FS_mat, 'X': XN, 'y': y_data})