import os
import sys

from phe import paillier
from torch import nn

from evaluates.defenses.defense_functions import add_differential_privacy
from load.LoadModels import load_models_per_party_vfedcd
from models.vfedcd_model import DispatcherCipherLayer
from utils.paillier_torch import he2ss, PaillierTensor

sys.path.append(os.pardir)
import numpy as np
import torch
from party.party import Party
from dataset.party_dataset import CausalDataset
import torch.distributions as dist
import scipy

def casual_criterion(x_mv, x, mask_interventions_oh=None):
    """Masked Gaussian negative log-likelihood, averaged over the batch.

    Args:
        x_mv: tensor indexed as ``x_mv[:, 0, :]`` (means) and
            ``x_mv[:, 1, :]`` (variances).
        x: observed values scored under ``Normal(mean, sqrt(variance))``.
        mask_interventions_oh: optional multiplicative mask applied to the
            per-entry log-probabilities; defaults to all ones.

    Returns:
        The negated, masked sum of log-probabilities divided by ``x.shape[0]``.
    """
    mean, variance = x_mv[:, 0, :], x_mv[:, 1, :]
    mask = (
        torch.ones_like(mean)
        if mask_interventions_oh is None
        else mask_interventions_oh
    )

    log_likelihood = dist.Normal(mean, variance ** (0.5)).log_prob(x)
    masked_total = (mask * log_likelihood).sum()

    # The sum is normalised by the number of samples; ideally it would not be,
    # because this interacts with the scale of the L1/L2 regularisers.
    return -masked_total / x.shape[0]

