import pandas as pd
import numpy as np
import time
import random
import os

from model_mis.conAR_mis import data_model as conAR
from model_mis.resgp_mis import data_model_resgp as resgp
from model_mis.ar_mis import ar as ar
 
# ---------------------------------------------------------------------------
# Initial settings
# ---------------------------------------------------------------------------
# Dataset names that can be placed in data_name_list:
# 'Poisson_mfGent_v5', 'Heat_mfGent_v5', 'Burget_mfGent_v5_15', 'TopOP_mfGent_v6', 'plasmonic2_MF'
data_name_list = ['Burget_mfGent_v5_15']

# Model name -> callable that runs one experiment and returns a metrics dict.
model_list = {'conAR': conAR, 'resgp': resgp, 'ar': ar}
# Order in which the models are run; taken from model_list (dict insertion
# order) so the two can never get out of sync.
model_name = list(model_list)

# One experiment run per seed value.
seed = list(range(50))
# Number of fidelity levels; one training mask is drawn per level.
fidelity_num = 5

def get_cost(ones_num_list):
        """Return the total cost of a per-fidelity sample allocation.

        Fidelity ``i`` (0-based) is charged ``2 ** (i + 1)`` per sample, so the
        result is ``sum(n_i * 2 ** (i + 1))`` over the entries of
        ``ones_num_list``.

        Args:
            ones_num_list: Sequence of sample counts, one per fidelity level.

        Returns:
            The weighted total (0 for an empty sequence).
        """
        return sum(count * 2 ** (level + 1)
                   for level, count in enumerate(ones_num_list))


def _build_mask_matrix(ones_num_list, sample_num, run_seed):
        """Build one shuffled 0/1 training mask per fidelity level.

        For fidelity ``fid`` the mask has ``sample_num`` entries. The first
        ``ones_num_list[fid]`` entries are set to 1, and then the mask is
        shuffled. Before each shuffle, the global NumPy RNG is reseeded with
        ``run_seed * len(ones_num_list) + fid``, so each (run_seed, fid) pair
        always yields the same mask.

        Args:
            ones_num_list: Number of selected samples per fidelity level.
            sample_num: Length of each mask.
            run_seed: Seed of the current experiment run.

        Returns:
            List of 1-D float arrays (np.zeros default dtype), one per fidelity.
        """
        mask_matrix = []
        for fid, ones_num in enumerate(ones_num_list):
                mask_tem = np.zeros(sample_num)
                mask_tem[:int(ones_num)] = 1
                # Reseed the *global* RNG (not a local Generator). This keeps
                # the masks identical to earlier runs. It also leaves the
                # global state unchanged, in case a model relies on it.
                np.random.seed(run_seed * len(ones_num_list) + fid)
                np.random.shuffle(mask_tem)
                mask_matrix.append(mask_tem)
        return mask_matrix


# Columns of the per-model result CSV, in file order.
_CSV_COLUMNS = ['cost', 'rmse', 'nrmse', 'r2', 'nll', 'time']
# Metrics copied verbatim from the dict returned by each model call.
_MODEL_METRICS = ('rmse', 'r2', 'nll', 'nrmse')

for data_name in data_name_list:
        train_sample_num = 64
        for _name in model_name:
                model = model_list[_name]
                recording = {'cost': [], 'rmse': [], 'r2': [], 'nll': [], 'nrmse': [], 'time': []}
                for k in seed:
                        # Draw how many training samples each fidelity gets in this run.
                        # NOTE(review): randint is inclusive, so a fidelity may get 0
                        # (or all) samples; confirm every model accepts an empty mask.
                        random.seed(k)
                        ones_num_list = [random.randint(0, train_sample_num) for _ in range(fidelity_num)]
                        print(ones_num_list)

                        mask_matrix = _build_mask_matrix(ones_num_list, train_sample_num, k)

                        # perf_counter is monotonic, so wall-clock adjustments
                        # during a long fit cannot skew the recorded duration.
                        t_start = time.perf_counter()
                        mod = model(
                                data_name,
                                mask=mask_matrix,
                                train_begin_index=0,
                                test_begin_index=0,
                                train_samples_num=train_sample_num,
                                test_samples_num=128,
                                fidelity_num=fidelity_num,
                                seed=k,
                                # spelled as the model keyword expects (presumably "interp"); do not rename here
                                need_inerp=True,
                        )
                        elapsed = time.perf_counter() - t_start

                        recording['cost'].append(get_cost(ones_num_list))
                        for metric in _MODEL_METRICS:
                                recording[metric].append(mod[metric])
                        recording['time'].append(elapsed)

                path_csv = os.path.join('exp', str(_name), data_name, 'cost')
                # exist_ok avoids the check-then-create race of exists()+makedirs().
                os.makedirs(path_csv, exist_ok=True)

                # Passing columns= selects and orders the dict's keys.
                df = pd.DataFrame(recording, columns=_CSV_COLUMNS)
                df.to_csv(os.path.join(path_csv, 'result_' + str(len(seed)) + '.csv'), index=False)