import torch
import numpy as np
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from modules.cigp_v10 import cigp
from modules.cigp_v10_mis import cigp_mis
from tools.prepare_data import data_preparation
from tools.calculate_metrix import calculate_metrix


# This work was developed against torch 1.11.0; older versions may not work.
print(torch.__version__)

# Work around the "duplicate OpenMP runtime" abort seen on macOS.
# NOTE(review): this is set after `import torch`; it only helps if the
# conflicting runtime is loaded later than this line — confirm, or move it
# above the imports.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

JITTER = 1e-6
EPS = 1e-10
# Full-precision pi; the former literal 3.1415 was off by roughly 9e-5.
PI = np.pi

print('testing')
print(torch.__version__)

def _impute_missing_fidelity(x_all, y_level, mask_row, n_samples):
    """Fill the masked-out training outputs of one fidelity with GP predictions.

    A ``cigp`` model is fit on the observed samples (``mask_row[k] != 0``)
    and its predictive mean replaces the missing ones (``mask_row[k] == 0``).

    Args:
        x_all: indexable training inputs; ``x_all[k]`` is sample ``k``.
        y_level: indexable training outputs of this fidelity.
        mask_row: indexable observation flags; 0 marks a missing output.
        n_samples: only samples ``0 .. n_samples - 1`` are considered.

    Returns:
        ``(y_full, var_full)``: the stacked outputs with missing rows replaced
        by the GP predictive mean, and a 1-D tensor of per-sample variance
        (0 for observed samples, predictive variance for imputed ones).

    Raises:
        ValueError: if none of the ``n_samples`` samples is observed.
    """
    exist_idx = []
    missing_idx = []
    for k in range(n_samples):
        if mask_row[k] == 0:
            missing_idx.append(k)
        else:
            exist_idx.append(k)

    if not exist_idx:
        raise ValueError(
            "every training sample of this fidelity is masked as missing; "
            "nothing to fit the imputation GP on")

    x_exist = torch.stack([x_all[k] for k in exist_idx])
    y_exist = torch.stack([y_level[k] for k in exist_idx])

    if not missing_idx:
        # Fully observed: use the data as-is, with zero extra variance.
        return y_exist, torch.zeros(n_samples)

    x_missing = torch.stack([x_all[k] for k in missing_idx])
    imputer = cigp(x_exist, y_exist)
    imputer.train_adam(100, lr=0.01)
    with torch.no_grad():
        missing_mean, missing_variance = imputer.forward(x_missing)

    # Re-interleave observed and imputed rows back into original sample order.
    y_rows = [None] * n_samples
    var_rows = [None] * n_samples
    # Zero variance with the same dtype/device as the GP output, so the
    # stack below needs neither dtype promotion nor a device transfer.
    zero_var = missing_variance.new_zeros(())
    for j, k in enumerate(exist_idx):
        y_rows[k] = y_exist[j]
        var_rows[k] = zero_var
    for j, k in enumerate(missing_idx):
        y_rows[k] = missing_mean[j]
        # NOTE(review): takes element [j][0]; assumes forward() returns a
        # per-point variance shaped (N, 1) — confirm against cigp.
        var_rows[k] = missing_variance[j][0]
    return torch.stack(y_rows), torch.stack(var_rows)


