# -*- coding: utf-8 -*-
# @Time    : 2022/7/28 0:43
# @File    : mTDR.py
# @Software: PyCharm
from basic_funs import *

# Seed first: everything below (network construction, data generation) draws from
# the seeded RNG streams, so statement order here affects the generated data.
setup_seed(12412)

# Data generating settings
dim = 5  # dimension passed to MultivariateNormalDataset / bregmanFNN
n = 5000  # training sample size for each of q and p
BatchSize = 500  # mini-batch size of the training DataLoaders
Val_n = 5000  # validation sample size for each of q and p
Test_n = 5000  # test sample size for each of q and p

# Optimiser / training settings
mylr, wd = 0.0001, 0.0001  # Adam learning rate and weight decay
width_vec1 = [2*dim, 64, 64]  # layer widths for bregmanFNN (presumably hidden sizes — confirm in basic_funs)
num_epochs = 20000  # epoch budget handed to Training_breprocess
rho1 = 0.9  # parameter of the q distribution (presumably the correlation — confirm)
rho2 = 0  # parameter of the p distribution (presumably the correlation — confirm)

rep_num = 10  # number of independent replications of the whole experiment
# Column 0: loss of the convolution-type bridge; column 1: loss of the mixing-type
# bridge. NaN-initialised so rows that were never filled are recognisable
# (np.ndarray(shape=...) would leave them holding arbitrary memory).
result_array = np.full((rep_num, 2), np.nan, dtype=float)
MyNet1 = bregmanFNN(dim, width_vec1)  # network that is re-initialised and trained
BestNet = bregmanFNN(dim, width_vec1)  # receives the best restart's weights
init_num = 1  # random restarts per bridge step

# NOTE(review): duplicate of the two constructions above. Kept on purpose: building
# the networks presumably consumes draws from the seeded torch RNG, so removing these
# lines would change every dataset generated below and break reproducibility of
# previously reported numbers. Remove only if that is acceptable.
MyNet1 = bregmanFNN(dim, width_vec1)
BestNet = bregmanFNN(dim, width_vec1)

# Number of bridge steps between q and p. For each bridge type the test-set estimate
# is the sum of the per-step network outputs (a telescoping sum over the bridge).
bridge_num = 5


def _fit_bridge_step(train_p_loader, train_q_loader, val_p, val_q):
    """Fit one bridge step and return ``BestNet`` loaded with the best restart.

    ``MyNet1`` is re-initialised ``init_num`` times with ``weight_init`` and trained
    by ``Training_breprocess`` using Adam (``mylr``, ``wd``, ``num_epochs``). The
    restart with the lowest returned score has its weights loaded from
    ``checkpoint.pt`` into ``BestNet``.

    Args:
        train_p_loader: DataLoader over training samples of the "p" end of the step.
        train_q_loader: DataLoader over training samples of the "q" end of the step.
        val_p: Validation samples of the "p" end.
        val_q: Validation samples of the "q" end.

    Returns:
        The module-level ``BestNet`` (updated in place).
    """
    best_score = None
    for _ in range(init_num):
        MyNet1.apply(weight_init)
        trainer = torch.optim.Adam(MyNet1.parameters(), lr=mylr, weight_decay=wd)
        score = Training_breprocess(MyNet1, trainer, train_p_loader, train_q_loader,
                                    val_p, val_q, num_epochs)
        if best_score is None or score < best_score:
            best_score = score
            # NOTE(review): assumes Training_breprocess saves the weights it selected
            # to 'checkpoint.pt' — confirm in basic_funs.
            BestNet.load_state_dict(torch.load('checkpoint.pt'))
    return BestNet


