import sbi
from sbi.inference import SNLE, SNRE, SNPE


from markovsbi.bm.api_utils import NPEModel, SBIModel
from markovsbi.utils.prior_utils import Normal

import torch
import numpy as np
import torch.distributions as dist


def run_factorized_nle_or_nre(cfg, task, data, method="nle"):
    """Train an sbi likelihood (NLE) or ratio (NRE) estimator on flattened data.

    Args:
        cfg: Run configuration. ``cfg.method.params_init``,
            ``cfg.method.params_train`` and
            ``cfg.method.params_build_posterior`` are each converted to a
            dict and forwarded as keyword arguments to the inference
            constructor, ``train`` and ``build_posterior`` respectively.
        task: Object exposing ``get_prior()``.
        data: Mapping with the entries ``"thetas"`` and ``"xs"``.
        method: ``"nle"`` selects ``SNLE``; ``"nre"`` selects ``SNRE``.

    Returns:
        An ``SBIModel`` built from the trained posterior, the size of the
        second axis of ``xs`` (read before flattening) and ``cfg``.

    Raises:
        ValueError: If ``method`` is neither ``"nle"`` nor ``"nre"``.
    """
    if method == "nle":
        inference_cls = SNLE
    elif method == "nre":
        inference_cls = SNRE
    else:
        raise ValueError(f"Unknown method {method}")

    parameters = data["thetas"]
    observations = data["xs"]
    # Record the second-axis length (presumably the number of time steps)
    # before every non-leading axis is folded into a single feature axis.
    T = int(observations.shape[1])
    observations = observations.reshape(observations.shape[0], -1)

    # TODO: This would be the correct thing to do
    # # Flatten xs
    # T = int(xs.shape[1])
    # # Assume xs is markov we only need to model density of last xs
    # cond_xs = xs[:, :-1, :]
    # cond_xs = cond_xs.reshape(-1, cond_xs.shape[-1])

    # # We have to extend thetas to also get cond_xs
    # thetas = torch.concatenate([thetas, cond_xs], dim=-1)
    # xs = xs[:, -1, :]

    # The inference object needs a torch distribution as its prior.
    torch_prior = convert_prior_to_torch(task.get_prior())
    constructor_kwargs = dict(cfg.method.params_init)
    trainer = inference_cls(prior=torch_prior, **constructor_kwargs)

    # Register the simulations and fit the estimator.
    trainer = trainer.append_simulations(parameters, observations)
    fit_kwargs = dict(cfg.method.params_train)
    _ = trainer.train(**fit_kwargs)

    # Wrap the trained estimator into a posterior object.
    build_kwargs = dict(cfg.method.params_build_posterior)
    posterior = trainer.build_posterior(**build_kwargs)

    return SBIModel(posterior, T, cfg=cfg)


