import torch 
import torch.nn as nn 
import numpy as np 

from .envs import Sequence, LogRewardLinear 
from .flows import ForwardFlow, BackwardFlow 

class VariationalApproximationSequence(nn.Module): 

    def __init__(self, max_size, vocab_size): 
        super(VariationalApproximationSequence, self).__init__() 
        self.max_size = max_size 
        self.vocab_size = vocab_size 

        # For each size, generate a distribution over the sequences' elements 

        # The model is: size ~ Cat(theta); elements ~ Cat(phi | size) 
        self.theta = nn.Parameter(torch.ones((self.max_size + 1,)) / self.max_size, requires_grad=True) 
        self.phi = nn.ParameterList(
            [torch.ones((size, vocab_size + 1)) / self.vocab_size for size in range(1, max_size + 1)] 
        )

        self.vocab = torch.arange(vocab_size + 1) 

    @torch.no_grad() 
    def fit(self, samples): 
        dtype = torch.get_default_dtype() 
        self.theta.data = (samples.lengths[:, None] == (torch.arange(self.max_size) + 1)[None, :]).to(dtype=dtype).mean(dim=0) 
        for length in torch.unique(samples.lengths.long()): 
            indices = (samples.lengths == length) 
            states = samples.state[indices, :length]
            # (N, L), 
            probs = (states[:, :, None] == self.vocab.view(1, 1, -1)).to(dtype=dtype).sum(dim=0) / sum(indices) 
            assert torch.isclose(probs.sum(dim=-1), torch.ones_like(probs.sum(dim=-1))).all(), (probs, probs.sum(dim=-1), probs.sum(dim=-1) == 1., states)  
            self.phi[length - 1] = probs
        return self  

    @torch.no_grad() 
    def log_prob(self, samples): 
        size_dist = torch.distributions.Categorical(self.theta)
        logprob = torch.zeros((samples.batch_size,)) 
        for size in range(1, self.max_size + 1): 
            indices = (samples.lengths == size) 
            states = samples.state[indices, :size] 
            states_dist = torch.distributions.Categorical(self.phi[size - 1]) 
            # print(self.phi[size - 1].shape, size, sum(indices), states.shape)
            logprob_samples = states_dist.log_prob(states).sum(dim=1) 
            logprob[indices] += logprob_samples 
            logprob[indices] += size_dist.log_prob(torch.tensor(size - 1))  
        return logprob 

    @torch.no_grad() 
    def sample(self, num_samples): 
        sample_shape = torch.Size((num_samples,)) 
        size_dist = torch.distributions.Categorical(self.theta) 
        sizes = size_dist.sample(sample_shape=sample_shape) 
        samples = torch.ones((0, self.max_size + 1)) 
        for size in torch.unique(sizes): 
            num_samples_size = (sizes == size).sum() 
            sample_shape_size = torch.Size((num_samples_size,)) 
            samples_dist = torch.distributions.Categorical(self.phi[size]) 
            samples_size = samples_dist.sample(sample_shape=sample_shape_size) 
            # Padding 
            padding = torch.zeros((samples_size.shape[0], self.max_size + 1 - samples_size.shape[1])) 
            samples_size = torch.hstack([samples_size, padding])  
            samples = torch.vstack([samples, samples_size]) 
        return samples 

class VariationalProductSequence(VariationalApproximationSequence):
    """Normalized product of several sequence approximations.

    The length distribution ``theta`` and every per-position symbol
    distribution in ``phi`` are the renormalized element-wise products of the
    corresponding distributions of the given approximations.
    """

    @torch.no_grad()
    def __init__(self, variational_approximations):
        """Build the product distribution.

        Args:
            variational_approximations: non-empty sequence of approximations
                whose ``theta`` has ``max_size`` entries and whose ``phi``
                tables have matching shapes.  ``max_size`` and ``vocab_size``
                are read from the last element.
        """
        var_apprx = variational_approximations[-1]
        super(VariationalProductSequence, self).__init__(var_apprx.max_size, var_apprx.vocab_size)

        # Accumulate the product in log space: a product of probabilities is
        # a sum of log-probabilities.
        log_theta = torch.zeros((self.max_size,))
        log_phi = [torch.zeros((size, self.vocab_size + 1)) for size in range(1, self.max_size + 1)]

        for var_apprx in variational_approximations:
            log_theta = log_theta + torch.log(var_apprx.theta.data)
            for i, accumulated in enumerate(log_phi):
                log_phi[i] = accumulated + torch.log(var_apprx.phi[i].data)

        # softmax(x) == exp(x - logsumexp(x)): renormalize the length
        # distribution over its single axis.
        self.theta.data = torch.softmax(log_theta, dim=0)
        # Each row of a phi table is a distribution over symbols, so the
        # normalization runs over the last (vocabulary) axis.  Normalizing
        # over axis 0 (positions) would rescale columns and distort the
        # per-position distributions.
        self.phi = nn.ParameterList(
            [nn.Parameter(torch.softmax(table, dim=-1)) for table in log_phi]
        )

@torch.no_grad()
def create_var(max_size=None, vocab_size=None, **kwargs):
    """Build a fresh ``VariationalApproximationSequence``.

    Only ``max_size`` and ``vocab_size`` are forwarded to the constructor;
    any additional keyword arguments are accepted and ignored.
    """
    approximation = VariationalApproximationSequence(
        max_size=max_size,
        vocab_size=vocab_size,
    )
    return approximation

@torch.no_grad()
def create_var_prod(var_apprxs):
    """Combine the approximations in ``var_apprxs`` into a ``VariationalProductSequence``."""
    product = VariationalProductSequence(variational_approximations=var_apprxs)
    return product

@torch.no_grad()
def samples_from_state(state, max_size=None, vocab_size=None, **kwargs):
    """Wrap a raw ``state`` tensor into a ``Sequence`` batch.

    A ``Sequence`` is created with one entry per row of ``state`` and no
    log-reward, then its ``state``, ``mask_padding`` and ``size`` attributes
    are overwritten.  Extra keyword arguments are accepted and ignored.

    Args:
        state: 2-D tensor of symbols, one row per sample; zeros are treated
            as padding.
        max_size: forwarded to ``Sequence``; ``size`` is set to
            ``max_size + 1``.
        vocab_size: forwarded to ``Sequence``.

    Returns:
        The populated ``Sequence`` object.
    """
    batch = Sequence(
        batch_size=state.size(0),
        max_size=max_size,
        vocab_size=vocab_size,
        log_reward=None,
    )
    batch.state = state.long()
    # Non-zero entries are real symbols; zeros are padding.
    batch.mask_padding = state.ne(0.)
    # The first position is always marked as non-padding, whatever it holds.
    batch.mask_padding[:, 0] = True
    batch.size = max_size + 1
    return batch
