import os.path
import pickle
import gzip
from functools import reduce
from itertools import product
from copy import deepcopy
import pyro
import torch
from pyro.nn import PyroModule
from torch import nn
from torch.distributions import constraints
import pyro.distributions as dist
from pyro.poutine import trace, replay
from pyro.distributions.transforms import StickBreakingTransform, ExpTransform
from sbsep.scaler import Scaler
from sbsep.dist import bnn_helper
from sbsep.config import BNNConfig, PrimitiveCBNNState, BNNConfigFactory
import logging

# Module-level logger; CBNN.save_state_to_history emits its debug output here.
logger = logging.getLogger(__name__)


class AbsModule(PyroModule):
    """Base Pyro module that records a fit error and estimates data evidence.

    :meth:`get_data_evidence` calls the instance itself as the model and
    ``self.guide`` as the guide, so subclasses must provide both with the
    same positional signature.
    """

    def __init__(self):
        super().__init__()
        # Root-mean-square error stored by the most recent ``record_error``.
        self.rms = 0.0

    def record_error(self, obs, mu):
        """Store the RMS difference between ``obs`` and ``mu`` in ``self.rms``.

        Both tensors are flattened first, so only their element counts have
        to agree.
        """
        mse = nn.functional.mse_loss(obs.flatten(), mu.flatten())
        self.rms = torch.sqrt(mse).item()

    def get_data_evidence(self, data_tup, num_samples=2000):
        """Summarise the negative log-probability of the observed sites.

        For every draw a trace of ``self.guide`` is recorded, the model is
        replayed against that trace, and ``log_prob`` of every observed
        sample site is summed.

        Args:
            data_tup: positional arguments passed to both guide and model.
            num_samples: number of guide draws to evaluate.

        Returns:
            ``(mean, q05, q95)`` as Python floats: the mean and the 5 % and
            95 % quantiles of the negated per-draw sums.

        Raises:
            ValueError: if ``num_samples`` is smaller than 1, or the model
                trace contains no observed sample site.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        evidence = []

        # Only scalar summaries leave this method, so no autograd graph is
        # needed; skipping it keeps memory flat over many draws.
        with torch.no_grad():
            for _ in range(num_samples):
                guide_trace = trace(self.guide).get_trace(*data_tup)
                model_trace = trace(replay(self, guide_trace)).get_trace(*data_tup)

                log_probs = [
                    node["fn"].log_prob(node["value"]).sum()
                    for node in model_trace.nodes.values()
                    if node["type"] == "sample" and node.get("is_observed")
                ]
                if not log_probs:
                    raise ValueError("model trace contains no observed sample sites")

                total = log_probs[0]
                for log_prob in log_probs[1:]:
                    total = total + log_prob
                evidence.append(total)

        evidence_dist = -torch.stack(evidence)
        evidence_5pct = torch.quantile(evidence_dist, 0.05, dim=0).item()
        evidence_95pct = torch.quantile(evidence_dist, 0.95, dim=0).item()

        return torch.mean(evidence_dist).item(), evidence_5pct, evidence_95pct


class CBNN(AbsModule):
    """Bayesian neural network configured by a ``BNNConfig``.

    Every layer is described by a ``(weight_loc, weight_scale, bias_loc,
    bias_scale)`` tuple of tensors.  ``state.priors`` holds the tuples that
    the model (:meth:`forward`) hands to ``bnn_helper``;
    ``state.var_init_value`` holds the initial values of the variational
    parameters registered by :meth:`guide`.
    """

    filetype_suffix = "cbnn_serialized"

    # Distribution class used for the observation site, keyed by the
    # transform name from the architecture config.
    map_transform_dist = {
        "IdTransform": dist.Normal,
        "ExpTransform": dist.LogNormal,
        "StickBreakingTransform": dist.Dirichlet,
    }
    # Transform lists handed to ``bnn_helper``: outer key is the ``observe``
    # flag, inner key the transform name.
    map_observe_transform = {
        True: {
            "IdTransform": [],
            "ExpTransform": [],
            "StickBreakingTransform": [StickBreakingTransform()],
        },
        False: {
            "IdTransform": [],
            "ExpTransform": [ExpTransform()],
            "StickBreakingTransform": [StickBreakingTransform()],
        },
    }

    def __init__(
        self,
        config: BNNConfig,
        state: PrimitiveCBNNState = PrimitiveCBNNState(),
        observe=False,
        debug=False,
    ):
        """Build the network from ``config`` and an optional saved ``state``.

        Args:
            config: supplies ``architecture`` (name, dims, domain, transform,
                normalize) and ``train`` (sigma_obs, save_state_each).
            state: layer parameters; it is deep-copied, so neither the
                caller's object nor the shared default instance is mutated.
            observe: when true, :meth:`forward` adds an observation site
                using ``config.train.sigma_obs``.
            debug: when true (and ``observe`` is false), :meth:`forward`
                adds a ``Delta`` observation site and records the RMS error.

        Raises:
            KeyError: if ``config.architecture.transform`` is not a key of
                ``map_observe_transform``.
        """
        super().__init__()
        self.observe = observe
        self.debug = debug

        self.config = config
        self.name = self.config.architecture.name
        self.dims = self.config.architecture.dims
        self.xa, self.xb = self.config.architecture.domain

        self.state = deepcopy(state)

        # record history, if necessary
        self._history = {}

        # TODO : compatible with observe mode

        self.nlayers = len(self.dims) - 1
        self.activation = nn.ReLU()

        self.transforms = self.map_observe_transform[observe][
            self.config.architecture.transform
        ]

        self.model_scale_scale = 1.0
        self.guide_scale_scale = 0.01

        self.sampling_dist = self.map_transform_dist.get(
            self.config.architecture.transform, dist.Normal
        )

        # scaling
        self.scaler_xcoord = Scaler(self.xa, self.xb)

        if self.state.priors is None:
            self.state.priors = self._default_layer_params(self.model_scale_scale)

        if self.state.var_init_value is None:
            self.state.var_init_value = self._default_layer_params(
                self.guide_scale_scale
            )

    def _default_layer_params(self, scale):
        """Return zero-loc layer tuples whose scales all equal ``scale``."""
        return [
            (
                torch.zeros(da, db),
                scale * torch.ones(da, db),
                torch.zeros(db),
                scale * torch.ones(db),
            )
            for da, db in zip(self.dims, self.dims[1:])
        ]

    @property
    def history(self):
        """Recorded parameter snapshots, wrapped in a dict keyed by the network name."""
        return {self.config.architecture.name: self._history}

    def guess_parameter_names(self):
        """Return the Pyro parameter names that :meth:`guide` registers.

        Names follow ``"<name>#<weight|bias>_<loc|scale>_<layer index>"``.
        """
        wtypes = ["weight", "bias"]
        ptypes = ["loc", "scale"]
        lemmas = [f"{self.name}#{wt}_{pt}" for wt, pt in product(wtypes, ptypes)]
        return [f"{lemma}_{ix}" for lemma, ix in product(lemmas, range(self.nlayers))]

    def forward(self, x0, ys=None):
        """Run the model on ``x0``, optionally conditioning on ``ys``.

        Args:
            x0: input coordinates; passed through ``self.scaler_xcoord.scale``
                before being handed to ``bnn_helper``.
            ys: optional observed values for the ``"<name>#obs"`` site.

        Returns:
            The ``bnn_helper`` output.  Outside observe mode it is passed
            through :meth:`_normalize` first when
            ``config.architecture.normalize`` is set.
        """
        x = self.scaler_xcoord.scale(x0)

        mu = bnn_helper(
            self.state.priors, x, name=self.name, transforms=self.transforms
        )

        if self.observe:
            if self.config.architecture.transform == "StickBreakingTransform":
                # dimensionality tells us it should be sigma^2 here
                pyro.sample(
                    f"{self.name}#obs",
                    self.sampling_dist(mu / self.config.train.sigma_obs),
                    obs=ys,
                )
                # NOTE: no fit error is recorded in this branch.
            else:
                obs = pyro.sample(
                    f"{self.name}#obs",
                    self.sampling_dist(mu, self.config.train.sigma_obs),
                    obs=ys,
                )
                # Apply the non-observe transforms to mu before comparing it
                # with the observation.
                ts = self.map_observe_transform[False][
                    self.config.architecture.transform
                ]
                mu_trans = reduce(lambda res, f: f(res), ts, mu)
                self.record_error(obs, mu_trans)
            return mu

        if self.config.architecture.normalize:
            mu = self._normalize(mu)

        if self.debug:
            obs = pyro.sample(
                f"{self.name}#obs",
                dist.Delta(mu),
                obs=ys,
            )
            self.record_error(obs, mu)
        return mu

    def guide(self, x0, ys=None):
        """Variational guide: register per-layer parameters and run the net.

        The initial values come from ``self.state.var_init_value``; scale
        parameters are constrained to be positive.  ``ys`` is accepted so the
        signature matches :meth:`forward`, but it is not used.
        """
        x = self.scaler_xcoord.scale(x0)

        gparams = []

        # Zipping with the dims keeps the number of registered layers bounded
        # by the architecture even if the state holds extra entries.
        for k, (_da, _db, layer_var) in enumerate(
            zip(self.dims, self.dims[1:], self.state.var_init_value)
        ):
            weight_loc, weight_scale, bias_loc, bias_scale = layer_var

            guess_weight_loc = pyro.param(f"{self.name}#weight_loc_{k}", weight_loc)
            guess_weight_scale = pyro.param(
                f"{self.name}#weight_scale_{k}",
                weight_scale,
                constraint=constraints.positive,
            )

            guess_bias_loc = pyro.param(f"{self.name}#bias_loc_{k}", bias_loc)
            guess_bias_scale = pyro.param(
                f"{self.name}#bias_scale_{k}",
                bias_scale,
                constraint=constraints.positive,
            )
            gparams.append(
                (guess_weight_loc, guess_weight_scale, guess_bias_loc, guess_bias_scale)
            )

        return bnn_helper(gparams, x, name=self.name, transforms=self.transforms)

    def _normalize(self, x):
        """Rescale ``x`` by ``x.shape[0] / (x.sum(axis=0) * (xb - xa))``."""
        volume = self.xb - self.xa
        return x.shape[0] * x / (x.sum(axis=0) * volume)

    def save(self, path=None):
        """Snapshot the learned parameters and serialise state plus config.

        ``self.state.var_init_value`` is overwritten with the current values
        from the Pyro param store (see :meth:`get_weights`).

        Args:
            path: directory (``~`` is expanded) in which
                ``"<name>_cbnn_serialized.pkl.gz"`` is written.  When
                ``None`` nothing is written.

        Returns:
            The ``{"state": ..., "config": ...}`` dict when ``path`` is
            ``None``; otherwise ``None``.
        """
        self.state.var_init_value = self.get_weights()

        obj_desc = {"state": self.state, "config": self.config.to_json()}

        if path is None:
            return obj_desc

        filename = os.path.join(
            os.path.expanduser(path), f"{self.name}_{self.filetype_suffix}.pkl.gz"
        )
        with gzip.open(filename, "wb") as file:
            pickle.dump(obj_desc, file)
        return None

    def get_weights(self):
        """Return detached copies of this network's guide parameters.

        Returns:
            One ``(weight_loc, weight_scale, bias_loc, bias_scale)`` tuple per
            layer, read from the global Pyro param store.
        """
        par_store = pyro.get_param_store()

        def fetch(kind, k):
            return par_store.get_param(f"{self.name}#{kind}_{k}").clone().detach()

        return [
            (
                fetch("weight_loc", k),
                fetch("weight_scale", k),
                fetch("bias_loc", k),
                fetch("bias_scale", k),
            )
            for k in range(len(self.dims) - 1)
        ]

    def save_state_to_history(self, epoch, override=False):
        """Store a parameter snapshot for ``epoch`` when one is due.

        A snapshot is taken when ``override`` is true, or when
        ``config.train.save_state_each`` is a non-zero number dividing
        ``epoch``.  ``None`` or ``0`` disables periodic snapshots.
        """
        each = self.config.train.save_state_each
        if not (override or (each and epoch % each == 0)):
            return

        params = self.get_weights()
        self._history[epoch] = params

        # The debug line reports the second layer, which only exists for
        # networks with at least two layers.
        if len(params) > 1 and logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                " epoch = %s, mean layer1 weight loc = %.3e",
                epoch,
                torch.mean(params[1][0]).item(),
            )

    @classmethod
    def load_from_file(
        cls,
        path,
        name,
        prior_variance_scale=1.0,
        init_variance_scale=1.0,
        load_var_init_as_prior=False,
        **kwargs,
    ):
        """Rebuild a network from a file written by :meth:`save`.

        Args:
            path: directory (``~`` is expanded) containing the file.
            name: network name used in the file name.
            prior_variance_scale: forwarded to ``state.scale_parameters``.
            init_variance_scale: forwarded to ``state.scale_parameters``.
            load_var_init_as_prior: when true, the stored variational values
                replace the stored priors before scaling.
            **kwargs: forwarded to the class constructor.

        Returns:
            A new instance built from the stored config and state.
        """
        filename = os.path.join(
            os.path.expanduser(path), f"{name}_{cls.filetype_suffix}.pkl.gz"
        )
        # SECURITY: unpickling executes arbitrary code; only load files that
        # come from a trusted source.
        with gzip.open(filename, "rb") as file:
            obj_desc = pickle.load(file)

        config = BNNConfigFactory.get_bnn_config(obj_desc["config"])
        # check here
        state = deepcopy(obj_desc["state"])

        if load_var_init_as_prior:
            # Copy instead of aliasing so that priors and initial values do
            # not share tensors when they are scaled below.
            state.priors = deepcopy(state.var_init_value)

        state.scale_parameters(
            prior_variance_scale=prior_variance_scale,
            init_variance_scale=init_variance_scale,
        )

        return cls(config=config, state=state, **kwargs)

    def check_sizes(self, priors=True):
        """Validate stored parameter shapes against ``self.dims``.

        Args:
            priors: check ``self.state.priors`` when true, otherwise
                ``self.state.var_init_value``.

        Raises:
            ValueError: if the number of layers or any tensor shape does not
                match the architecture.
        """
        params = self.state.priors if priors else self.state.var_init_value

        # zip() below would silently ignore missing or surplus layers.
        if len(params) != len(self.dims) - 1:
            raise ValueError(" incompatible size parameters loaded")

        for da, db, item in zip(self.dims, self.dims[1:], params):
            weight_loc, weight_scale, bias_loc, bias_scale = item
            weights_ok = weight_loc.shape == (da, db) and weight_scale.shape == (da, db)
            biases_ok = bias_loc.shape == (db,) and bias_scale.shape == (db,)
            if not (weights_ok and biases_ok):
                raise ValueError(" incompatible size parameters loaded")
