import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from torch.utils.data import DataLoader
from time import time
from copy import deepcopy
from special import ExpL


class ToBinary:
    """Callable transform that binarises a tensor by rounding each element.

    Uses ``torch.round``, which rounds half to even: values above 0.5 become
    1, values below 0.5 become 0, and an exact 0.5 becomes 0. Presumably
    applied to pixel intensities already scaled to [0, 1] — confirm against
    the dataset transform pipeline.
    """

    def __call__(self, x):
        binarised = torch.round(x)
        return binarised


class VAEFlow(nn.Module):
    """Convolutional VAE whose approximate posterior is refined by a flow.

    The encoder maps a 1-channel image to the mean and standard deviation of
    a diagonal Gaussian q0(z0|x). A reparameterised sample z0 is pushed
    through ``flow`` to obtain zK, which the decoder maps back to Bernoulli
    logits over the pixels. The layer shapes assume 28x28 inputs (the 7x7
    bottleneck is hard-coded).

    Args:
        latent_dim: dimension D of the latent space.
        num_channels: pair ``(C0, C1)`` of channel counts for the two
            convolutional stages.
        flow: module called as ``zk, log_jac_det = flow(z0)``;
            ``log_jac_det`` is expected to have shape (B,) (assumption —
            confirm against the flow implementation).
        seed_draw: seed of the private generator used for the
            reparameterisation noise, so latent draws are reproducible.

    Note:
        ``self.device`` is chosen here, but parameters are not moved to it;
        the caller presumably does ``model.to(model.device)``.
    """

    def __init__(self, latent_dim, num_channels, flow,
                 seed_draw=31415926):
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.latent_dim = latent_dim
        self.C1 = num_channels[1]
        # dedicated RNG so sampling noise is independent of the global seed
        self.gen = torch.Generator(device=self.device)
        self.gen.manual_seed(seed_draw)
        self.iteration = 0      # optimisation steps taken while annealing
        self.anneal_rate = 1.   # beta increment per step (set in train_test)

        # encoder
        C0, C1 = num_channels
        self.conv = nn.Sequential(
            # in_channels  = 1 <- grayscale image
            # out_channels = C0
            # kernel_size  = 3   (stride=1 by default)
            # max pooling with kernel_size=2 (stride=kernel_size by default)
            nn.Conv2d(1,  C0, 3, padding='same'), # -> (B, C0, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2),                      # -> (B, C0, 14, 14)
            nn.Conv2d(C0, C1, 3, padding='same'), # -> (B, C1, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2)                       # -> (B, C1, 7,  7)
        )
        self.fc_mu = nn.Linear(C1*49, latent_dim) # C1*7*7
        # ExpL comes from the project-local `special` module; presumably an
        # exp-like activation that keeps sd positive — confirm there.
        self.fc_sd = nn.Sequential(
            nn.Linear(C1*49, latent_dim),
            ExpL()
        )

        # flow
        self.flow = flow

        # decoder
        self.fc_decoder = nn.Sequential(
            nn.Linear(latent_dim, C1*49),             # view()-> (B, C1, 7, 7)
            nn.ReLU()
        )
        self.convTrans = nn.Sequential(
            nn.ConvTranspose2d(C1, C0, 3, stride=2,
                               padding=1, output_padding=1), #-> (B, C0, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(C0,  1, 3, stride=2,
                               padding=1, output_padding=1)  #-> (B, 1, 28, 28)
        )

    def encode(self, x):
        """Return ``(mu, sd)`` of q0(z0|x), each of shape (B, latent_dim)."""
        conv_out = self.conv(x).view(x.size(0), -1)
        mu = self.fc_mu(conv_out)
        sd = self.fc_sd(conv_out)
        return mu, sd

    def draw_from_flow(self, mu, sd):
        """Sample z0 = mu + sd * eps, with eps ~ N(0, I), and push it through the flow.

        Noise is drawn from ``self.gen``.

        Returns:
            ``(zk, log_jac_det, z0)``.
        """
        eps = torch.randn(sd.size(), generator=self.gen, device=self.device)
        z0 = mu + sd * eps
        zk, log_jac_det = self.flow(z0)
        return zk, log_jac_det, z0

    def decode(self, zk):
        """Map latents (B, latent_dim) to pixel logits (B, 1, 28, 28)."""
        hidden = self.fc_decoder(zk).view(-1, self.C1, 7, 7)
        logits = self.convTrans(hidden)
        return logits

    def forward(self, x):
        """Encode, sample through the flow, and decode.

        Returns:
            ``(logits, zk, log_jac_det, z0, mu, sd)``.
        """
        mu, sd = self.encode(x)
        zk, log_jac_det, z0 = self.draw_from_flow(mu, sd)
        logits = self.decode(zk)
        return logits, zk, log_jac_det, z0, mu, sd

    def loss_info_fn(self, x, logits, zk, log_jac_det, mu, sd):
        """Compute the per-sample KL term and negative log-likelihood.

        ``E_q0[log q0(z0)]`` uses the analytic expectation
        ``E[((z0-mu)/sd)^2] = 1``. The ``-0.5*log(2*pi)`` constants of q and
        of the standard-normal prior are dropped from both and cancel in the
        KL estimate.

        Args:
            mu: unused; kept for interface compatibility.

        Returns:
            dict with scalar tensors ``'kld'`` and ``'neg_logL'``, both
            averaged over the batch.
        """
        Elogq0_z0_sum = - (sd.log() + 0.5).sum()          # (B,D) -> scalar
        Elogqk_zk_sum = Elogq0_z0_sum - log_jac_det.sum() # (B,)  -> scalar
        logp_zk_sum   = - 0.5 * (zk**2).sum()             # (B,D) -> scalar
        kld = (Elogqk_zk_sum - logp_zk_sum) / x.size(0)

        neg_logL = F.binary_cross_entropy_with_logits(    # (B,1,28,28)
            logits, x, reduction='sum') / x.size(0)       #       -> scalar
        return {'kld':kld, 'neg_logL':neg_logL}

    def get_anneal_beta(self):
        """Return the KL weight: a linear warm-up, clipped to [0.01, 1]."""
        return min(1., max(0.01, self.anneal_rate*self.iteration))

    def train_loop(self, train_loader, optimizer, anneal):
        """Run one training epoch.

        When ``anneal`` is True, the KL term is scaled by the warm-up beta and
        ``self.iteration`` advances once per batch.
        """
        self.train()
        for x, _ in train_loader:
            x = x.to(self.device)
            logits, zk, log_jac_det, z0, mu, sd = self(x)
            loss_info = self.loss_info_fn(x, logits, zk, log_jac_det, mu, sd)
            if anneal:
                self.iteration += 1
                beta = self.get_anneal_beta()
                loss = beta * loss_info['kld'] + loss_info['neg_logL']
            else:
                loss =        loss_info['kld'] + loss_info['neg_logL']
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    @torch.no_grad()
    def test_loop(self, test_loader):
        """Evaluate the per-sample KL and negative log-likelihood over a loader.

        Each batch's losses (already per-sample means, see ``loss_info_fn``)
        are weighted by that batch's size. The result is therefore the exact
        per-sample mean over the loader, even when the last batch is smaller.

        Returns:
            Tensor of shape (3,): ``(kld + neg_logL, kld, neg_logL)``.

        Raises:
            ValueError: if ``test_loader`` yields no samples.
        """
        self.eval()
        kld_sum, neg_logL_sum, num_samples = 0., 0., 0
        for x, _ in test_loader:
            x = x.to(self.device)
            batch_size = x.size(0)
            logits, zk, log_jac_det, z0, mu, sd = self(x)
            loss_info = self.loss_info_fn(x, logits, zk, log_jac_det, mu, sd)
            kld_sum      = kld_sum + loss_info['kld'] * batch_size
            neg_logL_sum = neg_logL_sum + loss_info['neg_logL'] * batch_size
            num_samples += batch_size
        if num_samples == 0:
            raise ValueError('test_loader yielded no samples')
        kld_mean      = kld_sum / num_samples
        neg_logL_mean = neg_logL_sum / num_samples
        test_loss = torch.stack((
            kld_mean+neg_logL_mean, kld_mean, neg_logL_mean))
        return test_loss   # (3,)

    def train_test(self,
                   num_epochs,
                   optimizer, scheduler,
                   train_loader, test_loader,
                   anneal=True, anneal_epochs=100, verbose=True):
        """Train for ``num_epochs``, validating after each epoch.

        ``scheduler.step`` is called with the validation loss. This suits a
        metric-driven scheduler such as ``ReduceLROnPlateau``; an epoch-based
        scheduler would misread the argument.

        Args:
            anneal: linearly warm up the KL weight from 0.01 to 1 over
                ``anneal_epochs`` epochs.
            verbose: print per-epoch losses and the total elapsed time.

        Returns:
            dict with ``'best_model_state'`` (state_dict copy at the lowest
            validation loss), ``'best_test_loss'``, ``'best_epoch'`` (0 means
            the initial model), and ``'test_loss_hist'`` of shape
            (num_epochs+1, 3).
        """
        t_init = time()
        test_loss_lst = []
        if anneal:
            num_iters_per_epoch = len(train_loader)
            # beta reaches 1 after `anneal_epochs` epochs' worth of steps
            self.anneal_rate = 1. / (anneal_epochs*num_iters_per_epoch)

        t0 = time()
        test_loss_i = self.test_loop(test_loader)
        test_loss_lst.append(test_loss_i)
        best_epoch     = 0
        # a NaN initial loss must not block every later epoch from improving on it
        best_test_loss = torch.nan_to_num(test_loss_i[0], nan=torch.inf).item()
        best_model_state = deepcopy(self.state_dict())

        if verbose:
            print(f'Initial  ' \
                  f'Test Loss: {test_loss_i[0]:>7f}  ' \
                  f'kld: {test_loss_i[1]:>7f}  ' \
                  f'neg_logL: {test_loss_i[2]:>7f}  ' \
                  f'{round(time()-t0)} sec')

        for i_epoch in range(1, num_epochs+1):
            t0 = time()
            self.train_loop(train_loader, optimizer, anneal) # train
            test_loss_i = self.test_loop(test_loader)        # validate
            scheduler.step(test_loss_i[0].item())            # step lr

            test_loss_lst.append(test_loss_i)
            if test_loss_i[0].item() < best_test_loss:
                best_epoch       = i_epoch
                best_test_loss   = test_loss_i[0].item()
                best_model_state = deepcopy(self.state_dict())

            if verbose:
                print(f'Epoch {i_epoch}  ' \
                      f'Test Loss: {test_loss_i[0]:>7f}  ' \
                      f'kld: {test_loss_i[1]:>7f}  ' \
                      f'neg_logL: {test_loss_i[2]:>7f}  ' \
                      f'{round(time()-t0)} sec')
        if verbose:
            print(f'Time for Training and Testing {num_epochs} epochs' \
                  f'  {round((time()-t_init)/60, 1)} min')
        return {'best_model_state':best_model_state,
                'best_test_loss':best_test_loss,
                'best_epoch':best_epoch,
                'test_loss_hist':torch.stack(test_loss_lst)} # (num_epochs+1,3)

    @torch.no_grad()
    def evaluate(self, test_data, IS_size=500, reduction='mean', verbose=True):
        '''
        Evaluate the model by importance sampling.

        For every test example (processed one at a time), draw ``IS_size``
        latents from q_K(z|x) and estimate:
            * -ELBO,
            * -log p(x), via the log-mean-exp of p(x, z_K) / q_K(z_K|x),
            * KL(q_K || posterior) = log p(x) - ELBO.

        Args:
            test_data: dataset yielding ``(x, label)`` pairs; labels are ignored.
            IS_size: number of importance samples per example.
            reduction: ``'mean'`` averages over examples -> shape (3,);
                ``'none'`` returns per-example values -> shape (3, N).
            verbose: print the elapsed time when True.

        Returns:
            Tensor whose rows are (-ELBO, -log p(x), KL(q_K || posterior)).

        Raises:
            ValueError: if ``reduction`` is not ``'mean'`` or ``'none'``.
                This is checked before any (expensive) sampling is done.
        '''
        if reduction not in ('mean', 'none'):
            raise ValueError(f'Invalid reduction: {reduction!r}')
        self.eval()
        t_init = time()
        test_loader = DataLoader(test_data, batch_size=1)
        neg_logpx_lst = []
        kld_vi_lst    = []
        neg_elbo_lst  = []
        for x, _ in test_loader:
            x = x.to(self.device)
            mu, sd = self.encode(x)         # (1,D) <- (1,1,P,P)
            mu = mu.expand(IS_size, self.latent_dim)    # (IS,D)
            sd = sd.expand(IS_size, self.latent_dim)    # (IS,D)
            zk, log_jac_det, z0 = self.draw_from_flow(mu,sd) # (IS,D)(IS,)(IS,D)
            logits = self.decode(zk)      # (IS,1,P,P) <- (IS,D)
            logL = - F.binary_cross_entropy_with_logits(
                logits, x.expand(logits.shape),
                reduction='none').sum((-1,-2,-3))       # (IS,)
            # -0.5*log(2*pi) constants are dropped from both p and q; they cancel
            logp_zk  = - 0.5 * (zk**2).sum(-1)          # (IS,)
            quadra   = ((z0-mu)/sd)**2                  # (IS,D)
            logq0_z0 = -(sd.log() + 0.5*quadra).sum(-1) # (IS,)
            logqk_zk = logq0_z0 - log_jac_det           # (IS,)

            # marginal loglikelihood: log-mean-exp of the importance weights
            logpq = logL + logp_zk - logqk_zk           # (IS,)
            logpx = logpq.logsumexp(0) - math.log(IS_size)  # scalar
            neg_logpx_lst.append(-logpx)

            # negELBO: the q0 term uses its analytic expectation,
            #    E(quadra) = 1 (cf. logq0_z0 above)
            Elogq0_z0 = - (sd.log() + 0.5).sum(-1)  # (IS,)
            Elogqk_zk = Elogq0_z0 - log_jac_det     # (IS,)
            neg_elbo  = (Elogqk_zk - logp_zk - logL).mean()  # (IS,) -> scalar
            neg_elbo_lst.append(neg_elbo)

            # KL(q_K||posterior) = logp(x) - ELBO
            kld_vi_lst.append(logpx + neg_elbo)

        if verbose:
            print(f'Time for Estimating -ELBO, -logp(x), and KL(qK||posterior):' \
                  f'  {round((time()-t_init)/60, 1)} min')
        result = torch.stack((torch.stack(neg_elbo_lst),
                              torch.stack(neg_logpx_lst),
                              torch.stack(kld_vi_lst)))    # (3, N), N = len(test_data)
        if reduction == 'mean':
            return result.mean(1)   # (3,)
        return result               # (3, N)

