import numpy as np
from scipy.stats import norm
from itertools import chain, combinations
import matplotlib.pyplot as plt

def all_subsets(n):
    """Enumerate index subsets of ``{0, ..., n}`` with at most ``n`` elements.

    Subsets are returned as tuples, grouped by increasing size (0 first) and
    in lexicographic order within each size.

    NOTE(review): the index range has ``n + 1`` elements but sizes stop at
    ``n``, so the full set ``(0, ..., n)`` is never produced -- confirm this
    is intended.

    Args:
        n: Largest index in the universe (also the largest subset size).

    Returns:
        A list of tuples of ints.
    """
    universe = range(n + 1)
    subsets = []
    for size in range(n + 1):
        subsets.extend(combinations(universe, size))
    return subsets

def complement_operator(set_a, set_b):
    """Return the elements of ``set_a`` that do not appear in ``set_b``.

    Both arguments may be any iterables of hashable items. The result is a
    tuple built from a set, so duplicates are dropped and the element order
    is whatever set iteration yields (not guaranteed to be sorted).
    """
    excluded = set(set_b)
    remaining = set(set_a) - excluded
    return tuple(remaining)

def n_pdf(x, mean, var):
    """Normal density at ``x`` for the given mean and *variance*.

    The variance is converted to a standard deviation before being passed
    to ``scipy.stats.norm.pdf`` as its ``scale``.
    """
    std = var ** 0.5
    return norm.pdf(x, loc=mean, scale=std)

def n_cdf(x, mean, var):
    """Normal cumulative distribution at ``x`` for the given mean and *variance*.

    The variance is converted to a standard deviation before being passed
    to ``scipy.stats.norm.cdf`` as its ``scale``.
    """
    std = var ** 0.5
    return norm.cdf(x, loc=mean, scale=std)

def likelihood_prob(x1):
    """Evaluate the likelihood term used by ``cond_prob`` at a scalar ``x1``.

    Returns the N(3, 1) density at ``x1`` when ``x1`` lies strictly outside
    the band [1.75, 2.25], and the N(0, 1) density at ``x1`` otherwise.

    NOTE(review): ``cond_prob`` labels this term p(x2 | x1), yet the density
    is evaluated at ``x1`` and x2 is not an input; the mean of 3 also differs
    from the mean of 4 used by ``marg_prob`` and ``sample_x_even``. Confirm
    the intended model before relying on this.
    """
    outside_band = x1 < 1.75 or x1 > 2.25
    mean = 3 if outside_band else 0
    return n_pdf(x1, mean, 1)

def marg_prob(x):
    """Two-component normal mixture density evaluated at ``x``.

    The N(4, 1) component is weighted by ``2 * Phi(1.75; mean=2, var=1)``,
    which equals the probability that an N(2, 1) variable falls outside
    [1.75, 2.25]; the N(0, 1) component gets the remaining weight.
    """
    # Two-sided tail mass of N(2, 1) outside the symmetric band around 2.
    weight_outside = 2 * n_cdf(1.75, 2, 1)
    shifted_part = n_pdf(x, 4, 1) * weight_outside
    centred_part = n_pdf(x, 0, 1) * (1 - weight_outside)
    return shifted_part + centred_part

def cond_prob(x1, x2):
    """Bayes-rule style ratio ``likelihood(x1) * prior(x1) / marginal(x2)``.

    The numerator multiplies ``likelihood_prob(x1)`` (labelled p(x2 | x1) in
    the original code) by an N(0, 1) density at ``x1`` (labelled p(x1)); the
    denominator is ``marg_prob(x2)`` (labelled p(x2)).

    NOTE(review): the prior here is N(0, 1) while ``generate_sample`` draws
    the odd features from N(2, 1), and the numerator does not depend on
    ``x2`` -- confirm this matches the intended p(x1 | x2).
    """
    numerator = likelihood_prob(x1) * n_pdf(x1, 0, 1)
    return numerator / marg_prob(x2)

def generate_sample(n_pairs):
    """Draw one feature vector made of ``n_pairs`` (odd, even) feature pairs.

    Positions 0, 2, 4, ... hold independent N(2, 1) draws; each following
    position 1, 3, 5, ... holds the paired draw produced by
    ``sample_x_even`` from the value just before it.

    The pairs are filled with strided assignment instead of an element-wise
    loop. The loop stored a length-1 array into a scalar slot, which NumPy
    has deprecated (1.25+), and its local buffer shadowed the module-level
    ``sample`` function. The distribution is unchanged, but for a fixed seed
    the random draws are consumed in a different order than before.

    Args:
        n_pairs: Number of (odd, even) pairs to generate.

    Returns:
        1-D float array of length ``2 * n_pairs``.
    """
    x_odd = np.random.normal(2, 1, n_pairs)
    draws = np.empty(2 * n_pairs)
    draws[0::2] = x_odd
    draws[1::2] = sample_x_even(x_odd, n_pairs)
    return draws

