import os
import random
import time
import numpy as np
import torch as t
from torch import nn
from Core.DataSet import DataSet_python, DataSetTool
import OneTrainingTest

# ---- Experiment configuration ------------------------------------------
# Seed handed to the Python, NumPy and PyTorch RNGs before each fold runs.
rand_seed = 2
# Device string used only when CUDA is available; otherwise the CPU is used.
gpu_str = 'cuda:0'

# Input / output locations (both keep their trailing slash because paths
# below are built by plain string concatenation).
data_dir = 'DataSet/'
res_dir = 'Result/'
data_name = 'mushroom'

# Feature-extractor description forwarded unchanged to OneTrainingTest.run.
# One dict per entry, in order; the meaning of each key is defined by the
# consumer of this structure, which is not part of this file.
fea_ext_net_type = 'FCN'
fea_ext_net_structure = [
    dict(node_num=128, has_bias=True, has_BN=False,
         activation=nn.LeakyReLU(), has_dropout=False),
    dict(node_num=64, has_bias=True, has_BN=False,
         activation=nn.LeakyReLU(), has_dropout=False),
    dict(node_num=32, has_bias=True, has_BN=False,
         activation=nn.Softmax(dim=1), has_dropout=False),
]

# Log a timestamped banner naming the data set being processed.
print(f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())} {data_name}")

# The data file lives at <data_dir><data_name>/<data_name>.mat.
data_set_file = f'{data_dir}{data_name}/{data_name}.mat'
data_set = DataSet_python(data_set_file=data_set_file,
                          data_set_information=None)

# Sample-index table stored next to the data file; it is indexed by fold
# (one row per fold) when the training/test split is made further down.
rand_sam_ind = DataSetTool.LoadRandSamInd(
    mat_file_name=f'{data_dir}{data_name}/kfold.mat')


# Run one independent training/test experiment for every fold index below.
fold_list = [0, 1, 2, 3]
for fold in fold_list:
    print(f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())} "
          f"{data_name} fold= {fold}")

    # Split the full data set using this fold's row of sample indices;
    # 75% of the samples are requested for training.
    split = DataSet_python.Split_Training_Test_Dataset(
        data_set=data_set,
        rand_sam_ind=rand_sam_ind[fold, :],
        train_sam_ratio=0.75)
    train_data_set = split['train_data_set']
    test_data_set = split['test_data_set']

    # Output goes to <res_dir><data_name>/fold-<k>/rand<seed>+<timestamp>/.
    # The trailing slash is kept because the path is passed on as a string.
    exp_str = 'rand{}+{}'.format(
        rand_seed, time.strftime('%Y-%m-%d=%H-%M-%S', time.localtime()))
    result_dir = f'{res_dir}{data_name}/fold-{fold}/{exp_str}/'
    # No exist_ok: raises FileExistsError if this exact directory is already
    # there instead of silently reusing it.
    os.makedirs(result_dir)

    # Use the configured GPU when CUDA is available, the CPU otherwise.
    device = t.device(gpu_str) if t.cuda.is_available() else t.device('cpu')

    # Re-seed every RNG with the same seed and (re)apply the float64 default
    # dtype at the start of each fold.
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    t.manual_seed(rand_seed)
    t.set_default_dtype(t.float64)

    OneTrainingTest.run(
        result_dir=result_dir,
        train_data_set=train_data_set,
        test_data_set=test_data_set,
        fea_ext_net_type=fea_ext_net_type,
        fea_ext_net_structure=fea_ext_net_structure,
        batch_size=2048,
        max_epoch_num=100,
        decay_lr_how_often=50000,
        decay_lr_rate=0.9,
        device=device)
