import torch 
import torch.nn as nn 
import numpy as np 

from .envs import Set, LogRewardLinear 
from .flows import ForwardFlow, BackwardFlow  

class VariationalApproximationMultiset(nn.Module):
    """Categorical approximation of a distribution over fixed-size multisets.

    Each of the ``set_size`` elements is modelled as an independent draw (with
    replacement) from one categorical distribution over the object indices
    ``0 .. warehouse_size - 1``, parameterised by ``phi``.

    States are handled as 2-D tensors whose column 0 is a marker column:
    ``sample`` fills it with ``warehouse_size`` and ``log_prob`` skips it.
    """

    def __init__(self, set_size, warehouse_size):
        """Create the approximation with a uniform ``phi``.

        Args:
            set_size: number of elements drawn per multiset.
            warehouse_size: number of distinct objects that can be drawn.
        """
        super().__init__()
        self.set_size = set_size
        self.warehouse_size = warehouse_size

        # Parameters of the distribution over the objects within the warehouse
        # (sample with replacement), initialised to the uniform distribution.
        self.phi = nn.Parameter(
            torch.ones((self.warehouse_size,)) / self.warehouse_size, requires_grad=True
        )
        # Built from the current ``phi``; ``fit`` rebuilds it after re-estimation.
        self.cat_warehouse = torch.distributions.Categorical(probs=self.phi)

    @torch.no_grad()
    def fit(self, samples):
        """Estimate ``phi`` from the empirical object frequencies in ``samples``.

        Args:
            samples: object exposing a ``state`` tensor of object indices.

        Returns:
            ``self``, so that calls can be chained.

        Raises:
            ValueError: if no entry of ``samples.state`` lies in
                ``[0, warehouse_size)``.
        """
        states = samples.state
        # Build the index range on the same device as the states so that the
        # comparison below does not fail on a device mismatch.
        objects = torch.arange(self.warehouse_size, device=states.device)
        matches = states.flatten()[:, None] == objects[None, :]
        counts = matches.sum(dim=0)
        total = counts.sum()
        if total.item() == 0:
            raise ValueError(
                "fit() found no object index in [0, warehouse_size) in samples.state"
            )
        # Normalise by the number of matching entries only: entries that match
        # no object index (e.g. the marker value `warehouse_size` that `sample`
        # writes in column 0) must not dilute the frequencies, otherwise `phi`
        # would not sum to one. Integer / integer yields the default float dtype.
        self.phi.data = (counts / total).to(device=self.phi.device)
        self.cat_warehouse = torch.distributions.Categorical(probs=self.phi)
        return self

    @torch.no_grad()
    def log_prob(self, samples):
        """Return the log-probability of each row of ``samples.state``.

        Column 0 of the state is skipped; the per-element log-probabilities of
        the remaining columns are summed along dimension 1.
        """
        return self.cat_warehouse.log_prob(samples.state[:, 1:]).sum(dim=1)

    @torch.no_grad()
    def sample(self, num_samples):
        """Draw ``num_samples`` multisets from the approximation.

        Returns:
            Tensor of shape ``(num_samples, set_size + 1)`` in the default
            floating-point dtype: column 0 holds ``warehouse_size`` and the
            remaining columns hold the drawn object indices sorted in
            ascending order.
        """
        sample_shape = torch.Size((num_samples, self.set_size))
        samples = self.cat_warehouse.sample(sample_shape=sample_shape)
        # Marker column, created on the device of the drawn samples so that the
        # concatenation also works when the distribution lives off the CPU.
        marker = torch.full(
            (num_samples, 1),
            self.warehouse_size,
            dtype=torch.get_default_dtype(),
            device=samples.device,
        )
        # Sort the samples
        return torch.hstack([marker, torch.sort(samples, dim=1).values])

class VariationalProductMultiset(VariationalApproximationMultiset):
    """Normalised product of several multiset approximations.

    ``phi`` is set proportional to the element-wise product of the ``phi``
    vectors of the given approximations; the product is accumulated in log
    space and renormalised with ``logsumexp``.
    """

    def __init__(self, variational_approximations):
        """Combine ``variational_approximations`` into a single approximation.

        Args:
            variational_approximations: non-empty sequence of approximations
                exposing ``phi``, ``set_size`` and ``warehouse_size``. The
                sizes of the last element are used for the product.

        Note:
            If no object has non-zero probability under every approximation,
            all accumulated log-probabilities are ``-inf`` and the
            normalisation yields NaN.
        """
        reference = variational_approximations[-1]
        super().__init__(reference.set_size, reference.warehouse_size)

        # Accumulate on the device of the inputs so the in-place additions do
        # not fail when the approximations live off the default device.
        log_phi = torch.zeros((self.warehouse_size,), device=reference.phi.device)
        for approximation in variational_approximations:
            log_phi += torch.log(approximation.phi.data)

        # Renormalise in log space, then map back to probabilities.
        log_phi -= torch.logsumexp(log_phi, dim=0)
        self.phi.data = log_phi.exp()

        self.cat_warehouse = torch.distributions.Categorical(probs=self.phi)


@torch.no_grad()
def create_var(set_size=None, warehouse_size=None, **kwargs):
    """Build a fresh ``VariationalApproximationMultiset``.

    Args:
        set_size: number of elements per multiset.
        warehouse_size: number of distinct objects.
        **kwargs: accepted and ignored.
    """
    approximation = VariationalApproximationMultiset(
        set_size=set_size,
        warehouse_size=warehouse_size,
    )
    return approximation

@torch.no_grad()
def create_var_prod(var_apprxs):
    """Build the ``VariationalProductMultiset`` of the approximations in ``var_apprxs``."""
    product = VariationalProductMultiset(variational_approximations=var_apprxs)
    return product

@torch.no_grad()
def samples_from_state(state, set_size=None, warehouse_size=None, **kwargs):
    """Wrap a state tensor in a ``Set`` environment.

    A ``Set`` is created with a batch size equal to ``state.shape[0]`` and no
    log-reward, and its ``state`` attribute is then overwritten with ``state``
    cast to ``long``.

    Args:
        state: tensor whose first dimension is the batch dimension.
        set_size: forwarded to ``Set``.
        warehouse_size: forwarded to ``Set``.
        **kwargs: accepted and ignored.

    Returns:
        The ``Set`` instance holding the integer-cast state.
    """
    batch_size = state.shape[0]
    wrapped = Set(set_size, warehouse_size, batch_size=batch_size, log_reward=None)
    wrapped.state = state.long()
    return wrapped
