import torch.nn as nn
import torch
from typing import Literal
import numpy as np
import torch.distributions as dist

import sys, os

from utils.paillier_torch import PaillierTensor

sys.path.append(os.pardir)

class LinearParallel(nn.Module):
    """A stack of ``parallel_dim`` independent affine maps evaluated in one call.

    Slice ``p`` of the input is transformed with its own ``(in_dim, out_dim)``
    weight matrix and its own ``out_dim`` bias vector; slices never mix.
    """

    def __init__(self, in_dim, out_dim, parallel_dim):
        super().__init__()
        self.in_dim, self.out_dim, self.parallel_dim = in_dim, out_dim, parallel_dim

        weight_shape = (parallel_dim, in_dim, out_dim)
        bias_shape = (parallel_dim, out_dim)
        self.weight = nn.Parameter(torch.zeros(*weight_shape))
        self.bias = nn.Parameter(torch.zeros(*bias_shape))
        # The zeros above are placeholders; draw the actual initial values.
        self.reset_parameters()

    def forward(self, x):
        """Apply every per-slice affine map.

        Args:
            x (torch.Tensor): input tensor of shape (batch_size, parallel_dim, in_dim)
        Returns:
            torch.Tensor: output tensor of shape (batch_size, parallel_dim, out_dim)
        """
        # n: batch, p: parallel slice, i: input feature, o: output feature
        projected = torch.einsum("npi, pio -> npo", x, self.weight)
        return projected + self.bias

    @torch.no_grad()
    def reset_parameters(self):
        """Redraw weight and bias uniformly from ``[-1/sqrt(in_dim), 1/sqrt(in_dim)]``."""
        limit = 1.0 / self.in_dim**0.5
        # Weight first, then bias, so seeded runs consume the RNG in a fixed order.
        for param in (self.weight, self.bias):
            nn.init.uniform_(param, -limit, limit)

    def __repr__(self):
        return f"LinearParallel(in_dim={self.in_dim}, out_dim={self.out_dim}, parallel_dim={self.parallel_dim})"

