import os
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import pickle

import libclut32
import clut_sample
from pyclut import build_clut, F_bimodal
import lut_c_wrapper
from generate_distributions import cascade_rounding


def prepare_probs_for_c(probs, precision):
    """Quantise a weight vector to unsigned integers that sum to ``precision``.

    The weights are normalised to sum to one, scaled by ``precision`` and
    passed through ``cascade_rounding``; the result is stored as ``np.uint64``.

    Args:
        probs: NumPy array of weights (it must provide ``.sum()``). The array
            is left unmodified.
        precision: Integer total the quantised weights must add up to
            (the caller in this file passes ``2**precision_bits``).

    Returns:
        ``np.uint64`` array of quantised weights whose sum equals ``precision``.

    Raises:
        AssertionError: If the rounded weights do not add up to ``precision``.
    """
    # Divide out of place so the caller's array is left untouched; an in-place
    # ``/=`` would also fail outright for integer-typed input.
    normalised = probs / probs.sum()
    quantised = np.array(cascade_rounding(normalised * precision), dtype=np.uint64)
    assert quantised.sum() == precision
    return quantised


class cLUTSampler:
    """Base class for samplers that draw from a precomputed cLUT table.

    Besides the attributes set in ``__init__``, the methods below read the
    following attributes, which subclasses in this file assign themselves
    before calling :meth:`get_cLUT`:

    * ``F`` / ``F_kwargs`` -- CDF callable and its keyword arguments,
    * ``init_table`` -- truthy to build and cache the table, falsy to load it,
    * ``table_name`` -- file name of the cached table.
    """

    def __init__(self, sampler, precision=32, range_step=1e-3):
        """Store the backend name, the precision and the grid spacing.

        Args:
            sampler: Backend name; ``rvs`` accepts ``"libclut32int"``,
                ``"libclut32float"`` and ``"lut_c_wrapper"``.
            precision: Precision handed to the table builders.
            range_step: Exponent of the grid spacing; the spacing actually
                stored is ``10 ** (-range_step)``.
                NOTE(review): with the default ``1e-3`` the spacing is
                ``10 ** -0.001`` (about 0.998); callers presumably pass an
                integer exponent such as ``3`` -- confirm.
        """
        self.sampler = sampler
        self.precision = precision
        self.range_step = 10**(-1 * range_step)

    def _value_grid(self):
        """Return the evaluation grid from -10 up to (excluding) 10."""
        return np.arange(-10, 10, self.range_step)

    def get_cLUT(self):
        """Build the table and cache it, or load a previously cached table.

        When ``self.init_table`` is truthy the table is generated and pickled
        to ``results/distributions/<table_name>``; otherwise it is read back
        from that file.

        Raises:
            FileNotFoundError: If ``init_table`` is falsy and no cached table
                exists under ``table_name``.
        """
        path = "results/distributions"
        table_file = f"{path}/{self.table_name}"

        if self.init_table:
            self.generate_cLUT()
            os.makedirs(path, exist_ok=True)
            with open(table_file, "wb") as handle:
                pickle.dump((self.r, self.c, self.cLUT), handle)
        else:
            # NOTE: unpickling runs arbitrary code from the file; only load
            # cache files this project wrote itself.
            with open(table_file, "rb") as handle:
                self.r, self.c, self.cLUT = pickle.load(handle)
            # Rebuild the index -> value grid for every backend: ``rvs`` maps
            # indices through ``self.values`` for "libclut32int" as well as
            # for "lut_c_wrapper", so it must exist after a cache load too.
            self.values = self._value_grid()[1:]
            if self.sampler == "lut_c_wrapper":
                self.cLUT_sampler = lut_c_wrapper.PreinitializedCLUTSampler(
                    self.r, self.c, self.cLUT
                )

        print("cLUT generated", self.precision)

    def rvs(self, size, *args):
        """Draw ``size`` samples using the configured backend.

        Args:
            size: Number of samples to draw.
            *args: Accepted and ignored.

        Returns:
            For ``"libclut32int"`` and ``"lut_c_wrapper"``: ``self.values``
            indexed by the indices the backend returns. For
            ``"libclut32float"``: the result of ``libclut32.sample_float32``.

        Raises:
            ValueError: If ``self.sampler`` is not one of the three backends.
        """
        if self.sampler == "libclut32int":
            rng = np.random.default_rng()
            indices = libclut32.sample(
                self.cLUT,
                rng.integers(0, int(2**32), size, dtype="uint32"),
                self.r,
                self.c,
                size,
            )
            return self.values[indices]

        elif self.sampler == "libclut32float":
            rng = np.random.default_rng()
            return libclut32.sample_float32(
                self.cLUT,
                rng.integers(0, int(2**32), size, dtype="uint32"),
                self.r,
                self.c,
                size,
            )

        elif self.sampler == "lut_c_wrapper":
            result = clut_sample.sample_cLUT_fast(
                self.cLUT.astype(np.uint32), self.r, self.c, size
            )
            return self.values[result]
        else:
            raise ValueError(f"{self.sampler = } not valid!")

    def generate_cLUT(self):
        """Discretise ``self.F`` on the grid and build the table for the backend.

        The CDF is evaluated on the grid and differenced into per-bin
        probabilities; each bin is represented by its right grid point
        (``values[1:]``). Sets ``self.values``, ``self.r``, ``self.c`` and
        ``self.cLUT`` (plus ``self.cLUT_sampler`` for ``"lut_c_wrapper"``).

        Raises:
            AssertionError: For the ``build_clut`` backends, if
                ``r + c != precision``.
        """
        values = self._value_grid()
        cdf_values = self.F(values, **self.F_kwargs)
        probs = np.diff(cdf_values)
        indices = np.arange(len(probs), dtype=np.uint32)
        self.values = values[1:]
        print(self.precision, self.values)

        if self.sampler == "lut_c_wrapper":
            # The C wrapper takes integer weights summing to 2**precision.
            probs = prepare_probs_for_c(
                probs,
                int(2**self.precision)
            )
            sampler = lut_c_wrapper.CLUTSampler(probs)
            self.r, self.c, self.cLUT = sampler.r, sampler.c, sampler.cLUT
            self.cLUT_sampler = sampler
        elif self.sampler == "libclut32int":
            # The table stores bin indices; ``rvs`` maps them to values.
            self.cLUT, _, self.r, self.c = build_clut(
                indices, probs,
                self.precision
            )
            assert self.r + self.c == self.precision
        else:
            # The table stores the float32 grid values directly.
            self.cLUT, _, self.r, self.c = build_clut(
                self.values.astype("float32"),
                probs,
                self.precision
            )
            assert self.r + self.c == self.precision