for i in range(rep_num):
    # Fresh q (rho1) and p (rho2) samples for every replication. The creation order
    # is kept fixed (Testing_p included, although unused) so the seeded RNG stream
    # produces the same data as before.
    Training_q = MultivariateNormalDataset(n, dim, rho1)
    Training_p = MultivariateNormalDataset(n, dim, rho2)

    Validation_q = MultivariateNormalDataset(Val_n, dim, rho1)
    Validation_p = MultivariateNormalDataset(Val_n, dim, rho2)

    Testing_q = MultivariateNormalDataset(Test_n, dim, rho1)
    Testing_p = MultivariateNormalDataset(Test_n, dim, rho2)

    # NOTE(review): BestNet is never switched to eval(); if bregmanFNN contains
    # dropout or batch-norm the test predictions below run in training mode — confirm.

    # Convolution-type bridge: the intermediate sample at level a is
    # sqrt(1 - a^2) * X_q + a * X_p, so a = 0 gives X_q and a = 1 gives X_p.
    ltre_temp = torch.zeros(Test_n, 1)
    for m in range(bridge_num):
        lin_com_a1 = m / bridge_num
        lin_com_a2 = (m + 1) / bridge_num
        coef_q1 = math.sqrt(1 - lin_com_a1 ** 2)
        coef_q2 = math.sqrt(1 - lin_com_a2 ** 2)
        Train_q_Loader = torch.utils.data.DataLoader(
            coef_q1 * Training_q.x + lin_com_a1 * Training_p.x, batch_size=BatchSize,
            shuffle=True)
        Train_p_Loader = torch.utils.data.DataLoader(
            coef_q2 * Training_q.x + lin_com_a2 * Training_p.x, batch_size=BatchSize,
            shuffle=True)
        new_Validation_q = coef_q1 * Validation_q.x + lin_com_a1 * Validation_p.x
        new_Validation_p = coef_q2 * Validation_q.x + lin_com_a2 * Validation_p.x
        best_net = _fit_bridge_step(Train_p_Loader, Train_q_Loader,
                                    new_Validation_p, new_Validation_q)
        with torch.no_grad():
            ltre_temp = ltre_temp + best_net(Testing_q.x)
    ltre_l2_loss = accuracy_eval(ltre_temp, Testing_q.logpdf)
    result_array[i, 0] = ltre_l2_loss.detach().numpy().item()

    # Mixing-type bridge: each sample is taken from p with probability
    # a = m / bridge_num and from q otherwise, via Bernoulli indicators (delta = 1 -> p).
    mtre_temp = torch.zeros(Test_n, 1)
    train_delta_array = torch.zeros(n, bridge_num + 1)
    validat_delta_array = torch.zeros(Val_n, bridge_num + 1)
    for m in range(bridge_num + 1):
        lin_com_a = m / bridge_num
        sampler1 = torch.distributions.bernoulli.Bernoulli(torch.tensor([lin_com_a]))
        train_delta = sampler1.sample((n,))
        print(torch.mean(train_delta).numpy())  # empirical mixing fraction
        validat_delta = sampler1.sample((Val_n,))
        train_delta_array[:, m] = train_delta[:, 0]
        validat_delta_array[:, m] = validat_delta[:, 0]
    for m in range(bridge_num):
        train_delta1 = train_delta_array[:, m].reshape((-1, 1))
        train_delta2 = train_delta_array[:, m + 1].reshape((-1, 1))
        Train_q_Loader = torch.utils.data.DataLoader(
            (1 - train_delta1) * Training_q.x + train_delta1 * Training_p.x,
            batch_size=BatchSize, shuffle=True)
        Train_p_Loader = torch.utils.data.DataLoader(
            (1 - train_delta2) * Training_q.x + train_delta2 * Training_p.x,
            batch_size=BatchSize, shuffle=True)
        validat_delta1 = validat_delta_array[:, m].reshape((-1, 1))
        validat_delta2 = validat_delta_array[:, m + 1].reshape((-1, 1))
        new_Validation_q = (1 - validat_delta1) * Validation_q.x + validat_delta1 * Validation_p.x
        new_Validation_p = (1 - validat_delta2) * Validation_q.x + validat_delta2 * Validation_p.x
        best_net = _fit_bridge_step(Train_p_Loader, Train_q_Loader,
                                    new_Validation_p, new_Validation_q)
        with torch.no_grad():
            mtre_temp = mtre_temp + best_net(Testing_q.x)
    mtre_l2_loss = accuracy_eval(mtre_temp, Testing_q.logpdf)
    result_array[i, 1] = mtre_l2_loss.detach().numpy().item()

    # Persist the completed replications after each one so a crash in a later
    # (long-running) replication does not lose finished results.
    np.save('./results.npy', result_array[:i + 1])
print(result_array)
np.save('./results.npy', result_array)