from pomegranate.distributions import Categorical
from pomegranate.distributions import ConditionalCategorical
from pomegranate.bayesian_network import BayesianNetwork
import torch
import numpy as np
import random
import torch.nn.functional as F
from pomegranate.distributions import JointCategorical
from pomegranate.factor_graph import FactorGraph


def get_tree4_test(mu=0.1, num_var=15):
    """Build the seeded 15-variable binary-tree network used for testing.

    The original body defined ``get_prob_dist_binary_test`` but drew every
    table from the module-level ``get_prob_dist_binary`` instead, leaving the
    helper dead.  The tables are now drawn from the nested test helper, the
    same way ``get_markov_chain_test`` does.

    Args:
        mu: Unused; kept so existing callers keep working.
        num_var: Unused; the tree always has 15 variables because ``Tree4``
            hard-codes the index groups of a 15-node tree.

    Returns:
        A ``Tree4`` wrapping the distributions and the ``BayesianNetwork``.
    """

    def get_prob_dist_binary_test(mu=0.1):
        # Draw p uniformly from [0, 1) and randomly order the pair (p, 1 - p).
        p = torch.FloatTensor(1).uniform_(0, 1).item()
        lst = [p, 1 - p]
        random.shuffle(lst)
        return lst

    model = BayesianNetwork()

    # The root is drawn under its own fixed seed.
    torch.manual_seed(1000)
    random.seed(1000)
    nodes = [Categorical([get_prob_dist_binary_test()])]

    # Each of the 14 non-root variables is drawn right after re-seeding both
    # generators with its loop index, so every table is reproducible.
    for i in range(14):
        torch.manual_seed(i)
        random.seed(i)
        nodes.append(ConditionalCategorical(
            [[get_prob_dist_binary_test(), get_prob_dist_binary_test()]]))

    model.add_distributions(nodes)

    # Node k is the parent of nodes 2k+1 and 2k+2; edges are added in
    # child-index order: (0,1), (0,2), (1,3), (1,4), ..., (6,14).
    for child in range(1, len(nodes)):
        model.add_edge(nodes[(child - 1) // 2], nodes[child])

    return Tree4(nodes, model)

def get_tree4():
    """Build a random 15-variable binary-tree network wrapped in ``Tree4``.

    The root is a ``Categorical`` and the 14 other variables are
    ``ConditionalCategorical`` nodes, each with two rows drawn from
    ``get_prob_dist_binary``.  Node k is wired to nodes 2k+1 and 2k+2.
    """
    network = BayesianNetwork()

    nodes = [Categorical([get_prob_dist_binary()])]
    nodes.extend(
        ConditionalCategorical(
            [[get_prob_dist_binary(), get_prob_dist_binary()]])
        for _ in range(14)
    )
    network.add_distributions(nodes)

    # (parent, child) index pairs, listed level by level.
    parent_child = (
        (0, 1), (0, 2),
        (1, 3), (1, 4), (2, 5), (2, 6),
        (3, 7), (3, 8), (4, 9), (4, 10),
        (5, 11), (5, 12), (6, 13), (6, 14),
    )
    for parent, child in parent_child:
        network.add_edge(nodes[parent], nodes[child])

    return Tree4(nodes, network)


class Tree4:
    """Sampling and masking helper for a 15-variable, 4-level tree network.

    Variables are grouped by depth through the index lists ``root``,
    ``l1_leaf``, ``l2_leaf`` and ``l3_leaf``.
    """

    def __init__(self, vars, graph):
        """Store the distributions and the network they belong to.

        Args:
            vars: Sequence of distribution objects; each must expose ``probs``.
            graph: Object with a ``sample(n)`` method (the network).
        """
        self.vars = vars
        self.num_var = len(vars)
        self.graph = graph
        self.probs = [var.probs[0] for var in vars]

        # Variable indices per tree level.
        self.root = [0]
        self.l1_leaf = [1, 2]
        self.l2_leaf = [3, 4, 5, 6]
        self.l3_leaf = [7, 8, 9, 10, 11, 12, 13, 14]

    def sample(self, n):
        """Return ``self.graph.sample(n)``."""
        return self.graph.sample(n)

    def to_onehot(self, sample):
        """One-hot encode binary samples: (n, num_var) -> (n, num_var, 2)."""
        return F.one_hot(sample.to(torch.int64), num_classes=2)

    def mask_var(self, x, idx):
        """Hide variable ``idx``'s level and all deeper levels, then tag ``idx``.

        The old code flattened ``x`` to ``(n, 2 * num_var)`` and then sliced
        with *variable* indices (1, 3, 7), so it kept only part of the
        shallower levels (e.g. variable 2 was hidden when masking variable 5).
        It also zeroed the caller's tensor in place.  Masking is now done on
        the variable axis of a copy.

        Args:
            x: One-hot samples, ``(n, num_var, 2)`` or already flattened to
                ``(n, num_var * 2)``.
            idx: Index of the variable being masked.

        Returns:
            Tensor ``(n, num_var * 2 + num_var)``: the masked, flattened
            samples followed by a one-hot position vector marking ``idx``.

        Raises:
            Exception: If ``idx`` is not in any level list.
        """
        # Number of leading variables (all shallower levels) left visible.
        if idx in self.root:
            visible = 0
        elif idx in self.l1_leaf:
            visible = 1
        elif idx in self.l2_leaf:
            visible = 3
        elif idx in self.l3_leaf:
            visible = 7
        else:
            raise Exception("Mask idx not exist")

        n = x.size(0)
        # Work on a copy, viewed per variable, so the caller's tensor survives.
        masked = x.reshape(n, self.num_var, -1).clone()
        masked[:, visible:] = 0

        pos = torch.zeros(n, self.num_var, device=x.device)
        pos[:, idx] = 1

        return torch.cat([masked.reshape(n, -1), pos], dim=-1)

    def shrink_from_onehot(self, x):
        """Convert flattened one-hot rows back to class indices.

        Args:
            x: Tensor ``(n, num_var * 2)``.

        Returns:
            Tensor ``(n, num_var)`` of argmax indices (the original computed
            this but forgot to return it).
        """
        return x.reshape(x.size(0), -1, 2).argmax(-1)

    def get_prob(self, idx):
        """Return ``probs[0]`` of the ``idx``-th distribution."""
        return self.probs[idx]



def get_tree3_test(mu=0.1, num_var=10):
    """Build the seeded 7-variable binary-tree network used for testing.

    The original body defined ``get_prob_dist_binary_test`` but drew every
    table from the module-level ``get_prob_dist_binary`` instead, leaving the
    helper dead.  The tables are now drawn from the nested test helper, the
    same way ``get_markov_chain_test`` does.

    Args:
        mu: Unused; kept so existing callers keep working.
        num_var: Unused; the tree always has 7 variables because ``Tree3``
            hard-codes the index groups of a 7-node tree.

    Returns:
        A ``Tree3`` wrapping the distributions and the ``BayesianNetwork``.
    """

    def get_prob_dist_binary_test(mu=0.1):
        # Draw p uniformly from [0, 1) and randomly order the pair (p, 1 - p).
        p = torch.FloatTensor(1).uniform_(0, 1).item()
        lst = [p, 1 - p]
        random.shuffle(lst)
        return lst

    model = BayesianNetwork()

    # The root is drawn under its own fixed seed.
    torch.manual_seed(42)
    random.seed(42)
    nodes = [Categorical([get_prob_dist_binary_test()])]

    # Each of the 6 non-root variables is drawn right after re-seeding both
    # generators with its loop index, so every table is reproducible.
    for i in range(6):
        torch.manual_seed(i)
        random.seed(i)
        nodes.append(ConditionalCategorical(
            [[get_prob_dist_binary_test(), get_prob_dist_binary_test()]]))

    model.add_distributions(nodes)

    # Node k is the parent of nodes 2k+1 and 2k+2; edges are added in
    # child-index order: (0,1), (0,2), (1,3), (1,4), (2,5), (2,6).
    for child in range(1, len(nodes)):
        model.add_edge(nodes[(child - 1) // 2], nodes[child])

    return Tree3(nodes, model)

def get_tree3():
    """Build a random 7-variable binary-tree network wrapped in ``Tree3``.

    The root is a ``Categorical`` and the 6 other variables are
    ``ConditionalCategorical`` nodes, each with two rows drawn from
    ``get_prob_dist_binary``.  Node k is wired to nodes 2k+1 and 2k+2.
    """
    network = BayesianNetwork()

    nodes = [Categorical([get_prob_dist_binary()])]
    nodes.extend(
        ConditionalCategorical(
            [[get_prob_dist_binary(), get_prob_dist_binary()]])
        for _ in range(6)
    )
    network.add_distributions(nodes)

    # (parent, child) index pairs, listed level by level.
    parent_child = ((0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (2, 6))
    for parent, child in parent_child:
        network.add_edge(nodes[parent], nodes[child])

    return Tree3(nodes, network)


class Tree3:
    """Sampling and masking helper for a 7-variable, 3-level tree network.

    Variables are grouped by depth through the index lists ``root``,
    ``l1_leaf`` and ``l2_leaf``.
    """

    def __init__(self, vars, graph):
        """Store the distributions and the network they belong to.

        Args:
            vars: Sequence of distribution objects; each must expose ``probs``.
            graph: Object with a ``sample(n)`` method (the network).
        """
        self.vars = vars
        self.num_var = len(vars)
        self.graph = graph
        self.probs = [var.probs[0] for var in vars]

        # Variable indices per tree level.
        self.root = [0]
        self.l1_leaf = [1, 2]
        self.l2_leaf = [3, 4, 5, 6]

    def sample(self, n):
        """Return ``self.graph.sample(n)``."""
        return self.graph.sample(n)

    def to_onehot(self, sample):
        """One-hot encode binary samples: (n, num_var) -> (n, num_var, 2)."""
        return F.one_hot(sample.to(torch.int64), num_classes=2)

    def mask_var(self, x, idx):
        """Hide variable ``idx``'s level and all deeper levels, then tag ``idx``.

        The old code flattened ``x`` to ``(n, 2 * num_var)`` and then sliced
        with *variable* indices (1, 3), so it kept only part of the shallower
        levels (e.g. variable 2 was hidden when masking variable 5).  It also
        zeroed the caller's tensor in place.  Masking is now done on the
        variable axis of a copy.

        Args:
            x: One-hot samples, ``(n, num_var, 2)`` or already flattened to
                ``(n, num_var * 2)``.
            idx: Index of the variable being masked.

        Returns:
            Tensor ``(n, num_var * 2 + num_var)``: the masked, flattened
            samples followed by a one-hot position vector marking ``idx``.

        Raises:
            Exception: If ``idx`` is not in any level list.
        """
        # Number of leading variables (all shallower levels) left visible.
        if idx in self.root:
            visible = 0
        elif idx in self.l1_leaf:
            visible = 1
        elif idx in self.l2_leaf:
            visible = 3
        else:
            raise Exception("Mask idx not exist")

        n = x.size(0)
        # Work on a copy, viewed per variable, so the caller's tensor survives.
        masked = x.reshape(n, self.num_var, -1).clone()
        masked[:, visible:] = 0

        pos = torch.zeros(n, self.num_var, device=x.device)
        pos[:, idx] = 1

        return torch.cat([masked.reshape(n, -1), pos], dim=-1)

    def shrink_from_onehot(self, x):
        """Convert flattened one-hot rows back to class indices.

        Args:
            x: Tensor ``(n, num_var * 2)``.

        Returns:
            Tensor ``(n, num_var)`` of argmax indices (the original computed
            this but forgot to return it).
        """
        return x.reshape(x.size(0), -1, 2).argmax(-1)

    def get_prob(self, idx):
        """Return ``probs[0]`` of the ``idx``-th distribution."""
        return self.probs[idx]



def get_prob_dist_binary(mu=0.1):
    """Return a randomly ordered two-outcome distribution ``[p, 1 - p]``.

    ``p`` is drawn uniformly between 0.15 and 0.3 with torch's generator, and
    the order of the pair is decided by ``random.shuffle``.  ``mu`` is unused.
    """
    minority = torch.FloatTensor(1).uniform_(0.15, 0.3).item()
    pair = [minority, 1 - minority]
    random.shuffle(pair)
    return pair

def get_markov_chain(mu=0.1, num_var=10):
    """Build a random chain-structured network wrapped in ``MarkovChain``.

    The first variable is a ``Categorical``; each of the following
    ``num_var - 1`` variables is a ``ConditionalCategorical`` whose two rows
    come from ``get_prob_dist_binary``.  Every variable is linked to the next
    one.  ``mu`` is unused.
    """
    network = BayesianNetwork()

    chain = [Categorical([get_prob_dist_binary()])]
    chain.extend(
        ConditionalCategorical(
            [[get_prob_dist_binary(), get_prob_dist_binary()]])
        for _ in range(num_var - 1)
    )
    network.add_distributions(chain)

    # Link consecutive variables: chain[i] -> chain[i + 1].
    for step in range(num_var - 1):
        network.add_edge(chain[step], chain[step + 1])

    return MarkovChain(chain, network)

def get_markov_chain_test(mu=0.1, num_var=10):
    """Build the seeded chain-structured network used for testing.

    Both torch's and Python's generators are seeded with 42 before the first
    variable is drawn, and with the loop index before each later variable, so
    the tables are reproducible.  ``mu`` is unused.
    """

    def get_prob_dist_binary_test(mu=0.1):
        # p is uniform on [0, 1); the pair (p, 1 - p) is then shuffled.
        first = torch.FloatTensor(1).uniform_(0, 1).item()
        pair = [first, 1 - first]
        random.shuffle(pair)
        return pair

    network = BayesianNetwork()

    torch.manual_seed(42)
    random.seed(42)
    chain = [Categorical([get_prob_dist_binary_test()])]

    for seed in range(num_var - 1):
        torch.manual_seed(seed)
        random.seed(seed)
        rows = [get_prob_dist_binary_test(), get_prob_dist_binary_test()]
        chain.append(ConditionalCategorical([rows]))

    network.add_distributions(chain)

    # Link consecutive variables: chain[i] -> chain[i + 1].
    for step in range(num_var - 1):
        network.add_edge(chain[step], chain[step + 1])

    return MarkovChain(chain, network)


class MarkovChain:
    """Sampling and masking helper for a chain-structured network."""

    def __init__(self, vars, graph):
        """Store the distributions and the network they belong to.

        Args:
            vars: Sequence of distribution objects; each must expose ``probs``.
            graph: Object with a ``sample(n)`` method (the network).
        """
        self.vars = vars
        self.num_var = len(vars)
        self.graph = graph
        self.probs = [var.probs[0] for var in vars]

    def sample(self, n):
        """Return ``self.graph.sample(n)``."""
        return self.graph.sample(n)

    def to_onehot(self, sample):
        """One-hot encode binary samples: (n, num_var) -> (n, num_var, 2)."""
        return F.one_hot(sample.to(torch.int64), num_classes=2)

    def mask_var(self, x, idx):
        """Hide variable ``idx`` and every later variable, then tag ``idx``.

        The old code zeroed the caller's tensor in place through the view, so
        masking the same batch for a second index saw already-erased data.
        The mask is now applied to a copy.

        Args:
            x: One-hot samples, ``(n, num_var, 2)`` or already flattened to
                ``(n, num_var * 2)``.
            idx: Index of the variable being masked.

        Returns:
            Tensor ``(n, num_var * 2 + num_var)``: the masked, flattened
            samples followed by a one-hot position vector marking ``idx``.
        """
        n = x.size(0)
        # Copy, viewed per variable, so slicing by idx covers both one-hot
        # columns of each hidden variable.
        masked = x.reshape(n, self.num_var, -1).clone()
        masked[:, idx:] = 0

        pos = torch.zeros(n, self.num_var, device=x.device)
        pos[:, idx] = 1

        return torch.cat([masked.reshape(n, -1), pos], dim=-1)

    def shrink_from_onehot(self, x):
        """Convert flattened one-hot rows back to class indices.

        Args:
            x: Tensor ``(n, num_var * 2)``.

        Returns:
            Tensor ``(n, num_var)`` of argmax indices (the original computed
            this but forgot to return it).
        """
        return x.reshape(x.size(0), -1, 2).argmax(-1)

    def get_prob(self, idx):
        """Return ``probs[0]`` of the ``idx``-th distribution."""
        return self.probs[idx]



def get_general_network_test(mu=0.1):
    """Build the fixed five-variable test network wrapped in ``GeneralGraph``.

    Edges, in insertion order: A->C, B->C, C->E, A->E, E->F.  All tables are
    hard-coded; ``mu`` is unused.
    """
    network = BayesianNetwork()

    node_a = Categorical([[0.2, 0.8]])
    node_b = Categorical([[0.72, 0.28]])
    node_c = ConditionalCategorical([[
        [[0.17, 0.83], [0.75, 0.25]],
        [[0.21, 0.79], [0.27, 0.73]],
    ]])
    node_e = ConditionalCategorical([[
        [[0.3, 0.7], [0.16, 0.84]],
        [[0.9, 0.1], [0.2, 0.8]],
    ]])
    node_f = ConditionalCategorical([[[0.22, 0.78], [0.7, 0.3]]])

    nodes = [node_a, node_b, node_c, node_e, node_f]
    network.add_distributions([node_a, node_b, node_c, node_e, node_f])

    links = (
        (node_a, node_c),
        (node_b, node_c),
        (node_c, node_e),
        (node_a, node_e),
        (node_e, node_f),
    )
    for parent, child in links:
        network.add_edge(parent, child)

    return GeneralGraph(nodes, network)


def get_general_network(mu=0.1):
    """Build a random five-variable network wrapped in ``GeneralGraph``.

    Edges, in insertion order: A->C, B->C, C->E, A->E, E->F.  Every table row
    is drawn from ``get_prob_dist_binary``; ``mu`` is unused.
    """

    def draw_rows():
        # Two rows, drawn left to right.
        return [get_prob_dist_binary(), get_prob_dist_binary()]

    network = BayesianNetwork()

    node_a = Categorical([get_prob_dist_binary()])
    node_b = Categorical([get_prob_dist_binary()])
    node_c = ConditionalCategorical([[draw_rows(), draw_rows()]])

    # A sixth distribution is built here but never added to the network.  It
    # is still created so the sequence of random draws feeding E and F stays
    # the same.
    _discarded = ConditionalCategorical([draw_rows()])

    node_e = ConditionalCategorical([[draw_rows(), draw_rows()]])
    node_f = ConditionalCategorical([draw_rows()])

    nodes = [node_a, node_b, node_c, node_e, node_f]
    network.add_distributions([node_a, node_b, node_c, node_e, node_f])

    links = (
        (node_a, node_c),
        (node_b, node_c),
        (node_c, node_e),
        (node_a, node_e),
        (node_e, node_f),
    )
    for parent, child in links:
        network.add_edge(parent, child)

    return GeneralGraph(nodes, network)

class GeneralGraph:
    """Sampling and masking helper for the five-variable general network.

    The factories in this file pass the variables in the order
    ``[A, B, C, E, F]``, so index 3 is E and index 4 is F.
    """

    # idx -> number of leading variables left visible when masking idx.
    # A and B (0, 1) see nothing; later variables see everything before them.
    _VISIBLE = {0: 0, 1: 0, 2: 2, 3: 3, 4: 4}

    def __init__(self, vars, graph):
        """Store the distributions and the network they belong to.

        Args:
            vars: Sequence of distribution objects; each must expose ``probs``.
            graph: Object with a ``sample(n)`` method (the network).
        """
        self.vars = vars
        self.num_var = len(vars)
        self.probs = [var.probs[0] for var in vars]
        self.graph = graph

    def sample(self, n):
        """Return ``self.graph.sample(n)``."""
        return self.graph.sample(n)

    def to_onehot(self, sample):
        """One-hot encode binary samples: (n, num_var) -> (n, num_var, 2)."""
        return F.one_hot(sample.to(torch.int64), num_classes=2)

    def mask_var(self, x, idx):
        """Hide variable ``idx`` and the later ones, then tag ``idx``.

        The old code flattened ``x`` to ``(n, 2 * num_var)`` and then sliced
        with *variable* indices, so e.g. masking index 2 kept only the first
        variable's two columns and hid the second variable as well.  It also
        zeroed the caller's tensor in place.  Masking is now done on the
        variable axis of a copy.

        Args:
            x: One-hot samples, ``(n, num_var, 2)`` or already flattened to
                ``(n, num_var * 2)``.
            idx: Index of the variable being masked (0-4).

        Returns:
            Tensor ``(n, num_var * 2 + num_var)``: the masked, flattened
            samples followed by a one-hot position vector marking ``idx``.

        Raises:
            IndexError: If ``idx`` is not one of 0-4.
        """
        if idx not in self._VISIBLE:
            raise IndexError("Mask idx not exist")
        visible = self._VISIBLE[idx]

        n = x.size(0)
        # Work on a copy, viewed per variable, so the caller's tensor survives.
        masked = x.reshape(n, self.num_var, -1).clone()
        masked[:, visible:] = 0

        pos = torch.zeros(n, self.num_var, device=x.device)
        pos[:, idx] = 1
        return torch.cat([masked.reshape(n, -1), pos], dim=-1)

    def shrink_from_onehot(self, x):
        """Convert flattened one-hot rows back to class indices.

        Args:
            x: Tensor ``(n, num_var * 2)``.

        Returns:
            Tensor ``(n, num_var)`` of argmax indices (the original computed
            this but forgot to return it).
        """
        return x.reshape(x.size(0), -1, 2).argmax(-1)

    def get_prob(self, idx):
        """Return ``probs[0]`` of the ``idx``-th distribution."""
        return self.probs[idx]




def get_general2_network_test(mu=0.1):
    """Build the fixed five-variable test network wrapped in ``GeneralGraph2``.

    Edges, in insertion order: A->C, B->D, C->D, B->E, D->E.  All tables are
    hard-coded; ``mu`` is unused.
    """
    network = BayesianNetwork()

    node_a = Categorical([[0.2, 0.8]])
    node_b = Categorical([[0.72, 0.28]])
    node_c = ConditionalCategorical([[[0.72, 0.28], [0.3, 0.7]]])
    node_d = ConditionalCategorical([[
        [[0.17, 0.83], [0.75, 0.25]],
        [[0.21, 0.79], [0.27, 0.73]],
    ]])
    node_e = ConditionalCategorical([[
        [[0.3, 0.7], [0.16, 0.84]],
        [[0.9, 0.1], [0.2, 0.8]],
    ]])

    nodes = [node_a, node_b, node_c, node_d, node_e]
    network.add_distributions([node_a, node_b, node_c, node_d, node_e])

    links = (
        (node_a, node_c),
        (node_b, node_d),
        (node_c, node_d),
        (node_b, node_e),
        (node_d, node_e),
    )
    for parent, child in links:
        network.add_edge(parent, child)

    return GeneralGraph2(nodes, network)


def get_general2_network(mu=0.1):
    """Build a random five-variable network wrapped in ``GeneralGraph2``.

    Edges, in insertion order: A->C, B->D, C->D, B->E, D->E.  Every table row
    is drawn from ``get_prob_dist_binary``; ``mu`` is unused.
    """

    def draw_rows():
        # Two rows, drawn left to right.
        return [get_prob_dist_binary(), get_prob_dist_binary()]

    network = BayesianNetwork()

    node_a = Categorical([get_prob_dist_binary()])
    node_b = Categorical([get_prob_dist_binary()])
    node_c = ConditionalCategorical([draw_rows()])
    node_d = ConditionalCategorical([[draw_rows(), draw_rows()]])
    node_e = ConditionalCategorical([[draw_rows(), draw_rows()]])

    nodes = [node_a, node_b, node_c, node_d, node_e]
    network.add_distributions([node_a, node_b, node_c, node_d, node_e])

    links = (
        (node_a, node_c),
        (node_b, node_d),
        (node_c, node_d),
        (node_b, node_e),
        (node_d, node_e),
    )
    for parent, child in links:
        network.add_edge(parent, child)

    return GeneralGraph2(nodes, network)

class GeneralGraph2:
    """Sampling and masking helper for the second five-variable network.

    The factories in this file pass the variables in the order
    ``[A, B, C, D, E]``.
    """

    # idx -> number of leading variables left visible when masking idx.
    # A and B (0, 1) see nothing; later variables see everything before them.
    _VISIBLE = {0: 0, 1: 0, 2: 2, 3: 3, 4: 4}

    def __init__(self, vars, graph):
        """Store the distributions and the network they belong to.

        Args:
            vars: Sequence of distribution objects; each must expose ``probs``.
            graph: Object with a ``sample(n)`` method (the network).
        """
        self.vars = vars
        self.num_var = len(vars)
        self.probs = [var.probs[0] for var in vars]
        self.graph = graph

    def sample(self, n):
        """Return ``self.graph.sample(n)``."""
        return self.graph.sample(n)

    def to_onehot(self, sample):
        """One-hot encode binary samples: (n, num_var) -> (n, num_var, 2)."""
        return F.one_hot(sample.to(torch.int64), num_classes=2)

    def mask_var(self, x, idx):
        """Hide variable ``idx`` and the later ones, then tag ``idx``.

        The old code flattened ``x`` to ``(n, 2 * num_var)`` and then sliced
        with *variable* indices, so e.g. masking index 3 kept only three
        columns and hid the second and third variables that should have stayed
        visible.  It also zeroed the caller's tensor in place.  Masking is now
        done on the variable axis of a copy.

        Args:
            x: One-hot samples, ``(n, num_var, 2)`` or already flattened to
                ``(n, num_var * 2)``.
            idx: Index of the variable being masked (0-4).

        Returns:
            Tensor ``(n, num_var * 2 + num_var)``: the masked, flattened
            samples followed by a one-hot position vector marking ``idx``.

        Raises:
            IndexError: If ``idx`` is not one of 0-4.
        """
        if idx not in self._VISIBLE:
            raise IndexError("Mask idx not exist")
        visible = self._VISIBLE[idx]

        n = x.size(0)
        # Work on a copy, viewed per variable, so the caller's tensor survives.
        masked = x.reshape(n, self.num_var, -1).clone()
        masked[:, visible:] = 0

        pos = torch.zeros(n, self.num_var, device=x.device)
        pos[:, idx] = 1
        return torch.cat([masked.reshape(n, -1), pos], dim=-1)

    def shrink_from_onehot(self, x):
        """Convert flattened one-hot rows back to class indices.

        Args:
            x: Tensor ``(n, num_var * 2)``.

        Returns:
            Tensor ``(n, num_var)`` of argmax indices (the original computed
            this but forgot to return it).
        """
        return x.reshape(x.size(0), -1, 2).argmax(-1)

    def get_prob(self, idx):
        """Return ``probs[0]`` of the ``idx``-th distribution."""
        return self.probs[idx]




def get_wet_grass_network(mu=0.1):
    """Build a random four-variable wet-grass network wrapped in ``WetGrass``.

    Edges: cloudy->rain, cloudy->sprinkler, rain->wet_grass,
    sprinkler->wet_grass.  Every table row is drawn from
    ``get_prob_dist_binary``; ``mu`` is unused.
    """

    def draw_rows():
        # Two rows, drawn left to right.
        return [get_prob_dist_binary(), get_prob_dist_binary()]

    # Draw order matters for seeded runs: cloudy, rain, sprinkler, wet_grass.
    cloudy = Categorical([get_prob_dist_binary()])
    rain = ConditionalCategorical([draw_rows()])
    sprinkler = ConditionalCategorical([draw_rows()])
    wet_grass = ConditionalCategorical([[draw_rows(), draw_rows()]])

    network = BayesianNetwork()
    network.add_distributions([cloudy, rain, sprinkler, wet_grass])

    links = (
        (cloudy, rain),
        (cloudy, sprinkler),
        (rain, wet_grass),
        (sprinkler, wet_grass),
    )
    for parent, child in links:
        network.add_edge(parent, child)

    return WetGrass(cloudy, rain, sprinkler, wet_grass, network)


def get_wet_grass_network_test():
    """Build the fixed four-variable wet-grass test network.

    Edges: cloudy->rain, cloudy->sprinkler, rain->wet_grass,
    sprinkler->wet_grass.  All tables are hard-coded.
    """
    cloudy = Categorical([[0.2, 0.8]])
    rain = ConditionalCategorical([[[0.415, 1-0.415], [0.91, 0.09]]])
    sprinkler = ConditionalCategorical([[[0.78, 0.22], [0.45, 0.55]]])
    wet_grass = ConditionalCategorical([[
        [[0.63, 0.37], [0.22, 0.78]],
        [[0.1, 0.9], [0.2, 0.8]],
    ]])

    network = BayesianNetwork()
    network.add_distributions([cloudy, rain, sprinkler, wet_grass])

    links = (
        (cloudy, rain),
        (cloudy, sprinkler),
        (rain, wet_grass),
        (sprinkler, wet_grass),
    )
    for parent, child in links:
        network.add_edge(parent, child)

    return WetGrass(cloudy, rain, sprinkler, wet_grass, network)


class WetGrass:
    """Sampling and masking helper for the four-variable wet-grass network.

    Variable order used by the masks: cloudy (0), rain (1), sprinkler (2),
    wet_grass (3).
    """

    _NUM_VAR = 4

    def __init__(self, cloudy, rain, sprinkler, wet_grass, graph):
        """Store the four distributions and the network they belong to."""
        self.cloudy = cloudy
        self.rain = rain
        self.sprinkler = sprinkler
        self.wet_grass = wet_grass
        self.graph = graph

    def sample(self, n):
        """Return ``self.graph.sample(n)``."""
        return self.graph.sample(n)

    def to_onehot(self, sample):
        """One-hot encode binary samples: (n, num_var) -> (n, num_var, 2)."""
        return F.one_hot(sample.to(torch.int64), num_classes=2)

    def _mask_from(self, x, idx):
        """Zero variable ``idx`` and all later ones on a copy and tag ``idx``.

        The previous per-variable methods sliced ``x`` directly.  That is only
        right for 3-D ``(n, 4, 2)`` input; with the documented flat ``(n, 8)``
        input the slices addressed single one-hot columns (``mask_wet_grass``
        left half of the target visible).  They also wrote into the caller's
        tensor.  Viewing a copy as ``(n, 4, -1)`` handles both layouts.

        Returns:
            Integer-position-tagged tensor: the masked samples flattened to
            ``(n, -1)`` followed by a length-4 one-hot position vector.
        """
        n = x.size(0)
        masked = x.reshape(n, self._NUM_VAR, -1).clone()
        masked[:, idx:] = 0

        # Integer position vector, matching the dtype the old code built with
        # torch.tensor([...]).
        pos = torch.zeros(n, self._NUM_VAR, dtype=torch.long, device=x.device)
        pos[:, idx] = 1

        return torch.cat([masked.reshape(n, -1), pos], dim=-1)

    def mask_cloudy(self, x):
        """Hide every variable and tag position 0 (cloudy).

        Args:
            x: One-hot samples, ``(n, 4, 2)`` or flat ``(n, 8)``.

        Returns:
            Tensor ``(n, 12)`` for one-hot input.
        """
        return self._mask_from(x, 0)

    def mask_rain(self, x):
        """Keep only cloudy visible and tag position 1 (rain).

        Args:
            x: One-hot samples, ``(n, 4, 2)`` or flat ``(n, 8)``.

        Returns:
            Tensor ``(n, 12)`` for one-hot input.
        """
        return self._mask_from(x, 1)

    def mask_sprinkler(self, x):
        """Keep cloudy and rain visible and tag position 2 (sprinkler).

        Args:
            x: One-hot samples, ``(n, 4, 2)`` or flat ``(n, 8)``.

        Returns:
            Tensor ``(n, 12)`` for one-hot input.
        """
        return self._mask_from(x, 2)

    def mask_wet_grass(self, x):
        """Hide only wet_grass and tag position 3.

        Args:
            x: One-hot samples, ``(n, 4, 2)`` or flat ``(n, 8)``.

        Returns:
            Tensor ``(n, 12)`` for one-hot input.
        """
        return self._mask_from(x, 3)

    def shrink_from_onehot(self, x):
        """Convert flattened one-hot rows back to class indices.

        Args:
            x: Tensor ``(n, num_variable * 2)``.

        Returns:
            Tensor ``(n, num_variable)`` of argmax indices (the original
            computed this but forgot to return it).
        """
        return x.reshape(x.size(0), -1, 2).argmax(-1)

    def travel_graph(self, x):
        """Unfinished stub: reads column 0 of ``x`` and returns ``None``.

        NOTE(review): the body was never completed in the original; confirm
        whether this method is still needed.
        """
        cloudy_x = x[:, 0]

    def get_cloudy(self):
        """Return ``probs[0]`` of the cloudy distribution."""
        return self.cloudy.probs[0]

    def get_rain(self):
        """Return ``probs[0]`` of the rain distribution."""
        return self.rain.probs[0]

    def get_sprinkler(self):
        """Return ``probs[0]`` of the sprinkler distribution."""
        return self.sprinkler.probs[0]

    def get_wet_grass(self):
        """Return ``probs[0]`` of the wet_grass distribution."""
        return self.wet_grass.probs[0]