class BimodalDistribution:
    """Two-component Gaussian mixture with a continuous or a tabulated sampler.

    With ``sampler="continuous"`` samples are drawn directly from the two
    normal components. With ``sampler="discrete"`` samples are drawn from a
    probability table over a grid on [-10, 10), which is either built and
    cached (``init_table == 1``) or loaded from the cache file.
    """

    def __init__(self, mu1, sigma1, mu2, sigma2, sampler, weight1=0.5, init_table=0, range_step=1e-3):
        """Store the mixture parameters and, for "discrete", prepare the table.

        Args:
            mu1, sigma1: Mean and standard deviation of the first component.
            mu2, sigma2: Mean and standard deviation of the second component.
            sampler: ``"continuous"`` or ``"discrete"``.
            weight1: Mixture weight of the first component; the second gets
                ``1 - weight1``.
            init_table: ``1`` to build and cache the discrete table, anything
                else to load it from the cache.
            range_step: Exponent of the grid spacing; the spacing used is
                ``10 ** (-range_step)``.

        Raises:
            FileNotFoundError: If ``sampler == "discrete"``, ``init_table`` is
                not ``1`` and no cached table exists.
        """
        self.mu1 = mu1
        self.sigma1 = sigma1
        self.mu2 = mu2
        self.sigma2 = sigma2
        self.weight1 = weight1
        self.weight2 = 1 - weight1
        self.sampler = sampler
        self.range_step = 10**(-1 * range_step)
        self.init_table = init_table
        self.table_name = f"bimodal_numpy_{mu1}_{sigma1}_{mu2}_{sigma2}_{sampler}_{range_step}"

        if self.sampler == "discrete":
            table_dir = "results/distributions"
            table_file = f"{table_dir}/{self.table_name}"
            if self.init_table == 1:
                values = np.arange(-10, 10, self.range_step)
                # NOTE(review): the mixture weight is not passed to F_bimodal,
                # so the table uses whatever weight F_bimodal assumes -- confirm.
                cdf_values = F_bimodal(
                    values, mu1=mu1, sigma1=sigma1, mu2=mu2, sigma2=sigma2
                )
                self.p = np.diff(cdf_values)
                self.values = values[1:]
                # Make sure the cache directory exists before writing to it.
                os.makedirs(table_dir, exist_ok=True)
                with open(table_file, "wb") as handle:
                    pickle.dump((self.p, self.values), handle)
            else:
                # NOTE: unpickling runs arbitrary code from the file; only
                # load cache files this project wrote itself.
                with open(table_file, "rb") as handle:
                    self.p, self.values = pickle.load(handle)

    def pdf(self, x):
        """Return the mixture density at ``x``."""
        return self.weight1 * scipy.stats.norm.pdf(
            x, self.mu1, self.sigma1
        ) + self.weight2 * scipy.stats.norm.pdf(x, self.mu2, self.sigma2)

    def rvs(self, size=1):
        """Draw ``size`` samples with the configured sampler.

        Raises:
            ValueError: If ``self.sampler`` is neither ``"continuous"`` nor
                ``"discrete"``.
        """
        rng = np.random.default_rng()
        if self.sampler == "continuous":
            # Pick a component per sample, then take that component's draw.
            choices = rng.choice([0, 1], size=size, p=[self.weight1, self.weight2])
            return np.where(
                choices == 0,
                rng.normal(self.mu1, self.sigma1, size=size),
                rng.normal(self.mu2, self.sigma2, size=size),
            )
        if self.sampler == "discrete":
            return rng.choice(self.values, p=self.p, size=size)
        # Fail with a clear message instead of an unbound-local error.
        raise ValueError(f"{self.sampler = } not valid!")