def sample_x_even(x, n):
    """Sample ``n`` even-feature values conditioned on the paired odd feature.

    Entry ``i`` is drawn from N(4, 1) when the corresponding ``x`` value lies
    strictly outside [1.75, 2.25], and from N(0, 1) otherwise.

    Args:
        x: Odd-feature value(s): an array of length ``n``, or a scalar
            (Python or NumPy) that is applied to all ``n`` draws.
        n: Number of values to return.

    Returns:
        1-D float array of length ``n``.
    """
    x_arr = np.asarray(x)
    # Logical OR rather than ``+``: adding two Python bools yields an int,
    # which would then be used as a positional index instead of a mask.
    outside_band = np.logical_or(x_arr > 2.25, x_arr < 1.75)
    # One flag per requested draw, so a scalar ``x`` gets n independent draws
    # instead of a single value broadcast to every position.
    outside_band = np.broadcast_to(outside_band, (n,))

    x_even = np.random.normal(0, 1, n)
    n_outside = int(np.count_nonzero(outside_band))
    if n_outside:
        x_even[outside_band] = np.random.normal(4, 1, n_outside)
    return x_even

def sample(vals, p, n):
    """Draw ``n`` values from a tabulated density via inverse-CDF sampling.

    Used for the odd features. A pool of ``len(p)`` points is first generated
    by pushing uniform draws through the interpolated inverse CDF; the ``n``
    returned values are then picked from that pool with replacement.

    Args:
        vals: Grid of support points.
        p: Unnormalised density values at ``vals`` (normalised internally).
        n: Number of values to return.

    Returns:
        1-D array of ``n`` sampled values.
    """
    # Empirical CDF on the grid, scaled so its last entry is 1.
    total_mass = np.sum(p)
    cdf = np.cumsum(p) / total_mass

    # Inverse-CDF transform of len(p) uniform draws.
    pool = np.interp(np.random.rand(len(p)), cdf, vals)

    return np.random.choice(pool, size=n)

def function(x, weights):
    """Weighted fraction of even-position features with magnitude >= 2.

    Columns 1, 3, 5, ... of ``x`` are thresholded at ``|value| >= 2`` and the
    resulting 0/1 indicators are averaged per row using ``weights``
    (normalised by their sum).

    Args:
        x: 2-D array; one row per instance.
        weights: One weight per selected column.

    Returns:
        1-D array with one score per row of ``x``.
    """
    even_features = x[:, 1::2]
    exceeds = (np.abs(even_features) >= 2).astype(float)
    scale = 1 / np.sum(weights)
    return scale * np.dot(exceeds, weights)

def generate_probs(x2):
    """Tabulate ``cond_prob(x1, x2)`` over a fixed grid of ``x1`` values.

    The grid is 1000 evenly spaced points on [-4, 4].

    Args:
        x2: Observed value passed through to ``cond_prob``.

    Returns:
        Tuple ``(grid, values)`` of equal-length 1-D arrays.
    """
    grid = np.linspace(-4, 4, 1000)
    values = np.array([cond_prob(point, x2) for point in grid])
    return grid, values

def mask(x, mask_indices, n):
    """Create ``n`` copies of ``x`` with the masked features resampled.

    Features come in (odd, even) pairs stored at column indices
    ``(2k, 2k + 1)``. For every masked column:

    * both columns of a pair masked: the first is redrawn from N(2, 1) and
      the second from ``sample_x_even`` given the new first value;
    * only the first column masked: it is drawn from the tabulated
      conditional returned by ``generate_probs`` for the observed partner;
    * only the second column masked: it is drawn from ``sample_x_even``
      given the (unmasked) first column.

    Indices are de-duplicated and visited in ascending order, so the result
    no longer depends on ``mask_indices`` being sorted; previously a
    "skip the next entry" flag could skip an unrelated index when the
    partner column was not the very next element.

    Args:
        x: Instance to perturb. It is indexed as ``x[0][j]`` and tiled along
            axis 0, so it is presumably a 2-D array with a single row
            (TODO confirm against callers).
        mask_indices: Iterable of column indices to resample.
        n: Number of perturbed copies to generate.

    Returns:
        Array with ``n`` rows: ``x`` repeated, with masked columns replaced.
    """
    x_new = np.tile(x, (n, 1))
    masked = set(mask_indices)

    for index in sorted(masked):
        if index % 2 == 0:
            # First column of a pair (an "odd" feature in 1-based terms).
            if index + 1 in masked:
                # Whole pair is masked: resample it jointly.
                x_new[:, index] = np.random.normal(2, 1, n)
                x_new[:, index + 1] = sample_x_even(x_new[:, index], n)
            else:
                vals, probs = generate_probs(x[0][index + 1])
                x_new[:, index] = sample(vals, probs, n)
        elif index - 1 not in masked:
            # Second column masked on its own; if its partner was masked too,
            # it was already filled in by the joint branch above.
            x_new[:, index] = sample_x_even(x_new[:, index - 1], n)
    return x_new