import pyro
import torch
from pyro import distributions as pdist
from pyro.distributions.transforms import ExpTransform
import pickle
import gzip
import os
from sbsep.cbnn import CBNN, AbsModule
from sbsep.config import SBConfig
from collections import ChainMap
import logging

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class SBMixture(AbsModule):
    """Pyro model/guide pair that mixes the outputs of several CBNN components.

    ``forward`` is the model and ``guide`` is the matching guide; both
    register the ``"<name>#weight"`` Dirichlet sample site.  All component
    networks are collected in ``self.cbnns`` in construction order so that
    history tracking and serialization can iterate over them uniformly.
    """

    filetype_suffix = "cbnn_serialized"

    def __init__(self, config: SBConfig, folder):
        """Build or load the component CBNNs described by ``config``.

        :param config: mixture configuration.  ``config.cbnns`` must contain
            ``"vec_signal"`` and ``"dir_weight"``, and additionally
            ``"tsignal_norm"`` when ``config.modeling.normalized_space`` is
            false.
        :param folder: location handed to ``CBNN.load_from_file`` for every
            component that is loaded from file instead of built from its
            config.
        :raises ValueError: if a required component config is missing.
        """
        super().__init__()

        self.config = config
        self.include_space_scaling = not config.modeling.normalized_space
        # Hard-coded off: forward/guide then use a fixed Dirichlet scale
        # instead of a dedicated "dir_norm" network.
        self.dir_norm_flag = False

        self.name = config.architecture.name

        # function of energy
        if "vec_signal" not in config.cbnns:
            raise ValueError("vec_signal config missing")
        self.vec_signal = CBNN(
            config.cbnns["vec_signal"], observe=False, debug=False
        )
        self.cbnns = [self.vec_signal]

        if self.include_space_scaling:
            # function of space
            # norm : overall magnitude of signal
            self.tsignal_norm = self._load_cbnn_from_file("tsignal_norm", folder)
            self.cbnns.append(self.tsignal_norm)

        # simplex : StickBreakingTransform
        # simplex-weight : composition signal/ (bg + signal)
        self.dir_weight = self._load_cbnn_from_file("dir_weight", folder)
        self.cbnns.append(self.dir_weight)

        # positive : ExpTransform
        # mixture norm : certainty about the composition signal/bg
        if self.dir_norm_flag:
            # Raises like the other components when the config is missing,
            # instead of failing later with an AttributeError in forward().
            self.dir_norm = self._load_cbnn_from_file("dir_norm", folder)
            self.cbnns.append(self.dir_norm)

    def _load_cbnn_from_file(self, key, folder):
        """Load the component configured under ``config.cbnns[key]``.

        :param key: name of the component inside ``self.config.cbnns``.
        :param folder: location passed through to ``CBNN.load_from_file``.
        :returns: whatever ``CBNN.load_from_file`` returns for that component.
        :raises ValueError: if ``key`` is not present in ``self.config.cbnns``.
        """
        if key not in self.config.cbnns:
            raise ValueError(f"{key} config missing")

        lconfig = self.config.cbnns[key]
        return CBNN.load_from_file(
            folder,
            lconfig.architecture.name,
            prior_variance_scale=lconfig.varscaling.prior_variance_scale,
            init_variance_scale=lconfig.varscaling.init_variance_scale,
            load_var_init_as_prior=True,
            observe=False,
            debug=False,
        )

    def forward(self, scoord, pcoord, obs_signal=None):
        """Model: register the sample sites of the mixture.

        Sites registered here, in order: ``"<name>#weight"`` (Dirichlet),
        ``"<name>#obs"`` (observed at ``obs_signal``) and ``"<name>#obskl"``
        (observed at a fixed target).  Nothing is returned.

        :param scoord: input passed to ``dir_weight``, to ``dir_norm`` (when
            enabled) and to ``tsignal_norm`` (when space scaling is enabled).
        :param pcoord: input passed to ``vec_signal``.
        :param obs_signal: observed value for the ``"<name>#obs"`` site;
            ``None`` leaves that site unobserved.
        """
        # Fixed value the divergence term at the end is observed at.
        kl_target = torch.tensor(1.0e2)

        # ndata x 2 (original author's shape note)
        mu_signal = self.vec_signal(pcoord)

        # ndata x 2 (original author's shape note)
        weight = self.dir_weight(scoord)

        # Scale applied to the Dirichlet concentration: either produced by
        # the dir_norm component or a fixed constant.
        if self.dir_norm_flag:
            dir_norm = self.dir_norm(scoord)
        else:
            dir_norm = torch.tensor(1e2)

        weights = pyro.sample(
            f"{self.name}#weight", pdist.Dirichlet(weight * dir_norm)
        )

        # Observation scale is a fixed constant; the original file carried a
        # commented-out InverseGamma(30.0, 1.45) prior for it instead.
        sigma = torch.tensor(0.002)

        # Weighted sum over the component axis, kept as a trailing column.
        mu = torch.sum(mu_signal * weights, dim=1).unsqueeze(-1)
        if self.include_space_scaling:
            # ndata x 1 (original author's shape note)
            mu = mu * self.tsignal_norm(scoord)

        # LogNormal base pushed through the inverse of ExpTransform.
        dist = pdist.TransformedDistribution(
            pdist.LogNormal(mu, sigma), [ExpTransform().inv]
        )
        pyro.sample(f"{self.name}#obs", dist, obs=obs_signal)

        self.record_error(mu, obs_signal)

        # KL-style term between the two columns of mu_signal, tied to
        # kl_target through a Normal likelihood with scale 0.1.
        bg, signal = mu_signal[..., 0], mu_signal[..., 1]
        kl = torch.sum(signal * torch.log(signal / bg))
        pyro.sample(
            f"{self.name}#obskl", pdist.Normal(kl, torch.tensor(0.1)), obs=kl_target
        )

    def guide(self, scoord, pcoord, obs_signal=None):
        """Guide counterpart of ``forward``.

        Calls the ``guide`` of every active component in the same order as
        they were constructed, then registers the ``"<name>#weight"``
        Dirichlet site.  Nothing is returned.

        :param scoord: input passed to the ``dir_weight``, ``tsignal_norm``
            and ``dir_norm`` guides.
        :param pcoord: input passed to the ``vec_signal`` guide.
        :param obs_signal: unused here; accepted so the signature matches
            ``forward``.
        """
        # Return values of these component guides are not needed here; the
        # calls are kept so the components still run.
        self.vec_signal.guide(pcoord)
        concentration = self.dir_weight.guide(scoord)
        if self.include_space_scaling:
            self.tsignal_norm.guide(scoord)

        if self.dir_norm_flag:
            dir_norm = self.dir_norm.guide(scoord)
        else:
            # Unlike the fixed constant in forward(), the guide exposes the
            # scale as a pyro parameter initialised to the same value.
            dir_norm = pyro.param("dir_norm", torch.tensor(1e2))

        pyro.sample(
            f"{self.name}#weight", pdist.Dirichlet(concentration * dir_norm)
        )

    def save_state_to_history(self, epoch):
        """Snapshot every component's state on the configured epochs.

        A snapshot is taken when ``epoch`` is a multiple of
        ``config.train.save_state_each``.  ``None`` disables snapshots; ``0``
        is treated the same way rather than raising ``ZeroDivisionError``.

        :param epoch: current epoch number.
        """
        interval = self.config.train.save_state_each
        if not interval:
            return

        if epoch % interval == 0:
            for cbnn in self.cbnns:
                cbnn.save_state_to_history(epoch, override=True)

    @property
    def history(self):
        """Histories of all components merged into a single dict.

        The merge goes through ``ChainMap``, so on duplicate keys the entry
        of the component that comes first in ``self.cbnns`` wins.
        """
        return dict(ChainMap(*(cbnn.history for cbnn in self.cbnns)))

    def save(self, path=None):
        """Serialize the mixture configuration and all components.

        :param path: directory to write to (``~`` is expanded).  When given,
            the description is pickled into
            ``<name>_<filetype_suffix>.pkl.gz`` inside that directory and
            ``None`` is returned.  When ``None``, nothing is written.
        :returns: the description dict when ``path`` is ``None``, otherwise
            ``None``.
        """
        obj_desc = {
            "config": self.config.to_json(),
            "cbnns": {c.config.architecture.name: c.save() for c in self.cbnns},
        }

        if path is None:
            return obj_desc

        filename = os.path.join(
            os.path.expanduser(path), f"{self.name}_{self.filetype_suffix}.pkl.gz"
        )
        with gzip.open(filename, "wb") as file:
            pickle.dump(obj_desc, file)