class VFedParty(Party):
    """Party taking part in the VFedCD (vertical federated causal discovery) flow.

    Each party holds one local model per party index (``self.local_model[ik]``),
    a global model, and a Paillier key pair. Encrypted intermediate results are
    converted with ``he2ss`` into shares that are later summed
    (see ``give_pred_ss_z``).
    """

    def __init__(self, args, index):
        """Create the party, its Paillier key pair and all exchange buffers.

        Args:
            args: run configuration; reads ``k``, ``stage``, ``causal`` and
                ``dataset_split`` here.
            index: index of this party.
        """
        super().__init__(args, index)
        # NOTE(review): a 256-bit Paillier modulus is far below commonly
        # recommended key sizes; presumably chosen for speed -- confirm.
        keypair = paillier.generate_paillier_keypair(n_length=256)
        self.pk, self.sk = keypair

        self.criterion = casual_criterion

        self.pred_received = [[] for _ in range(args.k)]

        self.self_reconstruction_row = self.local_model[self.index].self_reconstruction_row

        self.global_pred = None
        self.global_loss = None

        self.alpha = args.causal['alpha'][args.stage]
        self.beta = args.causal['beta'][args.stage]
        self.eta = args.causal['eta'][args.stage]
        self.dims = args.dataset_split['dims']
        # for dag loss
        self.dag_pred = None
        self.dag_pred_clone = None
        self.dag_gradient = None
        self.dag_gradient_clone = None
        self.weights_grad_dag = None
        # for blind forward/backward
        self.encrypt_local_models = {}  # key:encrypted_by_party, value:encrypt_local_model
        self.plain_pred = None
        self.encrypted_pred = {}  # key:encrypted_by_party value:encrypted_pred
        self.pred_ss_epsilon = {}  # secretShareEpsilon key:encrypted_by_party value:epsilon
        self.pred_ss_mass = {}  # secretShareMass key:belong_to_model value:mass
        self.pred_ss_z, self.pred_ss_z_clone = None, None
        self.grad_ss_mass = {}  # key:belong_to_model value:mass
        self.grad_ss_fai = {}  # key:from_global_model value:fai
        self.reshaped_encrypted_gradient_z = {}  # key:from_global_model value:encrypted gradient of z

    def prepare_data(self, args, index):
        """Load the party data and wrap train/test splits in ``CausalDataset``.

        Raises:
            AssertionError: if ``self.args.need_auxiliary == 1`` (unsupported).
        """
        super().prepare_data(args, index)
        self.train_dst = CausalDataset(self.train_data)  # no self.train_label
        self.test_dst = CausalDataset(self.test_data)  # no self.test_label
        if self.args.need_auxiliary == 1:
            # Explicit raise (instead of ``assert 1 == 2``) so the guard is
            # not stripped when Python runs with -O.
            raise AssertionError("aux not supported for causal data preparation")

    def prepare_model(self, args, index):
        """Load local/global models and their optimizers for this party."""
        (
            args,
            self.local_model,
            self.local_model_optimizer,
            self.global_model,
            self.global_model_optimizer,
        ) = load_models_per_party_vfedcd(args, index)

    def give_encrypt_local_model(self, index=None):
        """Encrypt local models with this party's public key.

        Args:
            index: iterable of model indices to encrypt. When ``None``, every
                model except the one at ``self.index`` is encrypted.

        Returns:
            dict mapping model index (belong_to_model) -> encrypted model.
        """
        if index is None:  # encrypt all except the one belonging to myself
            targets = [ik for ik in range(self.args.k) if ik != self.index]
        else:
            targets = index
        return {
            ik: self.local_model[ik].encrypt(self.pk, self.index)
            for ik in targets
        }

    def receive_encrypt_local_model(self, encrypted_by_party, encrypted_model: DispatcherCipherLayer):
        """Store an encrypted copy of this party's model produced by another party.

        Args:
            encrypted_by_party: index of the party whose key encrypted the model.
            encrypted_model: the encrypted layer; must belong to and be owned
                by this party.
        """
        assert self.index == encrypted_model.belong_to_model, "Encrypted model belong to model {} while sent to party {}".format(
            encrypted_model.belong_to_model, self.index
        )
        assert self.index == encrypted_model.owned_by_party, "Encrypted model owned by party {} while sent to party {}".format(
            encrypted_model.owned_by_party, self.index
        )
        self.encrypt_local_models[encrypted_by_party] = encrypted_model

    def give_pred(self):
        """Run the plain and encrypted forward passes on the current batch.

        Returns:
            Tuple ``(plain_pred, encrypted_pred, encrypted_pred_ss_mass,
            dag_pred, dag_pred_clone)``; the two dicts are keyed by the index
            of the party whose key encrypted the values.
        """
        encrypted_pred_ss_mass = {}  # key:encrypted_by_party value:encrypted_mass
        self.plain_pred = self.local_model[self.index](self.local_batch_data)

        for encrypted_by_party in range(self.args.k):
            if encrypted_by_party != self.index:
                self.encrypted_pred[encrypted_by_party] = self.encrypt_local_models[encrypted_by_party](
                    self.local_batch_data)
                self.pred_ss_epsilon[encrypted_by_party], encrypted_pred_ss_mass[encrypted_by_party] = he2ss(
                    self.encrypted_pred[encrypted_by_party])

        self.give_dag_pred()

        return self.plain_pred, self.encrypted_pred, encrypted_pred_ss_mass, self.dag_pred, self.dag_pred_clone

    def give_dag_pred(self):
        """Concatenate the adjacency matrices of all local models along dim 0.

        Stores the result in ``self.dag_pred`` and a detached copy in
        ``self.dag_pred_clone``.
        """
        dag_list = [
            self.local_model[ik].get_adjacency_matrix()
            for ik in range(self.args.k)
        ]
        self.dag_pred = torch.cat(dag_list, dim=0)
        self.dag_pred_clone = self.dag_pred.detach().clone()

    def receive_pred_ss_mass(self, belong_to_model, encrypted_mass):
        """Decrypt and store the prediction mass share received for another model."""
        assert belong_to_model != self.index, "error"
        self.pred_ss_mass[belong_to_model] = encrypted_mass.decrypt(self.sk, device=self.args.device)

    def give_pred_ss_z(self):
        """Sum the plain prediction with all epsilon and mass shares.

        Returns:
            Tuple ``(pred_ss_z, pred_ss_z_clone)`` where the clone is detached.
        """
        z_list = [self.plain_pred]
        z_list += list(self.pred_ss_epsilon.values())
        z_list += list(self.pred_ss_mass.values())
        self.pred_ss_z = sum(z_list)
        self.pred_ss_z_clone = self.pred_ss_z.detach().clone()
        return self.pred_ss_z, self.pred_ss_z_clone

    def receive_pred(self, pred):
        """Store the aggregated prediction used by ``give_gradient``."""
        self.pred_received = pred

    def aggregate(self, pred_z, test=False):
        """Run the global model and compute the reconstruction loss.

        Args:
            pred_z: input to the global model.
            test: unused; kept for interface compatibility.

        Returns:
            Tuple ``(x_mv, loss)`` where ``x_mv`` stacks the two global-model
            outputs along dim 1.
        """
        x_m, x_v = self.global_model(pred_z, self.local_batch_data)
        x_mv = torch.stack((x_m, x_v), dim=1)
        loss = self.criterion(x_mv, self.local_batch_data)

        return x_mv, loss

    def gradient_calculation(self, pred_z, loss):
        """Return the gradient of ``loss`` w.r.t. ``pred_z`` and a detached copy.

        Returns:
            Tuple ``(pred_gradient, pred_gradient_clone)``; the first item is
            the tuple returned by ``torch.autograd.grad``.
        """
        pred_gradient = torch.autograd.grad(loss, pred_z, retain_graph=True, create_graph=True)
        pred_gradient_clone = pred_gradient[0].detach().clone()
        return pred_gradient, pred_gradient_clone

    def give_gradient(self):
        """Compute the global loss gradient and a Paillier-encrypted copy.

        Returns:
            Tuple ``(pred_gradient_clone, reshaped_encrypted_pred_gradient_clone)``;
            the encrypted tensor is built from the gradient reshaped to
            ``(batch, -1)`` and encrypted element by element.
        """
        pred_z = self.pred_received

        self.global_pred, self.global_loss = self.aggregate(pred_z)
        pred_gradient, pred_gradient_clone = self.gradient_calculation(pred_z, self.global_loss)
        bs = pred_gradient_clone.shape[0]
        reshaped_encrypted_pred_gradient_clone = PaillierTensor(
            [[self.pk.encrypt(x) for x in xs] for xs in pred_gradient_clone.reshape((bs, -1)).tolist()])

        return pred_gradient_clone, reshaped_encrypted_pred_gradient_clone

    def update_local_gradient(self, gradient):
        """Set ``self.local_gradient`` to X^T * grad(Z) for the current batch."""
        self.local_gradient = torch.einsum("in, noh -> ioh", self.local_batch_data.T, gradient)  # X^T * grad(Z)
        self.local_gradient = self.local_gradient.to(self.args.device)

    def receive_gradient(self, reshaped_encrypted_gradient_z, from_global_model):
        """Store the encrypted gradient of z received from a global model."""
        self.reshaped_encrypted_gradient_z[from_global_model] = reshaped_encrypted_gradient_z

    def give_encrypted_grad_ss_mass(self, from_global_model):
        """Compute the encrypted local gradient and split it with ``he2ss``.

        The first share is kept in ``self.grad_ss_fai[from_global_model]``;
        the encrypted second share is returned.
        """
        encrypted_gradient = torch.matmul(self.local_batch_data.T,
                                          self.reshaped_encrypted_gradient_z[from_global_model])  # X^T * ||grad(Z)||
        i = encrypted_gradient.size()[0]
        o = self.dims[from_global_model]
        # The reshape result used to be discarded. Whether the encrypted
        # tensor's ``reshape`` works in place is not visible from this file,
        # so rebind only when it returns a value; this is correct either way.
        reshaped = encrypted_gradient.reshape((i, o, -1))
        if reshaped is not None:
            encrypted_gradient = reshaped
        self.grad_ss_fai[from_global_model], encrypted_grad_ss_mass = \
            he2ss(encrypted_gradient)
        self.grad_ss_fai[from_global_model] = self.grad_ss_fai[from_global_model].to(self.args.device)
        return encrypted_grad_ss_mass

    def receive_grad_ss_mass(self, belong_to_model, encrypted_grad_ss_mass):
        """Decrypt and store the gradient mass share for ``belong_to_model``."""
        self.grad_ss_mass[belong_to_model] = encrypted_grad_ss_mass.decrypt(self.sk, device=self.args.device)

    def receive_dag_gradient(self, dag_gradient):
        """Store the DAG-penalty gradient sent by the server (may be ``None``)."""
        self.dag_gradient = dag_gradient

    def local_backward(self, weight=None):
        """Accumulate all gradient terms on every local model and step its optimizer.

        For each local model the DAG gradient (if any), the extra
        regularisation loss gradient and the reconstruction gradient are
        summed into ``w.grad``, masked with the dispatcher mask, and applied.

        Args:
            weight: unused; kept for interface compatibility.

        Returns:
            dict with the summed ``'l1'`` and ``'l2'`` extra-loss details.
        """
        self.num_local_updates += 1  # another update
        # NOTE(review): this enables anomaly detection globally on every call,
        # which slows autograd; presumably a debugging leftover -- confirm.
        torch.autograd.set_detect_anomaly(True)
        total_detail_extra_loss = {'l1': 0, 'l2': 0}
        for belong_to_model in range(self.args.k):
            self.local_model_optimizer[belong_to_model].zero_grad()
            # backward dag loss
            if self.dag_gradient is not None:
                self.weights_grad_dag = torch.autograd.grad(
                    self.dag_pred,
                    # if allow_unused == False, inputs=list(self.local_model[self.now_active_party].parameters())[0]
                    self.local_model[belong_to_model].parameters(),
                    grad_outputs=self.dag_gradient,
                    retain_graph=True,
                )
                for w, g in zip(self.local_model[belong_to_model].parameters(), self.weights_grad_dag):
                    if w.requires_grad:
                        w.grad = g.detach()
            # backward local extra loss
            extra_loss, detail_extra_loss = self.local_model[belong_to_model].extra_loss(self.alpha, self.beta, return_detailed_losses=True)
            total_detail_extra_loss['l1'] += detail_extra_loss['l1']
            total_detail_extra_loss['l2'] += detail_extra_loss['l2']
            extra_grad = torch.autograd.grad(
                extra_loss,
                self.local_model[belong_to_model].parameters(),
                retain_graph=True
            )
            for w, g in zip(self.local_model[belong_to_model].parameters(), extra_grad):
                if w.requires_grad:
                    if w.grad is None:
                        w.grad = g.detach()
                    else:
                        w.grad += g.detach()
            # backward reconstruction loss
            if belong_to_model == self.index:
                # Own model: own plain gradient plus the kept shares from the
                # other global models, concatenated along dim 1.
                local_grad_list = []
                for from_global_model in range(self.args.k):
                    if from_global_model == self.index:
                        local_grad_list.append(self.local_gradient)
                    else:
                        local_grad_list.append(self.grad_ss_fai[from_global_model])
                local_grad = torch.cat(local_grad_list, dim=1)  # [n, (o1, o2....), h]
            else:
                # Other party's model: only the rows in self_reconstruction_row
                # receive the decrypted mass share; the rest stays zero.
                local_grad = torch.zeros_like(self.local_model[belong_to_model].weight)
                local_grad[:, self.self_reconstruction_row[0]: self.self_reconstruction_row[1], :] = self.grad_ss_mass[belong_to_model]
            local_grad = (local_grad,)
            for w, g in zip(self.local_model[belong_to_model].parameters(), local_grad):
                if w.requires_grad:
                    w.grad += g.detach()
            # mask gradient
            for w in self.local_model[belong_to_model].parameters():
                if w.requires_grad:
                    w.grad *= torch.unsqueeze(self.local_model[belong_to_model].dispatcher.mask, dim=-1)
            self.local_model_optimizer[belong_to_model].step()

        return total_detail_extra_loss

    def global_backward(self):
        """Back-propagate reconstruction and extra losses through the global model.

        Returns:
            The detailed extra-loss value returned by ``global_model.extra_loss``.
        """
        # active party with trainable global layer
        _gradients = torch.autograd.grad(self.global_loss, self.global_pred, retain_graph=True)
        _gradients_clone = _gradients[0].detach().clone()

        # backward reconstruction loss
        self.global_model_optimizer.zero_grad()
        # trainable layer parameters
        assert self.args.apply_trainable_layer == True, "Not supported untrainable layer for causal"
        # load grads into parameters
        weights_grad_a = torch.autograd.grad(self.global_pred, self.global_model.parameters(),
                                             grad_outputs=_gradients_clone, retain_graph=True)
        for w, g in zip(self.global_model.parameters(), weights_grad_a):
            if w.requires_grad:
                w.grad = g.detach()

        # backward global extra loss
        extra_loss, detail_extra_loss = self.global_model.extra_loss(self.eta, return_detailed_losses=True)
        extra_grad = torch.autograd.grad(extra_loss, self.global_model.parameters(), retain_graph=True)
        for w, g in zip(self.global_model.parameters(), extra_grad):
            if w.requires_grad:
                w.grad += g.detach()
        # non-trainable layer: no need to update
        self.global_model_optimizer.step()
        return detail_extra_loss

    # def global_LR_decay(self, i_epoch):
    #     if self.global_model_optimizer != None:
    #         eta_0 = self.args.main_lr
    #         eta_t = eta_0 / (np.sqrt(i_epoch + 1))
    #         for param_group in self.global_model_optimizer.param_groups:
    #             param_group['lr'] = eta_t

    # def calculate_gradient_each_class(self, global_pred, local_pred_list, test=False):
    #     # print(f"global_pred.shape={global_pred.size()}") # (batch_size, num_classes)
    #     self.gradient_each_class = [[] for _ in range(global_pred.size(1))]
    #     one_hot_label = torch.zeros(global_pred.size()).to(global_pred.device)
    #     for ic in range(global_pred.size(1)):
    #         one_hot_label *= 0.0
    #         one_hot_label[:, ic] += 1.0
    #         if self.train_index != None:  # for graph data
    #             if test == False:
    #                 loss = self.criterion(global_pred[self.train_index], one_hot_label[self.train_index])
    #             else:
    #                 loss = self.criterion(global_pred[self.test_index], one_hot_label[self.test_index])
    #         else:
    #             loss = self.criterion(global_pred, one_hot_label)
    #         for ik in range(self.args.k):
    #             self.gradient_each_class[ic].append(
    #                 torch.autograd.grad(loss, local_pred_list[ik], retain_graph=True, create_graph=True))
    #     # end of calculate_gradient_each_class, return nothing

