import numpy as np
import matplotlib.pyplot as plt
import torch
import math
from torch import nn
from torch import distributions
import torch.nn.functional as F 
from torch.nn.parameter import Parameter


class n_flow_init(nn.Module):
    """Entry layer that hands every flow an unmodified copy of the input.

    No normalisation is performed, so the log-determinant contribution is a
    zero vector with one entry per flow.  ``dataset_stats`` is accepted by the
    constructor but never read.
    """

    def __init__(self, dataset_stats, n_flows):
        super().__init__()
        self.n_flows = n_flows
        # Registered as a buffer: stored in the state dict and moved with the
        # module, but not a trainable parameter.
        self.register_buffer('dataset_std_log_sum', torch.zeros(n_flows))

    def forward(self, x):
        """Broadcast ``x`` across flows.

        ``x`` is (n_obs, data_dim); the first return value is an expanded view
        of shape (n_obs, n_flows, 1, data_dim).  The second return value is
        the all-zero buffer of length ``n_flows``.
        """
        per_flow_view = x.unsqueeze(1).unsqueeze(1).expand(-1, self.n_flows, -1, -1)
        return per_flow_view, self.dataset_std_log_sum

    def reverse(self, z, flow_id):
        """Return ``z`` (n_samps, 1, dim) untouched; ``flow_id`` is ignored."""
        return z


class NaiveBatchNormFlow(nn.Module):
    def __init__(self, dataset_stats, n_flows):
        super().__init__()
        self.n_flows = n_flows
        dataset_mean = dataset_stats[0]
        dataset_std = dataset_stats[1]
        self.register_buffer('dataset_mean', torch.tensor(dataset_mean))
        self.register_buffer('dataset_std', torch.tensor(dataset_std))
        self.register_buffer('dataset_std_log', torch.tensor(dataset_std).log())
        self.register_buffer('dataset_std_log_sum', torch.tensor(dataset_std).log().sum().unsqueeze(0))

    def forward(self, x): #in: x (n_obs, data_dim) out: (n_obs, n_flows, 1, data_dim) 
        x_hat = (x - self.dataset_mean) / self.dataset_std
        return x_hat[:,None,None,:].expand(-1,self.n_flows,-1,-1), -self.dataset_std_log_sum.expand(self.n_flows)

    def reverse(self, z, flow_id): #(n_samps, 1, dim) 
        x = z * self.dataset_std + self.dataset_mean
        return x


class ActNorm(nn.Module):
    """Learnable per-flow, per-feature affine map ``z = x * exp(log_sigma) + mu``.

    Both parameters have shape (n_flows, 1, dim) and start at zero, so the
    layer is the identity at construction time.
    """

    def __init__(self, dim, n_flows):
        super().__init__()
        self.dim = dim
        param_shape = (n_flows, 1, dim)
        self.mu = nn.Parameter(torch.zeros(*param_shape))
        self.log_sigma = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        """Apply every flow's affine map to its slice of ``x``.

        ``x`` is (n_obs, n_flows, 1, dim).  Returns the transformed tensor and
        the log-determinant, one value per flow (shape (n_flows,)).
        """
        scale = torch.exp(self.log_sigma)
        z = x * scale + self.mu
        # Sum the log-scales over features, then drop the singleton axis.
        per_flow_log_det = self.log_sigma.sum(-1)
        return z, per_flow_log_det.squeeze(-1)

    def reverse(self, z, flow_id):
        """Invert the affine map of flow ``flow_id`` for ``z`` (n_samps, 1, dim)."""
        shift = self.mu[flow_id]
        inv_scale = torch.exp(-self.log_sigma[flow_id])
        return (z - shift) * inv_scale


