import torch 
import torch.nn as nn 
import numpy as np 

from .envs import Grid2D, LogRewardDist 
from .flows import ForwardFlow, BackwardFlow

class VariationalApproximationGrid(nn.Module):
    """Factorised categorical approximation over the cells of a 2-D grid.

    The x and y coordinates are modelled as independent categorical
    variables over ``{0, ..., width}`` and ``{0, ..., height}``. Their
    probability vectors live in ``self.phi`` (index 0 for x, index 1 for y)
    and start out uniform.
    """

    def __init__(self, width, height):
        """Create a uniform approximation for a grid of the given extent.

        Args:
            width: largest x coordinate; the x marginal has ``width + 1`` bins.
            height: largest y coordinate; the y marginal has ``height + 1`` bins.
        """
        super().__init__()
        self.width = width
        self.height = height

        # Wrap explicitly in nn.Parameter: passing plain (non-leaf) tensors
        # relies on ParameterList auto-wrapping, which not every PyTorch
        # release performs.
        self.phi = nn.ParameterList(
            [
                nn.Parameter(torch.ones(self.width + 1) / (self.width + 1)),
                nn.Parameter(torch.ones(self.height + 1) / (self.height + 1)),
            ]
        )

        # Every admissible coordinate value, used by `fit` to count hits.
        self.x_support = torch.arange(self.width + 1)
        self.y_support = torch.arange(self.height + 1)

        self._refresh_distributions()

    def _refresh_distributions(self):
        """Rebuild the two categorical distributions from the current ``phi``."""
        self.x_dist = torch.distributions.Categorical(probs=self.phi[0])
        self.y_dist = torch.distributions.Categorical(probs=self.phi[1])

    @torch.no_grad()
    def fit(self, samples):
        """Set ``phi`` to the empirical coordinate frequencies of ``samples``.

        Args:
            samples: object exposing ``pos``, a tensor whose column 0 holds
                x coordinates and column 1 holds y coordinates, one row per
                sample.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: if ``samples.pos`` has no rows. Frequencies are then
                undefined, and the check happens before any state is
                modified so the current parameters stay intact.
        """
        pos = samples.pos
        if pos.shape[0] == 0:
            raise ValueError(
                "cannot fit a variational approximation to zero samples"
            )

        default_dtype = torch.get_default_dtype()
        # The supports are plain attributes and do not follow the data's
        # device, so align them with `pos` before comparing.
        x_support = self.x_support.to(pos.device)
        y_support = self.y_support.to(pos.device)

        # (N, 1) == (1, W + 1) -> (N, W + 1) indicator rows; the mean over N
        # gives the relative frequency of every coordinate value.
        x_freqs = (pos[:, 0, None] == x_support[None, :]).to(dtype=default_dtype).mean(dim=0)
        y_freqs = (pos[:, 1, None] == y_support[None, :]).to(dtype=default_dtype).mean(dim=0)

        self.phi[0] = nn.Parameter(x_freqs)
        self.phi[1] = nn.Parameter(y_freqs)

        # Update the variational family with the estimated parameters
        self._refresh_distributions()
        return self

    @torch.no_grad()
    def log_prob(self, samples):
        """Return the log-probability of each row of ``samples.pos``.

        The result is the sum of the x and y marginal log-probabilities.
        """
        pos = samples.pos
        return self.x_dist.log_prob(pos[:, 0]) + self.y_dist.log_prob(pos[:, 1])

    @torch.no_grad()
    def sample(self, num_samples):
        """Draw ``num_samples`` positions as a ``(num_samples, 2)`` tensor.

        Column 0 holds the sampled x coordinates, column 1 the y coordinates.
        """
        sample_shape = torch.Size((num_samples,))
        xs = self.x_dist.sample(sample_shape)
        ys = self.y_dist.sample(sample_shape)
        return torch.vstack([xs, ys]).t()

class VariationalProductGrid(VariationalApproximationGrid):
    """Normalised product of several grid approximations.

    Each marginal is the renormalised element-wise product of the
    corresponding marginals of the given approximations, computed in log
    space.
    """

    def __init__(self, variational_approximations):
        """Combine the given approximations into a single one.

        Args:
            variational_approximations: non-empty sequence of objects that
                expose ``width``, ``height`` and ``phi`` (x probabilities at
                index 0, y probabilities at index 1). The grid extent is
                taken from the last element.

        Raises:
            ValueError: if the product cannot be normalised, e.g. because
                the factors share no cell with non-zero probability.
        """
        reference = variational_approximations[-1]
        super().__init__(reference.width, reference.height)

        # Allocate the accumulators next to the factors' parameters so the
        # in-place additions below do not fail on a device mismatch.
        x_log = torch.zeros(self.width + 1, device=reference.phi[0].device)
        y_log = torch.zeros(self.height + 1, device=reference.phi[1].device)

        # A product of probabilities is a sum of log-probabilities. The
        # factors are detached so no autograd graph is recorded through the
        # in-place accumulation when this runs outside a no_grad context.
        for approx in variational_approximations:
            x_log += torch.log(approx.phi[0].detach())
            y_log += torch.log(approx.phi[1].detach())

        x_log_norm = torch.logsumexp(x_log, dim=0)
        y_log_norm = torch.logsumexp(y_log, dim=0)
        if not (torch.isfinite(x_log_norm) and torch.isfinite(y_log_norm)):
            # Subtracting a non-finite normaliser would yield NaN
            # probabilities; fail with an explicit message instead.
            raise ValueError(
                "product of variational approximations cannot be normalised "
                "(the factors have no common support or contain invalid "
                "probabilities)"
            )

        self.phi = nn.ParameterList(
            [
                nn.Parameter((x_log - x_log_norm).exp()),
                nn.Parameter((y_log - y_log_norm).exp()),
            ]
        )

        self.x_dist = torch.distributions.Categorical(probs=self.phi[0])
        self.y_dist = torch.distributions.Categorical(probs=self.phi[1])

@torch.no_grad()
def create_var(width=None, height=None, **kwargs):
    """Build a fresh grid approximation for the given extent.

    Any additional keyword arguments are accepted and ignored.
    """
    approximation = VariationalApproximationGrid(width=width, height=height)
    return approximation

@torch.no_grad()
def create_var_prod(var_apprxs):
    """Build the product approximation of the approximations in ``var_apprxs``."""
    product = VariationalProductGrid(variational_approximations=var_apprxs)
    return product

@torch.no_grad()
def samples_from_state(state, width=None, height=None, **kwargs):
    """Wrap ``state`` in a ``Grid2D`` whose ``pos`` attribute is ``state``.

    The grid is created without a log-reward and with a batch size equal to
    the first dimension of ``state``. Additional keyword arguments are
    accepted and ignored.
    """
    batch_size = state.shape[0]
    env = Grid2D(width, height, log_reward=None, batch_size=batch_size)
    env.pos = state
    return env