class VFedCDServer:
    """Server side of VFedCD: evaluates the DAG penalty and shares its gradient.

    The server collects one DAG tensor per party, sums them, reduces the sum
    to an adjacency matrix and, depending on ``dag_penalty_flavor``, computes
    a penalty whose gradient is split into per-party pieces.
    """

    def __init__(self, args, adj):
        """Read the configuration and build the penalty helper.

        Args:
            args: run configuration; reads ``stage``, ``causal``,
                ``dataset_split``, ``k``, ``apply_global_dp`` and
                ``apply_server_dp``.
            adj: initial adjacency matrix handed to the penalty helper
                (unused when the flavor is ``'none'``).

        Raises:
            AssertionError: for an unknown ``dag_penalty_flavor``.
        """
        self.args = args
        self.stage = args.stage
        self.dag_penalty_flavor = args.causal['dag_penalty_flavor'][self.stage]
        self.power_iteration_n_steps = args.causal['power_iteration_n_steps'][self.stage]
        self.adjacency_p = args.causal['adj_p']
        self.d = sum(args.dataset_split['dims'])
        self.k = args.k
        self.dag_list = [None for _ in range(self.k)]
        self.dag_loss = None
        self.apply_global_dp = args.apply_global_dp and args.apply_server_dp
        self.adj = None
        if self.dag_penalty_flavor == "scc":
            # Previously passed ``self.in_dim``, an attribute this class never
            # defines; the matrix dimension is ``self.d``.
            self.power_grad = SCCPowerIteration(
                adj, self.d, 1000
            )
        elif self.dag_penalty_flavor == "power_iteration":
            self.power_grad = PowerIterationGradient(
                adj,
                self.d,
                n_iter=self.power_iteration_n_steps,
            )
        elif self.dag_penalty_flavor == 'none':
            self.power_grad = None
        else:
            # Explicit raise so the guard survives ``python -O``.
            raise AssertionError("unknown dag_penalty_flavor:{}".format(self.dag_penalty_flavor))

    def dag_reg_power_grad(self, adj):
        """Return the penalty ``(grad.detach() * A).sum()`` for ``adj``.

        ``grad`` and ``A`` come from ``self.power_grad.compute_gradient``.
        """
        grad, A = self.power_grad.compute_gradient(adj)
        h_val = (grad.detach() * A).sum()
        return h_val

    def aggregate(self, dag_list) -> "torch.Tensor | None":
        """Compute the DAG penalty for the collected DAG tensors.

        Returns:
            The penalty tensor, or ``None`` when the flavor is ``'none'``.

        Raises:
            AssertionError: for the unsupported ``'logdet'`` flavor or an
                unknown flavor.
        """
        if self.dag_penalty_flavor == "logdet":
            raise AssertionError("logdet not supported yet")
        if self.dag_penalty_flavor in ("scc", "power_iteration"):
            return self.dag_reg_power_grad(self.form_adjacency_matrix(dag_list))
        if self.dag_penalty_flavor == "none":
            return None
        raise AssertionError("unknown dag_penalty_flavor:{}".format(self.dag_penalty_flavor))

    def form_adjacency_matrix(self, dag_list):
        """Sum the DAG tensors and reduce dim 2 with a vector ``adjacency_p``-norm.

        Differential-privacy noise is applied to the sum first when
        ``self.apply_global_dp`` is set. The result is cached in ``self.adj``.
        """
        sum_result = sum(dag_list)
        if self.apply_global_dp:
            sum_result = add_differential_privacy(sum_result, self.args)
        adjacency = torch.linalg.vector_norm(sum_result, dim=2, ord=self.adjacency_p)
        self.adj = adjacency
        return adjacency

    def give_gradients(self, gamma):
        """Compute the scaled DAG loss and return one gradient piece per party.

        Args:
            gamma: multiplier applied to the DAG loss.

        Returns:
            List of length ``k``. When no penalty is configured (loss is
            ``None``) every entry is ``None``.
        """
        dag_list = self.dag_list
        self.dag_loss = self.aggregate(dag_list)
        if self.dag_loss is None:
            # No penalty -> nothing to differentiate. Calling autograd with a
            # ``None`` loss would raise, so hand each party ``None`` instead.
            return [None] * self.k
        self.dag_loss *= gamma
        _, pred_gradient_list_clone = self.gradient_calculation(dag_list, self.dag_loss)
        return pred_gradient_list_clone

    def gradient_calculation(self, dag_list, dag_loss):
        """Differentiate ``dag_loss`` w.r.t. ``dag_list[0]`` and split the result.

        The detached gradient is split into ``k`` tensors that sum to it:
        ``k - 1`` random tensors drawn uniformly from ``[-0.5, 0.5)`` and one
        remainder, appended last.

        Returns:
            Tuple ``(pred_gradient, pred_gradient_list_clone)``.
        """
        assert self.args.k > 0
        pred_gradient = torch.autograd.grad(dag_loss, dag_list[0], retain_graph=True, create_graph=True)
        pred_gradient_clone = pred_gradient[0].detach().clone()
        pred_gradient_list_clone = []
        for _ in range(1, self.args.k):
            delta = torch.rand_like(pred_gradient_clone) - 0.5
            pred_gradient_list_clone.append(delta)
            pred_gradient_clone -= delta
        pred_gradient_list_clone.append(pred_gradient_clone)
        return pred_gradient, pred_gradient_list_clone

    def receive_dag(self, dag_clone, ik):
        """Store the DAG tensor sent by party ``ik``."""
        self.dag_list[ik] = dag_clone