class BimodalDistributionFast(cLUTSampler):
    """Two-component Gaussian mixture sampled through a cLUT table.

    The table is built from ``F_bimodal`` (or loaded from the cache) as part
    of construction via ``get_cLUT``.
    """

    def __init__(
        self,
        mu1,
        sigma1,
        mu2,
        sigma2,
        weight1=0.5,
        sampler="",
        init_table=False,
        precision=32,
        range_step=1e-3,
    ):
        """Store the mixture parameters and build or load the table.

        ``range_step`` is an exponent: the grid spacing is ``10 ** (-range_step)``.
        """
        # The base initialiser stores sampler, precision and the grid spacing.
        super().__init__(sampler, precision=precision, range_step=range_step)

        self.mu1, self.sigma1 = mu1, sigma1
        self.mu2, self.sigma2 = mu2, sigma2
        self.weight1, self.weight2 = weight1, 1 - weight1

        # CDF used to tabulate the distribution.
        # NOTE(review): weight1 is not forwarded in F_kwargs, so the table uses
        # whatever mixture weight F_bimodal assumes -- confirm.
        self.F = F_bimodal
        self.F_kwargs = dict(mu1=mu1, sigma1=sigma1, mu2=mu2, sigma2=sigma2)

        # Cache file name; it is built from the raw ``range_step`` argument.
        self.init_table = init_table
        self.table_name = f"binomial_{mu1}_{sigma1}_{mu2}_{sigma2}_{sampler}_{range_step}_{precision}"
        self.get_cLUT()

    def pdf(self, x):
        """Return the mixture density at ``x``."""
        first = scipy.stats.norm.pdf(x, self.mu1, self.sigma1)
        second = scipy.stats.norm.pdf(x, self.mu2, self.sigma2)
        return self.weight1 * first + self.weight2 * second


class LognormalDistribution:
    """Log-normal distribution backed by ``scipy.stats.lognorm``.

    ``mu`` and ``sigma`` are mapped to SciPy's parametrisation as
    ``scale = exp(mu)`` and ``s = sigma``.
    """

    def __init__(self, mu, sigma):
        """Store ``mu`` and ``sigma`` and precompute ``scale = exp(mu)``."""
        self.mu, self.sigma = mu, sigma
        self.scale = np.exp(mu)

    def pdf(self, x):
        """Return the density at ``x``."""
        lognorm = scipy.stats.lognorm
        return lognorm.pdf(x, s=self.sigma, scale=self.scale)

    def rvs(self, size=1):
        """Draw ``size`` samples."""
        lognorm = scipy.stats.lognorm
        return lognorm.rvs(s=self.sigma, scale=self.scale, size=size)


