import math
import torch
from einops import rearrange
from torch import nn, Tensor, optim
from torch.nn.functional import binary_cross_entropy
from torch.types import Device

from csiva.model.decoder import CausalDecoder
from csiva.model.encoder import CausalEncoder
import lightning.pytorch as pl

from csiva.utils.metrics import average_shd, average_out_degree, num_dags


class CausalInducer(pl.LightningModule):
    """Encoder/decoder model that predicts a flattened adjacency matrix from a dataset.

    The encoder summarises the input dataset, an auxiliary head predicts per-node
    children/parents from that summary (used as an extra training loss), and the
    decoder emits the ``num_nodes ** 2`` adjacency entries as a sequence.
    """

    def __init__(
        self,
        num_nodes: int,
        d_model: int,
        dim_feedforward: int,
        depth_encoder: int,
        depth_decoder: int,
        num_heads: int,
        p_dropout: float = 0.,
        threshold: float = 0.5,
        lr: float = 1e-4,
        eps_layer_norm: float = 0.0005,
        encoder_layer_type: str = "custom",
        encoder_summary_type: str = "sdp",
        rff_depth: int = 1
    ):
        """CSIVA architecture.

        A single datapoint for the supervised training of the model is an entire dataset,
        of shape (num_samples, num_nodes).

        Parameters
        ----------
        num_nodes : int
            Number of input nodes in the causal graph.
        d_model : int
            Dimension of the key, query, value input embeddings.
        dim_feedforward : int
            Hidden dimension of the MLP after MHSA in encoder and decoder layer.
            Also used as the hidden width of the auxiliary head.
        depth_encoder : int
            Number of stacked layers in the encoder.
        depth_decoder : int
            Number of stacked layers in the decoder.
        num_heads : int
            Number of attention heads, passed to both the encoder and the decoder.
        p_dropout : float, default 0.0
            Dropout probability in multi-head attention in the alternate attention blocks.
            The same value is passed to the decoder and used by the auxiliary head's dropout.
        threshold : float, default 0.5
            Thresholding of the output adjacency matrix. Values above threshold are mapped
            to 1, values below (or equal to) threshold are mapped to 0.
        lr : float, default 1e-4
            The learning rate value.
        eps_layer_norm : float, default 0.0005
            LayerNorm epsilon in multi-head attention in the alternate attention blocks.
            Required for numerical stability.
        encoder_layer_type : str, default "custom"
            Specify which class to use for the encoder layer. Use "custom" for csiva.model.MHSA,
            or "torch" for torch.nn.TransformerEncoderLayer.
        encoder_summary_type : str, default "sdp"
            Specify how to compute the encoder summary. Use "sdp" for
            torch.nn.functional.scaled_dot_product_attention, or "mhsa" for
            torch.nn.MultiheadAttention.
        rff_depth : int, default 1
            Number of feed-forward layers in the MLP of the encoder layer in alternate attention.
        """
        super().__init__()

        # Records all constructor arguments so checkpoints can rebuild the module.
        self.save_hyperparameters()

        # Value written at position 0 of every decoder input sequence.
        self.start_token = -1
        self.threshold = threshold
        self.num_nodes = num_nodes
        self.lr = lr
        self.encoder = CausalEncoder(
            num_nodes=num_nodes,
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=depth_encoder,
            num_heads=num_heads,
            p_dropout=p_dropout,
            eps_layer_norm=eps_layer_norm,
            encoder_layer_type=encoder_layer_type,
            encoder_summary_type=encoder_summary_type,
            rff_depth=rff_depth,
        )
        # The decoder works on the flattened adjacency matrix: num_nodes ** 2 entries.
        self.decoder = CausalDecoder(
            num_nodes**2, d_model, depth_decoder, num_heads, dim_feedforward, p_dropout
        )

        # Predict children and parents directly from encoder summary for training
        self.aux_head = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(p_dropout),
            nn.Linear(dim_feedforward, 2 * num_nodes),
            nn.Sigmoid()
        )

    def _teacher_forced_forward(self, x: Tensor, y: Tensor):
        """Run encoder, auxiliary head and decoder with the shifted target as decoder input.

        Shared by ``training_step`` and ``validation_step``. The call order
        (encoder, auxiliary head, decoder) is kept fixed so that dropout consumes
        random numbers in the same order for both steps.

        Returns
        -------
        tuple of (Tensor, Tensor)
            The decoder output and the auxiliary head output.
        """
        encoder_summary = self.encoder(x)
        aux_out = self.aux_head(encoder_summary)
        mask = self._create_look_ahead_mask(x.device)
        out = self.decoder(self._shift_target(y), encoder_summary, mask)
        return out, aux_out

    def _binarize(self, output: Tensor) -> Tensor:
        """Map entries strictly above ``self.threshold`` to 1 and all others to 0, keeping the dtype."""
        return (output > self.threshold).to(output.dtype)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one ``(x, y)`` batch and log training metrics."""
        x, y = batch
        out, aux_out = self._teacher_forced_forward(x, y)
        loss = self.get_loss(out, aux_out, y)

        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        self.log("train_shd", self.get_avg_shd(out, y))
        self.log("train_degree", self.get_avg_degree(out))
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=0.00)
        return optimizer

    def validation_step(self, batch, batch_idx):
        """Evaluate one ``(x, y)`` batch with teacher forcing and log validation metrics."""
        x, y = batch
        out, aux_out = self._teacher_forced_forward(x, y)
        loss = self.get_loss(out, aux_out, y)

        self.log("val_loss", loss)
        self.log("val_cross_entropy", self.get_cross_entropy(out, y))
        self.log("val_aux_loss", self.get_aux_loss(aux_out, y))
        self.log("val_shd", self.get_avg_shd(out, y))
        self.log("val_degree", self.get_avg_degree(out))
        self.log("val_num_dags", self.get_num_dags(out))

    def test_step(self, batch, batch_idx):
        """Evaluate one ``(x, y)`` batch using autoregressive decoding and log test metrics."""
        x, y = batch
        out = self.predict_autoregressive(x)

        # NOTE(review): `out` holds hard 0/1 decisions (see predict_autoregressive), not
        # probabilities. torch's binary_cross_entropy clamps log terms at -100, so every
        # wrong entry contributes 100 before averaging. Confirm this is the intended metric.
        self.log("test_cross_entropy", self.get_cross_entropy(out, y))
        self.log("test_shd", self.get_avg_shd(out, y))
        self.log("test_degree", self.get_avg_degree(out))
        self.log("test_num_dags", self.get_num_dags(out))

    @torch.no_grad()
    def predict_autoregressive(self, data: Tensor) -> Tensor:
        """Decode the adjacency sequence one entry at a time, feeding back thresholded predictions.

        Parameters
        ----------
        data : Tensor
            Input batch for the encoder; ``data.shape[0]`` is used as the batch size.

        Returns
        -------
        Tensor
            Tensor of shape (batch, num_nodes ** 2) holding the thresholded (0/1) predictions.
        """
        seq_length = self.num_nodes ** 2
        # One extra leading slot holds the start token; it is stripped before returning.
        output = torch.zeros(size=(data.shape[0], seq_length + 1), device=data.device)
        output[:, 0] = self.start_token
        mask = self._create_look_ahead_mask(data.device)

        # Actual inference
        encoder_summary = self.encoder(data)
        for i in range(seq_length):
            # The decoder is re-run on the full (partially filled) sequence each step; the
            # look-ahead mask is what is meant to keep position i from seeing later slots.
            probs = self.decoder(output[:, :-1], encoder_summary, mask)
            # Only column i is needed, so threshold that column rather than the whole matrix.
            output[:, i + 1] = (probs[:, i] > self.threshold).to(output.dtype)
        return output[:, 1:]

    def _shift_target(self, target: Tensor) -> Tensor:
        """Return a copy of ``target`` shifted right by one, with the start token at position 0."""
        # Shift target for training. So model learns to predict token n from token n-1.
        shifted_target = torch.clone(target)
        shifted_target[:, 1:] = target[:, :-1]
        shifted_target[:, 0] = self.start_token  # indicate start of sequence
        return shifted_target

    def get_loss(self, output: Tensor, aux_output: Tensor, target: Tensor) -> Tensor:
        """Return the sum of the decoder cross-entropy and the auxiliary loss."""
        cross_entropy = self.get_cross_entropy(output, target)
        aux = self.get_aux_loss(aux_output, target)
        return cross_entropy + aux

    def get_cross_entropy(self, output: Tensor, target: Tensor) -> Tensor:
        """Return the mean binary cross-entropy between ``output`` and ``target``."""
        return binary_cross_entropy(output, target)

    def get_aux_loss(self, aux_output: Tensor, target: Tensor) -> Tensor:
        """Return the binary cross-entropy of the auxiliary head against per-node children/parents.

        ``aux_output`` must be 3-dimensional (the rearrange pattern below enforces it) and
        match the (b, d, 2 * d) target built here.
        """
        # Target has shape b x d^2. Need b x d x 2*d with first d entries per node showing children and second d showing
        # parents.
        adj = rearrange(target, 'b (d f) -> b d f', d=self.num_nodes)
        aux_target = torch.concat([adj, adj.transpose(-1, -2)], dim=2)
        # Flatten output and target for cross-entropy
        aux_outputs = rearrange(aux_output, 'b d f -> b (d f)')
        aux_target = rearrange(aux_target, 'b d f -> b (d f)')
        return binary_cross_entropy(aux_outputs, aux_target)

    def get_avg_shd(self, output: Tensor, target: Tensor) -> float:
        """Threshold ``output`` and return ``average_shd`` against ``target`` as square matrices."""
        thresholded_out = self._binarize(output)
        batch_size = output.shape[0]
        return average_shd(
            thresholded_out.reshape(batch_size, self.num_nodes, self.num_nodes),
            target.reshape(batch_size, self.num_nodes, self.num_nodes),
            batch_size
        )

    def get_avg_degree(self, output: Tensor) -> float:
        """Threshold ``output`` and return the result of ``average_out_degree``."""
        thresholded_out = self._binarize(output)
        return average_out_degree(thresholded_out, output.shape[0], self.num_nodes)

    def get_num_dags(self, output: Tensor) -> float:
        """Threshold ``output`` and return the result of ``num_dags``."""
        thresholded_out = self._binarize(output)
        return num_dags(thresholded_out, output.shape[0], self.num_nodes)

    def _create_look_ahead_mask(self, device: Device) -> Tensor:
        """Build an additive (num_nodes ** 2, num_nodes ** 2) causal mask on ``device``.

        Entries on and below the diagonal are 0; entries above the diagonal
        (future positions) are -1e9.
        """
        seq_length = self.num_nodes ** 2
        attn_shape = (seq_length, seq_length)
        mask = (1 - torch.tril(torch.ones(attn_shape, device=device))) * -1e9
        return mask