class DispatcherCipherLayer(nn.Module):
    """Dispatcher layer whose weight is an already-encrypted, flattened matrix.

    The weight is stored as given (no mask is applied and it is not registered
    as a trainable parameter). It is expected to be laid out as
    ``(in_dim, out_dim * hidden_dim)`` so that the matmul result can be
    reshaped to ``(batch_size, out_dim, hidden_dim)``.
    """

    def __init__(
            self,
            owned_by_party,
            belong_to_model,
            encrypted_by_party,
            in_dim,
            out_dim,
            hidden_dim,
            # Forward reference: keeps the annotation without needing the name
            # to be resolvable at class-definition time.
            encrypted_weight: "PaillierTensor",
    ):
        super().__init__()
        self.owned_by_party = owned_by_party
        self.belong_to_model = belong_to_model
        self.encrypted_by_party = encrypted_by_party
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        # Width of the flattened (out_dim, hidden_dim) embedding.
        self.embed_dim = out_dim * hidden_dim

        self._weight = encrypted_weight

    @property
    def weight(self):
        """The stored (encrypted) weight, returned unmodified."""
        return self._weight  # no mask

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (batch_size, in_dim)
        Returns:
            torch.Tensor: output tensor of shape (batch_size, out_dim, hidden_dim)
        """
        # ni * ie = ne; e:embed_dim = o:out_dim * h:hidden_dim
        x = torch.matmul(x, self.weight)
        # reshape is not in-place: its result must be kept, otherwise the
        # flat (batch_size, embed_dim) tensor is returned instead.
        x = x.reshape((-1, self.out_dim, self.hidden_dim))
        return x

    def __repr__(self):
        # This class has no ``adjacency_p`` attribute; referencing it here
        # made repr() raise AttributeError, so it is not reported.
        return (
            f"DispatcherLayer("
            f"in_dim={self.in_dim}, "
            f"out_dim={self.out_dim}, "
            f"hidden_dim={self.hidden_dim}"
            f")"
        )

class DispatcherPlainLayer(nn.Module):
    """Plaintext dispatcher layer mapping ``(batch, in_dim)`` to ``(batch, out_dim, hidden_dim)``.

    The trainable parameter is ``_weight`` of shape ``(in_dim, out_dim, hidden_dim)``.
    The ``weight`` property exposes it multiplied by a ``(in_dim, out_dim)``
    mask buffer, so masked (input, output) pairs contribute nothing.
    """

    def __init__(
        self,
        owned_by_party,
        belong_to_model,
        in_dim, # i
        out_dim, # o
        hidden_dim, # h
        adjacency_p=2.0,
        mask=None,
    ):
        super().__init__()
        self.owned_by_party = owned_by_party
        self.belong_to_model = belong_to_model
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.embed_dim = out_dim * hidden_dim
        self.adjacency_p = adjacency_p

        if mask is None:
            # No mask given: keep every (input, output) pair.
            mask_tensor = torch.ones((in_dim, out_dim))
        elif isinstance(mask, torch.Tensor):
            # Copy so the buffer does not alias the caller's tensor.
            mask_tensor = mask.detach().clone().float()
        else:
            mask_tensor = torch.tensor(mask).float()
        self.register_buffer("mask", mask_tensor)

        self._weight = nn.Parameter(torch.zeros(in_dim, out_dim, hidden_dim))
        self.reset_parameters_bounded_eigenvalues()

    @property
    def weight(self):
        """The parameter with the mask broadcast over the hidden dimension."""
        if self.mask is not None:
            return self._weight * self.mask[:, :, None]
        return self._weight

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input tensor of shape (batch_size, in_dim)
        Returns:
            torch.Tensor: output tensor of shape (batch_size, out_dim, hidden_dim)
        """
        # n: batch, i: in_dim, o: out_dim, h: hidden_dim
        x = torch.einsum("ni, ioh -> noh", x, self.weight)
        return x

    @torch.no_grad()
    def reset_parameters(self):
        """Re-initialise the parameter with the default bounded scheme."""
        self.reset_parameters_bounded_eigenvalues()

    @torch.no_grad()
    def reset_parameters_bounded_eigenvalues(self, scale=1.0):
        """Draw ``_weight`` uniformly from ``[-b, b]``.

        ``b = scale / in_dim / hidden_dim ** (1 / adjacency_p)``.
        """
        init_bound = scale / self.in_dim / self.hidden_dim ** (1.0 / self.adjacency_p)
        # Initialise the underlying parameter. ``self.weight`` is a freshly
        # computed masked product, so filling it in place would leave
        # ``_weight`` at its all-zero placeholder values.
        nn.init.uniform_(self._weight, -init_bound, init_bound)

    def get_adjacency_matrix(self):
        """Return the masked weight; only permitted when ``adjacency_p == 1``.

        Raises:
            AssertionError: if ``adjacency_p`` is not 1.
        """
        assert self.adjacency_p == 1, "adjacency_p({}) should equal 1".format(self.adjacency_p)
        return self.weight
        # return torch.linalg.vector_norm(self.weight, dim=2, ord=self.adjacency_p)

    def encrypt(self, pk, encrypted_by_party) -> "DispatcherCipherLayer":
        """Encrypt the masked weight element-wise and wrap it in a cipher layer.

        Args:
            pk: key object; only its ``encrypt(value)`` method is used here.
            encrypted_by_party: forwarded unchanged to the cipher layer.
        Returns:
            DispatcherCipherLayer holding the weight flattened to
            ``(in_dim, out_dim * hidden_dim)``.
        """
        flat_rows = (
            self.weight.detach()
            .reshape(self.in_dim, self.out_dim * self.hidden_dim)
            .tolist()
        )
        encrypted_weight = PaillierTensor(
            [[pk.encrypt(value) for value in row] for row in flat_rows]
        )
        # NOTE(review): owned_by_party is populated from belong_to_model rather
        # than self.owned_by_party -- kept as-is; confirm this is intended.
        return DispatcherCipherLayer(owned_by_party=self.belong_to_model,
                                     belong_to_model=self.belong_to_model,
                                     encrypted_by_party=encrypted_by_party,
                                     encrypted_weight=encrypted_weight,
                                     in_dim=self.in_dim, out_dim=self.out_dim, hidden_dim=self.hidden_dim)

    def set_weight(self, new_weight):
        """Overwrite the parameter's data with ``new_weight`` (unmasked values)."""
        self._weight.data = new_weight

    def __repr__(self):
        return (
            f"DispatcherLayer("
            f"in_dim={self.in_dim}, "
            f"out_dim={self.out_dim}, "
            f"hidden_dim={self.hidden_dim}, "
            f"adjacency_p={self.adjacency_p}"
            f")"
        )

