from copy import deepcopy
from typing import Optional

import networkx as nx
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

from nflows.distributions import StandardNormal
from nflows.flows import Flow
from nflows.nn.nets import ResidualNet
from nflows.transforms import PointwiseAffineTransform, Transform, AffineCouplingTransform, CompositeTransform
from nflows.transforms.splines import unconstrained_rational_quadratic_spline
from nflows.utils import torchutils


class ParameterizedAffineTransform(PointwiseAffineTransform):
    """Pointwise affine transform with one trainable shift and one trainable scale per dimension.

    Args:
        n_dim: number of dimensions, i.e. the length of the shift and scale vectors.
    """

    def __init__(self, n_dim):
        super().__init__()

        # Remove the `_shift` / `_scale` attributes that exist after the parent
        # constructor ran (presumably registered there), so that the same names
        # can be registered again below as trainable parameters.
        for attribute in ("_shift", "_scale"):
            delattr(self, attribute)

        # Shift starts as Gaussian noise scaled by 1 / sqrt(n_dim); scale starts at one.
        initial_shift = torch.randn(n_dim) / math.sqrt(n_dim)
        initial_scale = torch.ones(n_dim)

        self.register_parameter("_shift", nn.Parameter(initial_shift))
        self.register_parameter("_scale", nn.Parameter(initial_scale))


class GroupAffineTransform(PointwiseAffineTransform):
    """Pointwise affine transform whose shift and scale are shared within groups of dimensions.

    Args:
        group_mask: tensor holding one group identifier per dimension. The identifiers
            are used directly as indices into the per-group parameters, so they are
            expected to be integers in ``0 .. n_groups - 1``.
    """

    def __init__(self, group_mask):
        super().__init__()

        # Remove the `_shift` / `_scale` attributes that exist after the parent
        # constructor ran; the properties below take over those names.
        for attribute in ("_shift", "_scale"):
            delattr(self, attribute)

        self.group_mask = group_mask
        group_count = len(torch.unique(self.group_mask))

        # One trainable (shift, scale) pair per group: zero shift and unit scale initially.
        self.group_shift = nn.Parameter(torch.zeros(group_count))
        self.group_scale = nn.Parameter(torch.ones(group_count))

    @property
    def _scale(self):
        # Expand the per-group scales to one value per dimension.
        return self.group_scale[self.group_mask]

    @property
    def _shift(self):
        # Expand the per-group shifts to one value per dimension.
        return self.group_shift[self.group_mask]


