import torch
import torch.nn as nn
import matplotlib.pyplot as plt

import seaborn as sns
import pandas as pd
import numpy as np

import warnings


class BayesLinear(nn.Module):
  """Linear layer whose weights and biases carry a diagonal Gaussian posterior.

  Each weight/bias entry has a learnable mean (`*_mu`) and log-variance
  (`*_log_var`). The prior is a zero-mean Gaussian with standard deviation
  `prior_W_std` for weights and `prior_b_std` for biases.

  Args:
    in_f: number of input features.
    out_f: number of output features.
    prior_W_std: prior standard deviation of every weight entry.
    prior_b_std: prior standard deviation of every bias entry.
    is_last: when True, `prior_forward` uses a zero bias instead of sampling one.
  """

  def __init__(self, in_f, out_f, prior_W_std=1.0, prior_b_std=10.0, is_last=False):
    super().__init__()
    init_scale = 1e-3

    # Posterior means. The four parameters are created in the order
    # W_mu, b_mu, W_log_var, b_log_var so that both the registration order and
    # the consumption of the global RNG stay fixed.
    self.W_mu = nn.Parameter(torch.randn(out_f, in_f) * init_scale)
    self.b_mu = nn.Parameter(torch.randn(out_f) * init_scale)

    # Posterior log-variances.
    self.W_log_var = nn.Parameter(torch.randn(out_f, in_f) * init_scale)
    self.b_log_var = nn.Parameter(torch.randn(out_f) * init_scale)

    # Zero-mean Gaussian prior: only the standard deviations are stored.
    self.prior_W_std = prior_W_std
    self.prior_b_std = prior_b_std

    self.is_last = is_last

  @staticmethod
  def _draw(mu, log_var):
    """Return one reparameterised sample `mu + exp(log_var / 2) * eps`, eps ~ N(0, 1)."""
    std = (0.5 * log_var).exp()
    return mu + std * torch.randn_like(mu)

  def sample(self):
    """Draw one posterior sample and cache it as `pert_W_mu` / `pert_b_mu`."""
    # Weight noise is drawn before bias noise.
    self.pert_W_mu = self._draw(self.W_mu, self.W_log_var)
    self.pert_b_mu = self._draw(self.b_mu, self.b_log_var)

  def forward(self, x):
    """Apply the layer to `x` using a freshly sampled weight matrix and bias."""
    weight = self._draw(self.W_mu, self.W_log_var)
    bias = self._draw(self.b_mu, self.b_log_var)
    return nn.functional.linear(x, weight, bias)

  def kl_div(self):
    """Return KL(posterior || prior), summed over all weight and bias entries."""
    normal = torch.distributions.Normal

    def summed_kl(mu, log_var, prior_std):
      posterior = normal(mu, (0.5 * log_var).exp())
      prior = normal(torch.zeros_like(mu), torch.ones_like(mu) * prior_std)
      return torch.distributions.kl_divergence(posterior, prior).sum()

    weight_kl = summed_kl(self.W_mu, self.W_log_var, self.prior_W_std)
    bias_kl = summed_kl(self.b_mu, self.b_log_var, self.prior_b_std)
    return weight_kl + bias_kl

  def prior_forward(self, x):
    """Apply the layer to `x` with weights (and bias) sampled from the prior.

    The bias is sampled from the prior unless `is_last` is set, in which case
    it is left at zero.
    """
    weight = torch.zeros_like(self.W_mu) + self.prior_W_std * torch.randn_like(self.W_mu)
    bias = torch.zeros_like(self.b_mu)
    if not self.is_last:
      bias = bias + self.prior_b_std * torch.randn_like(self.b_mu)
    return nn.functional.linear(x, weight, bias)