def data_model_resgp(data_name,
                     mask,
                     train_begin_index=0,
                     test_begin_index=0,
                     train_samples_num=16,
                     test_samples_num=128,
                     fidelity_num=5,
                     seed=0,
                     need_inerp=True):
    """Train a residual multi-fidelity GP (ResGP) on partially observed data.

    Pipeline:
      1. For every fidelity, fit a ``cigp`` on the observed training samples
         and replace the missing ones (``mask[i][k] == 0``) with its
         predictive mean. Their predictive variance is kept as per-sample
         variance; observed samples get 0.
      2. Fit a ``cigp_mis`` model on the lowest fidelity.
      3. For each consecutive pair (lf, hf), fit a ``cigp_mis`` on the
         residual ``y[hf] - y[lf]`` and add its test prediction (mean and
         variance) to the running total.

    The inputs ``xtr[0]`` / ``xte[0]`` are used for every fidelity.

    Args:
        data_name: dataset identifier forwarded to ``data_preparation``.
        mask: per-fidelity observation flags. ``mask[i][k] == 0`` marks
            training sample ``k`` of fidelity ``i`` as missing. It must have
            ``fidelity_num`` rows of at least ``train_samples_num`` entries.
        train_begin_index: start of the training slice used by the ResGP
            models.
        test_begin_index: start of the evaluated test slice.
        train_samples_num: forwarded to ``data_preparation``. It is also the
            end of the training slice ``[train_begin_index:train_samples_num]``.
            Imputation always considers samples ``0 .. train_samples_num - 1``.
        test_samples_num: end of the test slice
            ``[test_begin_index:test_samples_num]``.
        fidelity_num: number of fidelity levels.
        seed: forwarded to ``data_preparation``.
        need_inerp: currently unused; kept for API compatibility.

    Returns:
        The ``calculate_metrix`` result comparing the summed ResGP prediction
        with the highest-fidelity test outputs.

    Raises:
        ValueError: if some fidelity has no observed training sample.
    """
    xtr, ytr, xte, yte = data_preparation(data_name, fidelity_num, seed, train_samples_num)

    # Honour the caller's begin indices; the same slices are applied to
    # inputs, outputs and per-sample variances so their lengths always agree.
    train_slice = slice(train_begin_index, train_samples_num)
    test_slice = slice(test_begin_index, test_samples_num)
    x_train = xtr[0][train_slice]
    x_test = xte[0][test_slice]

    # Impute missing training outputs for every fidelity.
    ytr_f = []
    missing_var = []
    for i in range(fidelity_num):
        y_full, var_full = _impute_missing_fidelity(xtr[0], ytr[i], mask[i], train_samples_num)
        ytr_f.append(y_full)
        missing_var.append(var_full)

    # Lowest-fidelity model.
    lf1_model = cigp_mis(x_train, ytr_f[0][train_slice], missing_var[0][train_slice])
    lf1_model.train_adam(100, lr=0.01)
    with torch.no_grad():
        yte_mean, yte_var = lf1_model.forward(x_test)
    lf1_metrics = calculate_metrix(y_test=yte[0][test_slice], y_mean_pre=yte_mean, y_var_pre=yte_var)

    # Residual models between consecutive fidelities.
    metrices = []
    for lf in range(fidelity_num - 1):
        hf = lf + 1
        ytr_res = ytr_f[hf][train_slice] - ytr_f[lf][train_slice]
        yte_res = yte[hf][test_slice] - yte[lf][test_slice]

        # NOTE(review): only the higher fidelity's imputation variance is used
        # for the residual; if both levels were imputed, var[lf] + var[hf] may
        # be intended — confirm.
        res_model = cigp_mis(x_train, ytr_res, missing_var[hf][train_slice])
        res_model.train_adam(100, lr=0.01)
        with torch.no_grad():
            yte_res_mean, yte_res_var = res_model.forward(x_test)

        metrices.append(calculate_metrix(y_test=yte_res, y_mean_pre=yte_res_mean, y_var_pre=yte_res_var))
        # Out-of-place sums: never mutate tensors handed back by forward().
        yte_mean = yte_mean + yte_res_mean
        yte_var = yte_var + yte_res_var

    yte_test = yte[fidelity_num - 1][test_slice]
    cp_metrics = calculate_metrix(y_test=yte_test, y_mean_pre=yte_mean, y_var_pre=yte_var)

    print("loss of lowest fidelity model:", lf1_metrics)
    for i, res_metrics in enumerate(metrices):
        print("loss of res-model between", i + 1, "and", i + 2, ":", res_metrics)
    print("loss of resgp:", cp_metrics)

    return cp_metrics