class GroupSplineTransform(Transform):
    """Rational-quadratic spline transform whose spline parameters are shared within groups.

    Every dimension belongs to one group; all dimensions of a group are pushed through
    the same spline (same widths, heights, derivatives and tail bound).

    Args:
        group_mask: tensor holding one group identifier per dimension. Identifiers are
            compared against ``range(n_groups)``, so they are expected to be the
            integers ``0 .. n_groups - 1``.
        n_knots: size of the last axis of the width / height / derivative parameters.
        min_tail_bound: constant added to ``exp(log_tail_bounds)`` so that the tail
            bound stays strictly positive.
    """

    # nflows RQ splines become slow with large batches, so inputs are processed
    # in chunks of this many rows.
    _chunk_size = 50

    def __init__(self, group_mask, n_knots=5, min_tail_bound=1e-3):
        super().__init__()

        self.group_mask = group_mask
        self.n_groups = len(torch.unique(self.group_mask))
        self.group_sizes = [int((self.group_mask == g).sum()) for g in range(self.n_groups)]

        self.min_tail_bound = min_tail_bound
        # All parameters start at zero; shapes are (n_groups,) for the tail bounds and
        # (n_groups, 1, n_knots) for the spline parameters.
        # NOTE(review): nflows appears to pad the derivatives by one entry on each side
        # for linear tails, which would suggest n_knots - 1 learnable derivatives rather
        # than n_knots - confirm against the nflows implementation before changing the
        # shape (doing so would also change the state_dict layout).
        self.log_tail_bounds = nn.Parameter(torch.zeros(self.n_groups))
        self.unnormalized_widths = nn.Parameter(torch.zeros(self.n_groups, 1, n_knots))
        self.unnormalized_heights = nn.Parameter(torch.zeros(self.n_groups, 1, n_knots))
        self.unnormalized_derivatives = nn.Parameter(torch.zeros(self.n_groups, 1, n_knots))

    @property
    def tail_bounds(self):
        """Per-group tail bound: ``exp(log_tail_bounds) + min_tail_bound``."""
        return torch.exp(self.log_tail_bounds) + self.min_tail_bound

    def _apply_splines(self, inputs: torch.Tensor, inverse: bool):
        """Run every group's spline over ``inputs`` in row chunks.

        Returns:
            Tuple ``(outputs, logabsdet)`` where ``outputs`` has the shape of ``inputs``
            and ``logabsdet`` has one entry per input row.
        """
        outputs = torch.zeros_like(inputs)
        # Allocate next to the inputs so device and dtype match the spline results.
        logabsdet = inputs.new_zeros(len(inputs))
        # Column selectors do not depend on the chunk, so build them once.
        group_columns = [self.group_mask == group_id for group_id in range(self.n_groups)]

        start = 0
        for batch in inputs.split(self._chunk_size):
            # Track the running row offset explicitly: the final chunk may be shorter
            # than the others, so `chunk_index * len(batch)` would address wrong rows.
            stop = start + len(batch)
            for group_id, columns in enumerate(group_columns):
                tiling = (len(batch), self.group_sizes[group_id], 1)
                group_outputs, group_logabsdet = unconstrained_rational_quadratic_spline(
                    batch[:, columns],
                    self.unnormalized_widths[group_id].unsqueeze(0).repeat(*tiling),
                    self.unnormalized_heights[group_id].unsqueeze(0).repeat(*tiling),
                    self.unnormalized_derivatives[group_id].unsqueeze(0).repeat(*tiling),
                    inverse=inverse,
                    tail_bound=self.tail_bounds[group_id],
                )
                outputs[start:stop, columns] = group_outputs
                logabsdet[start:stop] += group_logabsdet.sum(dim=1)
            start = stop
        return outputs, logabsdet

    def forward(self, inputs: torch.Tensor, context: Optional[torch.Tensor] = None):
        """Apply the splines; ``context`` is accepted for interface compatibility and ignored."""
        return self._apply_splines(inputs, inverse=False)

    def inverse(self, inputs: torch.Tensor, context: Optional[torch.Tensor] = None):
        """Apply the inverse splines; ``context`` is accepted for interface compatibility and ignored."""
        return self._apply_splines(inputs, inverse=True)


class PointwiseSplineTransform(GroupSplineTransform):
    """Group spline transform in which every dimension forms its own group.

    Args:
        n_dim: number of dimensions (and therefore of groups).
        n_knots: forwarded to ``GroupSplineTransform``.
        min_tail_bound: forwarded to ``GroupSplineTransform``.
    """

    def __init__(self, n_dim, n_knots=5, min_tail_bound=1e-3):
        # Group ids 0 .. n_dim - 1: one distinct group per dimension.
        super().__init__(
            torch.arange(n_dim),
            n_knots=n_knots,
            min_tail_bound=min_tail_bound,
        )


class DAGNet(nn.Module):
    """Conditioner network: maps the parent dimensions to coupling parameters for the child dimensions.

    A ``ResidualNet`` reads ``inputs[:, parent_dims]`` and emits ``n_coupling_params``
    values per sample. Those values are broadcast to every child dimension; all other
    dimensions receive zeros.

    Args:
        n_dim: total number of dimensions of the flow.
        child_dims: indices of the dimensions that receive parameters, e.g. ``[1, 2, 6, 7]``.
        parent_dims: indices of the dimensions fed to the network, e.g. ``[0, 3, 4]``.
        n_coupling_params: number of parameters per dimension (2 for affine coupling).
        hidden_features: hidden width of the network; defaults to ``len(parent_dims)``.
        num_blocks: number of residual blocks.
        activation: activation function passed to ``ResidualNet``.
    """

    def __init__(self,
                 n_dim,
                 child_dims: torch.Tensor,
                 parent_dims: torch.Tensor,
                 n_coupling_params=2,  # 2 for affine coupling.
                 hidden_features=None,
                 num_blocks=1,
                 activation=F.relu):
        super().__init__()

        self.n_dim = n_dim
        self.child_dims = child_dims  # e.g. [1, 2, 6, 7]
        self.parent_dims = parent_dims  # e.g. [0, 3, 4]
        self.n_coupling_params = n_coupling_params  # parameters for coupling transform

        # Boolean selectors over all n_dim dimensions.
        all_dims = torch.arange(self.n_dim)
        self.parent_mask = torch.isin(all_dims, parent_dims)
        self.child_mask = torch.isin(all_dims, child_dims)

        self.net = ResidualNet(
            in_features=len(self.parent_dims),
            out_features=self.n_coupling_params,
            hidden_features=len(self.parent_dims) if hidden_features is None else hidden_features,
            activation=activation,
            num_blocks=num_blocks
        )

    def forward(self, inputs, context=None):
        """Return a ``(len(inputs), n_dim, n_coupling_params)`` tensor of coupling parameters."""
        child_parameters = self.net(inputs[:, self.parent_mask], context)
        # Allocate from the network output so the result shares its device and dtype
        # (a bare torch.zeros would always land on the default device).
        parameters = child_parameters.new_zeros(len(inputs), self.n_dim, self.n_coupling_params)
        # The same per-sample parameters are shared by every child dimension.
        parameters[:, self.child_mask, :] = child_parameters.unsqueeze(1)
        return parameters