class VFedCDEncoder(nn.Module):
    """Encoder wrapping a single masked ``DispatcherPlainLayer``.

    Maps ``(batch, in_dim)`` inputs to ``(batch, out_dim, hidden_dims[0])``.
    Input ``i`` is always masked out of output column
    ``i + self_reconstruction_row[0]`` so a variable cannot feed its own slot.
    """

    # ni -> noh
    def __init__(
        self,
        owned_by_party,
        belong_to_model,
        in_dim,  # i
        out_dim,  # o
        hidden_dims,  # h
        self_reconstruction_row,
        adjacency_p: float = 2.0,
        mask=None,
        dag_penalty_flavor: Literal["scc", "power_iteration", "logdet", "none"] = "none",
    ):
        """
        Args:
            owned_by_party, belong_to_model: forwarded to the dispatcher layer.
            in_dim: number of input features.
            out_dim: number of output slots; must be strictly larger than in_dim.
            hidden_dims: sequence whose first entry is the dispatcher hidden size.
            self_reconstruction_row: pair ``(start, stop)`` with
                ``stop - start == in_dim``; ``start`` is the column offset of
                this encoder's own variables among the outputs.
            adjacency_p: forwarded to the dispatcher layer.
            mask: optional ``(in_dim, out_dim)`` array-like of allowed pairs.
            dag_penalty_flavor: stored on the instance; not otherwise used here.
        Raises:
            AssertionError: if the dimension constraints above do not hold.
        """
        super().__init__()
        assert self_reconstruction_row[1] - self_reconstruction_row[0] == in_dim, (
            "self_reconstruction_row[1]({}) - self_reconstruction_row[0]({}) "
            "should equal in_dim({})".format(
                self_reconstruction_row[1], self_reconstruction_row[0], in_dim
            )
        )
        assert out_dim > in_dim, "should hold that out_dim({}) > in_dim({})".format(out_dim, in_dim)
        self.owned_by_party = owned_by_party
        self.belong_to_model = belong_to_model
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_dims = hidden_dims
        self.adjacency_p = adjacency_p
        self.self_reconstruction_row = self_reconstruction_row
        self.dag_penalty_flavor = dag_penalty_flavor

        self_recon_mask = self._self_reconstruction_mask()
        if mask is None:
            mask = self_recon_mask
        else:
            # A pair is kept only if both the caller's mask and the
            # self-reconstruction mask allow it.
            mask = (
                np.asarray(mask).astype(bool) & self_recon_mask.astype(bool)
            ).astype(int)

        self.dispatcher = DispatcherPlainLayer(
                owned_by_party=self.owned_by_party,
                belong_to_model=self.belong_to_model,
                in_dim=self.in_dim,
                out_dim=self.out_dim,
                hidden_dim=hidden_dims[0],
                adjacency_p=self.adjacency_p,
                mask=mask,
            )

        self.reset_parameters()

    def _self_reconstruction_mask(self):
        """Return an ``(in_dim, out_dim)`` array of ones with the own-slot entries zeroed."""
        recon_mask = np.ones((self.in_dim, self.out_dim))
        rows = np.arange(self.in_dim)
        recon_mask[rows, rows + self.self_reconstruction_row[0]] = 0
        return recon_mask

    @property
    def weight(self):
        # Delegates to the dispatcher's ``weight`` property.
        return self.dispatcher.weight

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, x):
        """Run the dispatcher on ``x`` and return its output unchanged."""
        x = self.dispatcher(x)
        return x

    def get_adjacency_matrix(self):
        return self.dispatcher.get_adjacency_matrix()

    def update_mask(self, mask):
        """Replace the dispatcher mask with ``mask`` AND the self-reconstruction mask.

        Args:
            mask: ``(in_dim, out_dim)`` array-like of allowed pairs.
        """
        # Use the same offset-aware (in_dim, out_dim) mask as __init__; a square
        # (in_dim, in_dim) identity cannot be combined with an
        # (in_dim, out_dim) mask when out_dim > in_dim.
        keep = np.asarray(mask).astype(bool) & self._self_reconstruction_mask().astype(bool)
        # Stored as float, the dtype used when the mask is first built.
        self.dispatcher.mask = torch.tensor(keep.astype(int)).float().to(self.device)

    @torch.no_grad()
    def reset_parameters(self):
        self.dispatcher.reset_parameters()

    def l1_reg_dispatcher(self):
        """Sum of absolute values of the dispatcher's ``weight``."""
        return torch.sum(torch.abs(self.dispatcher.weight))

    def l2_reg_all_weights(self):
        """Sum of squared entries over every parameter that requires grad."""
        return sum(
            [
                torch.sum(p**2)
                for p_name, p in self.named_parameters()
                if p.requires_grad
            ]
        )

    def extra_loss(
        self,
        alpha=1.0,
        beta=1.0,
        return_detailed_losses=False,
    ):
        """Return ``alpha * L1 + beta * L2`` regularisation.

        When ``return_detailed_losses`` is true, also return a dict with the
        detached ``"l1"`` and ``"l2"`` terms.
        """
        alpha_l1_reg = alpha * self.l1_reg_dispatcher()  # * n_obs_norm
        beta_l2_reg = beta * self.l2_reg_all_weights()  # * n_obs_norm
        total_loss = alpha_l1_reg + beta_l2_reg

        if return_detailed_losses:
            return total_loss, {
                "l1": alpha_l1_reg.detach(),
                "l2": beta_l2_reg.detach(),
            }
        else:
            return total_loss

    def encrypt(self, pk, encrypted_by_party):
        """Delegate to the dispatcher's ``encrypt``."""
        return self.dispatcher.encrypt(pk, encrypted_by_party)

