import torch
from torch.distributions import Distribution
import sys
sys.path.append("..")
from cqc.models import cqc_models as models  # noqa: E402


def uneven_data_gen(gs_0, gs_1, prop_score, x_base_dist: Distribution, y_base_dist: Distribution, n_tot: int = 2000,
                    separate=False, seed=None):
    """Sample an observational dataset whose treatment is assigned by a propensity score.

    Covariates ``x`` are drawn from ``x_base_dist``. Each sample is assigned to the
    treatment group (indicator 1) with probability ``prop_score(x)``. Outcomes are a
    location-scale transform of ``y_base_dist`` noise, ``y = eps * gs[1](x) + gs[0](x)``,
    with ``gs = gs_1`` for treated samples and ``gs = gs_0`` for control samples.

    Sampling runs under its own seed and does not disturb the caller's global RNG
    stream. When ``seed`` is None, exactly one value is drawn from that stream to
    choose the seed, so repeated unseeded calls give different datasets.

    Args:
        gs_0: ``[shift, scale]`` callables of ``x`` for the control group.
        gs_1: ``[shift, scale]`` callables of ``x`` for the treatment group.
        prop_score: Callable mapping ``x`` to treatment probabilities.
        x_base_dist (Distribution): Distribution of the covariates.
        y_base_dist (Distribution): Base (noise) distribution of the outcome.
        n_tot (int): Total number of samples.
        separate (bool): If True, return treated and control data separately.
        seed: Seed used for sampling; drawn from the global RNG when None.

    Returns:
        If ``separate``: ``((y_1, x_1), (y_0, x_0))`` for treated and control samples.
        Otherwise: ``(y, x, a)``, where ``a`` is the float treatment indicator.
    """
    if seed is None:
        # Draw from the caller's stream so that consecutive unseeded calls differ.
        seed = torch.randint(0, 1000000, (1,)).item()
    # fork_rng restores the global RNG *state* on exit, even if an exception is raised.
    # The previous approach reseeded with torch.initial_seed(), which rewinds the
    # generator and replays random numbers the caller has already consumed.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        x = x_base_dist.sample((n_tot,))
        probs = prop_score(x)
        # Indicator of whether each sample is from the treatment group (1)
        zerone = torch.bernoulli(probs)
        y = y_base_dist.sample((n_tot,))
        y = y * (zerone * gs_1[1](x) + (1 - zerone) * gs_0[1](x)) + (zerone * gs_1[0](x) + (1 - zerone) * gs_0[0](x))
    if separate:
        treated = zerone.bool()
        return (y[treated], x[treated]), (y[~treated], x[~treated])
    return (y, x, zerone.float())


def data_gen_matched(gs_0, gs_1, x_base_dist: Distribution, y_base_dist: Distribution, n: int = 1000):
    """Draw ``n`` covariates together with both potential outcomes.

    Each potential outcome is an independent draw of ``y_base_dist`` pushed through
    its group's location-scale map: ``y_a = eps_a * gs_a[1](x) + gs_a[0](x)``.

    Args:
        gs_0: ``[shift, scale]`` callables of ``x`` for the control outcome.
        gs_1: ``[shift, scale]`` callables of ``x`` for the treated outcome.
        x_base_dist (Distribution): Distribution of the covariates.
        y_base_dist (Distribution): Base (noise) distribution of the outcomes.
        n (int): Number of samples.

    Returns:
        ``(y0, y1, x)``.
    """
    sample_shape = (n,)
    # Draw order (eps_0, eps_1, then x) fixes the outputs under a given global seed.
    eps_0, eps_1 = [y_base_dist.sample(sample_shape) for _ in range(2)]
    covariates = x_base_dist.sample(sample_shape)

    def push_forward(eps, gs):
        scale_fn = gs[1]
        shift_fn = gs[0]
        return eps * scale_fn(covariates) + shift_fn(covariates)

    return push_forward(eps_0, gs_0), push_forward(eps_1, gs_1), covariates