class MaskedAffineCouplingTransform(Transform):
    """Affine coupling layer whose scale and shift come from a ``DAGNet``.

    The ``DAGNet`` yields two values per dimension: an unconstrained log-scale (index 0)
    and a shift (index 1). Dimensions outside ``child_dims`` get zeros from the net,
    i.e. scale ``exp(0) = 1`` and shift ``0``.

    Args:
        n_dim: total number of dimensions.
        child_dims: indices of the dimensions being transformed.
        parent_dims: indices of the dimensions the parameters are conditioned on.
        **kwargs: forwarded to ``DAGNet``.
    """

    def __init__(self, n_dim, child_dims, parent_dims, **kwargs):
        super().__init__()
        self.features = n_dim
        self.transform_net = DAGNet(n_dim, child_dims, parent_dims, n_coupling_params=2, **kwargs)

    def forward(self, inputs, context=None):
        """Compute ``inputs * scale + shift`` and the log-determinant per row."""
        if inputs.dim() != 2:
            raise ValueError("Inputs must be a 2D tensor.")
        return self._coupling_transform_forward(
            inputs=inputs,
            transform_params=self.transform_net(inputs, context),
        )

    def inverse(self, inputs, context=None):
        """Compute ``(inputs - shift) / scale`` and the log-determinant per row.

        The parameters are recomputed from ``inputs`` themselves.
        NOTE(review): this presumably relies on the parent dimensions never being
        child dimensions of the same layer, so that they pass through unchanged.
        """
        if inputs.dim() != 2:
            raise ValueError("Inputs must be a 2D tensor.")
        return self._coupling_transform_inverse(
            inputs=inputs,
            transform_params=self.transform_net(inputs, context),
        )

    def _scale_and_shift(self, transform_params):
        # Last axis: [unconstrained log-scale, shift]; the scale is made positive via exp.
        raw_scale, offset = transform_params[..., 0], transform_params[..., 1]
        return torch.exp(raw_scale), offset

    def _coupling_transform_forward(self, inputs, transform_params):
        scale, shift = self._scale_and_shift(transform_params)
        transformed = inputs * scale + shift
        log_det = torchutils.sum_except_batch(torch.log(scale), num_batch_dims=1)
        return transformed, log_det

    def _coupling_transform_inverse(self, inputs, transform_params):
        scale, shift = self._scale_and_shift(transform_params)
        recovered = (inputs - shift) / scale
        log_det = -torchutils.sum_except_batch(torch.log(scale), num_batch_dims=1)
        return recovered, log_det