class VFedCDDecoder(nn.Module):
    """Per-output decoder turning ``(batch, out_dim, hidden)`` embeddings into a mean and variance.

    A learned bias is added to the embedding, followed by the activation, an
    optional stack of ``LinearParallel`` hidden layers, and ``LinearParallel``
    heads producing one mean (and, depending on ``model_variance_flavor``, one
    variance) per output.
    """

    # noh * ohc -> noc
    def __init__(
        self,
        hidden_dims,  # h
        out_dim,  # o
        model_variance_flavor: Literal["unit", "nn", "parameter"] = "nn",
        activation=nn.Sigmoid(),
        init_bound=None,
        k=1
    ):
        """
        Args:
            hidden_dims: sequence of hidden sizes; ``hidden_dims[0]`` is the
                embedding size, consecutive pairs define the hidden layers.
            out_dim: number of outputs (the parallel dimension of every layer).
            model_variance_flavor: ``"nn"`` adds a variance head, ``"parameter"``
                learns one variance per output, ``"unit"`` uses variance 1.
            activation: module applied after the bias and after every hidden layer.
            init_bound: if given, the bias is drawn uniformly from
                ``[-init_bound, init_bound]``; otherwise it stays at zero.
            k: stored on the instance; not otherwise used in this class.
        """
        super().__init__()
        self.out_dim = out_dim
        self.hidden_dims = hidden_dims
        self.model_variance_flavor = model_variance_flavor
        self.init_bound = init_bound
        self.k = k
        self.bias = nn.Parameter(torch.zeros(out_dim, self.hidden_dims[0]))
        if init_bound is not None:
            self.reset_parameters()
        self.activation = activation

        self.layers = nn.ModuleList()
        if len(self.hidden_dims) > 1:
            dims = self.hidden_dims
            for i in range(len(dims) - 1):
                self.layers.append(LinearParallel(dims[i], dims[i + 1], self.out_dim))

        if (
            self.model_variance_flavor == "nn"
            or self.model_variance_flavor == "parameter"
        ):
            # Softplus keeps the predicted variance positive.
            self.var_activation = nn.Softplus()

        dims = self.hidden_dims
        self.output_layer = LinearParallel(dims[-1], 1, self.out_dim)
        if self.model_variance_flavor == "nn":
            self.var_layer = LinearParallel(dims[-1], 1, self.out_dim)

        if self.model_variance_flavor == "parameter":
            self.gene_vars = nn.Parameter(torch.zeros(self.out_dim))

    @torch.no_grad()
    def reset_parameters(self):
        """Redraw the bias uniformly from ``[-init_bound, init_bound]``."""
        nn.init.uniform_(self.bias, -self.init_bound, self.init_bound)

    @property
    def device(self):
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    def forward(self, x, batch_data=None):
        """Decode embeddings into per-output mean and variance.

        Args:
            x (torch.Tensor): embedding; the bias of shape
                ``(out_dim, hidden_dims[0])`` is broadcast-added to it.
            batch_data: accepted for call compatibility but not used. It is
                optional so that ``self(x)`` (as in ``reconstruction_loss``)
                no longer raises ``TypeError``.
        Returns:
            tuple: ``(x_m, x_v)`` mean and variance, both with the head's
            trailing size-1 dimension squeezed away.
        Raises:
            NotImplementedError: for an unknown ``model_variance_flavor``.
        """
        x = x + self.bias
        x = self.activation(x)

        for layer in self.layers:
            x = self.activation(layer(x))

        x_m = self.output_layer(x).squeeze(2)

        if self.model_variance_flavor == "nn":
            x_v = self.var_activation(self.var_layer(x)).squeeze(2)
        elif self.model_variance_flavor == "parameter":
            x_v = torch.broadcast_to(
                self.var_activation(self.gene_vars).unsqueeze(0), x_m.shape
            )
        elif self.model_variance_flavor == "unit":
            x_v = torch.ones_like(x_m)
        else:
            raise NotImplementedError

        return x_m, x_v

    def reconstruction_loss(self, x, mask_interventions_oh=None):
        """Masked Gaussian negative log-likelihood of ``x``, divided by ``x.shape[0]``.

        Args:
            x: passed to ``forward`` and also used as the likelihood target.
            mask_interventions_oh: optional weights multiplied into the
                log-probabilities; defaults to all ones.
        """
        # NOTE(review): ``x`` serves both as decoder input and as the target of
        # log_prob; its shape must broadcast against the predicted mean --
        # confirm against callers.
        x_mean, x_var = self(x)

        if mask_interventions_oh is None:
            mask_interventions_oh = torch.ones_like(x_mean)

        nll = -(
            mask_interventions_oh * dist.Normal(x_mean, x_var ** (0.5)).log_prob(x)
        ).sum()
        # we normalize by the number of samples (but ideally we shouldn't, as it mess up
        # with the L1 and L2 regularization scales)
        nll /= x.shape[0]
        return nll

    def l2_reg_all_weights(self):
        """Sum of squared entries over trainable parameters (one legacy name excluded)."""
        return sum(
            [
                torch.sum(p**2)
                for p_name, p in self.named_parameters()
                if p.requires_grad and (p_name != "layers.0.gumbel_adjacency.log_alpha")
            ]
        )

    def extra_loss(
        self,
        beta=1.0,
        return_detailed_losses=False,
    ):
        """Return ``beta * L2`` regularisation.

        When ``return_detailed_losses`` is true, also return a dict with the
        detached ``"l2"`` term.
        """
        beta_l2_reg = beta * self.l2_reg_all_weights()  # * n_obs_norm

        total_loss = beta_l2_reg

        if return_detailed_losses:
            return total_loss, {
                "l2": beta_l2_reg.detach(),
            }
        else:
            return total_loss