def get_all_obj(y_base_dist: Distribution, gs_0, gs_1,
                x_base_dist: Distribution, propensity=None):
    """Returns all downstream distributional objects of a location-scale outcome model.

    The model is ``Y | X=x, A=a  =  eps * gs_a[1](x) + gs_a[0](x)`` with
    ``eps ~ y_base_dist``. The CDF formulas assume the scale functions are positive.

    Args:
        y_base_dist (Distribution): The base (noise) distribution for y
        gs_0 (List[Callable]): The shift and scale functions for y0 given x
        gs_1 (List[Callable]): The shift and scale functions for y1 given x
        x_base_dist (Distribution): The base distribution for x
        propensity (Callable, optional): The propensity function P(A=1 | X=x). Defaults to None.

    Returns:
        cdf_0: The CDF of Y|X,A=0
        cdf_1: The CDF of Y|X,A=1
        icdf_0: The quantile function of Y|X,A=0
        icdf_1: The quantile function of Y|X,A=1
        true_transform: g^*(y|x), which equals icdf_1(cdf_0(y, x), x)
        density_0: The density of X|A=0 (only returned when ``propensity`` is given)
        density_1: The density of X|A=1 (only returned when ``propensity`` is given)
    """

    def cdf_0(y, x):
        return y_base_dist.cdf((y - gs_0[0](x)) / gs_0[1](x))

    def cdf_1(y, x):
        return y_base_dist.cdf((y - gs_1[0](x)) / gs_1[1](x))

    def icdf_0(prob, x):
        return y_base_dist.icdf(prob) * gs_0[1](x) + gs_0[0](x)

    def icdf_1(prob, x):
        return y_base_dist.icdf(prob) * gs_1[1](x) + gs_1[0](x)

    def true_transform(y, x):
        # Standardise under the A=0 model, then map back through the A=1 model.
        return gs_1[1](x) * (y - gs_0[0](x)) / gs_0[1](x) + gs_1[0](x)

    if propensity is not None:
        # Bayes: p(x | A=1) = P(A=1 | x) p(x) / P(A=1). P(A=1) = E[propensity(X)]
        # is estimated by Monte Carlo, which draws from the global RNG.
        normalising_data = x_base_dist.sample((10000,))
        normalising_constant_1 = torch.mean(propensity(normalising_data))
        normalising_constant_0 = 1 - normalising_constant_1

        def density_0(x):
            # Weight the density itself (not its log) by P(A=0 | x).
            return torch.exp(x_base_dist.log_prob(x)) * (1 - propensity(x)) / normalising_constant_0

        def density_1(x):
            # Weight the density itself (not its log) by P(A=1 | x).
            return torch.exp(x_base_dist.log_prob(x)) * propensity(x) / normalising_constant_1
        return cdf_0, cdf_1, icdf_0, icdf_1, true_transform, density_0, density_1

    return cdf_0, cdf_1, icdf_0, icdf_1, true_transform


def get_all_obj_mlp(y_base_dist: Distribution, x_base_dist: Distribution,
                    gs_0_shift: models.PositiveMLPGenerator, true_cqc: models.PositiveMLPGenerator,
                    propensity=None):
    """Returns all downstream distributional objects for an MLP-parameterised outcome model.

    ``Y | X=x, A=0`` is ``eps * scale_0(x) + shift_0(x)`` with ``eps ~ y_base_dist``,
    where ``(shift_0, scale_0) = gs_0_shift.get_shift_and_scale(x)``. ``Y | X=x, A=1``
    applies a second location-scale map ``(shift, scale) = true_cqc.get_shift_and_scale(x)``
    to the A=0 quantiles: ``icdf_1 = icdf_0 * scale + shift``. The CDF formulas assume
    both scales are positive (presumably guaranteed by ``PositiveMLPGenerator``; confirm).

    Args:
        y_base_dist (Distribution): The base (noise) distribution for y
        x_base_dist (Distribution): The base distribution for x
        gs_0_shift (PositiveMLPGenerator): Provides shift and scale of Y|X,A=0 via ``get_shift_and_scale``
        true_cqc (PositiveMLPGenerator): Provides the shift and scale mapping Y0 quantiles to Y1 quantiles
        propensity (Callable, optional): The propensity function P(A=1 | X=x). Defaults to None.

    Returns:
        cdf_0: The CDF of Y|X,A=0
        cdf_1: The CDF of Y|X,A=1
        icdf_0: The quantile function of Y|X,A=0
        icdf_1: The quantile function of Y|X,A=1
        density_0: The density of X|A=0 (only returned when ``propensity`` is given)
        density_1: The density of X|A=1 (only returned when ``propensity`` is given)
    """

    def cdf_0(y, x):
        shift, scale = gs_0_shift.get_shift_and_scale(x)
        return y_base_dist.cdf((y - shift) / scale)

    def cdf_1(y, x):
        # Invert the A=1 location-scale map, then evaluate the A=0 CDF.
        shift, scale = true_cqc.get_shift_and_scale(x)
        return cdf_0((y - shift) / scale, x)

    def icdf_0(prob, x):
        shift, scale = gs_0_shift.get_shift_and_scale(x)
        return y_base_dist.icdf(prob) * scale + shift

    def icdf_1(prob, x):
        shift, scale = true_cqc.get_shift_and_scale(x)
        return icdf_0(prob, x) * scale + shift

    if propensity is not None:
        # Bayes: p(x | A=1) = P(A=1 | x) p(x) / P(A=1). P(A=1) = E[propensity(X)]
        # is estimated by Monte Carlo, which draws from the global RNG.
        normalising_data = x_base_dist.sample((10000,))
        normalising_constant_1 = torch.mean(propensity(normalising_data))
        normalising_constant_0 = 1 - normalising_constant_1

        def density_0(x):
            # Weight the density itself (not its log) by P(A=0 | x).
            return torch.exp(x_base_dist.log_prob(x)) * (1 - propensity(x)) / normalising_constant_0

        def density_1(x):
            # Weight the density itself (not its log) by P(A=1 | x).
            return torch.exp(x_base_dist.log_prob(x)) * propensity(x) / normalising_constant_1
        return cdf_0, cdf_1, icdf_0, icdf_1, density_0, density_1

    return cdf_0, cdf_1, icdf_0, icdf_1
