import torch
from torch import nn
import torch.nn.functional as F

from models import VariableMLP


class TrivialPredictor(nn.Module):
    """Baseline that predicts one shared probability for every outcome.

    The probability lives in the parameter ``pred`` (initialised to 0.5) and
    is broadcast to the shape of the targets. ``Z_dim`` and ``X_dim`` are
    only stored on the instance; they do not influence the prediction.
    """

    def __init__(self, Z_dim=None, X_dim=None):
        super().__init__()
        self.Z_dim = Z_dim
        self.X_dim = X_dim
        self.pred = nn.Parameter(torch.tensor(0.5))

    def eval_seq(self, Zkwargs, X, Y, return_preds=False):
        """Return the element-wise binary cross-entropy against ``Y``.

        Args:
            Zkwargs: Ignored; accepted so the signature matches the other
                predictors in this file.
            X: Ignored, same reason.
            Y: Target tensor; the output has the same shape.
            return_preds: When true, also return the constant predictions.

        Returns:
            The unreduced loss tensor, or ``(loss, predictions)`` when
            ``return_preds`` is set.
        """
        # Same value everywhere: the single learnable scalar.
        constant_probs = self.pred * torch.ones_like(Y)
        bce = nn.functional.binary_cross_entropy(
            constant_probs, Y, reduction='none'
        )
        return (bce, constant_probs) if return_preds else bce


class MarginalPredictorContext(nn.Module):
    """Non-sequential predictor: each column is scored from ``[Z, X]`` only.

    The per-row feature ``Z`` is copied to every column, concatenated with
    that column's ``X`` and pushed through ``top_layer`` (a ``VariableMLP``).
    The output is used directly as a Bernoulli probability in
    ``binary_cross_entropy``, so ``top_layer`` presumably ends in a sigmoid
    with a single output unit -- TODO confirm against ``VariableMLP``.
    """

    def __init__(self, X_dim=None, Z_dim=None, MLP_layer=6, MLP_width=50, prior_scale=0):
        """Build the predictor.

        Args:
            X_dim: Size of the last dimension of ``X``. Required.
            Z_dim: Size of the last dimension of ``Z``. Required.
            MLP_layer: ``num_layers`` passed to ``VariableMLP``.
            MLP_width: ``width`` passed to ``VariableMLP``.
            prior_scale: Must be 0; accepted only for signature parity with
                other models.

        Raises:
            ValueError: If ``Z_dim`` or ``X_dim`` is ``None``.
        """
        super().__init__()

        self.MLP_width = MLP_width
        self.MLP_layer = MLP_layer
        self.prior_scale = prior_scale
        assert prior_scale == 0 # not implemented here but argument is passed for other models

        # The previous guard (`Z_dim is not None or X_dim is not None`) let a
        # missing Z_dim through and then crashed on `None + X_dim` below.
        if Z_dim is None:
            raise ValueError("Need a Z feature")
        if X_dim is None:
            raise ValueError("Need an X feature dimension (X_dim)")

        self.z_encoder = lambda x: x
        self.z_encoder_output_dim = Z_dim

        self.X_dim = X_dim
        self.top_layer = VariableMLP(input_dim=self.z_encoder_output_dim+self.X_dim,
                             num_layers=MLP_layer, width=MLP_width)

    def _build_input(self, Zkwargs, X):
        """Concatenate the (column-broadcast) encoded Z with X on the last dim."""
        embed = self.z_encoder(Zkwargs)

        bs, ncol, _ = X.shape
        bsZ, _ = embed.shape
        assert bs == bsZ
        # expand() is a view; torch.cat materialises the copy.
        embedZ = embed.unsqueeze(1).expand(-1, ncol, -1)
        return torch.cat([embedZ, X], 2)

    def forward(self, Zkwargs, X):
        """Predict one value per column.

        Args:
            Zkwargs: 2-D tensor ``(batch, Z_dim)``.
            X: 3-D tensor ``(batch, columns, X_dim)``.

        Returns:
            ``top_layer`` output with dimension 2 squeezed, i.e.
            ``(batch, columns)`` when ``top_layer`` emits one unit.
        """
        return self.top_layer(self._build_input(Zkwargs, X)).squeeze(2)

    def get_features(self, Zkwargs, X):
        """Return intermediate activations of ``top_layer`` for ``[Z, X]``.

        Only the first five child modules of ``top_layer.model`` are applied;
        which layers those are depends on how ``VariableMLP`` builds ``model``
        (not visible here).
        """
        input_ = self._build_input(Zkwargs, X)

        partial_top_layer = nn.Sequential(*list(self.top_layer.model.children())[:5])
        return partial_top_layer(input_)

    def eval_seq(self, Zkwargs, X, Y, N=None, return_preds=False, trainlen=None, exact=False):
        """Element-wise BCE of the marginal predictions on the first ``N`` columns.

        Args:
            Zkwargs: 2-D tensor ``(batch, Z_dim)``.
            X: 3-D tensor ``(batch, columns, X_dim)``.
            Y: 2-D target tensor ``(batch, columns)``.
            N: Number of leading columns to score; defaults to, and is capped
                at, ``Y.shape[1]``.
            return_preds: When true, also return the predictions.
            trainlen: Unused; kept for signature parity with other models.
            exact: Unused; kept for signature parity with other models.

        Returns:
            Loss of shape ``(batch, N)``, or ``(loss, predictions)``.
        """
        if N is None:
            N = Y.shape[1]
        N = min(N, Y.shape[1])

        # not a sequence model (marginal loss)
        # X is truncated as well; otherwise predictions had X.shape[1]
        # columns and could not be compared with Y[:, :N] when N was smaller.
        p_hat_pred = self.forward(Zkwargs, X[:, :N])
        # true success_p and p_hat_pred have dimensions (rows, N)
        loss_matrix = nn.functional.binary_cross_entropy(p_hat_pred, Y[:,:N],
                                                  reduction='none')
        if return_preds:
            return loss_matrix, p_hat_pred
        return loss_matrix