class SCCPowerIteration(nn.Module):
    """Power iteration run separately on each strongly connected component.

    Keeps two vectors ``v`` and ``vt`` of length ``d`` that are refined by
    repeated multiplication with the squared adjacency matrix (and its
    transpose), restricted to each strongly connected component.
    """

    def __init__(self, init_adj_mtx, d, update_scc_freq=1000):
        """
        Args:
            init_adj_mtx: adjacency matrix used to find the initial components
                and to initialise ``v`` / ``vt``.
            d: number of nodes (length of ``v`` and ``vt``).
            update_scc_freq: components are recomputed whenever the number of
                ``compute_gradient`` calls so far is a multiple of this value.
        """
        super().__init__()
        self.d = d
        self.update_scc_freq = update_scc_freq

        self._dummy_param = nn.Parameter(
            torch.zeros(1), requires_grad=False
        )  # Used to track device

        self.scc_list = None
        self.update_scc(init_adj_mtx)

        self.register_buffer("v", None)
        self.register_buffer("vt", None)
        self.initialize_eigenvectors(init_adj_mtx)

        self.n_updates = 0

    @property
    def device(self):
        """Device of the module, read from the dummy parameter."""
        return self._dummy_param.device

    def initialize_eigenvectors(self, adj_mtx):
        """Reset ``v``/``vt`` to normalised all-ones vectors and run 5 steps.

        Returns:
            The squared adjacency matrix returned by ``power_iteration``.
        """
        self.v, self.vt = torch.ones(size=(2, self.d), device=self.device)
        self.v = normalize(self.v)
        self.vt = normalize(self.vt)
        return self.power_iteration(adj_mtx, 5)

    def update_scc(self, adj_mtx):
        """Recompute ``self.scc_list`` as one index array per strong component."""
        # Import the submodules explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.sparse.csgraph`` is loaded.
        from scipy.sparse import coo_matrix
        from scipy.sparse.csgraph import connected_components

        n_components, labels = connected_components(
            csgraph=coo_matrix(adj_mtx.cpu().detach().numpy()),
            directed=True,
            return_labels=True,
            connection="strong",
        )
        self.scc_list = [np.where(labels == i)[0] for i in range(n_components)]

    @staticmethod
    def _power_steps(sub_matrix, v, vt, n_iter):
        """Run ``n_iter`` normalised power steps on ``v`` and ``vt``."""
        for _ in range(n_iter):
            v = normalize(sub_matrix.mv(v) + 1e-6 * v.sum())
            vt = normalize(sub_matrix.T.mv(vt) + 1e-6 * vt.sum())
        return v, vt

    def power_iteration(self, adj_mtx, n_iter=5):
        """Refine ``v``/``vt`` on every component of ``adj_mtx ** 2``.

        Returns:
            The element-wise squared adjacency matrix.
        """
        matrix = adj_mtx**2
        for scc in self.scc_list:
            if len(scc) == self.d:
                # Single component covering all nodes: work on the full matrix.
                self.v, self.vt = self._power_steps(matrix, self.v, self.vt, n_iter)
            else:
                sub_matrix = matrix[scc][:, scc]
                v, vt = self._power_steps(sub_matrix, self.v[scc], self.vt[scc], n_iter)
                self.v[scc] = v
                self.vt[scc] = vt

        return matrix

    def compute_gradient(self, adj_mtx):
        """Build the block matrix ``outer(vt, v) / inner(vt, v)`` plus ``100 * I``.

        Returns:
            Tuple ``(gradient, matrix)`` where ``matrix`` is ``adj_mtx ** 2``.
        """
        if self.n_updates % self.update_scc_freq == 0:
            self.update_scc(adj_mtx)

        # Always restart from the all-ones vectors. This also covers the case
        # where the components were just refreshed, so no second
        # initialisation is needed.
        matrix = self.initialize_eigenvectors(adj_mtx)

        gradient = torch.zeros(size=(self.d, self.d), device=self.device)
        for scc in self.scc_list:
            if len(scc) == self.d:
                v = self.v
                vt = self.vt
                gradient = torch.outer(vt, v) / torch.inner(vt, v)
            else:
                v = self.v[scc]
                vt = self.vt[scc]
                # ``gradient[scc][:, scc] = ...`` wrote into a temporary copy
                # created by advanced indexing and left ``gradient`` untouched;
                # a single paired index assignment stores the block in place.
                idx = torch.as_tensor(scc, dtype=torch.long, device=gradient.device)
                gradient[idx[:, None], idx] = torch.outer(vt, v) / torch.inner(vt, v)

        gradient += 100 * torch.eye(self.d, device=self.device)

        self.n_updates += 1

        return gradient, matrix


