
import os
import sys
import torch

# Put the repository root (the first path component named 'ML_gp' above this
# script) on sys.path so the project-local `utils` / `module` packages imported
# below resolve when the script is run directly.
_sep = os.path.sep
_script_parts = os.path.abspath(__file__).split(_sep)
if 'ML_gp' not in _script_parts:
    # Same exception type as the bare list.index() failure, but actionable.
    raise ValueError(
        "cannot locate the 'ML_gp' project root above {!r}; this script must "
        "be run from inside the ML_gp repository".format(
            os.path.abspath(__file__)))
realpath = _sep.join(_script_parts[:_script_parts.index('ML_gp') + 1])
if realpath not in sys.path:  # avoid duplicate entries on re-import
    sys.path.append(realpath)

from utils.main_controller import controller
from module.ind_hogp import HOGP_MODULE
from module.ind_hogp_multi_fidelity import HOGP_MF_MODULE

# Forwarded as the 'interp_data' dataset option and echoed into record.txt.
interp_data = False

# Dataset names grouped as "real" multi-fidelity datasets. The experiment loop
# in this script names its dataset explicitly rather than iterating this list.
real_dataset = [
    'FlowMix3D_MF',
    'MolecularDynamic_MF',
    'plasmonic2_MF',
    'SOFC_MF',
]

# Dataset names grouped as generated datasets; likewise not iterated by the
# experiment loop. NOTE(review): spellings like 'Burget' / 'Piosson' presumably
# have to match the loader's dataset names — confirm before "correcting" them.
gen_dataset = [
    'poisson_v4_02',
    'burger_v4_02',
    'Burget_mfGent_v5',
    'Burget_mfGent_v5_02',
    # 'Heat_mfGent_v5',  # currently excluded
    'Piosson_mfGent_v5',
    'Schroed2D_mfGent_v1',
    'TopOP_mfGent_v5',
]

if __name__ == '__main__':
    for _dataset in ['SOFC_MF']:
        for _seed in [None, 0, 1, 2, 3, 4]:
            with open('record.txt', 'a') as _temp_file:
                _temp_file.write('-'*40 + '\n')
                _temp_file.write('\n')
                _temp_file.write('  Demo sGAR \n')
                _temp_file.write('  seed: {} \n'.format(_seed))
                _temp_file.write('  interp_data: {} \n'.format(interp_data))
                _temp_file.write('\n')
                _temp_file.write('-'*40 + '\n')
                _temp_file.write('-'*3 + '> Training x -> yl part\n')
                _temp_file.flush()

            controller_config = {
                'max_epoch': 1000
            } # use defualt config

            ct_module_config = {
                'dataset': {'name': 'SOFC_MF',
                            'interp_data': interp_data,

                            'seed': _seed,
                            'train_start_index': 0, 
                            'train_sample': 32, 
                            'eval_start_index': 0,
                            'eval_sample': 128,

                            'inputs_format': ['x[0]'],
                            'outputs_format': ['y[0]'],

                            'force_2d': False,
                            'x_sample_to_last_dim': False,
                            'y_sample_to_last_dim': True,
                            'slice_param': [0.6, 0.4], 
                            },
            } # only change dataset config, others use default config
            ct = controller(HOGP_MODULE, controller_config, ct_module_config)
            ct.start_train()
            ct.smart_restore_state(-1)
            ct.rc_file.write('---> final result')
            ct.rc_file.flush()
            ct.start_eval({'eval state':'final'})
            ct.rc_file.write('---> end\n\n')
            ct.rc_file.flush()

            for _sample in [4, 8, 16, 32]:
                with open('record.txt', 'a') as _temp_file:
                    _temp_file.write('-'*40 + '\n')
                    _temp_file.write('\n')
                    _temp_file.write('  Demo sGAR \n')
                    _temp_file.write('  seed: {} \n'.format(_seed))
                    _temp_file.write('  interp_data: {} \n'.format(interp_data))
                    _temp_file.write('\n')
                    _temp_file.write('-'*40 + '\n')
                    _temp_file.write('-'*3 + '> Training x -> yl part\n')
                    _temp_file.flush()

                mfct_module_config = {
                    'dataset': {'name': 'SOFC_MF',
                                'interp_data': interp_data,

                                'seed': _seed,
                                'train_start_index': 0,
                                'train_sample': _sample, 
                                'eval_start_index': 0, 
                                'eval_sample':128,

                                'inputs_format': ['x[0]', 'y[0]'],
                                'outputs_format': ['y[-1]'],

                                'force_2d': False,
                                'x_sample_to_last_dim': False,
                                'y_sample_to_last_dim': True,
                                'slice_param': [0.6, 0.4],
                                },
                } # only change dataset config, others use default config

                mfct = controller(HOGP_MF_MODULE, controller_config, mfct_module_config)
                
                with torch.no_grad():
                    # use x->yl_predict for test x+yl -> yh
                    mfct.module.inputs_eval[1] = ct.module.predict_y
                    pass

                mfct.start_train()
                mfct.smart_restore_state(-1)
                mfct.rc_file.write('---> final result')
                mfct.rc_file.flush()
                mfct.start_eval({'eval state':'final',
                                'module_name': 'SGAR',
                                'cp_record_file': True})
                mfct.rc_file.write('---> end\n\n')
                mfct.rc_file.flush()

    mfct.clear_record()