import torch
torch.set_default_tensor_type(torch.DoubleTensor)
import sys
import numpy as np
import os
sys.path.append('../')

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from modules.cigp_v10 import cigp
from modules.cigp_v10_mis import cigp_mis
from modules.cigp_v10_rho_mis import ConstMeanCIGP_mis
from tools.prepare_data import data_preparation
from tools.calculate_metrix import calculate_metrix


print("cigp_lar:", torch.__version__)
# Developed against torch 1.11.0; lower versions may not work.
# Allow a duplicated OpenMP runtime (works around a strange error seen on macOS).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
# Small numerical constants.
JITTER = 1e-6
EPS = 1e-10
# Full-precision pi; the former literal 3.1415 was only accurate to ~4 digits.
PI = np.pi


def _impute_missing(x, y, mask_row, n_samples):
    """Impute the masked-out training outputs of one fidelity with a GP.

    A ``cigp`` model is fitted on the rows whose mask entry is non-zero and
    then used to predict the rows whose mask entry equals 0.

    Args:
        x: training inputs; rows ``0 .. n_samples-1`` are used.
        y: training outputs of this fidelity, row-aligned with ``x``.
        mask_row: per-sample flags for this fidelity; 0 marks a missing row.
        n_samples: number of leading training rows to consider.

    Returns:
        ``(y_full, var)``. ``y_full`` stacks observed rows and predicted means
        for the missing rows, in the original row order. ``var`` is a 1-D
        tensor of length ``n_samples``: the predictive variance for imputed
        rows and 0 for observed rows.

    Raises:
        ValueError: if every one of the ``n_samples`` rows is masked out.
    """
    is_missing = [mask_row[k] == 0 for k in range(n_samples)]
    x_exist = [x[k] for k in range(n_samples) if not is_missing[k]]
    y_exist = [y[k] for k in range(n_samples) if not is_missing[k]]
    x_missing = [x[k] for k in range(n_samples) if is_missing[k]]

    if not x_exist:
        # torch.stack([]) would fail with an opaque message; be explicit.
        raise ValueError("every training sample of this fidelity is masked out")
    x_exist = torch.stack(x_exist)
    y_exist = torch.stack(y_exist)

    if not x_missing:
        # Nothing to impute: observed data as-is, zero extra variance.
        return y_exist, torch.zeros(n_samples)

    gp = cigp(x_exist, y_exist)
    gp.train_adam(100, lr = 0.001)
    with torch.no_grad():
        missing_mean, missing_variance = gp.forward(torch.stack(x_missing))

    # Re-interleave observed rows and imputed rows in the original order.
    y_full = []
    var = []
    j_mis = 0
    j_exist = 0
    for missing in is_missing:
        if missing:
            y_full.append(missing_mean[j_mis])
            var.append(missing_variance[j_mis][0])
            j_mis += 1
        else:
            y_full.append(y_exist[j_exist])
            # Same dtype/device as the predicted variances (avoids stacking
            # an int64 placeholder with floating-point values).
            var.append(missing_variance.new_zeros(()))
            j_exist += 1
    return torch.stack(y_full), torch.stack(var)