class nets(nn.Module):
    """Stack of ``n_flows`` independent MLPs with a tanh-bounded output.

    Every layer keeps its weights in one tensor of shape
    (n_flows, fan_in, fan_out) and its bias in one of shape
    (n_flows, 1, fan_out), so a single broadcast ``matmul`` evaluates all
    flows at once.  Hidden layers use ReLU; the last layer uses tanh.
    """

    def __init__(self, in_dim, out_dim, n_flows, n_h=64, n_layers=3):
        super().__init__()
        self.n_layers = n_layers
        # in_dim -> n_h, then (n_layers - 2) maps n_h -> n_h, then n_h -> out_dim.
        widths = [in_dim, n_h] + [n_h] * (n_layers - 2) + [out_dim]
        layer_shapes = list(zip(widths[:-1], widths[1:]))
        self.weights = nn.ParameterList(
            [nn.Parameter(torch.empty(n_flows, fan_in, fan_out)) for fan_in, fan_out in layer_shapes]
        )
        self.bias = nn.ParameterList(
            [nn.Parameter(torch.zeros(n_flows, 1, fan_out)) for _, fan_out in layer_shapes]
        )
        # NOTE(review): Xavier init is given each stacked 3-D tensor as a
        # whole, so the fan values it derives may not equal those of a single
        # per-flow 2-D layer -- confirm this is intended.
        for weight in self.weights:
            nn.init.xavier_uniform_(weight)
        self.act = nn.ReLU(inplace=True)
        self.act_end = nn.Tanh()

    def forward(self, x):
        """Evaluate all flows at once; ``x`` is (n_obs, n_flows, 1, dim)."""
        hidden = x
        for layer in range(self.n_layers - 1):
            pre_activation = hidden.matmul(self.weights[layer]) + self.bias[layer]
            # In-place ReLU only touches the freshly created pre-activation.
            hidden = self.act(pre_activation)
        return self.act_end(hidden.matmul(self.weights[-1]) + self.bias[-1])

    def reverse(self, x, flow_id):
        """Evaluate only flow ``flow_id``; ``x`` is (n_samps, 1, dim).

        Despite the name this is the same network function as ``forward``,
        restricted to one flow's weights; it is not an inverse of the MLP.
        """
        hidden = x
        for layer in range(self.n_layers - 1):
            weight = self.weights[layer][flow_id]
            offset = self.bias[layer][flow_id]
            hidden = self.act(hidden.matmul(weight) + offset)
        last_weight = self.weights[-1][flow_id]
        last_offset = self.bias[-1][flow_id]
        return self.act_end(hidden.matmul(last_weight) + last_offset)