class Model(nn.Module):
  """Regression model with separate networks for the predictive mean and precision.

  Each network is a stack of four linear layers with SiLU activations between
  them. A network is either Bayesian (`BayesLinear` layers), deterministic
  (`nn.Linear` layers), or replaced by a fixed "oracle" tensor, in which case
  no network is built for that quantity.

  Args:
    data: tensor stored as the trainable parameter `self.data`.
    mean_bayes: build the mean network from `BayesLinear` layers.
    prec_bayes: build the precision network from `BayesLinear` layers.
    hs: hidden width.
    i: input width.
    alpha: prior strength for the mean network. Must be non-zero when
      `mean_bayes` is True, because the prior std is computed as `1 / alpha`
      (a zero value raises ZeroDivisionError).
    beta: prior strength for the precision network; same caveat with `prec_bayes`.
    oracle_mean: optional fixed mean; when given, no mean network is built.
    oracle_prec: optional fixed precision; when given, no precision network is built.
    grid: unused; kept so existing callers keep working.
  """

  def __init__(self, data, mean_bayes=False, prec_bayes=False, hs=5, i=4, alpha=0.0, beta=0.0, oracle_mean=None, oracle_prec=None, grid=None):
    super().__init__()
    # The mean network is built before the precision network so that submodule
    # registration order and RNG consumption match the established layout.
    if oracle_mean is None:
      self.b1, self.b2, self.b3, self.b4 = self._make_layers(mean_bayes, i, hs, alpha)

    if oracle_prec is None:
      self.s1, self.s2, self.s3, self.s4 = self._make_layers(prec_bayes, i, hs, beta)

    self.alpha = torch.tensor(alpha)
    self.beta = torch.tensor(beta)

    self.oracle_mean = oracle_mean
    self.oracle_prec = oracle_prec

    self.act = nn.SiLU()

    # Plain Python lists used for bookkeeping (the layers themselves are
    # already registered as submodules through the attributes above).
    self.bayes_layers = []
    self.determ_layers = []
    # Number of leading entries of `determ_layers` that belong to the mean
    # network; `map` uses it to choose between the alpha and beta priors.
    self._n_determ_mean = 0

    if oracle_mean is None:
      mean_layers = [self.b1, self.b2, self.b3, self.b4]
      if mean_bayes:
        self.bayes_layers += mean_layers
      else:
        self.determ_layers += mean_layers
        self._n_determ_mean = len(mean_layers)

    if oracle_prec is None:
      prec_layers = [self.s1, self.s2, self.s3, self.s4]
      if prec_bayes:
        self.bayes_layers += prec_layers
      else:
        self.determ_layers += prec_layers

    self.mean_bayes = mean_bayes
    self.prec_bayes = prec_bayes
    # Older spelling of the same flag, kept for backward compatibility.
    self.var_bayes = prec_bayes

    self.data = torch.nn.Parameter(data, requires_grad=True)

  @staticmethod
  def _make_layers(bayes, i, hs, strength):
    """Build the four layers of one network (Bayesian when `bayes` is True)."""
    if not bayes:
      return nn.Linear(i, hs), nn.Linear(hs, hs), nn.Linear(hs, hs), nn.Linear(hs, 1)
    # Prior weight std is (1 / strength) / sqrt(number of input features).
    return (
        BayesLinear(i, hs, 1/strength * torch.tensor(i).pow(-.5)),
        BayesLinear(hs, hs, 1/strength * torch.tensor(hs).pow(-.5)),
        BayesLinear(hs, hs, 1/strength * torch.tensor(hs).pow(-.5)),
        BayesLinear(in_f=hs, out_f=1, prior_W_std=1/strength * torch.tensor(hs).pow(-.5), is_last=True),
    )

  def _run_stack(self, layers, x, from_prior=False):
    """Feed `x` through `layers`, applying the activation between consecutive layers.

    With `from_prior=True` each layer's `prior_forward` is called instead of
    its regular forward pass.
    """
    out = x
    for idx, layer in enumerate(layers):
      if idx > 0:
        out = self.act(out)
      out = layer.prior_forward(out) if from_prior else layer(out)
    return out

  def forward(self, x):
    """Return `(mean, precision)` for input `x`.

    Oracle values, when present, are reshaped to a column and returned in place
    of the corresponding network output. The network precision is passed
    through softplus and offset by 1e-10 to keep it strictly positive.
    """
    if self.oracle_mean is None:
      m = self._run_stack((self.b1, self.b2, self.b3, self.b4), x)
    else:
      m = torch.reshape(self.oracle_mean, (len(self.oracle_mean), 1))

    if self.oracle_prec is None:
      prec = self._run_stack((self.s1, self.s2, self.s3, self.s4), x)
      prec = torch.nn.functional.softplus(prec)+1e-10
    else:
      prec = torch.reshape(self.oracle_prec, (len(self.oracle_prec), 1))

    return m, prec

  # revisit
  def prior_forward(self, x):
    """Return `(mean, precision)` with Bayesian networks sampled from their prior.

    Deterministic networks are evaluated normally. A quantity supplied by an
    oracle is returned as None (unlike `forward`, which returns the oracle).
    """
    if self.oracle_mean is None:
      m = self._run_stack((self.b1, self.b2, self.b3, self.b4), x, from_prior=self.mean_bayes)
    else:
      m = None

    if self.oracle_prec is None:
      prec = self._run_stack((self.s1, self.s2, self.s3, self.s4), x, from_prior=self.prec_bayes)
      if self.prec_bayes:
        prec = torch.nn.functional.softplus(prec)+1e-10
      else:
        # NOTE(review): this branch offsets by 1e-8 while `forward` uses 1e-10;
        # the value is kept as it was - confirm whether the difference is intended.
        prec = torch.nn.functional.softplus(prec)+1e-8
    else:
      prec = None

    return m, prec

  def elbo(self, y, mu_hat, prec_hat):
    """Return the Gaussian log-likelihood of `y` minus the KL of all Bayesian layers.

    `y` and `mu_hat` must have identical shapes.
    """
    assert(y.shape == mu_hat.shape)
    std_hat = prec_hat.pow(-0.5)
    ll = torch.distributions.Normal(mu_hat, std_hat).log_prob(y).sum()
    kl = 0.0
    for layer in self.bayes_layers:
      kl += layer.kl_div()
    return ll - kl

  def map(self, y, mu_hat, prec_hat):
    """Return the Gaussian log-likelihood of `y` plus the log-prior of the deterministic layers.

    NaNs in the predicted std / mean are replaced (by 9999999. and 0.
    respectively) and a warning is issued. Weights of mean-network layers get
    a zero-mean Gaussian prior with std `alpha ** -0.5 / sqrt(weight.shape[0])`;
    precision-network layers use `beta` instead. Biases always get std 10.
    Only layers in `determ_layers` contribute, so Bayesian or oracle-replaced
    networks add nothing to the prior term.
    """
    std_hat = prec_hat.pow(-0.5)

    if std_hat.isnan().sum() > 0:
      std_hat = torch.nan_to_num(std_hat, 9999999.)
      warnings.warn("invalid nan std")

    if mu_hat.isnan().sum() > 0:
      mu_hat = torch.nan_to_num(mu_hat, 0.)
      warnings.warn("invalid nan mean")

    ll = torch.distributions.Normal(mu_hat, std_hat).log_prob(y).sum()
    log_prior = 0.0

    alpha_std = self.alpha.pow(-0.5)
    beta_std = self.beta.pow(-0.5)

    for idx, l in enumerate(self.determ_layers):
      # Deterministic mean layers (if any) come first in `determ_layers`.
      # Counting them explicitly keeps the prior correct when only one of the
      # two networks is deterministic; splitting the list in half would not.
      std = alpha_std if idx < self._n_determ_mean else beta_std

      # NOTE(review): weight.shape[0] is the layer's output width, whereas the
      # BayesLinear priors are scaled by the input width - confirm intent.
      log_prior += torch.distributions.Normal(torch.zeros_like(l.weight), torch.ones_like(l.weight)* std / torch.tensor(l.weight.shape[0]).sqrt() ).log_prob(l.weight).sum()
      log_prior += torch.distributions.Normal(torch.zeros_like(l.bias), torch.ones_like(l.bias)*10.).log_prob(l.bias).sum()

    return ll + log_prior