def ar(data_name, 
       mask, 
       train_begin_index = 0, 
       test_begin_index = 0,
       train_samples_num = 16, 
       test_samples_num = 128, 
       fidelity_num = 5, 
       seed = 0,
       need_inerp = True):
    """Train and evaluate a chained multi-fidelity GP on data with missing samples.

    Steps:
      1. Load data with ``data_preparation``.
      2. For every fidelity, impute the training outputs whose ``mask`` entry
         is 0 using a ``cigp`` fitted on the observed rows. Keep the imputed
         rows' predictive variances in ``missing_var``.
      3. Fit a ``cigp_mis`` on the (imputed) lowest fidelity.
      4. For each higher fidelity, fit a ``ConstMeanCIGP_mis``. It is
         conditioned on the previous fidelity's imputed training outputs and
         on that level's test-set prediction, chaining predictions upward.
      5. Print per-level metrics and return the final level's metrics.

    Args:
        data_name: dataset identifier forwarded to ``data_preparation``.
        mask: indexable as ``mask[fidelity][sample]``; an entry equal to 0
            marks that training sample as missing for that fidelity.
        train_begin_index, test_begin_index: ignored; both are reset to 0.
            They are kept for interface compatibility.
        train_samples_num: number of leading training rows used (also
            forwarded to ``data_preparation``).
        test_samples_num: number of leading test rows evaluated.
        fidelity_num: number of fidelity levels.
        seed: forwarded to ``data_preparation``.
        need_inerp: unused.

    Returns:
        The ``calculate_metrix`` result comparing the highest-fidelity test
        outputs with the final prediction of the chain.

    Raises:
        ValueError: if every training sample of some fidelity is masked out.
    """
    xtr, ytr, xte, yte = data_preparation(data_name, fidelity_num, seed, train_samples_num)

    # NOTE(review): the begin indices are deliberately forced to 0. The
    # imputation step always walks rows 0..train_samples_num-1, so honoring
    # a non-zero train_begin_index would misalign it with the sliced inputs.
    train_begin_index = 0
    test_begin_index = 0
    train_slice = slice(train_begin_index, train_samples_num)
    test_slice = slice(test_begin_index, test_samples_num)
    # All fidelities share the inputs stored in xtr[0] / xte[0].
    x_train = xtr[0][train_slice]
    x_test = xte[0][test_slice]

    # Impute the missing training outputs of every fidelity.
    ytr_f = []
    missing_var = []
    for i in range(fidelity_num):
        y_full, m_var = _impute_missing(xtr[0], ytr[i], mask[i], train_samples_num)
        ytr_f.append(y_full)
        missing_var.append(m_var)

    # Lowest-fidelity GP. Train on the imputed outputs (not raw ytr[0]) so
    # the data matches missing_var[0] and the masked samples' true values
    # do not leak into training.
    lf1_model = cigp_mis(x_train, ytr_f[0][train_slice], missing_var[0])
    lf1_model.train_adam(100, lr = 0.001)

    with torch.no_grad():
        yte_mean, yte_var = lf1_model.forward(x_test)

    lf1_metrics = calculate_metrix(y_test=yte[0][test_slice], y_mean_pre=yte_mean, y_var_pre = yte_var)

    # Chain of models, each conditioned on the level below it.
    yte_next_mean = [yte_mean]
    yte_next_var = [yte_var]
    metrices = []
    for lf in range(fidelity_num - 1):
        hf = lf + 1
        yln = ytr[lf][train_slice][0].numel()
        yhn = ytr[hf][train_slice][0].numel()
        ytr_lf = ytr_f[lf][train_slice]

        res_model = ConstMeanCIGP_mis(x_train, ytr_f[hf][train_slice], yln, yhn, missing_var[hf])
        res_model.train_adam(ytr_lf, 100, lr = 0.001)

        with torch.no_grad():
            print("begin test")
            yte_res_mean, yte_res_var = res_model.forward(
                x_test,
                ytr_lf,
                yte_next_mean[lf],
                yte_next_var[lf])
        yte_next_mean.append(yte_res_mean)
        yte_next_var.append(yte_res_var)

        metrices.append(calculate_metrix(y_test = yte[hf][test_slice],
                                         y_mean_pre = yte_res_mean,
                                         y_var_pre = yte_res_var))

    # Use the last prediction in the chain. With fidelity_num == 1 this is
    # the lowest-fidelity prediction (previously a NameError).
    yte_test = yte[fidelity_num - 1][test_slice]
    cp_metrics = calculate_metrix(y_test = yte_test,
                                  y_mean_pre = yte_next_mean[-1],
                                  y_var_pre = yte_next_var[-1])

    print("loss of lowest fidelity mdoel:", lf1_metrics)
    for i, level_metrics in enumerate(metrices):
        print("loss of res-model between", i+1, "and", i+2, ":", level_metrics)
    print("loss of our model:", cp_metrics)

    return cp_metrics

