import torch
import numpy as np
import os
import sys
import random

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from modules.cigp_v10 import cigp
from tools.prepare_data import data_preparation
from tools.calculate_metrix import calculate_metrix


# Report the torch version in use. This work was developed with torch 1.11.0;
# lower versions may not work.
print(torch.__version__)

import os
# Work around a runtime error seen on MacOS.
# NOTE(review): presumably this targets the duplicate-OpenMP-runtime abort; the
# variable only matters when that runtime is loaded, and torch/numpy are already
# imported above this line — confirm it is set early enough to have an effect.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

JITTER = 1e-6  # NOTE(review): unused in this file's visible code; presumably kernel jitter
EPS = 1e-10    # NOTE(review): unused in this file's visible code; presumably a div/log guard
# Full double-precision pi; a truncated literal would bias any formula using PI.
PI = np.pi

# NOTE(review): leftover debug output — the torch version is printed twice on import.
print('testing')
print(torch.__version__)


def resgp(data_name,
          train_begin_index=0,
          test_begin_index=0,
          train_samples_num=16,
          test_samples_num=128,
          dec_rate=0.5,
          fidelity_num=5,
          seed=1,
          need_inerp=True,
          adam_iters=100,
          lr=0.01):
    """Train a residual multi-fidelity GP (ResGP) stack and evaluate it.

    A base ``cigp`` model is fitted to the lowest-fidelity training data. Then
    one ``cigp`` model per consecutive fidelity pair ``(lf, hf)`` is fitted to
    the residual ``y[hf] - y[lf]``. The prediction at the highest fidelity is
    the base prediction plus the sum of the predicted residuals. Because the
    models are trained independently, their predictive variances are summed
    as well.

    Args:
        data_name: dataset identifier forwarded to ``data_preparation``.
        train_begin_index: first training row used at every fidelity.
        test_begin_index: first test row used for evaluation.
        train_samples_num: number of training rows at the lowest fidelity.
        test_samples_num: number of test rows evaluated, starting at
            ``test_begin_index``.
        dec_rate: per-fidelity shrink factor of the training-set size;
            fidelity ``i`` uses ``int(train_samples_num * dec_rate**i)`` rows.
        fidelity_num: number of fidelity levels (must be >= 1).
        seed: seed forwarded to ``data_preparation``.
        need_inerp: currently unused; kept for API compatibility.
        adam_iters: value passed as the first argument of ``cigp.train_adam``
            (presumably the number of Adam iterations; default 100).
        lr: learning rate passed to ``cigp.train_adam`` (default 0.01).

    Returns:
        The metrics returned by ``calculate_metrix`` for the final
        highest-fidelity prediction.

    Raises:
        ValueError: if ``fidelity_num < 1`` or some fidelity level would get
            no training rows.
    """
    # Validate sizes before loading data: a GP cannot be fitted on zero rows.
    if fidelity_num < 1:
        raise ValueError(f"fidelity_num must be >= 1, got {fidelity_num}")
    train_num = [int(train_samples_num * pow(dec_rate, i)) for i in range(fidelity_num)]
    if min(train_num) < 1:
        raise ValueError(
            f"per-fidelity training sizes {train_num} include an empty level; "
            "increase train_samples_num or dec_rate, or reduce fidelity_num")

    xtr, ytr, xte, yte = data_preparation(data_name, fidelity_num, seed, train_samples_num)

    def _train_rows(n):
        # The first ``n`` training rows, counted from ``train_begin_index``.
        return slice(train_begin_index, train_begin_index + n)

    # Test rows are selected the same way as training rows: a window of
    # ``test_samples_num`` rows starting at ``test_begin_index``.
    test_rows = slice(test_begin_index, test_begin_index + test_samples_num)
    xte_eval = xte[0][test_rows]

    # Base model on the lowest fidelity.
    rows = _train_rows(train_num[0])
    lf1_model = cigp(xtr[0][rows], ytr[0][rows])
    lf1_model.train_adam(adam_iters, lr=lr)
    with torch.no_grad():
        yte_mean, yte_var = lf1_model.forward(xte_eval)
    lf1_metrics = calculate_metrix(y_test=yte[0][test_rows], y_mean_pre=yte_mean, y_var_pre=yte_var)

    # One residual model per consecutive fidelity pair.
    metrices = []
    for lf in range(fidelity_num - 1):
        hf = lf + 1
        # Residual lf -> hf uses the rows available at the higher fidelity.
        # Inputs come from the fidelity-0 design. NOTE(review): this assumes
        # nested designs (same leading x rows at every fidelity) — confirm
        # against data_preparation.
        rows = _train_rows(train_num[hf])
        ytr_res = ytr[hf][rows] - ytr[lf][rows]
        # Held-out residual, used to score this residual model on its own.
        yte_res = yte[hf][test_rows] - yte[lf][test_rows]

        res_model = cigp(xtr[0][rows], ytr_res)
        res_model.train_adam(adam_iters, lr=lr)
        with torch.no_grad():
            yte_res_mean, yte_res_var = res_model.forward(xte_eval)

        metrices.append(calculate_metrix(y_test=yte_res, y_mean_pre=yte_res_mean, y_var_pre=yte_res_var))
        # Out-of-place updates so the tensors returned by forward() are not mutated.
        yte_mean = yte_mean + yte_res_mean
        # Independent models: the variance of the sum is the sum of the variances.
        yte_var = yte_var + yte_res_var

    yte_test = yte[fidelity_num - 1][test_rows]
    resgp_metrics = calculate_metrix(y_test=yte_test, y_mean_pre=yte_mean, y_var_pre=yte_var)

    print("loss of lowest fidelity mdoel:", lf1_metrics)
    for i, res_metrics in enumerate(metrices):
        print("loss of res-model between", i + 1, "and", i + 2, ":", res_metrics)
    print("loss of our model:", resgp_metrics)

    return resgp_metrics