class nett(nn.Module):
    """Stack of ``n_flows`` independent MLPs with an unbounded linear output.

    Every layer keeps its weights in one tensor of shape
    (n_flows, fan_in, fan_out) and its bias in one of shape
    (n_flows, 1, fan_out), so a single broadcast ``matmul`` evaluates all
    flows at once.  Hidden layers use ReLU; the last layer has no activation.
    """

    def __init__(self, in_dim, out_dim, n_flows, n_h=64, n_layers=3):
        super().__init__()
        self.n_layers = n_layers
        # in_dim -> n_h, then (n_layers - 2) maps n_h -> n_h, then n_h -> out_dim.
        widths = [in_dim, n_h] + [n_h] * (n_layers - 2) + [out_dim]
        layer_shapes = list(zip(widths[:-1], widths[1:]))
        self.weights = nn.ParameterList(
            [nn.Parameter(torch.empty(n_flows, fan_in, fan_out)) for fan_in, fan_out in layer_shapes]
        )
        self.bias = nn.ParameterList(
            [nn.Parameter(torch.zeros(n_flows, 1, fan_out)) for _, fan_out in layer_shapes]
        )
        # NOTE(review): Xavier init is given each stacked 3-D tensor as a
        # whole, so the fan values it derives may not equal those of a single
        # per-flow 2-D layer -- confirm this is intended.
        for weight in self.weights:
            nn.init.xavier_uniform_(weight)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Evaluate all flows at once; ``x`` is (n_obs, n_flows, 1, dim)."""
        hidden = x
        for layer in range(self.n_layers - 1):
            pre_activation = hidden.matmul(self.weights[layer]) + self.bias[layer]
            # In-place ReLU only touches the freshly created pre-activation.
            hidden = self.act(pre_activation)
        return hidden.matmul(self.weights[-1]) + self.bias[-1]

    def reverse(self, x, flow_id):
        """Evaluate only flow ``flow_id``; ``x`` is (n_samps, 1, dim).

        Despite the name this is the same network function as ``forward``,
        restricted to one flow's weights; it is not an inverse of the MLP.
        """
        hidden = x
        for layer in range(self.n_layers - 1):
            weight = self.weights[layer][flow_id]
            offset = self.bias[layer][flow_id]
            hidden = self.act(hidden.matmul(weight) + offset)
        last_weight = self.weights[-1][flow_id]
        last_offset = self.bias[-1][flow_id]
        return hidden.matmul(last_weight) + last_offset


class Multi_RealNVP(nn.Module):
    """Two stacked affine coupling steps, evaluated for ``n_flows`` flows.

    The feature axis is split into a "lower" part (first ``dim // 2``
    features) and an "upper" part (the rest).  Step one rescales and shifts
    the upper part conditioned on the lower part; step two does the same to
    the lower part conditioned on the updated upper part.
    """

    def __init__(self, dim, n_flows=2, n_h=64, n_layers=3):
        super().__init__()
        self.dim = dim
        n_lower = dim // 2
        n_upper = dim - n_lower
        shared = dict(n_flows=n_flows, n_h=n_h, n_layers=n_layers)
        # t* are translation networks, s* are log-scale networks.
        self.t1 = nett(in_dim=n_lower, out_dim=n_upper, **shared)
        self.s1 = nets(in_dim=n_lower, out_dim=n_upper, **shared)
        self.t2 = nett(in_dim=n_upper, out_dim=n_lower, **shared)
        self.s2 = nets(in_dim=n_upper, out_dim=n_lower, **shared)

    def forward(self, x):
        """Transform ``x`` (n_obs, n_flows, 1, data_dim).

        Returns the transformed tensor (same shape) and the log-determinant,
        which is the sum of both log-scale outputs over the feature axis.
        """
        split = self.dim // 2
        lower = x[..., :split]
        upper = x[..., split:]

        # Step 1: update the upper half from the untouched lower half.
        shift_up = self.t1(lower)
        log_scale_up = self.s1(lower)
        upper = shift_up + upper * torch.exp(log_scale_up)

        # Step 2: update the lower half from the new upper half.
        shift_low = self.t2(upper)
        log_scale_low = self.s2(upper)
        lower = shift_low + lower * torch.exp(log_scale_low)

        z = torch.cat([lower, upper], dim=-1)
        log_det = log_scale_up.sum(-1).squeeze(-1) + log_scale_low.sum(-1).squeeze(-1)
        return z, log_det

    def reverse(self, z, flow_id):
        """Invert flow ``flow_id`` for ``z`` (n_samps, 1, dim).

        The two coupling steps are undone in the opposite order to
        ``forward``: the lower half first, then the upper half.
        """
        split = self.dim // 2
        lower = z[..., :split]
        upper = z[..., split:]

        # Undo step 2 (the lower half was conditioned on the final upper half).
        shift_low = self.t2.reverse(upper, flow_id)
        log_scale_low = self.s2.reverse(upper, flow_id)
        lower = (lower - shift_low) * torch.exp(-log_scale_low)

        # Undo step 1 using the recovered lower half.
        shift_up = self.t1.reverse(lower, flow_id)
        log_scale_up = self.s1.reverse(lower, flow_id)
        upper = (upper - shift_up) * torch.exp(-log_scale_up)

        return torch.cat([lower, upper], dim=-1)


class NormalizingFlowModel(nn.Module):
    """Chain of flow layers applied in sequence.

    Each element of ``flows`` must provide ``forward(x) -> (x, log_det)`` and
    ``reverse(x, flow_id) -> x``.
    """

    def __init__(self, flows, n_flows):
        super().__init__()
        self.n_flows = n_flows
        self.flows = nn.ModuleList(flows)

    def forward(self, x):
        """Run ``x`` through every layer, accumulating log-determinants.

        Returns the final tensor and a (n_obs, n_flows) log-determinant that
        starts at zero and has each layer's contribution added in place
        (broadcasting per-flow vectors over observations).
        """
        n_obs = x.shape[0]
        total_log_det = x.new_zeros(n_obs, self.n_flows)
        for layer in self.flows:
            x, layer_log_det = layer(x)
            total_log_det.add_(layer_log_det)
        return x, total_log_det

    def reverse(self, z, flow_id):
        """Map ``z`` back through flow ``flow_id``, last layer first.

        A singleton axis is inserted at position 1 before the layers are
        applied and squeezed out again afterwards.
        """
        x = z.unsqueeze(1)
        for layer in reversed(self.flows):
            x = layer.reverse(x, flow_id)
        return x.squeeze(1)


class Gaussian_Dist_Exact_nf(nn.Module):
    """Bank of diagonal Gaussians, ``int(2**n_bits / n_flows)`` per flow.

    Each component stores a mean ``mu`` and ``sqrt2_sigma_inv``; the density
    formula in ``forward`` corresponds to ``sqrt2_sigma_inv = 1 / (sqrt(2) * sigma)``.
    Both parameters have shape (1, components_per_flow, n_flows, dim).
    """

    def __init__(self, dim, n_bits, n_flows):
        super().__init__()
        components_per_flow = int(2**n_bits / n_flows)
        param_shape = (1, components_per_flow, n_flows, dim)
        # Means start uniformly in [-1, 1); inverse scales in [1e-6, 1).
        self.mu = torch.nn.Parameter(torch.empty(*param_shape).uniform_(-1., 1.))
        # NOTE(review): this parameter is unconstrained, and forward() takes
        # its log -- confirm training keeps it strictly positive.
        self.sqrt2_sigma_inv = torch.nn.Parameter(torch.empty(*param_shape).uniform_(1e-6, 1.))
        self.register_buffer("logsqrtpi", torch.tensor(0.5 * math.log(math.pi)))

    def forward(self, z):
        """Log-density of ``z`` under every component.

        ``z`` is (n_obs, n_flows, 1, data_dim).  Returns a tensor of shape
        (n_obs, components_per_flow, n_flows): per-feature log-densities
        summed over the feature axis.
        """
        # (n_obs, n_flows, 1, dim) -> (n_obs, 1, n_flows, dim) to broadcast
        # against the component axis of the parameters.
        z_per_component = z.squeeze(2).unsqueeze(1)
        scaled_residual = (z_per_component - self.mu) * self.sqrt2_sigma_inv
        log_norm = self.sqrt2_sigma_inv.log() - self.logsqrtpi
        per_feature = log_norm - scaled_residual.pow(2)
        return per_feature.sum(-1)


