import torch
import os
import sys
torch.set_default_tensor_type(torch.DoubleTensor)
sys.path.append('../')
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from modules.cigp_v10 import cigp
from modules.cigp_v10_rho import ConstMeanCIGP
from tools.prepare_data import data_preparation
from tools.calculate_metrix import calculate_metrix


# Announce the torch version at import time. This work was developed against
# torch 1.11.0; older releases may not work.
print("cigp_lar:", torch.__version__)

# Workaround for an error seen when running on macOS. NOTE(review): this flag
# presumably silences the Intel OpenMP "duplicate runtime" abort — confirm.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})

# Module-level numeric constants (not referenced inside this file's visible code).
JITTER, EPS = 1e-1, 1e-1
# NOTE(review): truncated approximation of pi (math.pi is 3.14159...); value kept as-is.
PI = 3.1415

def ar(data_name,
       train_begin_index=0,
       test_begin_index=0,
       train_samples_num=16,
       test_samples_num=128,
       dec_rate=0.5,
       fidelity_num=5,
       seed=1,
       need_inerp=True):
    """Train a chain of GPs across fidelities and score the top-fidelity prediction.

    A ``cigp`` model is fit on the lowest-fidelity training data. Then, for each
    higher fidelity, a ``ConstMeanCIGP`` model is built from the shared inputs
    ``xtr[0]`` and that fidelity's outputs. It is trained together with the
    previous fidelity's outputs and used to propagate the test-set predictive
    mean/variance one fidelity up.

    Args:
        data_name: dataset identifier forwarded to ``data_preparation``.
        train_begin_index: index of the first training sample used at every fidelity.
        test_begin_index: index of the first test sample used for prediction/scoring.
        train_samples_num: training budget of the lowest fidelity. It is also
            forwarded to ``data_preparation``.
        test_samples_num: number of test samples, counted from ``test_begin_index``.
        dec_rate: per-fidelity decay of the training budget. Fidelity ``i`` gets
            ``int(train_samples_num * dec_rate**i)`` samples.
        fidelity_num: number of fidelity levels (forwarded to ``data_preparation``).
        seed: forwarded to ``data_preparation``.
        need_inerp: currently unused; kept for interface compatibility.

    Returns:
        The result of ``calculate_metrix`` comparing the highest-fidelity test
        outputs with the chain's final predictive mean and variance. When
        ``fidelity_num == 1`` this is the low-fidelity model's prediction.
    """
    xtr, ytr, xte, yte = data_preparation(data_name, fidelity_num, seed, train_samples_num)

    # Training budget per fidelity, shrinking geometrically by dec_rate.
    train_num = [int(train_samples_num * pow(dec_rate, i)) for i in range(fidelity_num)]

    # Test window: `test_samples_num` samples starting at `test_begin_index`.
    # The original code forced the start to 0, ignoring the argument.
    test_window = slice(test_begin_index, test_begin_index + test_samples_num)
    x_test = xte[0][test_window]

    # Train the low-fidelity GP.
    lf_window = slice(train_begin_index, train_begin_index + train_num[0])
    model_l = cigp(xtr[0][lf_window], ytr[0][lf_window])
    model_l.train_adam(100, lr=0.01)

    with torch.no_grad():
        yte_mean, yte_var = model_l(x_test)

    # Predictions propagated fidelity by fidelity; entry k belongs to fidelity k.
    yte_next_mean = [yte_mean]
    yte_next_var = [yte_var]
    for lf in range(fidelity_num - 1):
        hf = lf + 1
        # NOTE(review): the lf -> hf step uses the lower fidelity's budget
        # train_num[lf] for both fidelities' outputs; confirm this is intended.
        window = slice(train_begin_index, train_begin_index + train_num[lf])
        x_train = xtr[0][window]
        y_low = ytr[lf][window]
        y_high = ytr[hf][window]
        yln = y_low[0].numel()
        yhn = y_high[0].numel()

        res_model = ConstMeanCIGP(x_train, y_high, yln, yhn)
        res_model.train_adam(y_low, 100, lr=0.01)

        with torch.no_grad():
            yte_res_mean, yte_res_var = res_model.forward(
                x_test,
                y_low,
                yte_next_mean[lf],
                yte_next_var[lf])
        yte_next_mean.append(yte_res_mean)
        yte_next_var.append(yte_res_var)

    # Score the last propagated prediction against the top fidelity. Using [-1]
    # also covers fidelity_num == 1, which previously raised NameError.
    yte_test = yte[fidelity_num - 1][test_window]
    metrics_lar = calculate_metrix(y_test=yte_test,
                                   y_mean_pre=yte_next_mean[-1],
                                   y_var_pre=yte_next_var[-1])
    return metrics_lar

