import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from flowvi import *


class GLMVI(FlowVI):
    """Base class for flow-based VI on a GLM posterior over coefficients.

    The unnormalized log target is ``logprior(beta) + logL(beta)``.
    Subclasses provide the likelihood (``logL``) and the synthetic data set
    (``generate_data``).

    NOTE(review): ``self.dim`` is read below but never assigned in this file;
    presumably ``FlowVI.__init__`` sets it -- confirm against ``flowvi``.
    """

    def __init__(self, n_data, p_nonzero, rho_X, prior_scale,
                 flow, seed_train, seed_test, seed_glm_data):
        """Initialise the parent, store the prior scale and build the data.

        Args:
            n_data: number of observations handed to ``generate_data``.
            p_nonzero: number of leading coefficients kept nonzero.
            rho_X: correlation parameter handed to ``generate_data``.
            prior_scale: scale used by ``logprior``.
            flow: forwarded to ``FlowVI.__init__``; its ``replicate``
                attribute is forwarded to ``generate_data``.
            seed_train, seed_test: forwarded to ``FlowVI.__init__``.
            seed_glm_data: seed handed to ``generate_data``.
        """
        super().__init__(flow, seed_train, seed_test)
        self.prior_scale = prior_scale
        # Data generation runs after the parent initialiser because the
        # subclasses' implementations read attributes from ``self``.
        generated = self.generate_data(
            n_data, p_nonzero, rho_X, flow.replicate, seed_glm_data)
        self.data, self.betaTrue = generated

    def logpzx(self, beta):
        """Return the unnormalized log target density, one value per row.

        Sum of the log prior and the log likelihood; per the shape notes in
        this file ``beta`` is (B, dim) and the result is (B,).
        """
        log_prior = self.logprior(beta)
        log_lik = self.logL(beta)
        return log_prior + log_lik  # (B,)

    def logprior(self, beta):
        """Return the log density of the spike prior at ``beta``.

        Per coordinate the density is
        ``log(1 + (scale / beta)**2) / (2 * pi * scale)``; coordinates are
        summed in log scale over the last axis, so a (B, dim) input gives a
        (B,) output.
        """
        s = self.prior_scale
        # Normalising constant, one factor 1 / (2*pi*s) per coordinate.
        log_norm = - self.dim * math.log(2 * math.pi * s)
        ratio_sq = (s / beta).pow(2)                               # (B, dim)
        log_kern = torch.log(torch.log1p(ratio_sq)).sum(dim=-1)    # (B,)
        return log_norm + log_kern                                 # (B,)

    def logL(self, beta):
        """Log likelihood of the data at ``beta``; subclasses must override."""
        raise NotImplementedError("Loglikelihood is not implemented.")

    def generate_data(self, n_data, p_nonzero, rho_X, replicate, seed_glm_data):
        """Build ``(data, betaTrue)``; subclasses must override."""
        raise NotImplementedError("Data Generation is not implemented.")


class LMVI(GLMVI):
    """Linear regression target: Gaussian likelihood with unit noise variance.

    NOTE(review): ``self.dim`` and ``self.device`` are read but not assigned
    in this file; presumably they come from ``FlowVI`` -- confirm.
    """

    def __init__(self, n_data, p_nonzero, rho_X, prior_scale, flow,
                 seed_train=11235813, seed_test=31415926, seed_glm_data=42):
        """Forward every argument to ``GLMVI.__init__`` (only adds defaults)."""
        super().__init__(n_data, p_nonzero, rho_X, prior_scale, flow,
                         seed_train, seed_test, seed_glm_data)

    def logL(self, beta):
        """Return the unit-variance Gaussian log likelihood for each row.

        Computes ``-n/2 * log(2*pi) - 0.5 * sum((y - X @ beta)**2)`` with
        ``(y, X) = self.data``.  Per the shape notes in this file ``beta`` is
        (batch, p) and the result is (batch,).

        NOTE(review): ``X`` is built from NumPy output (presumably float64);
        ``beta`` has to share its dtype for the matmul -- confirm.
        """
        response, design = self.data
        n_obs = response.shape[0]
        # (n, p) @ (batch, p, 1) -> (batch, n, 1) -> (batch, n)
        fitted = torch.matmul(design, beta.unsqueeze(-1)).squeeze(-1)
        # (n,) -> (1, n) -> (batch, n)
        residual = response[None, :].expand_as(fitted) - fitted
        log_norm = - 0.5 * n_obs * math.log(2 * math.pi)   # scalar
        sq_err = residual.pow(2).sum(dim=-1)               # (batch,)
        return log_norm - 0.5 * sq_err                     # (batch,)

    def generate_data(self, n_data, p_nonzero, rho_X, replicate, seed_glm_data):
        """Simulate ``y = X @ betaTrue + err`` for one replicate.

        Rows of ``X`` are zero-mean Gaussian with covariance entries
        ``rho_X ** |i - j|``; ``betaTrue`` is uniform on [-1, 1) with every
        entry from index ``p_nonzero`` onwards set to zero; ``err`` is
        standard normal.

        Returns:
            ``((y, X), betaTrue)``.  ``y`` and ``X`` are moved to
            ``self.device``; ``betaTrue`` is returned without a device move.
        """
        n, p = n_data, self.dim

        # Give each replicate its own stretch of the random stream by
        # advancing the generator.  The data take n*p + p + n variates
        # (X: n*p, betaTrue: p, err: n); the factor 1000 is a loose upper
        # bound on the draws used so that replicates do not overlap.
        stride = 1000 * (n * p + p + n)
        rng = np.random.default_rng(seed_glm_data)
        rng.bit_generator.advance(stride * replicate)

        # Covariance of the design rows: entry (j, i) is rho_X ** |i - j|.
        idx = np.arange(p)
        cov_X = rho_X ** np.abs(idx[None, :] - idx[:, None])

        # Draw order (X, betaTrue, err) determines the simulated data.
        X = torch.tensor(rng.multivariate_normal(np.zeros(p), cov_X, size=n))
        betaTrue = torch.tensor(rng.uniform(low=-1., high=1., size=p))
        betaTrue[p_nonzero:] = 0.
        err = torch.tensor(rng.normal(size=n))
        y = X @ betaTrue + err

        on_device = (y.to(device=self.device),
                     X.to(device=self.device))
        return on_device, betaTrue