# class AutoEncoderLayers(nn.Module):
#     def __init__(
#         self,
#         in_dim,
#         hidden_dims,
#         activation=nn.ReLU(),
#         model_variance_flavor: Literal["unit", "nn", "parameter"] = "unit",
#         shared_layers: bool = True,
#         adjacency_p: float = 2.0,
#         mask=None,
#         dag_penalty_flavor: Literal["scc", "power_iteration", "logdet", "none"] = "scc",
#         use_gumbel=False,
#         power_iteration_n_steps=5,
#     ):
#         super().__init__()
#         self.in_dim = in_dim
#         self.hidden_dims = hidden_dims
#         self.activation = activation
#         self.model_variance_flavor = model_variance_flavor
#         self.shared_layers = shared_layers
#         self.adjacency_p = adjacency_p
#         self.use_gumbel = use_gumbel
#
#         if (
#             self.model_variance_flavor == "nn"
#             or self.model_variance_flavor == "parameter"
#         ):
#             self.var_activation = nn.Softplus()
#
#         self.dag_penalty_flavor = dag_penalty_flavor
#         if dag_penalty_flavor == "none":
#             # Need to mask out identity to prevent learning self-loops
#             if mask is not None:
#                 mask = (
#                     mask.astype(bool) & (1 - np.eye(self.in_dim)).astype(bool)
#                 ).astype(int)
#             else:
#                 mask = 1 - np.eye(self.in_dim)
#
#         self.layers = nn.ModuleList()
#         self.layers.append(
#             DispatcherLayer(
#                 self.in_dim,
#                 self.in_dim,
#                 hidden_dims[0],
#                 adjacency_p=self.adjacency_p,
#                 mask=mask,
#                 use_gumbel=self.use_gumbel,
#             )
#         )
#
#         if dag_penalty_flavor == "scc":
#             self.power_grad = SCCPowerIteration(
#                 self.get_adjacency_matrix(), self.in_dim, 1000
#             )
#         elif dag_penalty_flavor == "power_iteration":
#             self.power_grad = PowerIterationGradient(
#                 self.get_adjacency_matrix(),
#                 self.in_dim,
#                 n_iter=power_iteration_n_steps,
#             )
#
#         self.identity = torch.eye(self.in_dim)
#
#         # if layers are shared, use regular dense layers
#         # else use parallel layers
#         if shared_layers:
#             dims = self.hidden_dims
#             for i in range(len(dims) - 1):
#                 self.layers.append(nn.Linear(dims[i], dims[i + 1]))
#             self.output_layer = nn.Linear(dims[-1], 1)
#             if self.model_variance_flavor == "nn":
#                 self.var_layer = nn.Linear(dims[-1], 1)
#         else:
#             dims = self.hidden_dims
#             for i in range(len(dims) - 1):
#                 self.layers.append(LinearParallel(dims[i], dims[i + 1], self.in_dim))
#             self.output_layer = LinearParallel(dims[-1], 1, self.in_dim)
#             if self.model_variance_flavor == "nn":
#                 self.var_layer = LinearParallel(dims[-1], 1, self.in_dim)
#
#         if self.model_variance_flavor == "parameter":
#             self.gene_vars = nn.Parameter(torch.zeros(self.in_dim))
#
#         self.reset_parameters()
#
#     @property
#     def device(self):
#         return next(self.parameters()).device
#
#     def forward(self, x):
#         """
#         Args:
#             x (torch.Tensor): input tensor of shape (batch_size, in_dim)
#         Returns:
#             torch.Tensor: output tensor of shape (batch_size, out_dim)
#         """
#         for layer in self.layers:
#             x = self.activation(layer(x))
#
#         x_m = self.output_layer(x).squeeze(2)
#
#         if self.model_variance_flavor == "nn":
#             x_v = self.var_activation(self.var_layer(x)).squeeze(2)
#         elif self.model_variance_flavor == "parameter":
#             x_v = torch.broadcast_to(
#                 self.var_activation(self.gene_vars).unsqueeze(0), x_m.shape
#             )
#         elif self.model_variance_flavor == "unit":
#             x_v = torch.ones_like(x_m)
#         else:
#             raise NotImplementedError
#
#         return x_m, x_v
#
#     def get_adjacency_matrix(self):
#         return self.layers[0].get_adjacency_matrix()
#
#     def update_mask(self, mask):
#         mask = (mask.astype(bool) & (1 - np.eye(self.in_dim)).astype(bool)).astype(int)
#         self.layers[0].mask = torch.tensor(mask).to(self.device)
#
#     @torch.no_grad()
#     def reset_parameters(self):
#         for layer in self.layers:
#             layer.reset_parameters()
#
#     def reconstruction_loss(self, x, mask_interventions_oh=None):
#         x_mean, x_var = self(x)
#
#         if mask_interventions_oh is None:
#             mask_interventions_oh = torch.ones_like(x_mean)
#
#         nll = -(
#             mask_interventions_oh * dist.Normal(x_mean, x_var ** (0.5)).log_prob(x)
#         ).sum()
#         # we normalize by the number of samples (but ideally we shouldn't, as it mess up
#         # with the L1 and L2 regularization scales)
#         nll /= x.shape[0]
#         return nll
#
#     def l1_reg_dispatcher(self):
#         # maybe change to abs of the collapsed weights (sum over hidden dim)
#         if self.use_gumbel:
#             return torch.sum(
#                 self.layers[0].gumbel_adjacency.get_proba()
#                 * self.layers[0].adjacency_mask
#             )
#         return torch.sum(torch.abs(self.layers[0].weight))
#
#     def l2_reg_all_weights(self):
#         return sum(
#             [
#                 torch.sum(p**2)
#                 for p_name, p in self.named_parameters()
#                 if p.requires_grad and (p_name != "layers.0.gumbel_adjacency.log_alpha")
#             ]
#         )
#
#     def dag_reg(self):
#         A = self.get_adjacency_matrix() ** 2
#         h = -torch.slogdet(self.identity - A)[1]
#         return h
#
#     def dag_reg_power_grad(self):
#         grad, A = self.power_grad.compute_gradient(self.get_adjacency_matrix())
#         # with torch.no_grad():
#         #     grad = grad - A * (grad * A).sum() / ((A**2).sum() + 1e-6) / 2
#         # grad = grad + torch.eye(self.in_dim)
#         h_val = (grad.detach() * A).sum()
#         return h_val
#
#     def loss(
#         self,
#         x,
#         alpha=1.0,
#         beta=1.0,
#         gamma=1.0,
#         n_observations=None,
#         mask_interventions_oh=None,
#         return_detailed_losses=False,
#     ):
#         nll = self.reconstruction_loss(x, mask_interventions_oh=mask_interventions_oh)
#         l1_reg = alpha * self.l1_reg_dispatcher()  # * n_obs_norm
#         l2_reg = beta * self.l2_reg_all_weights()  # * n_obs_norm
#         # mu = 1 / gamma
#         if self.dag_penalty_flavor == "logdet":
#             dag_reg = self.dag_reg()
#         elif self.dag_penalty_flavor in ("scc", "power_iteration"):
#             dag_reg = self.dag_reg_power_grad()
#         elif self.dag_penalty_flavor == "none":
#             dag_reg = 0.0
#         # dag_reg = dag_reg.to(self.device)
#
#         total_loss = nll + l1_reg + l2_reg + gamma * dag_reg
#
#         if return_detailed_losses:
#             return total_loss, {
#                 "nll": nll.detach(),
#                 "l1": l1_reg.detach(),
#                 "l2": l2_reg.detach(),
#                 "dag": dag_reg.detach() if type(dag_reg) != float else torch.zeros(1),
#             }
#         else:
#             return total_loss