class LognormalDistributionFast(cLUTSampler):
    """Log-normal distribution sampled through a cLUT table.

    The table is built from ``scipy.stats.lognorm.cdf`` (or loaded from the
    cache) as part of construction via ``get_cLUT``.
    """

    def __init__(
            self, mu, sigma, sampler="",
            init_table=False,
            precision=32, range_step=1e-3
        ):
        """Store the parameters and build or load the table.

        Args:
            mu, sigma: Log-normal parameters; SciPy receives ``s = sigma`` and
                ``scale = exp(mu)``.
            sampler: Backend name understood by ``cLUTSampler.rvs``.
            init_table: Truthy to build and cache the table, falsy to load it.
            precision: Precision handed to the table builder.
            range_step: Exponent of the grid spacing (``10 ** (-range_step)``).
        """
        self.mu = mu
        self.sigma = sigma
        self.scale = np.exp(mu)
        self.sampler = sampler
        self.F = scipy.stats.lognorm.cdf
        self.F_kwargs = {"s": self.sigma, "scale": np.exp(self.mu)}
        self.init_table = init_table
        # Distribution-specific prefix: the cache name must not coincide with
        # the one used by the normal-distribution sampler for equal
        # parameters. Tables cached under the former "binomial_" prefix have
        # to be regenerated with init_table=True.
        self.table_name = f"lognormal_{mu}_{sigma}_{sampler}_{range_step}_{precision}"
        self.precision = precision
        self.range_step = 10**(-1 * range_step)
        self.get_cLUT()

    def pdf(self, x):
        """Return the density at ``x``."""
        return scipy.stats.lognorm.pdf(
            x, s=self.sigma,
            scale=self.scale
        )


class NormalDistribution:
    """Normal distribution backed by ``scipy.stats.norm``."""

    def __init__(self, mu, sigma):
        """Store the mean ``mu`` and the standard deviation ``sigma``."""
        self.mu = mu
        self.sigma = sigma

    def pdf(self, x):
        """Return the density at ``x``."""
        return scipy.stats.norm.pdf(x, loc=self.mu, scale=self.sigma)

    def rvs(self, size=1):
        """Draw ``size`` samples from the normal distribution.

        Sampling uses ``scipy.stats.norm`` so that it matches :meth:`pdf`;
        ``scipy.stats.lognorm`` is a different family and cannot be called
        without its shape parameter ``s``.
        """
        return scipy.stats.norm.rvs(loc=self.mu, scale=self.sigma, size=size)


class NormalDistributionFast(cLUTSampler):
    """Normal distribution sampled through a cLUT table.

    The table is built from ``scipy.stats.norm.cdf`` (or loaded from the
    cache) as part of construction via ``get_cLUT``.
    """

    def __init__(
            self, mu, sigma, sampler="",
            init_table=False,
            precision=32, range_step=1e-3
        ):
        """Store the parameters and build or load the table.

        Args:
            mu, sigma: Mean and standard deviation.
            sampler: Backend name understood by ``cLUTSampler.rvs``.
            init_table: Truthy to build and cache the table, falsy to load it.
            precision: Precision handed to the table builder.
            range_step: Exponent of the grid spacing (``10 ** (-range_step)``).
        """
        self.mu = mu
        self.sigma = sigma
        self.sampler = sampler
        self.F = scipy.stats.norm.cdf
        self.F_kwargs = {"loc": self.mu, "scale": self.sigma}
        self.init_table = init_table
        # Distribution-specific prefix: the cache name must not coincide with
        # the one used by the log-normal sampler for equal parameters. Tables
        # cached under the former "binomial_" prefix have to be regenerated
        # with init_table=True.
        self.table_name = f"normal_{mu}_{sigma}_{sampler}_{range_step}_{precision}"
        self.precision = precision
        self.range_step = 10**(-1 * range_step)
        self.get_cLUT()

    def pdf(self, x):
        """Return the density at ``x``."""
        return scipy.stats.norm.pdf(x, loc=self.mu, scale=self.sigma)