class TimeSeriesEmbeddingNet(torch.nn.Module):
    """GRU embedding that summarises a NaN-padded sequence by one hidden state.

    The input has shape ``(*batch, n_samples, n_dim)``. Steps from the first
    one that contains a NaN onwards are treated as padding, and the GRU output
    at the step just before it is returned as the embedding, with shape
    ``(*batch, output_dim)``.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: GRU hidden size, which is also the embedding size.
        num_layers: Number of stacked GRU layers.
    """

    def __init__(self, input_dim, output_dim, num_layers=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.gru = torch.nn.GRU(
            self.input_dim, self.output_dim, num_layers=num_layers, batch_first=True
        )

    def forward(self, x):
        """Embed ``x`` of shape ``(*batch, n_samples, n_dim)``.

        Returns:
            Tensor of shape ``(*batch, output_dim)``.
        """
        *batch_shape, n_samples, n_dim = x.shape
        # ``reshape`` rather than ``view``: merging the leading batch axes of
        # a non-contiguous tensor (e.g. a transposed batch) makes ``view``
        # raise, whereas ``reshape`` copies when it has to.
        x = x.reshape(-1, n_samples, n_dim)

        # Flag every step that contains at least one NaN.
        is_padding = torch.isnan(x).any(-1, keepdim=True).to(torch.int32)
        # argmax gives the first flagged step; minus one is the last valid
        # step. Without any NaN the mask is all zeros, argmax yields 0 and the
        # index becomes -1, i.e. the final step. A NaN in the very first step
        # also maps to -1.
        indices = torch.argmax(is_padding, dim=1).squeeze(1) - 1

        # Replace NaNs with zeros so the GRU never propagates them.
        x = torch.nan_to_num(x, nan=0.0)
        hs = self.gru(x)[0]

        # Pick, per sequence, the GRU output at its last valid step.
        batch_index = torch.arange(hs.shape[0], device=hs.device)
        x = hs[batch_index, indices, :]
        x = x.reshape(*batch_shape, self.output_dim)
        return x


def run_npe_embedding_network(cfg, task, data, rng_method):
    """Train NPE with a recurrent embedding net on variable-length sequences.

    The training set is augmented with truncated copies of the simulations:
    for every cut-off ``t`` in ``range(2, T_max)`` a batch of rows is drawn
    (with replacement) from the original simulations and every step from
    ``t`` onwards is overwritten with NaN.

    Args:
        cfg: Run configuration. Reads
            ``cfg.method.subseq_data_augmentation_fraction``,
            ``cfg.task.num_simulations``, ``cfg.method.neural_net``
            (``name``, ``output_dim``, ``num_layers``),
            ``cfg.method.params_init.density_estimator``,
            ``cfg.method.params_train`` and
            ``cfg.method.params_build_posterior``.
        task: Object exposing ``get_prior()``.
        data: Mapping with the entries ``"thetas"`` and ``"xs"``; ``xs`` is
            indexed as ``(simulation, step, feature)``.
        rng_method: Not used by this function.

    Returns:
        An ``NPEModel`` built from the trained posterior and ``cfg``.

    Raises:
        ValueError: If ``cfg.method.neural_net.name`` is not ``"rnn"``.
    """
    # Data
    thetas = data["thetas"]
    xs = data["xs"]

    # Data augmentation for different sequence lengths
    num_original = xs.shape[0]
    T_max = xs.shape[1]
    if T_max > 2:
        # Same number of truncated copies for every cut-off length.
        num_per_length = int(
            cfg.method.subseq_data_augmentation_fraction
            * cfg.task.num_simulations
            / (T_max - 1)
        )
        xs_parts = [xs]
        theta_parts = [thetas]
        for t in range(2, T_max):
            # Always sample from the *original* simulations. Sampling from the
            # growing augmented pool could pick rows already cut at an earlier
            # t, for which a later cut is a no-op, skewing the data towards
            # short sequences.
            idx = torch.randint(0, num_original, (num_per_length,))
            xs_subset = xs[idx]  # index-tensor lookup returns a copy
            xs_subset[:, t:] = torch.nan  # Cut off data after t
            xs_parts.append(xs_subset)
            theta_parts.append(thetas[idx])
        # Concatenate once instead of re-copying the growing tensor per step.
        xs = torch.cat(xs_parts, dim=0)
        thetas = torch.cat(theta_parts, dim=0)

    # Setup inference
    prior_torch = convert_prior_to_torch(task.get_prior())

    neural_net_params = cfg.method.neural_net
    if neural_net_params.name != "rnn":
        raise ValueError(f"Unknown neural net {neural_net_params.name}")

    embedding_net = TimeSeriesEmbeddingNet(
        input_dim=xs.shape[-1],
        output_dim=neural_net_params.output_dim,
        num_layers=neural_net_params.num_layers,
    )
    # z-scoring of x is switched off; presumably because the NaN padding
    # would contaminate the statistics (TODO confirm).
    density_estimator = sbi.utils.posterior_nn(
        model=cfg.method.params_init.density_estimator,
        embedding_net=embedding_net,
        z_score_x="none",
    )

    inf = SNPE(prior=prior_torch, density_estimator=density_estimator)

    # Perform training. The NaN-padded rows must be kept, hence
    # exclude_invalid_x=False.
    inf = inf.append_simulations(thetas, xs, exclude_invalid_x=False)
    train_kwargs = dict(cfg.method.params_train)
    _ = inf.train(**train_kwargs)

    # Build posterior
    posterior_kwargs = dict(cfg.method.params_build_posterior)
    posterior = inf.build_posterior(**posterior_kwargs)

    return NPEModel(posterior, cfg=cfg)


def convert_prior_to_torch(prior):
    """Translate a project prior into an equivalent ``torch.distributions`` object.

    Only ``Normal`` priors are handled: their ``mean`` and ``std`` are copied
    into torch tensors (via NumPy) and wrapped so that the last batch axis is
    treated as the event axis.

    Args:
        prior: Prior object; must be an instance of ``Normal``.

    Returns:
        ``Independent(Normal(mean, std), 1)``.

    Raises:
        ValueError: If ``prior`` is not a ``Normal`` instance.
    """
    if not isinstance(prior, Normal):
        raise ValueError(f"Unknown prior {prior}")

    loc = torch.tensor(np.array(prior.mean))
    scale = torch.tensor(np.array(prior.std))
    elementwise_normal = dist.Normal(loc, scale)
    return dist.Independent(elementwise_normal, 1)
