import itertools
import os
import pickle
import sys

import torch.autograd

from party.vfedcd_party import VFedCDServer
from utils.basic_functions import compute_metrics
from evaluates.attacks.unsplit import USLinear, model_inversion_stealing, pearson_correlation

sys.path.append(os.pardir)
import tensorflow as tf

import time
import copy

# from models.vision import resnet18, MLP2

# from evaluates.attacks.attack_api import apply_attack
from evaluates.defenses.defense_functions import *
from utils.communication_protocol_funcs import compress_pred
import networkx as nx

# Process-wide TensorFlow switch applied once at import time: run TF ops eagerly.
# NOTE(review): `tf` is not referenced anywhere else in the visible part of this
# file — confirm whether the TensorFlow import and this switch are still needed.
tf.compat.v1.enable_eager_execution()


class MainTaskVFedCD(object):
    def __init__(self, args):
        """Set up a VFedCD training task from the experiment configuration.

        The constructor does the following, in order:
          * creates ``args.save_path`` and opens its log file;
          * reads the optimisation settings and the per-stage causal
            hyper-parameters (``args.causal[name][args.stage]``);
          * builds the server from the parties' initial adjacency estimate;
          * distributes the initial encrypted local models between the parties.

        Args:
            args: experiment configuration namespace. Besides the scalar fields
                read below, it carries ``parties`` (indexed by party id
                ``0..k-1``) and ``dataset_split['dims']`` (one entry per party).

        Raises:
            AssertionError: if ``save_path`` already exists, or if gamma freezing
                is enabled with ``freeze_gamma_threshold > threshold``.
        """
        self.args = args
        self.k = args.k
        self.device = args.device
        self.dataset_name = args.dataset
        self.save_path = args.save_path
        # Refuse to reuse an existing output directory.
        if os.path.exists(self.save_path):
            assert 1 == 2, "save_path exists, please run after 1 minute"
        else:
            os.mkdir(self.save_path)
        # NOTE(review): the file name has a leading space (" log.txt"); kept as-is
        # in case other tooling already reads that exact path.
        self.log_file = open(os.path.join(self.save_path, " log.txt"), "a", encoding="utf-8")
        self.epochs = args.main_epochs
        self.lr = args.main_lr
        self.batch_size = args.batch_size
        self.models_dict = args.model_list
        self.num_classes = args.num_classes
        self.exp_res_dir = args.exp_res_dir

        self.exp_res_path = args.exp_res_path
        self.parties = args.parties

        self.Q = args.Q  # FedBCD: number of local iterations per batch

        self.parties_data = None
        self.gt_one_hot_label = None
        self.clean_one_hot_label = None
        self.pred_list = []
        self.pred_list_clone = []
        self.pred_gradients_list = []
        self.pred_gradients_list_clone = []

        # FedBCD related
        self.local_pred_list = []
        self.local_pred_list_clone = []
        self.local_pred_gradients_list = []
        self.local_pred_gradients_list_clone = []

        self.loss = None
        self.flag = 1
        self.stopping_iter = 0
        self.stopping_time = 0.0
        self.stopping_commu_cost = 0
        self.communication_cost = 0

        # Causal-discovery hyper-parameters for the current stage.
        self.stage = args.stage
        self.alpha = args.causal['alpha'][self.stage]
        self.beta = args.causal['beta'][self.stage]
        self.eta = args.causal['eta'][self.stage]
        self.gamma_from = args.causal['gamma_from'][self.stage]
        self.gamma_increase = args.causal['gamma_increase'][self.stage]
        self.gamma = 0
        self.gamma_cap = None  # set by train() once gamma is frozen
        self.threshold = args.causal['threshold'][self.stage]
        self.mask_threshold = args.causal['mask_threshold'][self.stage]
        self.mask = args.causal['mask'][self.stage]
        if self.mask == 'none':
            self.mask = None
        self.freeze_gamma_at_dag = args.causal['freeze_gamma_at_dag'][self.stage] == 1
        self.freeze_gamma_threshold = args.causal['freeze_gamma_threshold'][self.stage]
        if self.freeze_gamma_at_dag:
            assert self.freeze_gamma_threshold <= self.threshold, \
                "freeze_gamma_threshold({}) should be smaller than threshold({})".format(
                    self.freeze_gamma_threshold, self.threshold
                )
        self.adj_p = args.causal['adj_p']
        self.d = sum(args.dataset_split['dims'])
        self.server = VFedCDServer(args, self.get_adjacency_matrix())
        self.out_dims = args.dataset_split['dims']
        # Prefix sums of out_dims: party i owns columns acc_out_dims[i]:acc_out_dims[i + 1]
        # of the aggregated prediction (see pred_transmit / test_batch).
        self.acc_out_dims = [sum(self.out_dims[:ind]) for ind in range(self.k)] + [sum(self.out_dims)]

        # Early Stop
        self.early_stop_threshold = args.early_stop_threshold
        self.final_epoch = 0
        self.current_epoch = 0
        self.current_step = 0

        # Snapshots of the VFL state taken during training.
        self.first_epoch_state = None
        self.final_state = None

        self.debug = args.debug
        if self.debug:
            self.debug_init("model3.ckpt")

        self.num_update_per_batch = args.num_update_per_batch
        self.num_batch_per_workset = args.Q  # args.num_batch_per_workset
        self.max_staleness = self.num_update_per_batch * self.num_batch_per_workset
        self.init_encrypted_local_model()
        self.us_mse = None

    def log(self, msg):
        """Echo ``msg`` to stdout and append it as one line to the run log.

        The log file is flushed after every message so its contents stay current.
        """
        print(msg)
        sink = self.log_file
        sink.write(msg + "\n")
        sink.flush()

    def log_close(self):
        """Close the run log and drop the handle; ``log`` cannot be used afterwards."""
        handle = self.log_file
        handle.close()
        self.log_file = None

    def pickle_save(self, obj, path):
        full_path = os.path.join(self.save_path, path)
        with open(full_path, "wb") as f:
            pickle.dump(obj, f)

    def torch_save(self, model, path):
        full_path = os.path.join(self.save_path, path)
        # 确保保存路径的目录存在
        directory = os.path.dirname(full_path)
        if not os.path.exists(directory):
            os.makedirs(directory)
        torch.save(model.state_dict(), full_path)

    def init_encrypted_local_model(self):
        """Deliver every party's encrypted local-model shares to their owners.

        Each party ``encrypted_by_party`` returns a mapping
        ``{belong_to_model: encrypted_model}`` from ``give_encrypt_local_model()``.
        Each entry is handed to party ``belong_to_model`` through
        ``receive_encrypt_local_model(encrypted_by_party, encrypted_model)``.

        Raises:
            AssertionError: if a party tries to send an encrypted model to itself.
        """
        for encrypted_by_party in range(self.k):
            encrypted_models = self.parties[encrypted_by_party].give_encrypt_local_model()
            for belong_to_model, encrypted_model in encrypted_models.items():
                # Compare against the sender's index. Comparing against the dict
                # itself, as before, made this check vacuous.
                assert belong_to_model != encrypted_by_party, \
                    "trying to send model encrypted by party {} to model {}".format(
                        encrypted_by_party, belong_to_model
                    )
                self.parties[belong_to_model].receive_encrypt_local_model(encrypted_by_party,
                                                                          encrypted_model)

    def pred_transmit(self):  # forward-pass message exchange between parties and server
        """Exchange the forward-pass messages for the current batch.

        Phase 1, for every party ``belong_to_model``:
          * ``give_pred()`` returns five values; only ``encrypted_pred_ss_mass``
            and ``dag_pred_clone`` are used here;
          * every other party ``encrypted_by_party`` receives its entry
            ``encrypted_pred_ss_mass[encrypted_by_party]``;
          * ``dag_pred_clone`` optionally gets server-side DP noise and/or
            compression, is wrapped with ``torch.autograd.Variable(...,
            requires_grad=True)``, moved to ``args.device`` and sent to the server.

        Phase 2: each party's ``ss_z_clone`` (optionally with client-side DP
        noise) is summed into ``z``. After optional global DP noise, each party
        receives its own slice of ``z`` along dim 1, delimited by
        ``self.acc_out_dims``.

        The order of the calls matters: DP noise is random and parties are
        stateful, so do not reorder these statements.
        """
        for belong_to_model in range(self.k):
            plain_pred, encrypted_pred, encrypted_pred_ss_mass, dag_pred, dag_pred_clone = self.parties[
                belong_to_model].give_pred()
            # Hand every other party the secret-share mass addressed to it.
            for encrypted_by_party in range(self.k):
                if encrypted_by_party != belong_to_model:
                    self.parties[encrypted_by_party].receive_pred_ss_mass(belong_to_model,
                                                                          encrypted_pred_ss_mass[encrypted_by_party])
            if self.args.apply_local_dp and self.args.apply_server_dp:
                dag_pred_clone = add_differential_privacy(dag_pred_clone, self.args)
            if self.args.server.communication_protocol in ['Quantization', 'Topk']:
                dag_pred_clone = compress_pred(self.args.server, dag_pred_clone, None, None, None).to(self.args.device)
            # Legacy Variable API: presumably makes the server's copy a fresh grad-tracking
            # tensor so its gradient can be read back later — TODO confirm.
            dag_pred_clone = torch.autograd.Variable(dag_pred_clone, requires_grad=True).to(self.args.device)
            self.server.receive_dag(dag_pred_clone, belong_to_model)
        ss_z_clone_list = []
        for ik in range(self.k):
            ss_z, ss_z_clone = self.parties[ik].give_pred_ss_z()
            if self.args.apply_local_dp and self.args.apply_client_dp:
                ss_z_clone = add_differential_privacy(ss_z_clone, self.args)
            ss_z_clone_list.append(ss_z_clone)
        # Summing the parties' shares reconstructs the aggregated prediction z.
        z = sum(ss_z_clone_list)
        z = torch.autograd.Variable(z, requires_grad=True).to(self.args.device)
        # NOTE(review): global DP noise is applied after z is made grad-tracking; if
        # add_differential_privacy returns a new tensor, parties receive a non-leaf — confirm intended.
        if self.args.apply_global_dp and self.args.apply_client_dp:
            z = add_differential_privacy(z, self.args)
        # Each party gets the columns of z that belong to its own variables.
        for from_global_model in range(self.k):
            self.parties[from_global_model].receive_pred(
                z[:, self.acc_out_dims[from_global_model]: self.acc_out_dims[from_global_model + 1], :])

    # def get_extra_loss(self, return_detailed_losses=False):
    #     total_extra_loss = 0
    #     total_l1_loss = 0
    #     total_l2_loss = 0
    #     total_dag_loss = 0
    #     for ik in range(self.k):
    #         extra_loss, detail_loss = self.parties[ik].local_model[self.now_active_party].extra_loss(
    #             alpha=self.alpha, beta=self.beta, return_detailed_losses=True)
    #         total_extra_loss += extra_loss
    #         total_l1_loss += detail_loss['l1']
    #         total_l2_loss += detail_loss['l2']
    #     extra_loss, detail_loss = self.parties[self.now_active_party].global_model.extra_loss(beta=self.beta,
    #                                                                                                return_detailed_losses=True)
    #     total_extra_loss += extra_loss
    #     total_l2_loss += detail_loss['l2']
    #     if self.dag_penalty_flavor == "logdet":
    #         # dag_reg = self.dag_reg()
    #         assert 1 == 2, "logdet not supported yet"
    #     elif self.dag_penalty_flavor in ("scc", "power_iteration"):
    #         dag_reg = self.dag_reg_power_grad()
    #     elif self.dag_penalty_flavor == "none":
    #         dag_reg = 0.0
    #     else:
    #         assert 1 == 2, "unknown dag_penalty_flavor:{}".format(self.dag_penalty_flavor)
    #     detail_loss = {'dag': dag_reg.detach() if type(dag_reg) != float else torch.zeros(1)}
    #     total_extra_loss += dag_reg
    #     total_dag_loss += detail_loss['dag']
    #
    #     if return_detailed_losses:
    #         return total_extra_loss, {
    #             'l1': total_l1_loss,
    #             'l2': total_l2_loss,
    #             'dag': total_dag_loss,
    #         }
    #     else:
    #         return total_extra_loss

    # def dag_reg(self):
    #     A = self.get_adjacency_matrix() ** 2
    #     h = -torch.slogdet(self.identity - A)[1]
    #     return h

    def gradient_transmit(self):  # backward-pass message exchange between parties and server
        """Exchange the backward-pass messages for the current batch.

        For every party ``from_global_model``:
          * ``give_gradient()`` yields a plain gradient clone and a reshaped,
            encrypted copy;
          * every other party receives the encrypted copy and answers with an
            encrypted gradient secret-share mass, which is forwarded back to
            ``from_global_model``;
          * ``from_global_model`` then applies the plain clone through
            ``update_local_gradient``.

        Finally the server's ``give_gradients(self.gamma)`` result is indexed per
        party and delivered through ``receive_dag_gradient``. The statements are
        order-dependent (stateful parties), so keep their sequence.
        """
        for from_global_model in range(self.k):
            gradient_z_clone, reshaped_encrypted_gradient_z_clone = self.parties[
                from_global_model].give_gradient()  # gradient_clone

            for belong_to_model in range(self.k):
                if belong_to_model != from_global_model:  # other parties return encrypted_grad_ss_mass
                    self.parties[belong_to_model].receive_gradient(
                        reshaped_encrypted_gradient_z_clone, from_global_model)
                    encrypted_grad_ss_mass = self.parties[belong_to_model].give_encrypted_grad_ss_mass(from_global_model)
                    self.parties[from_global_model].receive_grad_ss_mass(belong_to_model, encrypted_grad_ss_mass)
            self.parties[from_global_model].update_local_gradient(gradient_z_clone)

        dag_gradients_clone = self.server.give_gradients(self.gamma)

        # Deliver to each party its own entry of the server-side DAG gradients.
        for ik in range(self.k):
            self.parties[ik].receive_dag_gradient(dag_gradients_clone[ik])

    # def LR_Decay(self, i_epoch):
    #     for ik in range(self.k):
    #         self.parties[ik].LR_decay(i_epoch)
    #     self.parties[self.now_active_party].global_LR_decay(i_epoch)

    def train_batch(self, parties_data):
        # prepare
        recon_losses = []
        l1_losses = []
        l2_losses = []
        dag_losses = []
        # allocate data to each party
        for ik in range(self.k):
            self.parties[ik].obtain_local_data(parties_data[ik].clone())

        # ====== normal vertical federated learning ======
        torch.autograd.set_detect_anomaly(True)
        # ======== Commu ===========
        # if self.args.communication_protocol in ['Vanilla', 'FedBCD_p', 'Quantization',
        #                                         'Topk'] or self.Q == 1:  # parallel FedBCD & noBCD situation
        if True:
            for q in range(self.Q):
                if q == 0:
                    # exchange info between parties
                    self.pred_transmit()
                    self.gradient_transmit()
                    # update parameters for all parties
                    for ik in range(self.k):
                        detail_extra_loss = self.parties[ik].local_backward()
                        l1_losses.append(detail_extra_loss['l1'].item())
                        l2_losses.append(detail_extra_loss['l2'].item())
                        detail_extra_loss = self.parties[ik].global_backward()
                        l2_losses.append(detail_extra_loss['l2'].item())
                        # record reconstruction loss
                        recon_losses.append(self.parties[ik].global_loss.item())
                    # record dag loss
                    dag_loss = self.server.dag_loss
                    if dag_loss is not None:
                        dag_losses.append(dag_loss.item())
                    self.init_encrypted_local_model()
                else:  # FedBCD: additional iterations without info exchange
                    # assert 1 == 2, "q > 1 not supported yet"

                    for ik in range(self.k):
                        # _pred, _pred_detach, _dag, _dag_detach = self.parties[ik].give_pred()
                        self.parties[ik].give_dag_pred()
                        detail_extra_loss = self.parties[ik].local_backward()
                        l1_losses.append(detail_extra_loss['l1'].item())
                        l2_losses.append(detail_extra_loss['l2'].item())

                        _gradient = self.parties[ik].give_gradient()
                        detail_extra_loss = self.parties[ik].global_backward()
                        l2_losses.append(detail_extra_loss['l2'].item())
                        # record reconstruction loss and dag loss
                        recon_loss = self.parties[ik].global_loss
                        recon_losses.append(recon_loss.item())
                        dag_loss = self.server.dag_loss
                        if dag_loss is not None:
                            dag_losses.append(dag_loss.item())
        # elif self.args.communication_protocol in ['CELU']:
        #     for q in range(self.Q):
        #         if (q == 0) or (parties_data.shape[0] != self.args.batch_size):
        #             # exchange info between parties
        #             self.pred_transmit()
        #             self.gradient_transmit()
        #             # update parameters for all parties
        #             for ik in range(self.k):
        #                 self.parties[ik].local_backward()
        #             self.parties[self.now_active_party].global_backward()
        #
        #             if (parties_data.shape[0] == self.args.batch_size): # available batch to cache
        #                 for ik in range(self.k):
        #                     batch = self.num_total_comms  # current batch id
        #                     self.parties[ik].cache.put(batch, [self.parties[ik].local_pred, self.parties[ik].dag_pred],
        #                                                [self.parties[ik].local_gradient, self.parties[ik].dag_gradient],
        #                                                self.num_total_comms + self.parties[ik].num_local_updates)
        #         else:
        #             for ik in range(self.k):
        #                 # Sample from cache
        #                 batch, val = self.parties[ik].cache.sample(self.parties[ik].prev_batches)
        #                 batch_cached_pred, batch_cached_grad, batch_cached_at, batch_num_update = val
        #
        #                 _pred, _pred_detach, _dag, _dag_detach = self.parties[ik].give_pred()
        #                 weight = ins_weight([_pred_detach, _dag_detach], batch_cached_pred, self.args.smi_thresh) # ins weight
        #
        #                 # Using this batch for backward
        #                 if (ik == self.now_active_party): # active
        #                     self.parties[ik].update_local_gradient(batch_cached_grad)
        #                     self.parties[ik].local_backward(weight)
        #                     self.parties[ik].global_backward()
        #                 else:
        #                     self.parties[ik].receive_gradient(batch_cached_grad)
        #                     self.parties[ik].local_backward(weight)
        #
        #                 # Mark used once for this batch + check staleness
        #                 self.parties[ik].cache.inc(batch)
        #                 if (self.num_total_comms + self.parties[ik].num_local_updates - batch_cached_at >= self.max_staleness) or \
        #                         (batch_num_update + 1 >= self.num_update_per_batch):
        #                     self.parties[ik].cache.remove(batch)
        #
        #                 self.parties[ik].prev_batches.append(batch)
        #                 self.parties[ik].prev_batches = self.parties[ik].prev_batches[1:]#[-(num_batch_per_workset - 1):]
        #                 self.parties[ik].num_local_updates += 1
        #
        # elif self.args.communication_protocol in ['FedBCD_s']: # Sequential FedBCD_s
        #     for q in range(self.Q):
        #         if q == 0:
        #             #first iteration, active party gets pred from passsive party
        #             self.pred_transmit()
        #             _gradient = self.parties[self.k-1].give_gradient()
        #             if len(_gradient)>1:
        #                 for _i in range(len(_gradient)-1):
        #                     self.communication_cost += get_size_of(_gradient[_i+1])#MB
        #             # active party: update parameters
        #             self.parties[self.k-1].local_backward()
        #             self.parties[self.k-1].global_backward()
        #         else:
        #             # active party do additional iterations without info exchange
        #             self.parties[self.k-1].give_pred()
        #             _gradient = self.parties[self.k-1].give_gradient()
        #             self.parties[self.k-1].local_backward()
        #             self.parties[self.k-1].global_backward()
        #
        #     # active party transmit grad to passive parties
        #     self.gradient_transmit()
        #
        #     # passive party do Q iterations
        #     for _q in range(self.Q):
        #         for ik in range(self.k-1):
        #             _pred, _pred_clone= self.parties[ik].give_pred()
        #             self.parties[ik].local_backward()
        else:
            assert 1 > 2, 'Communication Protocol not provided'
            # ============= Commu ===================
        batch_l1_loss = np.sum(l1_losses)
        batch_l2_loss = np.sum(l2_losses)
        batch_dag_loss = np.sum(dag_losses)
        batch_recon_loss = np.sum(recon_losses)
        batch_loss = batch_recon_loss + batch_l1_loss + batch_l2_loss
        if batch_dag_loss is not None:
            batch_loss += batch_dag_loss
        return batch_loss, {
            'l1': batch_l1_loss,
            'l2': batch_l2_loss,
            'dag': batch_dag_loss,
        }

    def test_batch(self, parties_data):
        l1_losses = []
        l2_losses = []
        dag_losses = []
        total_losses = []
        # prepare
        pred_list = []
        dag_list = []
        total_loss = 0
        # allocate data to each party
        for ik in range(self.k):
            self.parties[ik].obtain_local_data(parties_data[ik].clone())
        for owned_by_party in range(self.k):
            party_dag_list = []
            for belong_to_model in range(self.k):
                # collect local pred
                _local_pred = self.parties[owned_by_party].local_model[belong_to_model](parties_data[belong_to_model])
                pred_list.append(_local_pred)
                # collect local dag
                party_dag_list.append(self.parties[owned_by_party].local_model[belong_to_model].get_adjacency_matrix())
                # local extra loss
                extra_loss, detail_extra_loss = self.parties[owned_by_party].local_model[belong_to_model].extra_loss(
                    self.alpha,
                    self.beta,
                    return_detailed_losses=True)
                total_loss += extra_loss
                l1_losses.append(detail_extra_loss['l1'].item())
                l2_losses.append(detail_extra_loss['l2'].item())

            dag_list.append(torch.cat(party_dag_list, dim=0))
        _z = sum(pred_list)
        recon_loss_list = []
        for from_global_model in range(self.k):
            # recon loss
            test_mv, recon_loss = self.parties[from_global_model].aggregate(
                _z[:, self.acc_out_dims[from_global_model]: self.acc_out_dims[from_global_model + 1], :], test=True)
            total_loss += recon_loss
            recon_loss_list.append(recon_loss)

            # global extra loss
            extra_loss, detail_extra_loss = self.parties[from_global_model].global_model.extra_loss(self.eta,
                                                                                                    return_detailed_losses=True)
            total_loss += extra_loss
            l2_losses.append(detail_extra_loss['l2'].item())

        # dag reg
        dag_reg = self.server.aggregate(dag_list)
        if dag_reg is not None:
            dag_reg *= self.gamma
            total_loss += dag_reg
            dag_losses.append(dag_reg.detach().item())

        total_losses.append(total_loss.item())
        return np.sum(total_losses), {
            'l1': np.sum(l1_losses),
            'l2': np.sum(l2_losses),
            'dag': np.sum(dag_losses)
        }

    def train(self):
        """Train all parties for ``self.epochs`` epochs and return the learned structure.

        Each epoch runs these steps:
          1. set ``self.gamma`` from the linear schedule
             ``gamma_from + gamma_increase * epoch``, unless it is frozen at
             ``self.gamma_cap``;
          2. run ``train_batch`` on every zipped batch of the parties' train loaders;
          3. evaluate ``test_batch`` on the zipped test loaders under ``no_grad``;
          4. if ``freeze_gamma_at_dag`` is set, freeze gamma (from epoch 2 on)
             once the adjacency thresholded at ``freeze_gamma_threshold`` is
             acyclic, and unfreeze it when the adjacency thresholded at
             ``threshold`` is no longer acyclic;
          5. compute structure metrics against ``args.B_true`` (when given) and
             log a summary line.

        After the last epoch, the first-batch and final snapshots are pickled
        into ``save_path`` as ``first_epoch_state`` and ``final_state``.

        Returns:
            ``(new_mask, new_dag_pred)``:
              * ``new_mask`` is the adjacency thresholded at ``mask_threshold``,
                with the diagonal zeroed;
              * ``new_dag_pred`` is the acyclic adjacency produced by
                ``adjacency_dag_at_threshold`` at ``threshold``.
        """
        print_every = 1

        def fmt2(value):
            # '{:.2f}' raises TypeError on None (the recall metrics when B_true
            # is absent, or the loss when no batch ran), so print 'None' instead.
            return 'None' if value is None else '{:.2f}'.format(value)

        for ik in range(self.k):
            self.parties[ik].prepare_data_loader(batch_size=self.batch_size)

        self.num_total_comms = 0
        total_time = 0.0
        flag = 0
        self.current_epoch = 0

        for i_epoch in range(self.epochs):
            self.current_epoch = i_epoch
            if self.gamma_cap is None:
                self.gamma = self.gamma_from + self.gamma_increase * i_epoch
            else:
                self.gamma = self.gamma_cap
            i_batch = -1
            data_loader_list = [self.parties[ik].train_loader for ik in range(self.k)]

            self.current_step = 0

            for ik in range(self.k):
                for i_local_model in range(self.k):
                    self.parties[ik].local_model[i_local_model].train()
                self.parties[ik].global_model.train()

            for parties_data in zip(*data_loader_list):
                self.parties_data = parties_data
                i_batch += 1

                # Snapshot the models before the very first update.
                if i_batch == 0 and i_epoch == 0:
                    self.first_epoch_state = self.save_state(True)

                enter_time = time.time()
                self.loss, _ = self.train_batch(self.parties_data)
                exit_time = time.time()
                total_time += (exit_time - enter_time)
                if self.debug:
                    self.log(f"total time till batch {i_batch} in epoch {i_epoch} is {total_time}, batch loss:{self.loss}")
                # Record the stopping statistics once, after the first batch.
                if flag == 0:
                    self.stopping_time = total_time
                    self.stopping_iter = self.num_total_comms
                    self.stopping_commu_cost = self.communication_cost
                    flag = 1

                # Complete the first snapshot with the post-update state.
                if i_batch == 0 and i_epoch == 0:
                    self.first_epoch_state.update(self.save_state(False))

                self.current_step = self.current_step + 1

            # validation
            if i_epoch % print_every == 0:
                self.log("validate and test")
                for ik in range(self.k):
                    for i_local_model in range(self.k):
                        self.parties[ik].local_model[i_local_model].eval()
                    self.parties[ik].global_model.eval()

                test_losses = []
                test_l1_losses = []
                test_l2_losses = []
                test_dag_losses = []

                with torch.no_grad():
                    data_loader_list = [self.parties[ik].test_loader for ik in range(self.k)]
                    for parties_data in zip(*data_loader_list):
                        self.parties_data = parties_data
                        test_loss, detail_test_loss = self.test_batch(self.parties_data)
                        test_losses.append(test_loss)
                        test_l1_losses.append(detail_test_loss['l1'])
                        test_l2_losses.append(detail_test_loss['l2'])
                        test_dag_losses.append(detail_test_loss['dag'])

                self.test_loss = np.mean(test_losses)
                test_l1_loss = np.mean(test_l1_losses)
                test_l2_loss = np.mean(test_l2_losses)
                dag_loss = np.mean(test_dag_losses)

                # Current adjacency estimate.
                B_pred = self.get_adjacency_matrix().cpu().detach().numpy()
                if self.freeze_gamma_at_dag:
                    if i_epoch > 1 and self.gamma_cap is None:
                        # Freeze gamma once the graph at freeze_gamma_threshold is a DAG.
                        is_dag_freeze = nx.is_directed_acyclic_graph(
                            nx.DiGraph(B_pred > self.freeze_gamma_threshold)
                        )
                        if is_dag_freeze:
                            self.gamma_cap = self.gamma
                    elif self.gamma_cap is not None:
                        # Unfreeze if the graph at the final threshold is no longer a DAG.
                        is_dag_thresh = nx.is_directed_acyclic_graph(
                            nx.DiGraph(B_pred > self.threshold)
                        )
                        if not is_dag_thresh:
                            self.gamma_cap = None

                # Structural Hamming distance and recall, when the ground truth is known.
                if self.args.B_true is not None:
                    dag_adjacency = self.adjacency_dag_at_threshold(B_pred,
                                                                    self.args.causal['threshold'][self.args.stage])
                    metrics_dict = compute_metrics(dag_adjacency.astype(int), self.args.B_true)
                    dag_shd = metrics_dict['shd']
                    dag_recall = metrics_dict['recall']
                    d = B_pred.shape[0]
                    adjacency = (B_pred > self.mask_threshold).astype(int) * (1 - np.eye(d, dtype=int))
                    metrics_dict = compute_metrics(adjacency, self.args.B_true)
                    shd = metrics_dict['shd']
                    recall = metrics_dict['recall']
                else:
                    dag_shd = None
                    dag_recall = None
                    shd = None
                    recall = None

                self.final_epoch = i_epoch
                self.log('Epoch {} \t train_loss:{} test_loss:{:.2f} l1:{:.2f} l2:{:.2f} dag:{:.2f} '
                         'dag_shd:{} dag_recall:{} shd::{} recall:{} gamma:{}'.format(
                             i_epoch, fmt2(self.loss), self.test_loss, test_l1_loss, test_l2_loss, dag_loss,
                             dag_shd, fmt2(dag_recall), shd, fmt2(recall), self.gamma))

        self.final_state = self.save_state(True)
        self.final_state.update(self.save_state(False))
        self.final_state.update(self.save_party_data())
        self.pickle_save(self.first_epoch_state, "first_epoch_state")
        self.pickle_save(self.final_state, "final_state")

        B_pred = self.get_adjacency_matrix().cpu().detach().numpy()
        d = B_pred.shape[0]
        new_mask = (B_pred > self.mask_threshold).astype(int) * (1 - np.eye(d, dtype=int))
        new_dag_pred = self.adjacency_dag_at_threshold(B_pred, self.threshold)

        return new_mask, new_dag_pred

    def save_state(self, BEFORE_MODEL_UPDATE=True):
        if BEFORE_MODEL_UPDATE:
            return {
                "before_model": [[copy.deepcopy(self.parties[i_party].local_model[i_model]) for i_model in range(self.args.k)] for i_party in range(self.args.k)],
                "before_global_model": [copy.deepcopy(self.parties[ik].global_model) for ik in range(self.args.k)],

            }
        else:
            return {
                "data": copy.deepcopy(self.parties_data),
                # "local_model_gradient": [copy.deepcopy(self.parties[ik].weights_grad_a) for ik in range(self.k)],
                "loss": copy.deepcopy(self.loss),
                "global_pred": self.parties[self.k - 1].global_pred,
                "after_model": [[copy.deepcopy(self.parties[i_party].local_model[i_model]) for i_model in range(self.args.k)] for i_party in range(self.args.k)],
                "after_global_model": [copy.deepcopy(self.parties[ik].global_model) for ik in range(self.args.k)],
                "B_pred": self.get_adjacency_matrix().cpu().detach().numpy(),
                "B_true": self.args.B_true
            }

    def save_party_data(self):
        """Return deep copies of each party's train/test data and loaders, plus the batch size."""
        parties = [self.parties[i] for i in range(self.k)]
        snapshot = {}
        for attr in ("train_data", "test_data", "train_loader", "test_loader"):
            snapshot[attr] = [copy.deepcopy(getattr(party, attr)) for party in parties]
        snapshot["batchsize"] = self.args.batch_size
        return snapshot

    def get_adjacency_matrix(self):
        adjacency_submatrixs = [[] for ik in range(self.k)]
        for owned_by_party in range(self.k):
            for belong_to_model in range(self.k):
                adjacency_submatrixs[owned_by_party].append(
                    self.parties[owned_by_party].local_model[belong_to_model].get_adjacency_matrix())
        adjacency = sum([torch.cat(
            [adjacency_submatrixs[owned_by_party][belong_to_model] for belong_to_model in range(self.k)], dim=0) for
            owned_by_party in range(self.k)])
        adjacency = torch.linalg.vector_norm(adjacency, dim=2, ord=self.adj_p)
        return adjacency

    def adjacency_dag_at_threshold(self, adjacency, threshold=0.1):
        """Threshold an adjacency matrix and drop edges that would create a cycle.

        Candidate edges ``(i, j)`` are visited greedily by descending weight
        (ties keep row-major order). Visiting stops at the first weight below
        ``threshold``. An edge is kept only if the graph built so far has no
        path from ``j`` back to ``i``, which also rejects self-loops.

        Returns:
            The resulting graph as a dense numpy adjacency array.
        """
        n_nodes = adjacency.shape[0]
        candidates = sorted(
            ((src, dst, adjacency[src, dst]) for src, dst in itertools.product(range(n_nodes), repeat=2)),
            key=lambda edge: -edge[2],
        )
        dag = nx.DiGraph()
        dag.add_nodes_from(range(n_nodes))
        for src, dst, weight in candidates:
            if weight < threshold:
                break
            if not nx.has_path(dag, dst, src):
                dag.add_edge(src, dst)
        return nx.to_numpy_array(dag)

    @staticmethod
    def compute_metrics(B_pred_thresh, B_true):
        if B_true is not None:
            diff = B_true != B_pred_thresh
            score = diff.sum()
            shd = score - (((diff == diff.transpose()) & (diff != 0)).sum() / 2)
            recall = (B_true.astype(bool) & B_pred_thresh.astype(bool)).sum() / np.clip(
                B_true.sum(), 1, None
            )
            precision = (B_true.astype(bool) & B_pred_thresh.astype(bool)).sum() / np.clip(
                B_pred_thresh.sum(), 1, None
            )
        else:
            recall = "na"
            precision = "na"
            score = "na"
            shd = "na"

        n_edges_pred = (B_pred_thresh).sum()
        return {
            "score": score,
            "shd": shd,
            "precision": precision,
            "recall": recall,
            "n_edges_pred": n_edges_pred,
        }

    def debug_init(self, ckpt_path):
        assert self.out_dims == [10, 15], "only support dims=[10, 15]"
        device = torch.device('cuda')
        # 加载模型存档
        checkpoint = torch.load(ckpt_path)

        # layers.0._weight all zeros, i * o * h
        w = checkpoint['layers.0._weight'].to(device)
        w0, w1 = torch.split(w, [10, 15], dim=0)

        # 生成噪声
        epsl0 = torch.randn(10, 25, 10).to(device)
        epsl1 = torch.randn(15, 25, 10).to(device)

        # 计算 mass0 和 mass1
        mass0 = w0 - epsl0
        mass1 = w1 - epsl1

        # 加载参数到 parties 的 local_model
        self.parties[0].local_model[0].dispatcher.set_weight(mass0.to(device))
        self.parties[0].local_model[1].dispatcher.set_weight(epsl1.to(device))
        self.parties[1].local_model[1].dispatcher.set_weight(mass1.to(device))
        self.parties[1].local_model[0].dispatcher.set_weight(epsl0.to(device))

        # 加载 layers.0.bias, o * h

        bias0, bias1 = torch.split(checkpoint['layers.0.bias'].to(device), [10, 15], dim=0)
        self.parties[0].global_model.bias.data = bias0.to(device)
        self.parties[1].global_model.bias.data = bias1.to(device)

        # mask eye, ignore
        mask = checkpoint['layers.0.mask']

        # 加载 output_layer.weight, o * h * c
        output_weight0, output_weight1 = torch.split(checkpoint['output_layer.weight'], [10, 15], dim=0)
        self.parties[0].global_model.output_layer.weight.data = output_weight0.to(device)
        self.parties[1].global_model.output_layer.weight.data = output_weight1.to(device)

        # 加载 output_layer.bias, o * c
        output_bias0, output_bias1 = torch.split(checkpoint['output_layer.bias'], [10, 15], dim=0)
        self.parties[0].global_model.output_layer.bias.data = output_bias0.to(device)
        self.parties[1].global_model.output_layer.bias.data = output_bias1.to(device)

        # 加载 var_layer.weight, o * h * c
        var_weight0, var_weight1 = torch.split(checkpoint['var_layer.weight'], [10, 15], dim=0)
        self.parties[0].global_model.var_layer.weight.data = var_weight0.to(device)
        self.parties[1].global_model.var_layer.weight.data = var_weight1.to(device)

        # 加载 var_layer.bias, o * c
        var_bias0, var_bias1 = torch.split(checkpoint['var_layer.bias'], [10, 15], dim=0)
        self.parties[0].global_model.var_layer.bias.data = var_bias0.to(device)
        self.parties[1].global_model.var_layer.bias.data = var_bias1.to(device)

    def get_batch_z(self, parties_data, target_party):
        # prepare
        pred_list = []
        # allocate data to each party
        for ik in range(self.k):
            self.parties[ik].obtain_local_data(parties_data[ik].clone())
        for owned_by_party in range(self.k):
            for belong_to_model in range(self.k):
                if belong_to_model == target_party:
                    continue
                # collect local pred
                _local_pred = self.parties[owned_by_party].local_model[belong_to_model](parties_data[belong_to_model])
                pred_list.append(_local_pred)
        _z = sum(pred_list)
        # return _z
        return _z[:, self.acc_out_dims[target_party]: self.acc_out_dims[target_party + 1], :]

    def _collect_unsplit_samples(self, data_loader_list, attacker_party, sample_need):
        """Collect up to ``sample_need`` (passive-party input, intermediate feature) rows.

        The loaders are iterated in lockstep. For each batch:

        - ``sample_data`` concatenates, along dim 1, the data of every party
          except ``attacker_party``;
        - the feature is the matching output of ``get_batch_z`` for ``attacker_party``.

        Collection stops once ``sample_need`` rows are gathered or a loader is
        exhausted. No extra batch is pulled after the target is reached.

        Returns:
            ``(sample_datas, sample_features, sample_count)``. The first two are
            lists of per-batch tensors. ``sample_count`` is the number of rows
            collected, which may be smaller than ``sample_need``.
        """
        sample_datas = []
        sample_features = []
        sample_count = 0
        if sample_need <= 0:
            return sample_datas, sample_features, sample_count
        for parties_data in zip(*data_loader_list):
            still_need = sample_need - sample_count
            sample_data = torch.cat([parties_data[ik] for ik in range(self.k) if ik != attacker_party], dim=1)
            batch_z = self.get_batch_z(parties_data, target_party=attacker_party)
            # Slicing is a no-op for all but the last batch, which may be trimmed.
            sample_datas.append(sample_data[:still_need])
            sample_features.append(batch_z[:still_need])
            sample_count += min(sample_data.shape[0], still_need)
            if sample_count >= sample_need:
                break
        return sample_datas, sample_features, sample_count

    def evaluate_unsplit(self):
        """Run the UnSplit model-inversion attack on the test data and score it.

        For each id in ``self.args.attacker_id``:

        1. Collect up to ``attack_configs['sample_num']`` test rows of the other
           parties' data, together with the intermediate features visible to
           the attacker.
        2. Reconstruct each row with ``model_inversion_stealing``, using one
           ``USLinear`` auxiliary model per attacker that is shared across that
           attacker's samples.
        3. Score each reconstruction with ``pearson_correlation``.

        Attackers for which no rows could be collected are skipped.

        Returns:
            The mean of the per-sample Pearson correlations over all attackers.

        Raises:
            ValueError: if no sample was evaluated at all.
        """
        attack_configs = self.args.attack_configs
        dims = self.args.causal['dims']
        d = sum(dims)
        sample_need = attack_configs['sample_num']
        person_list = []
        data_loader_list = [self.parties[ik].test_loader for ik in range(self.k)]
        for attacker_party in self.args.attacker_id:
            sample_datas, sample_features, sample_count = self._collect_unsplit_samples(
                data_loader_list, attacker_party, sample_need)
            if sample_count == 0:
                continue
            sample_datas = torch.cat(sample_datas, dim=0).to(self.device)
            # Reshape with the rows actually collected: the loaders may hold fewer than sample_need.
            sample_features = torch.cat(sample_features, dim=0).to(self.device).reshape((sample_count, -1))
            # DataLoader's default batch_size is 1, so batch `index` is row `index` of sample_datas.
            sample_features_loader = torch.utils.data.DataLoader(sample_features, shuffle=False)

            input_dim = d - dims[attacker_party]
            input_size = [1, input_dim]
            aux_model = USLinear(input_dim, self.args.model_list[str(attacker_party)]['hidden_dim'][0] * dims[attacker_party]).to(self.device)
            for index, sample_feature in enumerate(sample_features_loader):
                sample_data_truth = sample_datas[index]
                sample_data_pred = model_inversion_stealing(aux_model, sample_feature, input_size, sample_data_truth,
                                                            main_iters=attack_configs['main_iters'], input_iters=attack_configs['input_iters'],
                                                            model_iters=attack_configs['model_iters'], lambda_l2=attack_configs['lambda_l2'],
                                                            show_tqdm=False)

                sample_data_pred = sample_data_pred.reshape(sample_data_truth.shape)
                person = pearson_correlation(sample_data_pred, sample_data_truth)
                print("sample person:{}".format(person))
                person_list.append(person)
        if not person_list:
            raise ValueError("evaluate_unsplit: no samples were evaluated; check attacker_id and attack_configs['sample_num']")
        return sum(person_list) / len(person_list)