class TrueSkill:
    """Two-player skill model evaluated with weighted proposal samples.

    Skills are drawn from the two skill proposals and performances from
    ``N(0, sqrt(1 + beta**2))``. Each joint sample is weighted by the ratio of
    the model density (including the indicator ``p1 > p2``) to the proposal
    density.
    """

    def __init__(
        self,
        proposal_skill1=scipy.stats.norm(0, 1),
        proposal_skill2=scipy.stats.norm(0, 1),
        beta=1,
        name="",
    ):
        """Store the proposals and derive the performance proposals.

        Args:
            proposal_skill1: Object with ``rvs(size=...)`` and ``pdf(x)`` for
                player 1's skill.
            proposal_skill2: Same for player 2's skill.
            beta: Standard deviation of a performance around its skill.
            name: Tag used in the names of the files written by this class.
        """
        self.proposal_skill1 = proposal_skill1
        self.proposal_skill2 = proposal_skill2
        self.proposal_performance1 = scipy.stats.norm(0, np.sqrt(1 + beta**2))
        self.proposal_performance2 = scipy.stats.norm(0, np.sqrt(1 + beta**2))
        self.beta = beta
        self.name = name

    def sample_with_outcome(
        self, n=1000000, sample_only: bool = False, plot_results: bool = False
    ):
        """Draw ``n`` joint samples and weight them by the outcome ``p1 > p2``.

        Args:
            n: Number of samples per variable.
            sample_only: If true, only draw the samples and return ``None``.
            plot_results: If true, plot the results and pickle them to
                ``data/<name>.pickle``.

        Returns:
            ``None`` when ``sample_only`` is true; otherwise a dict with the
            keys ``skill1``, ``skill2``, ``performance1``, ``performance2``
            and ``weights``.
        """
        s1 = self.proposal_skill1.rvs(size=n)
        s2 = self.proposal_skill2.rvs(size=n)
        p1 = self.proposal_performance1.rvs(size=n)
        p2 = self.proposal_performance2.rvs(size=n)
        if sample_only:
            return

        # Model factors: skill densities, performance given skill, outcome.
        f1 = self.proposal_skill1.pdf(s1)
        # s2 is scored under player 2's own skill distribution, mirroring f1;
        # scoring it under player 1's is only equivalent when both proposals
        # are identical.
        f2 = self.proposal_skill2.pdf(s2)
        f3 = scipy.stats.norm(s1, self.beta).pdf(p1)
        f4 = scipy.stats.norm(s2, self.beta).pdf(p2)
        f5 = p1 > p2

        # Proposal factors: densities the samples were actually drawn from.
        g1 = self.proposal_skill1.pdf(s1)
        g2 = self.proposal_skill2.pdf(s2)
        g3 = self.proposal_performance1.pdf(p1)
        g4 = self.proposal_performance2.pdf(p2)

        weights = (f1 * f2 * f3 * f4 * f5) / (g1 * g2 * g3 * g4)

        results = {
            "skill1": s1,
            "skill2": s2,
            "performance1": p1,
            "performance2": p2,
            "weights": weights,
        }

        if plot_results == 1:
            self.plot_result(results)
            print("save")
            # Make sure the output directory exists before writing to it.
            os.makedirs("data", exist_ok=True)
            with open(f"data/{self.name}.pickle", "wb") as handle:
                pickle.dump(results, handle)

        return results

    def _plot_prior_posterior(self, samples, weights, labels, quantity, out_file):
        """Save unweighted (left) and weighted (right) histograms of two series."""
        colors = ["red", "blue"]
        fig, axs = plt.subplots(1, 2, figsize=(8, 4))
        axs = axs.flatten()

        axs[1].hist(
            samples,
            weights=[weights, weights],
            label=labels,
            bins=50,
            color=colors,
        )
        axs[1].set_title(f"{quantity} posterior distributions")
        axs[1].legend()

        axs[0].hist(samples, label=labels, bins=50, color=colors)
        axs[0].set_title(f"{quantity} prior distributions")
        axs[0].legend()

        fig.tight_layout()
        fig.savefig(out_file, dpi=300)
        plt.close(fig)

    def plot_result(self, results, path="results/trueskill"):
        """Write the skill and performance histogram figures into ``path``.

        Args:
            results: Dict as returned by :meth:`sample_with_outcome`.
            path: Output directory; it is created if missing and both figures
                are saved inside it.
        """
        os.makedirs(path, exist_ok=True)

        self._plot_prior_posterior(
            [results["skill1"], results["skill2"]],
            results["weights"],
            [r"$p(s_1)$", r"$p(s_2)$"],
            "Skills",
            f"{path}/skills_results_{self.name}.png",
        )
        self._plot_prior_posterior(
            [results["performance1"], results["performance2"]],
            results["weights"],
            ["p(p1)", "p(p2)"],
            "Performances",
            f"{path}/performances_results_{self.name}.png",
        )