class ZCA(Transform):
    """Fixed linear ZCA (Mahalanobis) whitening transform.

    The whitening matrix is computed once by ``fit`` and is not trainable. ``forward``
    multiplies every input row by ``A``; ``inverse`` multiplies by ``A^{-1}``.
    """

    def __init__(self):
        super().__init__()
        # Placeholders until `fit` is called.
        self.A: torch.Tensor = torch.tensor(0.0)
        self.A_inv: torch.Tensor = torch.tensor(0.0)
        self.logabsdet_forward: torch.Tensor = torch.tensor(0.0)
        self.logabsdet_inverse: torch.Tensor = torch.tensor(0.0)

    def fit(self, x):
        """Compute the ZCA whitening matrix from data.

        Args:
            x: tensor of shape [N x M] with one observation per row and one variable
                per column (the same layout that ``forward`` expects).

        Sets ``A`` ([M x M] whitening matrix), ``A_inv`` and the two log-determinants.
        """
        # Detach and move to the CPU first: `.numpy()` rejects tensors that require
        # grad or live on another device.
        data = x.detach().cpu().numpy()
        # Covariance between columns of x. atleast_2d keeps the single-variable
        # case (where np.cov returns a scalar) usable by the SVD below.
        sigma = np.atleast_2d(np.cov(data, rowvar=False))  # [M x M]
        # sigma is symmetric, so its SVD provides eigenvectors U and eigenvalues S.
        U, S, _ = np.linalg.svd(sigma)
        # Whitening constant: prevents division by zero
        epsilon = 1e-10
        # ZCA whitening matrix: U * diag(1 / sqrt(S + eps)) * U'
        whitening = U @ np.diag(1.0 / np.sqrt(S + epsilon)) @ U.T
        self.A = torch.tensor(whitening).float()  # [M x M]
        self.A_inv = torch.linalg.inv(self.A).float()
        # slogdet yields log|det A| directly; log(abs(det(A))) overflows or underflows
        # to +-inf once the determinant leaves the float32 range.
        self.logabsdet_forward = torch.linalg.slogdet(self.A).logabsdet
        self.logabsdet_inverse = -self.logabsdet_forward

    def forward(self, inputs: torch.Tensor, context=None):
        """Whiten ``inputs`` (rows are samples); returns ``(outputs, logabsdet)``."""
        outputs = (self.A @ inputs.T).T
        # Same constant log-determinant for every row, allocated next to the inputs.
        logabsdet = inputs.new_ones(len(inputs)) * self.logabsdet_forward
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Undo the whitening; returns ``(outputs, logabsdet)``."""
        outputs = (self.A_inv @ inputs.T).T
        logabsdet = inputs.new_ones(len(inputs)) * self.logabsdet_inverse
        return outputs, logabsdet


class Standardization(Transform):
    """Fixed per-dimension standardization: ``(x - mean) / std``.

    Until ``fit`` is called the transform is the identity (mean 0, std 1).

    Args:
        min_std: lower bound applied to the fitted standard deviations, so that a
            constant column does not cause a division by zero or an infinite
            log-determinant.
    """

    def __init__(self, min_std=1e-12):
        super().__init__()
        self.min_std = min_std
        self.mu = 0.0
        self.std = 1.0
        self.logabsdet_forward = 0.0
        self.logabsdet_inverse = 0.0

    def fit(self, x):
        """Estimate mean and standard deviation over dim 0 of ``x`` (rows are samples)."""
        self.mu = torch.mean(x, dim=0)
        # Only degenerate (near-constant) columns are affected by the clamp.
        self.std = torch.std(x, dim=0).clamp_min(self.min_std)
        # log|det| of dividing each dimension by its std.
        self.logabsdet_forward = -torch.sum(torch.log(self.std))
        self.logabsdet_inverse = -self.logabsdet_forward

    def forward(self, inputs, context=None):
        """Standardize ``inputs``; returns ``(outputs, logabsdet)`` with one logabsdet per row."""
        outputs = (inputs - self.mu) / self.std
        # Allocate next to the inputs so the result lives on their device.
        logabsdet = inputs.new_ones(len(inputs)) * self.logabsdet_forward
        return outputs, logabsdet

    def inverse(self, inputs, context=None):
        """Undo the standardization; returns ``(outputs, logabsdet)``."""
        outputs = inputs * self.std + self.mu
        logabsdet = inputs.new_ones(len(inputs)) * self.logabsdet_inverse
        return outputs, logabsdet


def flatten(t):
    """Concatenate the items of every sub-iterable of ``t`` into one flat list (one level deep)."""
    flat = []
    for sublist in t:
        flat.extend(sublist)
    return flat


def build_ecoflow(graph,
                  pointwise_transform_type='affine',
                  additional_coupling=None,
                  n_additional_coupling_layers=2,
                  whitening_data=None,
                  zca_cond_threshold=1000,
                  **kwargs):
    """Build a normalizing flow whose coupling layers follow the edges of a DAG.

    The flow is constructed from the last to the first layer: leaves of a copy of
    ``graph`` are peeled off round by round, and for every leaf that has parents a
    coupling layer (leaf dimensions conditioned on parent dimensions) followed by a
    pointwise layer is inserted at the front of the layer list.

    Args:
        graph: directed graph. Every node needs the attributes ``'n_dim'`` (number of
            dimensions) and ``'indices'`` (its dimension indices in the flat vector);
            every edge needs the attribute ``'mask'``, an iterable of positions into
            the parent's ``'indices'``.
        pointwise_transform_type: one of ``'affine'``, ``'spline'``, ``'group_affine'``
            or ``'group_spline'``.
        additional_coupling: if ``'affine'``, extra ``AffineCouplingTransform`` layers
            with alternating masks are appended to the layer list.
        n_additional_coupling_layers: number of such extra layers.
        whitening_data: optional data tensor. If given, a fitted ``Standardization``
            (and, when well conditioned, a fitted ``ZCA``) is placed at the start.
        zca_cond_threshold: the ZCA layer is only used when the condition number of
            its matrix is below this value.
        **kwargs: forwarded to every ``MaskedAffineCouplingTransform``.

    Returns:
        A ``Flow`` with a standard normal base distribution.

    Raises:
        NotImplementedError: for an unknown ``pointwise_transform_type``.
    """

    def make_pointwise_layer(*layer_args, **layer_kwargs):
        # The per-dimension variants ignore the arguments (they only need n_dim);
        # the group variants receive the group mask through them.
        if pointwise_transform_type == 'affine':
            return ParameterizedAffineTransform(n_dim)
        elif pointwise_transform_type == 'spline':
            return PointwiseSplineTransform(n_dim)
        elif pointwise_transform_type == 'group_affine':
            return GroupAffineTransform(*layer_args, **layer_kwargs)
        elif pointwise_transform_type == 'group_spline':
            return GroupSplineTransform(*layer_args, **layer_kwargs)
        else:
            raise NotImplementedError(f"{pointwise_transform_type} pointwise transform not implemented")

    n_dim = sum(graph.nodes[v]['n_dim'] for v in graph.nodes)
    G = deepcopy(graph)  # Nodes are removed below; keep the caller's graph intact
    flow_layers = []
    # One group id per dimension: the i-th node contributes n_dim copies of i
    pointwise_group_mask = torch.tensor(
        flatten([[i] * graph.nodes[v]['n_dim'] for i, v in enumerate(graph.nodes)]),
        dtype=torch.long
    )
    while (leaves := [x for x in G.nodes() if G.out_degree(x) == 0]):
        # While there are leaves in the graph
        for leaf in leaves:
            parents = list(G.predecessors(leaf))
            if not parents:
                # Root node: nothing to condition on
                G.remove_node(leaf)
                continue
            child_dims = torch.tensor(G.nodes[leaf]['indices'], dtype=torch.long)
            parent_dims_set = set()
            for p in parents:
                parent_indices = G.nodes[p]['indices']
                # Read the attribute of this single edge instead of rebuilding the
                # attribute dict of all edges for every lookup.
                edge_mask = G.edges[p, leaf]['mask']
                parent_dims_set.update(parent_indices[index] for index in edge_mask)
            parent_dims = torch.tensor(sorted(parent_dims_set), dtype=torch.long)
            coupling_layer = MaskedAffineCouplingTransform(
                n_dim=n_dim,
                child_dims=child_dims,
                parent_dims=parent_dims,
                **kwargs
            )
            pointwise_transform = make_pointwise_layer(pointwise_group_mask)
            flow_layers.insert(0, pointwise_transform)
            flow_layers.insert(0, coupling_layer)
            G.remove_node(leaf)
    flow_layers.insert(0, make_pointwise_layer(pointwise_group_mask))  # An initial transform to deal with leaves
    if additional_coupling == 'affine':
        mask = torch.ones(n_dim)
        mask[::2] = -1
        for _ in range(n_additional_coupling_layers):
            mask *= -1  # Alternate which half of the dimensions is transformed
            flow_layers.append(AffineCouplingTransform(
                mask=mask,
                transform_net_create_fn=lambda in_features, out_features: ResidualNet(
                    in_features,
                    out_features,
                    hidden_features=n_dim,
                    num_blocks=2
                )
            ))

    # If whitening data is not None, compute a ZCA whitening matrix and insert it at the start of the flow
    if whitening_data is not None:
        # Standardize data first
        standardization = Standardization()
        standardization.fit(whitening_data)
        standardized_data = standardization.forward(whitening_data)[0]

        zca = ZCA()
        zca.fit(standardized_data)

        # Check if condition number is sufficiently small. If not, do NOT add ZCA to the flow.
        if (zca_cond := torch.linalg.cond(zca.A)) < zca_cond_threshold:
            flow_layers.insert(0, zca)
        else:
            print(f'ZCA matrix condition number too large ({float(zca_cond):.2f}). '
                  f'ZCA will be skipped. '
                  f'Increasing zca_cond_threshold lets you avoid this, but may cause instabilities.')
        # Standardization is always added, in front of the (optional) ZCA layer
        flow_layers.insert(0, standardization)

    base_distribution_object = StandardNormal(shape=[n_dim])
    flow = Flow(transform=CompositeTransform(flow_layers), distribution=base_distribution_object)
    return flow


if __name__ == '__main__':
    # Funnel example: fit an EcoFlow to a few funnel samples with early stopping on a
    # validation set, then compare flow samples with held-out data in a scatter plot.
    from nfmc_jax.hierarchical.problems import Funnel
    import torch.optim as optim
    from tqdm import tqdm
    import matplotlib.pyplot as plt

    torch.manual_seed(0)

    problem = Funnel(n_dim=1000)
    x_train = problem.sample(10)
    x_val = problem.sample(10)
    x_test = problem.sample(1000)

    dag = problem.graph(draw=True)
    flow = build_ecoflow(dag)

    print(f'Number of trainable parameters: {sum([p.numel() for p in flow.parameters() if p.requires_grad])}')

    patience = 1000  # Epochs without validation improvement before stopping
    n_epochs: int = 20000

    optimizer = optim.AdamW(flow.parameters(), lr=1e-1, weight_decay=1e-4)
    best_state = deepcopy(flow.state_dict())
    best_loss_val = torch.inf
    best_epoch = 0
    for epoch in (pbar := tqdm(range(n_epochs))):
        optimizer.zero_grad()
        loss = -flow.log_prob(x_train).mean()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            loss_val = -flow.log_prob(x_val).mean()

        # Abort on NaN/inf BEFORE the bookkeeping below; otherwise a -inf validation
        # loss would be stored as the best checkpoint and restored after the loop.
        if not (torch.isfinite(loss) and torch.isfinite(loss_val)):
            break

        if loss_val < best_loss_val:
            best_loss_val = loss_val
            best_state = deepcopy(flow.state_dict())
            best_epoch = epoch
        if epoch - best_epoch > patience:
            break

        pbar.set_postfix_str(
            f'Train loss: {float(loss):.4f} | '
            f'Val loss: {float(loss_val):.4f} [{float(best_loss_val):.4f} @ {best_epoch}]'
        )
    # Restore the parameters with the lowest validation loss
    flow.load_state_dict(best_state)
    flow.eval()

    # Pure evaluation from here on: no autograd graph is needed
    with torch.no_grad():
        flow_samples = flow.sample(1000)
        print(f'Test logp: {flow.log_prob(x_test).mean()}')

    fig, ax = plt.subplots()
    dim0 = 0
    dim1 = 1
    ax.scatter(x_test[:, dim0], x_test[:, dim1], label='Test')
    ax.scatter(flow_samples[:, dim0], flow_samples[:, dim1], label='EcoFlow')
    ax.scatter(x_train[:, dim0], x_train[:, dim1], label='Train', ec='k', s=2 ** 7)
    ax.scatter(x_val[:, dim0], x_val[:, dim1], label='Validation', ec='k', s=2 ** 7)
    ax.set_xscale('symlog')
    ax.set_yscale('symlog')
    ax.set_xlabel(f'Dimension {dim0}')
    ax.set_ylabel(f'Dimension {dim1}')
    ax.legend()
    ax.set_title('EcoFlow performance on the funnel')
    fig.tight_layout()
    plt.show()

    # print(f'Log posterior on test data: {problem.log_posterior(x_test).mean()}')
    # print(f'Log posterior on flow samples: {problem.log_posterior(flow_samples).mean()}')
    # print(f'Flow log probability on test data: {flow.log_prob(x_test).mean()}')
    # print(f'Flow log probability on flow samples: {flow.log_prob(flow_samples).mean()}')
