# Postpone evaluation of annotations so type hints can refer to the enclosing class (e.g. as a return type).
from __future__ import annotations

import numpy as np
import torch
from typing import Dict, Optional, Any, Callable, List
from torch import nn
from ...datasets.dataset import Dataset
from torch.utils.data import DataLoader, TensorDataset
from ...utils.data_mask_utils import to_tensors
from ...utils.torch_utils import generate_fully_connected
from ...utils.training_objectives import get_input_and_scoring_masks

# vmap is currently used only for simple operations:
#   1. Batching a matrix-vector product.

from .variational_distributions import VarDistA_Simple, VarDistA_ENCO

import torch.distributions as td

from ...models.torch_model import TorchModel
from ...models.imodel import IModelForCausalInference, IModelForImputation
from ...datasets.variables import Variables

from .generation_functions import ContractiveInvertibleGNN

from typing import Tuple


class FlowVICause(TorchModel, IModelForCausalInference, IModelForImputation):
    """
    Flow-based VICause model, which does causal discovery using a contractive and
    invertible GNN. The adjacency is a random variable over which we do inference.

    Missing values are handled with an amortized Gaussian variational distribution whose
    parameters are produced by ``imputer_network``; the same network is used by ``impute``.
    """

    def __init__(
        self,
        model_id: str,
        variables: Variables,
        save_dir: str,
        device: torch.device,
        lambda_dag: float = 1.0,
        lambda_sparse: float = 1.0,
        tau_gumbel: float = 1.0,
        train_base: bool = False,
        var_dist_A_mode: str = "simple",
        layers_imputer: Optional[List[int]] = None,
        mode_f_sem: str = "linear",
        mode_adjacency: str = "learn",
    ):
        """
        Args:
            model_id: Unique model ID for referencing this model instance.
            variables: Information about variables/features used by this model.
            save_dir: Location to save any information about this model, including training data.
            device: Device to load model to.
            lambda_dag: Coefficient for the prior term that enforces DAG (used when opt_mode is "simple").
            lambda_sparse: Coefficient for the prior term that enforces sparsity.
            tau_gumbel: Temperature for the gumbel softmax trick.
            train_base: Whether to train the (log-)scale of the base distribution. Its mean is always fixed.
            var_dist_A_mode: Variational distribution for adjacency matrix. Admits {"simple", "enco"}. "simple"
                             parameterizes each edge (including orientation) separately. "enco" parameterizes
                             existence of an edge and orientation separately.
            layers_imputer: Sizes of the hidden layers of the imputer NN for the variational distribution.
                            Defaults to two layers of width max(80, 2 * input_dim).
            mode_f_sem: Mode used for function. Admits {"linear", "lrelu", "gnn_i"}. The first one
                        is a linear function, the second leaky relu. The third one is described in the
                        accompanying pdf.
            mode_adjacency: In {"upper", "lower", "learn"}. If "learn", do our method as usual. If
                            "upper"/"lower" fix adjacency matrix to strictly upper/lower triangular.

        Raises:
            NotImplementedError: If ``var_dist_A_mode`` is not one of the supported values.
        """
        super().__init__(model_id, variables, save_dir, device)

        self._device = device
        self.lambda_dag = lambda_dag
        self.lambda_sparse = lambda_sparse
        self.input_dim = variables.num_processed_cols
        self.ICGNN = ContractiveInvertibleGNN(self.input_dim, device, mode_f_sem)
        # Needs self.input_dim and self._device, so it must come after they are set.
        self.mean_base, self.logscale_base = self._initialize_params_base_dist(train=train_base)
        layers_imputer = layers_imputer or [max(80, 2 * self.input_dim)] * 2
        # Input: [observed values, mask]; output: [mean, logscale] of the variational distribution.
        self.imputer_network = generate_fully_connected(
            input_dim=2 * self.input_dim,
            output_dim=2 * self.input_dim,
            hidden_dims=layers_imputer,
            non_linearity=nn.LeakyReLU,
            activation=nn.Identity,
            device=self._device,
        )

        self.mode_adjacency = mode_adjacency
        if var_dist_A_mode == "simple":
            self.var_dist_A = VarDistA_Simple(device=device, input_dim=self.input_dim, tau_gumbel=tau_gumbel)
        elif var_dist_A_mode == "enco":
            self.var_dist_A = VarDistA_ENCO(device=device, input_dim=self.input_dim, tau_gumbel=tau_gumbel)
        else:
            raise NotImplementedError("Variational distribution mode %s not implemented" % var_dist_A_mode)

    @classmethod
    def name(cls) -> str:
        return "flow_vicause"

    def get_adj_matrix(self, round: bool = True, samples: int = 100) -> np.ndarray:
        """
        Returns the adjacency matrix (or several sampled ones) as a numpy array.

        Args:
            round: Forwarded to ``var_dist_A.get_adj_matrix`` when ``samples == 1``; ignored otherwise.
            samples: If 1, return ``var_dist_A.get_adj_matrix(round=round)``. Otherwise draw ``samples``
                matrices with ``var_dist_A.sample_A()`` and stack them along a new leading axis.

        Returns:
            A float64 numpy array.
        """
        if samples == 1:
            adj_matrix = self.var_dist_A.get_adj_matrix(round=round)
        else:
            adj_matrix = torch.stack([self.var_dist_A.sample_A() for _ in range(samples)])
        # Here we have the cast to np.float64 because the original type
        # np.float32 has some issues with json, when saving causality results
        # to a file.
        return adj_matrix.detach().cpu().numpy().astype(np.float64)

    def dagness_factor(self, A: torch.Tensor) -> torch.Tensor:
        """
        Computes the dag penalty for matrix A as trace(expm(A)) - dim.

        Args:
            A: Binary adjacency matrix, size (input_dim, input_dim).

        Returns:
            A scalar tensor.
        """
        return torch.trace(torch.matrix_exp(A)) - self.input_dim

    def _log_prior_A(self, A: torch.Tensor) -> torch.Tensor:
        """
        Computes the (unnormalized) log prior for adjacency matrix A. Only the sparsity term
        -lambda_sparse * ||vec(A)||_1 is included here; the DAG term is handled separately in
        ``compute_loss`` via ``dagness_factor`` (see https://arxiv.org/pdf/2106.07635.pdf).

        Args:
            A: Adjacency matrix of shape (input_dim, input_dim), binary.

        Returns:
            Log probability of A for prior distribution, a number.
        """
        return -self.lambda_sparse * A.sum()  # This is -lambda_sparse * ||vec(A)||_1

    def _ELBO_terms(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Computes all terms involved in the ELBO.

        Args:
            X: Batched samples from the dataset, size (batch_size, input_dim).

        Returns:
            Tuple (penalty_dag, log_p_A, log_p_base, log_q_A). log_q_A is zeroed when the adjacency
            is fixed (mode_adjacency "upper"/"lower"), since it is then not inferred.

        Raises:
            NotImplementedError: If ``mode_adjacency`` is not one of {"learn", "upper", "lower"}.
        """
        # Get adjacency matrix with weights
        if self.mode_adjacency == "learn":
            A_sample = self.var_dist_A.sample_A()
            factor_q = 1.0
        elif self.mode_adjacency == "upper":
            A_sample = torch.triu(torch.ones(self.input_dim, self.input_dim, device=self._device), diagonal=1)
            factor_q = 0.0
        elif self.mode_adjacency == "lower":
            A_sample = torch.tril(torch.ones(self.input_dim, self.input_dim, device=self._device), diagonal=-1)
            factor_q = 0.0
        else:
            raise NotImplementedError("Adjacency mode %s not implemented" % self.mode_adjacency)
        W_adj = A_sample * self.ICGNN.get_weighted_adjacency()
        Z = self.ICGNN.invert_GNN(X, W_adj)
        log_p_A = self._log_prior_A(A_sample)  # A number
        penalty_dag = self.dagness_factor(A_sample)  # A number
        log_p_base = self._log_prob_base(Z)  # (B)
        log_q_A = -self.var_dist_A.entropy()  # A number
        return penalty_dag, log_p_A, log_p_base, log_q_A * factor_q

    def compute_loss(
        self,
        step: int,
        x: torch.Tensor,
        mask_train_batch: torch.Tensor,
        input_mask: torch.Tensor,
        num_samples: int,
        tracker: Dict,
        train_config_dict: Optional[Dict[str, Any]] = None,
        alpha: Optional[float] = None,
        rho: Optional[float] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Computes loss and updates trackers of different terms.

        Args:
            step: Outer optimization step; used to anneal the adjacency entropy when
                ``anneal_entropy == "linear"``.
            x: Batch of data, shape (batch_size, input_dim).
            mask_train_batch: Mask indicating which values are observed in the dataset (1 = observed).
            input_mask: Mask given to the imputer for the reconstruction term. It may hide additional,
                artificially dropped values.
            num_samples: Size of the full training set. Per-dataset terms are divided by it so that
                summing over all batches counts them once.
            tracker: Dictionary of running loss terms; updated in place and returned.
            train_config_dict: Needs "reconstruction_loss_factor", "opt_mode" ("simple"/"auglag") and
                "anneal_entropy" ("linear"/"noanneal").
            alpha: Lagrange multiplier, required when opt_mode is "auglag".
            rho: Penalty coefficient, required when opt_mode is "auglag".

        Returns:
            Tuple (loss, tracker).

        Raises:
            NotImplementedError: For an unknown "opt_mode" or "anneal_entropy".
            ValueError: If opt_mode is "auglag" and alpha or rho is missing.
        """
        if train_config_dict is None:
            train_config_dict = {}
        # Get mean for imputation with artificially masked data, "regularizer" for imputation network
        _, _, mean_rec = self.impute_and_compute_entropy(x, input_mask)
        # Reconstruction loss, scored on every entry observed in the dataset
        diff = mean_rec - x
        reconstruction_term = diff * diff * mask_train_batch  # Shape (batch_size, input_dim)
        reconstruction_term = (
            reconstruction_term.sum(1) * train_config_dict["reconstruction_loss_factor"]
        )  # Shape (batch_size)
        # Fill in missing data with amortized variational approximation for causality (without artificially masked data)
        x_fill, entropy_filled_term, _ = self.impute_and_compute_entropy(x, mask_train_batch)
        # Compute remaining terms
        penalty_dag, log_p_A_sparse, log_p_base, log_q_A = self._ELBO_terms(x_fill)
        log_p_term = log_p_base
        log_p_A_term = log_p_A_sparse / num_samples
        log_q_A_term = log_q_A / num_samples

        opt_mode = train_config_dict["opt_mode"]
        if opt_mode == "simple":
            penalty_dag_term = penalty_dag * self.lambda_dag / num_samples
        elif opt_mode == "auglag":
            if alpha is None or rho is None:
                raise ValueError("alpha and rho must be provided when opt_mode is 'auglag'.")
            penalty_dag_term = penalty_dag * alpha / num_samples
            penalty_dag_term += penalty_dag * penalty_dag * rho / (2 * num_samples)
        else:
            raise NotImplementedError("Optimization mode %s not implemented" % opt_mode)

        anneal_entropy = train_config_dict["anneal_entropy"]
        if anneal_entropy == "linear":
            # Weight 1 while step <= 6, then 1 / (step - 5).
            log_q_A_weighted = log_q_A_term / max(step - 5, 1)
        elif anneal_entropy == "noanneal":
            log_q_A_weighted = log_q_A_term
        else:
            raise NotImplementedError("Entropy annealing mode %s not implemented" % anneal_entropy)
        ELBO = log_p_term + entropy_filled_term + log_p_A_term - log_q_A_weighted - penalty_dag_term
        loss = -ELBO.sum() + reconstruction_term.sum()

        scale_batch_size = x.shape[0] / num_samples
        tracker["loss"] += loss.item()
        tracker["penalty_dag"] += penalty_dag.item() * scale_batch_size
        tracker["log_p_A_sparse"] += log_p_A_term.item() * x.shape[0]
        tracker["log_p_x"] += log_p_term.sum().item()
        tracker["h_filled"] += entropy_filled_term.sum().item()
        tracker["log_q_A"] += log_q_A_term.item() * x.shape[0]
        tracker["reconstruction"] += reconstruction_term.sum().item()
        return loss, tracker

    def initialize_loss_trackers(self) -> Dict:
        """
        Initialize trackers for different loss terms. Not really needed, but useful during development.
        """
        return {
            "loss": 0,
            "penalty_dag": 0,
            "log_p_A_sparse": 0,
            "log_p_x": 0,
            "log_q_A": 0,
            "h_filled": 0,
            "reconstruction": 0,
        }

    def print_tracker(self, epoch: int, tracker: Dict) -> None:
        """
        Prints formatted contents of loss terms that are being tracked.
        """
        loss = tracker["loss"]
        log_p_x = tracker["log_p_x"]
        penalty_dag = tracker["penalty_dag"]
        log_p_A_sparse = tracker["log_p_A_sparse"]
        log_q_A = tracker["log_q_A"]
        h_filled = tracker["h_filled"]
        reconstr = tracker["reconstruction"]
        print(
            f"Epoch: {epoch}, loss: {loss:.2f}, log p(x|A): {log_p_x:.2f}, dag: {penalty_dag:.8f}, log p(A)_sp: {log_p_A_sparse:.2f}, log q(A): {log_q_A:.2f}, H filled: {h_filled:.2f}, rec: {reconstr:.3f}"
        )

    def get_auglag_penalty(self, tracker_dag_penalty: List) -> float:
        """
        Computes DAG penalty for augmented Lagrangian update step as the average of dag factors of binary
        adjacencies sampled during this inner optimization step.
        """
        return torch.mean(torch.Tensor(tracker_dag_penalty)).item()

    def _initialize_params_base_dist(self, train: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the parameters (mean, logscale) of the base distribution, a Gaussian for now.
        Both start at zero. The mean is never trainable.

        Args:
            train: Whether the log-scale is trainable or not.
        """
        mean = nn.Parameter(torch.zeros(self.input_dim, device=self._device), requires_grad=False)
        logscale = nn.Parameter(torch.zeros(self.input_dim, device=self._device), requires_grad=train)
        return mean, logscale

    def _log_prob_base(self, z: torch.Tensor) -> torch.Tensor:
        """
        Computes the log probability of the base distribution, given input z.

        Args:
            z: Array of size (input_dim) or (batch_size, input_dim), works both ways (i.e. single sample
            or batched).

        Returns:
            Log probability of samples. A number if z has shape (input_dim), or an array of
            shape (batch_size) if z has shape (batch_size, input_dim).
        """
        dist = td.Independent(td.Normal(self.mean_base, torch.exp(self.logscale_base)), 1)
        return dist.log_prob(z)

    def get_params_variational_distribution(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given a batch of samples x with missing values, returns the mean and scale of the variational
        approximation over missing values, each of shape (batch_size, input_dim).
        """
        x_obs = x * mask
        input_net = torch.cat([x_obs, mask], dim=1)  # Shape (batch_size, 2*input_dim)
        out = self.imputer_network(input_net)
        mean = out[:, : self.input_dim]  # Shape (batch_size, input_dim)
        logscale = out[:, self.input_dim :]  # Shape (batch_size, input_dim)
        # Clip to keep exp() numerically safe.
        logscale = torch.clip(logscale, min=-20, max=5)
        return mean, torch.exp(logscale)

    def impute_and_compute_entropy(
        self, x: torch.Tensor, mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Given a batch of samples x with missing values, returns the reparameterized filled-in data,
        the entropy term for the variational bound and the mean of the variational distribution.

        Returns:
            Tuple (x_filled, entropy, mean). x_filled keeps observed entries of x and fills missing
            ones with a reparameterized sample; entropy has shape (batch_size). If the batch is fully
            observed, returns (x, zeros(1), x): there is nothing to impute, so the "mean" is the data
            itself and any reconstruction error computed from it is zero.
        """
        # If fully observed do nothing (for efficiency)
        if (mask == 1).all():
            return x, torch.zeros(1, device=self._device), x
        mean, scale = self.get_params_variational_distribution(x, mask)
        var_dist = td.Normal(mean, scale)
        sample = var_dist.rsample()  # Shape (batch_size, input_dim)
        x_filled = x * mask + sample * (1 - mask)  # Shape (batch_size, input_dim)
        entropy = var_dist.entropy() * (1 - mask)  # Shape (batch_size, input_dim)
        return x_filled, entropy.sum(1), mean

    def impute(
        self,
        data: np.ndarray,
        mask: np.ndarray,
        impute_config_dict: Optional[Dict[str, int]] = None,
        *,
        vamp_prior_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
        average: bool = True,
    ) -> np.ndarray:
        """
        Imputes missing values (mask == 0) using the amortized variational distribution.

        Args:
            data: Data array, shape (num_rows, input_dim).
            mask: Observation mask of the same shape (1 = observed).
            impute_config_dict: Must contain "sample_count" when ``average`` is False.
            vamp_prior_data: Unused by this model.
            average: If True, fill missing entries with the variational mean. Otherwise draw
                "sample_count" imputations.

        Returns:
            Array of shape (num_rows, input_dim) if ``average``, else (sample_count, num_rows, input_dim).
            Observed entries are always kept from ``data``.

        Raises:
            ValueError: If ``average`` is False and "sample_count" is not provided.
        """
        data_t = torch.from_numpy(data.astype(np.float32)).to(self._device)
        mask_t = torch.from_numpy(mask.astype(np.float32)).to(self._device)
        with torch.no_grad():
            mean, scale = self.get_params_variational_distribution(data_t, mask_t)
            if average:
                imputed = data_t * mask_t + mean * (1.0 - mask_t)
                return imputed.detach().cpu().numpy()
            if impute_config_dict is None or "sample_count" not in impute_config_dict:
                raise ValueError("impute_config_dict['sample_count'] is required when average is False.")
            var_dist = td.Normal(mean, scale)
            samples = [
                data_t * mask_t + var_dist.sample() * (1.0 - mask_t)
                for _ in range(impute_config_dict["sample_count"])
            ]
            return torch.stack(samples).detach().cpu().numpy()

    def process_dataset(
        self, dataset: Dataset, train_config_dict: Optional[Dict[str, Any]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Generates the training data and mask.

        Args:
            dataset: Dataset to use.
            train_config_dict: Dictionary with training hyperparameters; needs the (sic) keys
                "stardardize_data_mean" and "stardardize_data_std".

        Returns:
            Tuple with data and mask.
        """
        if train_config_dict is None:
            train_config_dict = {}
        self.data_processor._squash_input = False  # Avoid squashing data to [0, 1]
        processed_dataset = self.data_processor.process_dataset(dataset)
        data, mask = processed_dataset.train_data_and_mask
        # NOTE(review): the statistics below include missing entries (mask == 0); masked statistics
        # may be more appropriate when data has missing values.
        if train_config_dict["stardardize_data_mean"]:
            data = data - data.mean(axis=0)
        if train_config_dict["stardardize_data_std"]:
            std = data.std(axis=0)
            # Leave constant columns unscaled instead of dividing by zero (which would yield nan/inf).
            data = data / np.where(std > 0, std, 1.0)
        return data, mask

    def run_train(
        self,
        dataset: Dataset,
        train_config_dict: Optional[Dict[str, Any]] = None,
        report_progress_callback: Optional[Callable[[str, int, int], None]] = None,
    ) -> None:
        """
        Runs training, dispatching on train_config_dict["opt_mode"] ("simple" or "auglag").

        Raises:
            NotImplementedError: For an unknown "opt_mode".
        """
        if train_config_dict is None:
            train_config_dict = {}
        data, mask = self.process_dataset(dataset, train_config_dict)
        tensor_dataset = TensorDataset(*to_tensors(data, mask, device=self._device))
        dataloader = DataLoader(tensor_dataset, batch_size=train_config_dict["batch_size"], shuffle=True)
        opt_mode = train_config_dict["opt_mode"]
        if opt_mode == "simple":
            self.run_train_simple(data.shape[0], dataloader, train_config_dict, report_progress_callback)
        elif opt_mode == "auglag":
            self.run_train_auglag(data.shape[0], dataloader, train_config_dict, report_progress_callback)
        else:
            # Could implement QPM
            raise NotImplementedError("Optimization mode %s not implemented" % opt_mode)

    def run_train_simple(
        self,
        num_samples: int,
        dataloader,
        train_config_dict: Optional[Dict[str, Any]] = None,
        report_progress_callback: Optional[Callable[[str, int, int], None]] = None,
    ) -> None:
        """
        Runs optimization with a single penalty parameter for DAGness, and hope for the best.
        This is to try simple stuff, should use auglag mode in practice.

        The model is saved whenever the epoch loss improves. No values are artificially dropped
        for the reconstruction term: the imputer sees the full observation mask.
        """
        if train_config_dict is None:
            train_config_dict = {}
        # Create optimizer
        opt = torch.optim.Adam(self.parameters(), lr=train_config_dict["learning_rate"])
        # Train
        best_loss = np.nan
        best_epoch = -1
        for epoch in range(train_config_dict["epochs"]):
            tracker_loss_terms = self.initialize_loss_trackers()
            for x, mask_train_batch in dataloader:
                loss, tracker_loss_terms = self.compute_loss(
                    step=1,
                    x=x,
                    mask_train_batch=mask_train_batch,
                    input_mask=mask_train_batch,
                    num_samples=num_samples,
                    tracker=tracker_loss_terms,
                    train_config_dict=train_config_dict,
                )
                # Opt step
                opt.zero_grad()
                loss.backward()
                opt.step()

            # Save if loss improved
            if np.isnan(best_loss) or tracker_loss_terms["loss"] < best_loss:
                best_loss = tracker_loss_terms["loss"]
                best_epoch = epoch
                self.save()

            if np.isnan(tracker_loss_terms["loss"]):
                print("Loss is nan, I'm done %i." % epoch, flush=True)
                break

            if report_progress_callback is not None:
                report_progress_callback(self.model_id, epoch + 1, train_config_dict["epochs"])

            if epoch % 25 == 0:
                self.print_tracker(epoch, tracker_loss_terms)
        print(f"Best model found at epoch {best_epoch}, with Loss {best_loss:.2f}")

    def run_train_auglag(
        self,
        num_samples: int,
        dataloader,
        train_config_dict: Optional[Dict[str, Any]] = None,
        report_progress_callback: Optional[Callable[[str, int, int], None]] = None,
    ) -> None:
        """
        Runs optimization with augmented lagrangian. This function has the outer loop,
        inner loop implemented by each class.

        Stops after 5 accepted steps with DAG penalty below "tol_dag", after 3 steps at
        rho >= "safety_rho", or after "max_steps_auglag" steps. The model is saved after every step.
        """
        if train_config_dict is None:
            train_config_dict = {}
        print_adjacency_info = False
        rho = train_config_dict["rho"]
        alpha = train_config_dict["alpha"]
        progress_rate = train_config_dict["progress_rate"]
        self.opt = torch.optim.Adam(self.parameters(), lr=train_config_dict["learning_rate"], weight_decay=1e-5)
        # Outer optimization loop
        dag_penalty_prev = None
        num_below_tol = 0
        num_max_rho = 0
        num_not_done = 0
        for step in range(train_config_dict["max_steps_auglag"]):
            if num_below_tol >= 5 or num_max_rho >= 3:
                break
            if rho >= train_config_dict["safety_rho"]:
                num_max_rho += 1
            print("Step: %i" % step)
            # Optimize adjacency for fixed rho and alpha
            done_inner, dag_penalty = self.optimize_inner_auglag(
                rho, alpha, step, num_samples, dataloader, train_config_dict
            )
            print("Dag penalty after inner: %.10f" % dag_penalty)
            # Update alpha (and possibly rho) if inner optimization done. A single unfinished
            # inner optimization is tolerated; after a second one the update is made anyway.
            if done_inner or num_not_done == 1:
                num_not_done = 0
                if dag_penalty < train_config_dict["tol_dag"]:
                    num_below_tol += 1
                if report_progress_callback is not None:
                    report_progress_callback(self.model_id, step + 1, train_config_dict["max_steps_auglag"])
                with torch.no_grad():
                    if print_adjacency_info:
                        adj_matrix = self.get_adj_matrix()
                        if len(adj_matrix.shape) == 3:
                            adj_matrix = adj_matrix.mean(0)
                        print(adj_matrix)
                    if dag_penalty_prev is None:
                        dag_penalty_prev = dag_penalty
                    else:
                        # Not enough progress on the penalty: increase rho; otherwise update alpha.
                        if dag_penalty > dag_penalty_prev * progress_rate:
                            print("Updating rho, dag penalty prev: %.10f" % dag_penalty_prev)
                            rho *= 10.0
                        else:
                            print("Updating alpha.")
                            dag_penalty_prev = dag_penalty
                            alpha += rho * dag_penalty
                            if dag_penalty == 0.0:
                                alpha *= 5
                        if rho >= train_config_dict["safety_rho"]:
                            alpha *= 5
                        rho = min(rho, train_config_dict["safety_rho"])
                        alpha = min(alpha, train_config_dict["safety_alpha"])
            else:
                num_not_done += 1
                print("Not done inner optimization.")

            self.save()

            if dag_penalty_prev is not None:
                print("Dag penalty: %.15f" % dag_penalty)
                print("Rho: %.2f, alpha: %.2f" % (rho, alpha))

    def optimize_inner_auglag(
        self,
        rho: float,
        alpha: float,
        step: int,
        num_samples: int,
        dataloader,
        train_config_dict: Optional[Dict[str, Any]] = None,
    ) -> Tuple[bool, float]:
        """
        Optimizes for a given alpha and rho.

        The learning rate is reset to train_config_dict["learning_rate"] at the start. It is divided
        by 10 whenever the epoch loss has not improved for 50 epochs (and at least 50 epochs have
        passed since the last reduction). The optimization counts as done after 2 reductions, and
        exits after 3 reductions or after 200 epochs without improvement.

        Returns:
            Tuple (done, dag_penalty). done says whether the inner optimization converged;
            dag_penalty is the mean of the per-epoch tracked DAG penalties.
        """
        if train_config_dict is None:
            train_config_dict = {}

        def get_lr():
            for param_group in self.opt.param_groups:
                return param_group["lr"]

        def set_lr(factor):
            for param_group in self.opt.param_groups:
                param_group["lr"] = param_group["lr"] * factor

        def initialize_lr(val):
            for param_group in self.opt.param_groups:
                param_group["lr"] = val

        lim_updates_down = 3
        num_updates_lr_down = 0
        initialize_lr(train_config_dict["learning_rate"])
        print("LR:", get_lr())
        best_loss = np.nan
        best_epoch = -1
        last_updated = -1
        done_opt = False
        tracker_dag_penalty = []
        for epoch in range(train_config_dict["max_epochs_per_step"]):
            tracker_loss_terms = self.initialize_loss_trackers()
            for x, mask_train_batch in dataloader:
                # Artificially drop some observed values so the imputer learns to reconstruct them.
                input_mask, _ = get_input_and_scoring_masks(
                    mask_train_batch,
                    max_p_train_dropout=train_config_dict["max_p_train_dropout"],
                    score_imputation=True,
                    score_reconstruction=True,
                )
                loss, tracker_loss_terms = self.compute_loss(
                    step,
                    x,
                    mask_train_batch,
                    input_mask,
                    num_samples,
                    tracker_loss_terms,
                    train_config_dict,
                    alpha,
                    rho,
                )
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
            tracker_dag_penalty.append(tracker_loss_terms["penalty_dag"])

            # Track best loss (best_loss starts as nan, so the first epoch always sets it)
            if np.isnan(best_loss) or tracker_loss_terms["loss"] < best_loss:
                best_loss = tracker_loss_terms["loss"]
                best_epoch = epoch

            # Check if has to reduce step size
            if epoch >= best_epoch + 50 and epoch >= last_updated + 50:
                last_updated = epoch
                num_updates_lr_down += 1
                set_lr(0.1)
                print("Reducing lr to %.5f" % get_lr())
                if num_updates_lr_down >= 2:
                    done_opt = True
                if num_updates_lr_down >= lim_updates_down:
                    done_opt = True
                    print("Exiting at epoch %i." % epoch)
                    break

            if epoch >= best_epoch + 200:
                done_opt = True
                print("Exiting at epoch %i." % epoch)
                break

            if np.isnan(tracker_loss_terms["loss"]):
                print(tracker_loss_terms)
                print("Loss is nan, I'm done.", flush=True)
                break

            if epoch % 50 == 0:
                self.print_tracker(epoch, tracker_loss_terms)
        print(f"Best model found at epoch {best_epoch}, with Loss {best_loss:.2f}")
        return done_opt, self.get_auglag_penalty(tracker_dag_penalty)