class LogisticVI(GLMVI):
    """Logistic regression target: Bernoulli likelihood on ``X @ beta`` logits.

    NOTE(review): ``self.dim`` and ``self.device`` are read but not assigned
    in this file; presumably they come from ``FlowVI`` -- confirm.
    """

    def __init__(self, n_data, p_nonzero, rho_X, prior_scale, flow,
                 seed_train=11235813, seed_test=31415926, seed_glm_data=42):
        """Forward every argument to ``GLMVI.__init__`` (only adds defaults)."""
        super().__init__(n_data, p_nonzero, rho_X, prior_scale, flow,
                         seed_train, seed_test, seed_glm_data)

    def logL(self, beta):
        """Return the Bernoulli log likelihood for each row of ``beta``.

        This is the negated binary cross-entropy between the logits
        ``X @ beta`` and the 0/1 responses, summed over observations, with
        ``(y, X) = self.data``.  Per the shape notes in this file ``beta`` is
        (batch, p) and the result is (batch,).

        NOTE(review): ``X`` is built from NumPy output (presumably float64);
        ``beta`` has to share its dtype for the matmul -- confirm.
        """
        response, design = self.data
        # (n, p) @ (batch, p, 1) -> (batch, n, 1) -> (batch, n)
        eta = torch.matmul(design, beta.unsqueeze(-1)).squeeze(-1)
        # The loss needs a target of the same size as the logits:
        # (n,) -> (1, n) -> (batch, n)
        labels = response[None, :].expand_as(eta)
        per_obs = F.binary_cross_entropy_with_logits(
            eta, labels, reduction='none')                 # (batch, n)
        return - per_obs.sum(-1)                           # (batch,)

    def generate_data(self, n_data, p_nonzero, rho_X, replicate, seed_glm_data):
        """Simulate ``y ~ Bernoulli(sigmoid(X @ betaTrue))`` for one replicate.

        Rows of ``X`` are zero-mean Gaussian with covariance entries
        ``rho_X ** |i - j|``; ``betaTrue`` is uniform on [-1, 1) with every
        entry from index ``p_nonzero`` onwards set to zero.

        Returns:
            ``((y, X), betaTrue)``.  ``y`` is cast to the dtype of ``X``;
            ``y`` and ``X`` are moved to ``self.device``; ``betaTrue`` is
            returned without a device move.
        """
        n, p = n_data, self.dim

        # Give each replicate its own stretch of the random stream by
        # advancing the generator.  The data take n*p + p + n variates
        # (X: n*p, betaTrue: p, y: n); the factor 1000 is a loose upper
        # bound on the draws used so that replicates do not overlap.
        stride = 1000 * (n * p + p + n)
        rng = np.random.default_rng(seed_glm_data)
        rng.bit_generator.advance(stride * replicate)

        # Covariance of the design rows: entry (j, i) is rho_X ** |i - j|.
        idx = np.arange(p)
        cov_X = rho_X ** np.abs(idx[None, :] - idx[:, None])

        # Draw order (X, betaTrue, y) determines the simulated data.
        X = torch.tensor(rng.multivariate_normal(np.zeros(p), cov_X, size=n))
        betaTrue = torch.tensor(rng.uniform(low=-1., high=1., size=p))
        betaTrue[p_nonzero:] = 0.
        prob = torch.sigmoid(X @ betaTrue)
        draws = rng.binomial(1, prob)
        y = torch.tensor(draws).to(dtype=X.dtype)

        on_device = (y.to(device=self.device),
                     X.to(device=self.device))
        return on_device, betaTrue