class Model2(nn.Module):
  """Deterministic mean/precision regression model with L2 weight penalties.

  Two four-layer SiLU networks map the input to a predictive mean and a
  positive precision. Either network is skipped when the corresponding oracle
  tensor is supplied, in which case the oracle is returned by `forward`.

  Args:
    data: tensor stored as the trainable parameter `self.data`.
    hs: hidden width.
    i: input width.
    alpha: L2 penalty coefficient on the mean-network weights.
    beta: L2 penalty coefficient on the precision-network weights.
    oracle_mean: optional fixed mean; disables the mean network.
    oracle_prec: optional fixed precision; disables the precision network.
  """

  def __init__(self, data, hs=5, i=4, alpha=0.0, beta=0.0, oracle_mean=None, oracle_prec=None):
    super().__init__()
    use_mean_net = oracle_mean is None
    use_prec_net = oracle_prec is None

    # Mean network: only built when no fixed mean is provided.
    if use_mean_net:
      self.mean1 = nn.Linear(i, hs)
      self.mean2 = nn.Linear(hs, hs)
      self.mean3 = nn.Linear(hs, hs)
      self.mean4 = nn.Linear(hs, 1)

    # Precision network: only built when no fixed precision is provided.
    if use_prec_net:
      self.prec1 = nn.Linear(i, hs)
      self.prec2 = nn.Linear(hs, hs)
      self.prec3 = nn.Linear(hs, hs)
      self.prec4 = nn.Linear(hs, 1)

    self.alpha = torch.tensor(alpha)
    self.beta = torch.tensor(beta)

    self.oracle_mean = oracle_mean
    self.oracle_prec = oracle_prec

    self.act = nn.SiLU()

    # Layer lists consumed by the penalty terms; empty when an oracle is used.
    self.mean_layers = []
    self.prec_layers = []
    if use_mean_net:
      self.mean_layers += [self.mean1, self.mean2, self.mean3, self.mean4]
    if use_prec_net:
      self.prec_layers += [self.prec1, self.prec2, self.prec3, self.prec4]

    # Held as a parameter so gradients with respect to the data are available
    # (used for the approximate integral).
    self.data = torch.nn.Parameter(data, requires_grad=True)

  def _chain(self, first, rest, x):
    """Apply `first` to `x`, then each layer in `rest` to the activated previous output."""
    hidden = first(x)
    for layer in rest:
      hidden = layer(self.act(hidden))
    return hidden

  def forward(self, x):
    """Return `(mean, precision)` for `x`, substituting column-shaped oracles where given."""
    if self.oracle_mean is not None:
      mean = torch.reshape(self.oracle_mean, (len(self.oracle_mean), 1))
    else:
      mean = self._chain(self.mean1, (self.mean2, self.mean3, self.mean4), x)

    if self.oracle_prec is not None:
      prec = torch.reshape(self.oracle_prec, (len(self.oracle_prec), 1))
    else:
      raw = self._chain(self.prec1, (self.prec2, self.prec3, self.prec4), x)
      # softplus plus a small offset keeps the precision strictly positive
      prec = torch.nn.functional.softplus(raw)+1e-10

    return mean, prec

  # approximately a map loss
  def n_map(self, y, mu_hat, prec_hat):
    """Return the penalised loss: precision-weighted squared error minus mean log-precision,
    plus `alpha` / `beta` times the summed squared weights of the mean / precision layers."""
    weighted_sq_err = ((y - mu_hat).pow(2) * prec_hat).mean()
    un_pen_loss = weighted_sq_err - torch.log(prec_hat).mean()

    # Biases are not penalised; only the weight matrices enter the L2 terms.
    l2_mean = sum(layer.weight.pow(2).sum() for layer in self.mean_layers)
    l2_prec = sum(layer.weight.pow(2).sum() for layer in self.prec_layers)

    return un_pen_loss + self.alpha * l2_mean + self.beta * l2_prec