class SequentialPredictorContext(nn.Module):
    """Sequential predictor driven by a running sufficient statistic.

    For column ``j`` the top-layer MLP receives ``[Z, X_j, state_j]`` where
    ``state_j`` is computed by :meth:`get_state` from the (optionally
    encoded) X and the Y of the columns before ``j``. The state is a mean
    vector plus a flattened ``X_dim x X_dim`` matrix, each entry repeated
    ``repeat_suffstat`` times.

    The top-layer output is used as a Bernoulli probability, so
    ``VariableMLP`` presumably ends in a sigmoid with one output unit --
    TODO confirm.
    """

    def __init__(self, bert_encoder=None, X_dim=None, Z_dim=None, MLP_layer=6,
                 MLP_width=50, init_mean=0, repeat_suffstat=100, prior_scale=0, X_MLP_layer=0, X_MLP_width=0,
                 suffstat_eps = 0.01):
        """Build the predictor.

        Args:
            bert_encoder: Unused; kept for signature parity.
            X_dim: Size of the last dimension of ``X``. Required.
            Z_dim: Size of the last dimension of ``Z``; ``None`` means the
                model has no Z features.
            MLP_layer: ``num_layers`` of the top-layer ``VariableMLP``.
            MLP_width: ``width`` of the top-layer ``VariableMLP``.
            init_mean: Value used for every entry of the prior mean.
            repeat_suffstat: How many times each state entry is repeated
                before being fed to the top layer.
            prior_scale: Must be 0; accepted only for signature parity.
            X_MLP_layer: ``num_layers`` of the X encoder used for the
                sufficient statistic; 0 disables the encoder.
            X_MLP_width: ``width`` of that encoder; 0 disables the encoder.
            suffstat_eps: Ridge term added to the Gram matrix in
                :meth:`get_state`.

        Raises:
            ValueError: If ``X_dim`` is ``None``.
        """
        super().__init__()
        if X_dim is None:
            raise ValueError("Need an X feature dimension (X_dim)")

        self.init_mean = init_mean
        self.MLP_layer = MLP_layer
        self.MLP_width = MLP_width

        self.repeat_suffstat = repeat_suffstat
        self.X_dim = X_dim

        # These used to be overwritten with MLP_layer / MLP_width, which
        # silently ignored the caller's X_MLP_* arguments (and made the
        # identity branch below unreachable for any usable top layer).
        # NOTE: checkpoints saved under the old behaviour contain
        # x_suff_encoder weights; pass X_MLP_layer=MLP_layer and
        # X_MLP_width=MLP_width to rebuild that architecture.
        self.X_MLP_layer = X_MLP_layer
        self.X_MLP_width = X_MLP_width

        assert prior_scale==0 # not implemented here but argument is passed for other models

        self.suffstat_eps = suffstat_eps

        if Z_dim is not None:
            self.z_encoder = lambda x: x
            self.z_encoder_output_dim = Z_dim

        else:
            # There are no Z features
            self.z_encoder = lambda x: None
            self.z_encoder_output_dim = 0

        if self.X_MLP_layer == 0 or self.X_MLP_width == 0:
            self.x_suff_encoder = lambda x: x
        else:
            self.x_suff_encoder = VariableMLP(input_dim=X_dim,
                    output_dim=X_dim,
                    num_layers=self.X_MLP_layer,
                    width=self.X_MLP_width)

        # size of suff statistic
        self.suffStatDim = X_dim + X_dim**2

        self.top_layer = VariableMLP(input_dim=self.z_encoder_output_dim + self.X_dim + self.suffStatDim*repeat_suffstat,
                             num_layers=MLP_layer, width=MLP_width)

    @staticmethod
    def _cat_features(embedZ, X, states):
        """Concatenate ``[Z, X, state]`` on the last dim, skipping Z if absent."""
        parts = [X, states] if embedZ is None else [embedZ, X, states]
        return torch.cat(parts, 2)

    def forward(self, Zkwargs, X):
        """Score every column using the *initial* state (no conditioning on Y).

        Args:
            Zkwargs: 2-D tensor ``(batch, Z_dim)``; ignored when the model
                was built without Z features.
            X: 3-D tensor ``(batch, columns, X_dim)``.

        Returns:
            The raw ``top_layer`` output (dimension 2 is not squeezed).
        """
        # NOTE(review): this feeds the *encoded* X to the top layer, whereas
        # eval_seq / fill_table_* feed the raw X. Left as is; confirm intent.
        embed = self.z_encoder(Zkwargs)

        bs, ncol, _ = X.shape
        embedZ = None
        if embed is not None:
            bsZ, _ = embed.shape
            assert bs == bsZ
            embedZ = embed.unsqueeze(1).expand(-1, ncol, -1)
        X_enc = self.x_suff_encoder(X)
        state = self.init_model_states(bs)
        embedState = state.unsqueeze(1).expand(-1, ncol, -1)

        return self.top_layer(self._cat_features(embedZ, X_enc, embedState))


    def get_device(self):
        """Return the device of the first ``top_layer`` parameter."""
        return next(self.top_layer.parameters()).device


    def init_model_states(self, batch_size):
        """Return the state used when nothing has been observed.

        The mean is ``init_mean`` in every coordinate and the matrix part is
        the identity; the result has shape
        ``(batch_size, (X_dim + X_dim**2) * repeat_suffstat)``.
        """
        device = self.get_device()
        mean = torch.ones(self.X_dim, device=device) * self.init_mean
        cov = torch.eye(self.X_dim, device=device)
        init_state = torch.cat([mean, cov.flatten()])
        return init_state.repeat((batch_size, 1)).repeat_interleave(self.repeat_suffstat, dim=1)


    def get_state(self, X, Y):
        """Compute the sufficient-statistic state from observed ``(X, Y)``.

        With ``A = suffstat_eps * I + sum_j x_j x_j^T`` the state holds
        ``A^{-1} (prior_mean + sum_j x_j y_j)`` followed by the flattened
        ``A^{-1}``, each entry repeated ``repeat_suffstat`` times.

        Args:
            X: ``(batch, observed, X_dim)`` features (already encoded).
            Y: ``(batch, observed)`` outcomes.

        Returns:
            ``(batch, (X_dim + X_dim**2) * repeat_suffstat)`` tensor; the
            initial state when ``Y`` has zero columns.
        """
        if Y.shape[1] == 0:
            return self.init_model_states(X.shape[0])
        device = self.get_device()
        inv_cov = self.suffstat_eps * torch.eye(self.X_dim, device=device) + torch.einsum('ijk,ijl->ikl', X, X)
        cov = torch.linalg.inv(inv_cov)

        prior_mean = torch.ones(self.X_dim, device=device) * self.init_mean
        # prior_mean broadcasts over the batch dimension.
        mean_num = prior_mean + torch.sum(X * Y.unsqueeze(2), dim=1)
        mean = torch.einsum('ijk,ik->ij', cov, mean_num)

        state = torch.cat([mean, cov.flatten(start_dim=1, end_dim=2)], dim=1)
        return state.repeat_interleave(self.repeat_suffstat, dim=1)


    # no mask as input, use mask later.
    def eval_seq(self, Zkwargs, X, Y, true_p=None, N=None, return_preds=False, trainlen=None, exact=False):
        """Element-wise BCE of the one-step-ahead predictions.

        Column ``j`` is predicted from ``Z``, ``X[:, j]`` and the state built
        from columns ``0..j-1``.

        Args:
            Zkwargs: 2-D tensor ``(batch, Z_dim)``; ignored when the model
                was built without Z features.
            X: 3-D tensor ``(batch, columns, X_dim)``.
            Y: 2-D target tensor ``(batch, columns)``.
            true_p: Unused; kept for signature parity.
            N: Number of leading columns to score; defaults to, and is capped
                at, ``Y.shape[1]``.
            return_preds: When true, also return the predictions.
            trainlen: Unused; kept for signature parity.
            exact: Unused; kept for signature parity.

        Returns:
            Loss of shape ``(batch, N)``, or ``(loss, predictions)``.
        """
        if N is None:
            N = Y.shape[1]
        N = min(N, Y.shape[1])

        # Truncate both inputs up front. Previously only Y was truncated, so
        # the concatenation below failed whenever N < X.shape[1].
        X = X[:, :N]
        Y = Y[:, :N]

        encoded_Z = self.z_encoder(Zkwargs)
        embedZ = None if encoded_Z is None else encoded_Z.unsqueeze(1).expand(-1, N, -1)

        # Get all states: state j only sees columns strictly before j.
        X_enc = self.x_suff_encoder(X)
        all_states = [self.get_state(X_enc[:, :j], Y[:, :j]) for j in range(N)]
        all_states_cat = torch.stack(all_states, dim=1)

        input_ = self._cat_features(embedZ, X, all_states_cat)
        p_hat_pred = self.top_layer(input_).squeeze(2)
        loss_matrix = nn.functional.binary_cross_entropy(p_hat_pred, Y,
                                                  reduction='none')
        if return_preds:
            return loss_matrix, p_hat_pred

        return loss_matrix

    def fill_table_naive(self, Z, hist_X, hist_Y, hist_mask, eval_X):
        """Autoregressively sample outcomes for ``eval_X`` after a history.

        Each new column is sampled with ``torch.bernoulli`` from the model's
        prediction and then appended to the history used for the next one.

        Args:
            Z: 2-D tensor ``(batch, Z_dim)``; ignored when the model was
                built without Z features.
            hist_X: ``(batch, hist_len, X_dim)`` observed features.
            hist_Y: ``(batch, hist_len)`` observed outcomes.
            hist_mask: Unused; kept for signature parity.
            eval_X: ``(batch, m, X_dim)`` features to generate outcomes for.

        Returns:
            ``(batch, m)`` tensor of sampled outcomes.
        """
        m = eval_X.shape[1]
        assert hist_Y.shape[1] == hist_X.shape[1] # number of timesteps observed
        hist_len = hist_Y.shape[1]

        # Allocate on the model's device; get_state builds its identity
        # matrix there, so CPU buffers broke any non-CPU model.
        device = self.get_device()
        Y = torch.zeros(hist_Y.shape[0], hist_len + m, device=device)
        X = torch.zeros(hist_X.shape[0], hist_len + m, hist_X.shape[2], device=device)
        Y[:,:hist_len] = hist_Y
        X[:,:hist_len,:] = hist_X

        encoded_Z = self.z_encoder(Z)
        z_col = None if encoded_Z is None else encoded_Z.unsqueeze(1)

        for j in range(m):
            seen = j + hist_len
            X_enc = self.x_suff_encoder(X[:, :seen, :])
            curr_state = self.get_state(X_enc, Y[:, :seen])
            input_ = self._cat_features(z_col, eval_X[:, [j], :], curr_state.unsqueeze(1))

            # Generate new Y's, for this j
            p_hat_pred = self.top_layer(input_).squeeze(2)
            new_cur_Ys = torch.bernoulli(p_hat_pred)

            # Update relevant parts of Y with the generated Y's
            Y[:, seen] = new_cur_Ys[:, 0]
            X[:, seen] = eval_X[:, j]

        return Y[:, hist_len:]


    def fill_table_naive_finite(self, Z, hist_X, hist_Y, hist_mask, eval_X):
        """Impute the unobserved entries of ``hist_Y`` column by column.

        For each column ``t`` the state is computed from the entries whose
        mask is currently 1 (the rest are zeroed out), an outcome is sampled,
        observed entries are kept and unobserved ones are replaced by the
        sample, and column ``t`` is then marked as observed.

        Args:
            Z: 2-D tensor ``(batch, Z_dim)``; ignored when the model was
                built without Z features.
            hist_X: Unused.
            hist_Y: ``(batch, T)`` outcomes; only entries with mask 1 are
                trusted.
            hist_mask: ``(batch, T)`` mask, presumably 0/1 valued with 1
                meaning observed -- TODO confirm with callers.
            eval_X: ``(batch, T, X_dim)`` features for all columns.

        Returns:
            ``(batch, T)`` tensor with the unobserved entries filled in.
        """
        T = hist_mask.shape[1]
        imputedY = torch.clone(hist_Y)
        curr_mask = torch.clone(hist_mask)

        encoded_Z = self.z_encoder(Z)
        z_col = None if encoded_Z is None else encoded_Z.unsqueeze(1)

        for t in range(T):
            # Broadcasting the mask over the feature dim zeroes unobserved rows.
            masked_X = self.x_suff_encoder(eval_X) * curr_mask.unsqueeze(2)
            curr_state = self.get_state(masked_X, imputedY * curr_mask)

            # Generate new Y's, for this column / t
            input_ = self._cat_features(z_col, eval_X[:, [t], :], curr_state.unsqueeze(1))
            p_hat_pred = self.top_layer(input_).squeeze(2)
            gen_Ys = torch.bernoulli(p_hat_pred)

            # Update relevant parts of Y with the generated Y's
            imputedY[:,t] = imputedY[:,t] * curr_mask[:,t] + gen_Ys[:,0] * (1-curr_mask[:,t])
            curr_mask[:,t] = 1

        return imputedY


