import abc

import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.distributions as D
import matplotlib.pyplot as plt

from nfmc_jax.utils.torch_distributions import Funnel as _Funnel
import stan_jupyter as stan  # pip install pystan-jupyter


class AbstractProblem(abc.ABC):
    """Base interface for a sampling problem defined over unconstrained parameters.

    Subclasses provide the log likelihood, the log prior, the ``logj`` term, the problem's DAG
    and a way of obtaining ground truth posterior samples.
    """

    def __init__(self, n_dim):
        self.n_dim = n_dim

    @abc.abstractmethod
    def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
        """Compute log likelihood values for ``x``.

        Parameters ``x`` are assumed to be unconstrained, i.e. on the real number line.
        """

    @abc.abstractmethod
    def log_prior(self, x: torch.Tensor) -> torch.Tensor:
        """Compute log prior values for ``x``.

        Parameters ``x`` are assumed to be unconstrained, i.e. on the real number line.
        """

    @abc.abstractmethod
    def logj(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``logj`` for ``x``.

        Parameters ``x`` are assumed to be unconstrained, i.e. on the real number line.
        """

    def log_posterior(self, x: torch.Tensor) -> torch.Tensor:
        """Compute unnormalized log posterior values for ``x``.

        The result is the sum of the log likelihood, the log prior and ``logj``.
        Parameters ``x`` are assumed to be unconstrained, i.e. on the real number line.
        """
        likelihood_term = self.log_likelihood(x)
        prior_term = self.log_prior(x)
        jacobian_term = self.logj(x)
        return likelihood_term + prior_term + jacobian_term

    @abc.abstractmethod
    def graph(self) -> nx.DiGraph:
        """Return the problem's DAG."""

    @abc.abstractmethod
    def sample(self, n: int = None):
        """Obtain ground truth samples from the posterior (using HMC or otherwise)."""

    @staticmethod
    def assign_dag_indices(g: nx.DiGraph):
        """Store an ``'indices'`` list on every node of ``g``, in place.

        Nodes are visited in the iteration order of ``g.nodes``; each node receives the next
        ``n_dim`` consecutive integers, starting from 0 for the first node.
        """
        offset = 0
        for name in g.nodes:
            attributes = g.nodes[name]
            width = attributes['n_dim']
            attributes['indices'] = list(range(offset, offset + width))
            offset += width


class Funnel(AbstractProblem):
    """Funnel problem whose density is carried entirely by the prior term.

    ``log_likelihood`` and ``logj`` are constant zero; ``log_prior`` and ``sample`` delegate to
    the wrapped ``_Funnel`` distribution.
    """

    def __init__(self, n_dim):
        super().__init__(n_dim)
        self.dist = _Funnel(n_dim)

    def log_prior(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``self.dist.log_prob(x)``."""
        return self.dist.log_prob(x)

    def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
        """Return a scalar zero tensor: there is no likelihood term."""
        return torch.tensor(0.0)

    def logj(self, x: torch.Tensor) -> torch.Tensor:
        """Return a scalar zero tensor: there is no ``logj`` contribution."""
        return torch.tensor(0.0)

    def graph(self, draw=False, **kwargs) -> nx.DiGraph:
        """Build the two-node DAG ``x0 -> x1`` and optionally draw it.

        ``x0`` has one dimension and ``x1`` has the remaining ``n_dim - 1``. The edge carries
        ``mask=[0]``. When ``draw`` is true the graph is plotted with matplotlib and ``kwargs``
        are forwarded to ``nx.draw``.
        """
        dag = nx.DiGraph()

        # Nodes first, so that the index assignment sees both of them
        dag.add_node('x0', n_dim=1)
        dag.add_node('x1', n_dim=self.n_dim - 1)
        self.assign_dag_indices(dag)

        # Single edge from x0 to x1
        dag.add_edge('x0', 'x1', mask=[0])

        if not draw:
            return dag

        node_labels = {k: f'{k}: {v}' for k, v in nx.get_node_attributes(dag, 'n_dim').items()}
        layout = {'x0': (0, 1), 'x1': (0, 0)}

        plt.figure()
        nx.draw(dag, with_labels=True, labels=node_labels, node_size=3500, pos=layout, font_color='white', **kwargs)
        nx.draw_networkx_edge_labels(dag, pos=layout)
        plt.ylim(-0.25, 1.25)
        plt.show()

        return dag

    def sample(self, n: int = 1000):
        """Return ``self.dist.sample(n)``."""
        return self.dist.sample(n)


class Funnel2(Funnel):
    """Funnel with an additional Gaussian likelihood on all coordinates except the first.

    Funnel as described in the DLMC paper.
    """

    def __init__(self, n_dim: int, scale: float = 1.0):
        super().__init__(n_dim=n_dim)
        self.scale = scale
        # Draw a random center, perturb it with Gaussian noise of the given scale, and keep
        # every coordinate but the first as the likelihood mean (length n_dim - 1).
        mean_generator = D.Independent(D.Normal(loc=torch.randn(n_dim), scale=scale), 1)
        self.likelihood_mean = mean_generator.sample()[1:]

    def log_likelihood(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Gaussian log likelihood of ``x[:, 1:]`` around ``self.likelihood_mean``.

        The first column of ``x`` does not enter the likelihood.
        """
        likelihood = D.Independent(D.Normal(loc=self.likelihood_mean, scale=self.scale), 1)
        return likelihood.log_prob(x[:, 1:])

    def sample(self, n: int = 1000, n_chains: int = 4):
        """Sample the posterior with Stan and return the draws as a pandas dataframe.

        ``n`` is passed as ``num_samples`` and ``n_chains`` as ``num_chains``; the Stan random
        seed is fixed to 0.
        """
        stan_program = """
        data {
          int<lower=1> d;
          vector[d-1] mu;
          real<lower=0> scale;
        }

        parameters {
          real theta;
          vector[d-1] z;
        }

        model {
          // Prior
          theta ~ normal(0, 3);
          z ~ normal(0, exp(theta / 2));
          
          // Likelihood
          z ~ normal(mu, scale);
        }
        """

        stan_inputs = {
            "d": self.n_dim,
            "mu": self.likelihood_mean.numpy().tolist(),
            "scale": self.scale,
        }

        compiled = stan.build(stan_program, data=stan_inputs, random_seed=0)
        draws = compiled.sample(num_chains=n_chains, num_samples=n)
        return draws.to_frame()


class GermanCredit(AbstractProblem):
    """Logistic regression on the German credit data with coefficients ``tau * lambda * beta``.

    Model reference: https://arxiv.org/abs/1903.03704
    Download the dataset here: https://archive.ics.uci.edu/ml/machine-learning-databases/statlog/german/

    The 51 unconstrained dimensions are laid out as ``log(tau)`` (index 0), ``log(lambda)``
    (indices 1..25) and ``beta`` (indices 26..50). ``tau`` and ``lambda`` are positive and are
    mapped to the real line with a log transform; ``logj`` is the matching log-Jacobian.
    """

    def __init__(self, dataset_path):
        # tau, lambda (25), beta (25)
        super().__init__(n_dim=51)
        self.x = None
        self.y = None
        self.load_data(dataset_path)

    @staticmethod
    def _split(unconstrained_parameters: torch.Tensor):
        # Slice a (batch, 51) tensor into log_tau (batch,), log_lambda (batch, 25), beta (batch, 25).
        log_tau = unconstrained_parameters[:, 0]
        log_lmbd = unconstrained_parameters[:, 1:26]
        beta = unconstrained_parameters[:, 26:51]
        return log_tau, log_lmbd, beta

    def load_data(self, dataset_path):
        """Read the whitespace-separated numeric dataset and store it in ``self.x`` / ``self.y``.

        The last column is the label (shifted down by one); all other columns are features.
        A constant column of ones is appended to the features as a bias term.
        """
        # sep=r'\s+' is the documented replacement for the deprecated delim_whitespace=True
        data = np.array(pd.read_table(dataset_path, header=None, sep=r'\s+')).astype(np.float32)
        x = data[:, :-1]
        y = (data[:, -1] - 1).astype(np.int32)

        # Scale every feature by its range, then apply 2x - 1.
        # NOTE(review): the column minimum is not subtracted before dividing, so the result is
        # not strictly within [-1, 1]; presumably kept to match the reference preprocessing —
        # confirm before changing. A constant feature column would divide by zero here.
        x_min = np.min(x, 0, keepdims=True)
        x_max = np.max(x, 0, keepdims=True)
        x /= (x_max - x_min)
        x = 2.0 * x - 1.0
        x = np.concatenate([x, np.ones([x.shape[0], 1])], -1)

        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)  # Torch wants floats, even if we have a Bernoulli likelihood

    def log_prior(self, unconstrained_parameters: torch.Tensor):
        """Return the log prior density evaluated at the constrained values, one value per row.

        ``tau`` and each ``lambda`` have a Gamma(0.5, 0.5) prior; each ``beta`` has a standard
        normal prior. The change-of-variables term is not included here (see ``logj``).
        """
        log_tau, log_lmbd, beta = self._split(unconstrained_parameters)

        tau = torch.exp(log_tau)
        lmbd = torch.exp(log_lmbd)

        log_prob_tau = D.Gamma(0.5, 0.5).log_prob(tau)
        log_prob_lmbd = torch.sum(D.Gamma(0.5, 0.5).log_prob(lmbd), dim=1)
        log_prob_beta = torch.sum(D.Normal(0, 1).log_prob(beta), dim=1)

        return log_prob_tau + log_prob_lmbd + log_prob_beta

    def log_likelihood(self, unconstrained_parameters: torch.Tensor):
        """Return the Bernoulli log likelihood of ``self.y``, one value per parameter row."""
        log_tau, log_lmbd, beta = self._split(unconstrained_parameters)

        tau = torch.exp(log_tau)
        lmbd = torch.exp(log_lmbd)

        # Effective regression coefficients, one row per parameter vector: (batch, 25)
        weights = tau.view(-1, 1) * beta * lmbd
        # One logit per (parameter row, data row): (batch, N)
        logits = weights @ self.x.T

        # Parameterize by logits instead of sigmoid(logits): probabilities saturate at 0/1 for
        # large |logit|, which distorts the log-probabilities and their gradients.
        # This also mirrors the bernoulli_logit call in the Stan model used by `sample`.
        return D.Independent(D.Bernoulli(logits=logits), 1).log_prob(self.y)

    def logj(self, unconstrained_parameters: torch.Tensor):
        """Return the log-Jacobian of the exp transform applied to ``log(tau)`` and ``log(lambda)``.

        For ``t = exp(u)`` the log-Jacobian is ``u``, so this is ``log_tau`` plus the row sum of
        ``log_lambda``. ``beta`` is untransformed and contributes nothing.
        """
        log_tau, log_lmbd, _ = self._split(unconstrained_parameters)
        return log_tau + torch.sum(log_lmbd, dim=1)

    def graph(self, draw=False, **kwargs) -> nx.DiGraph:
        """Build the DAG with nodes ``tau`` (1), ``lambda`` (25), ``beta`` (25) and no edges.

        When ``draw`` is true the graph is plotted with matplotlib and ``kwargs`` are forwarded
        to ``nx.draw``.
        """
        g = nx.DiGraph()

        # Define the nodes
        g.add_node('tau', n_dim=1)
        g.add_node('lambda', n_dim=25)
        g.add_node('beta', n_dim=25)

        self.assign_dag_indices(g)

        if draw:
            labels = {k: f'{k}: {v}' for k, v in nx.get_node_attributes(g, 'n_dim').items()}
            pos = {
                'tau': (-1, 0),
                'lambda': (0, 0),
                'beta': (1, 0),
            }
            plt.figure()
            nx.draw(g, with_labels=True, labels=labels, node_size=7000, pos=pos, font_color='white', **kwargs)
            plt.xlim(-2, 2)
            plt.ylim(-1, 1)
            plt.show()

        return g

    def sample(self, n_chains: int = 4, n: int = 1000, **kwargs):
        """Sample the posterior with Stan and return the draws as a pandas dataframe.

        ``n_chains`` is passed as ``num_chains`` and ``n`` as ``num_samples``; the Stan random
        seed is fixed to 0. Note that ``n_chains`` comes first positionally, so pass ``n`` by
        keyword. Extra ``kwargs`` are accepted but currently unused.
        """
        model_code = """
        data {
          int<lower=0> N;
          matrix[N, 25] x;
          array[N] int y;
        }
        
        parameters {
          real<lower=0> tau;
          vector<lower=0>[25] lambda;
          vector[25] beta;
        }
        
        model {
          tau ~ gamma(0.5, 0.5);
          lambda ~ gamma(0.5, 0.5);
          beta ~ normal(0, 1);
          
          // This is much faster than the for loop
          y ~ bernoulli_logit(x * (tau * lambda .* beta));
          
          // for (i in 1:N) {
          //   y[i] ~ bernoulli_logit(dot_product(x[i], tau * lambda .* beta));
          // }
        }
        """

        model_data = {
            "N": len(self.x),
            "x": self.x.numpy().tolist(),
            "y": self.y.numpy().astype(np.int32).tolist()
        }

        posterior = stan.build(model_code, data=model_data, random_seed=0)
        fit = posterior.sample(num_chains=n_chains, num_samples=n)
        return fit.to_frame()  # pandas dataframe


if __name__ == '__main__':
    # --- Funnel demo: draw the DAG, sample with Stan, convert the draws to one tensor ---
    problem = Funnel2(100)
    dag = problem.graph(draw=True)
    posterior_samples = problem.sample()

    # First column is theta; the remaining columns are the dataframe columns matching 'z\.'
    theta_column = torch.tensor(posterior_samples['theta'].values).view(-1, 1)
    z_columns = torch.tensor(posterior_samples.filter(regex=r'z\.').values)
    posterior_samples_torch = torch.column_stack([theta_column, z_columns])

    # --- German credit demo: draw the DAG, sample with Stan, print the first rows ---
    # The dataset path below is machine-specific; adjust it to the local copy of the data.
    problem = GermanCredit(
        '/home/david/PycharmProjects/normalizing-flow-stability-analysis/datasets/german-data-numeric.tsv'
    )
    dag = problem.graph(draw=True)
    posterior_samples = problem.sample()
    print(posterior_samples.head())
