import numpy as np
import torch
from NetEM_final import NetEM_final
from utils import *

import matplotlib.pyplot as plt
import copy
import time

# Driver script: for each of `num` experiments, load saved cascades and a
# warm-start fit, run NetEM_final's EM optimisation, save the result, and
# collect error metrics against the ground-truth A / B / pathway probabilities.

# Fix NumPy and PyTorch RNG seeds; the per-experiment initial pathway
# probabilities below are drawn with np.random.
np.random.seed(233)
torch.manual_seed(233)

dist = "exp"  # passed to NetEM_final and used as the cascade-file prefix
delta = 1     # passed to NetEM_final as its second positional argument

nd = 500  # length of the pathway-probability vectors; presumably A.shape[0] — TODO confirm

# Ground-truth matrices. np.load on an .npz returns an NpzFile that keeps its
# file open until closed, so read the arrays inside a ``with`` block.
with np.load("true_A.npz") as npz:
    A = npz['A']
with np.load("true_B.npz") as npz:
    B = npz['B']

# Reference B for the B metrics: same as B but with the diagonal zeroed.
B_use = B.copy()
np.fill_diagonal(B_use, 0)
# 0/1 indicator of the strictly positive entries of A (handed to em_optimize).
supp_A = (A > 0).astype(int)

num = 5  # number of independent experiments (files are numbered 1..num)

# Per-experiment metric buffers, filled inside the experiment loop.
A_mae_collect = np.zeros(num)
B_mae_collect = np.zeros(num)
B_Acc_collect = np.zeros(num)
B_Pre_collect = np.zeros(num)
B_Recall_collect = np.zeros(num)
pi_error_collect = np.zeros(num)
time_vec = np.zeros(num)  # wall-clock seconds per experiment

nc = 50000  # first dimension of the per-row pathway matrix; presumably #cascades — confirm
t = 10      # passed to em_optimize as its last argument

# True pathway probabilities, used as the reference for pi_error.
P_pathway = np.ones(nd) * 0.5

for ii in range(num):
    run_id = ii + 1  # data/result files on disk are numbered from 1

    # NpzFile objects hold an open file handle until closed; use ``with`` so
    # every iteration releases its files.
    with np.load(dist + "_cascade_{}.npz".format(run_id)) as npz:
        cascades = npz['cascades']

    # Warm start from a previously saved fit (one open/close for both arrays).
    with np.load("res_{}.npz".format(run_id)) as npz:
        A_ini = npz['A']
        B_ini = npz['B']

    # Alternative random initialisation, kept for reference:
    # A_ini = np.maximum(np.random.randn(nd, nd), 0) * A
    # B_ini = np.maximum(np.random.randn(nd, nd), 0) * B
    # B_ini += np.ones_like(A) * 0.005

    ## optimization

    # Hand copies to the optimizer in case it modifies its inputs in place.
    # (B_update_use is not initialised here: it is taken from the result below.)
    A_update = A_ini.copy()
    B_update = B_ini.copy()

    # Random starting pathway probabilities in [0.3, 0.7): a length-nd vector
    # and an (nc, nd) matrix. Keep this draw order for reproducibility.
    P_pathway_update = np.random.uniform(0.3, 0.7, nd)
    P_pathway_ind = np.random.uniform(0.3, 0.7, (nc, nd))

    print("experiment", ii, "started\n")

    start = time.time()

    ############################################### MNCS EM ###############################################
    model = NetEM_final(dist, delta, lr_A=1e-5, lr_B=4.25e-6, penl_svd=0.01, penl_l1_B=150,
                 penl_l1_A=150, hard_thres=0.01, max_Iter=500, eps=1e-7, svd_freq=1, loss_freq=500, batch_size=4000, rank_guess=5, oversample=15)
    res = model.em_optimize(A_update, B_update, P_pathway_update, P_pathway_ind, supp_A, cascades, t)
    #######################################################################################################

    A_update = res['A']
    B_update = res['B']
    B_update_use = res['B_use']
    P_pathway_update = res['p']
    loss = res['loss']

    end = time.time()
    print("time:", end - start)
    time_vec[ii] = end - start

    # The loss history comes back as torch tensors; convert each to NumPy.
    loss = [tensor.detach().cpu().numpy() for tensor in loss]

    # Plot on a fresh figure and close it afterwards so curves from different
    # runs do not pile up on one axes (e.g. under a non-interactive backend).
    plt.figure()
    plt.plot(loss)
    plt.show()
    plt.close()

    # NOTE(review): ``time`` stores the whole time_vec accumulated so far
    # (entries for later runs are still 0), not only this run's time.
    np.savez_compressed("our_res_{}.npz".format(run_id), A=A_update, B=B_update, p=P_pathway_update, loss=loss, time=time_vec)

    # Metrics: A against the ground truth; B with its diagonal excluded
    # (B_update_use vs B_use); pathway probabilities as mean relative error.
    A_mae_collect[ii] = get_normalized_mae(A_update, A, hard_thres=0)
    B_mae_collect[ii] = get_normalized_mae(B_update_use, B_use, hard_thres=0)
    B_Acc_collect[ii] = get_acc(B_update_use, B_use, 0)
    B_Pre_collect[ii] = get_pre(B_update_use, B_use, 0)
    B_Recall_collect[ii] = get_recall(B_update_use, B_use, 0)
    pi_error_collect[ii] = np.mean(np.abs((P_pathway_update - P_pathway) / P_pathway))

    print("experiment", ii, "finished\n")

# Report mean and standard deviation over all experiments. Column order:
# A normalized MAE, B normalized MAE, B accuracy, B precision, B recall.
summary_metrics = (A_mae_collect, B_mae_collect, B_Acc_collect, B_Pre_collect, B_Recall_collect)
print("Our Result:")
print("mean:", *[np.mean(values) for values in summary_metrics])
print("sd:", *[np.std(values) for values in summary_metrics])
# Pathway-probability relative error: mean, sd, then a trailing blank line.
print("pi_error:", np.mean(pi_error_collect), np.std(pi_error_collect), "\n")