class PowerIterationGradient(nn.Module):
    """Keeps two power-iteration vectors ``u`` and ``v`` for an adjacency matrix.

    ``u`` is iterated with the transposed matrix and ``v`` with the matrix
    itself; ``compute_gradient`` combines them into
    ``outer(u, v) / (u . v + 1e-6)``.
    """

    def __init__(self, init_adj_mtx, d, n_iter=5):
        """
        Args:
            init_adj_mtx: matrix used for the initial iterations; its device
                is used for the vectors.
            d: length of ``u`` and ``v``.
            n_iter: number of power steps per initialisation / gradient call.
        """
        super().__init__()
        self.d = d
        self.n_iter = n_iter

        self.device = init_adj_mtx.device

        for buffer_name in ("u", "v"):
            self.register_buffer(buffer_name, None)

        self.init_eigenvect(init_adj_mtx)

    def init_eigenvect(self, adj_mtx):
        """Reset both vectors to normalised all-ones and run ``n_iter`` steps."""
        start = torch.ones(size=(2, self.d), device=self.device)
        self.u = normalize(start[0])
        self.v = normalize(start[1])
        self.iterate(adj_mtx, self.n_iter)

    def iterate(self, adj_mtx, n=2):
        """Run ``n`` power steps on ``adj_mtx + 1e-6`` without tracking gradients."""
        with torch.no_grad():
            shifted = adj_mtx + 1e-6
            remaining = n
            while remaining > 0:
                self.one_iteration(shifted)
                remaining -= 1

    def one_iteration(self, A):
        """One iteration of power method"""
        next_u = A.T @ self.u
        self.u = normalize(next_u)
        next_v = A @ self.v
        self.v = normalize(next_v)

    def compute_gradient(self, adj_mtx):
        """Advance the vectors and return ``(grad, adj_mtx)``.

        ``grad`` is the outer product of ``u`` and ``v`` divided by their dot
        product plus ``1e-6``.
        """
        self.iterate(adj_mtx, self.n_iter)
        outer_uv = self.u[:, None] @ self.v[None]
        denominator = self.u.dot(self.v) + 1e-6
        return outer_uv / denominator, adj_mtx

def normalize(v):
    """Divide ``v`` by its vector norm (``torch.linalg.vector_norm`` defaults)."""
    magnitude = torch.linalg.vector_norm(v)
    return v / magnitude
