import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from time import time
from copy import deepcopy


class FlowVI:
    '''
    Variational Inference with Normalizing Flows.

    Trains and evaluates a given flow model by minimizing a Monte Carlo
    estimate of the negative evidence lower bound (-ELBO). The unnormalized
    log target density ``logpzx`` must be implemented by a subclass.

    The flow is used as a ``torch.nn.Module`` (``train``/``eval``/
    ``state_dict``/``parameters``). It must expose an integer attribute
    ``dim``, and calling it on a batch of base samples ``z`` must return
    ``(zk, log_jac_det)``.
    '''
    def __init__(self, flow, seed_train, seed_test):
        '''
        Args:
            flow: normalizing-flow module (see class docstring).
            seed_train: seed for the generator that draws training samples;
                the generator is created once and advances across
                ``train_loop`` calls.
            seed_test: seed re-applied at the start of every ``test_loop``,
                so every evaluation uses the same base samples.
        '''
        self.flow   = flow
        self.dim    = flow.dim
        # Sample on the device the flow already lives on. Sampling on CUDA
        # while the flow's weights sit on the CPU would fail at forward time.
        self.device = self._infer_device(flow)
        self.rng_train = torch.Generator(device=self.device)
        self.rng_train.manual_seed(seed_train)
        self.seed_test = seed_test

    @staticmethod
    def _infer_device(flow):
        '''Return the device of the flow's first parameter (or buffer).

        Falls back to CUDA-if-available-else-CPU when the flow holds no
        tensors at all.
        '''
        for tensors in (flow.parameters, flow.buffers):
            first = next(iter(tensors()), None)
            if first is not None:
                return first.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def train_test(self,
                   num_loops,
                   optimizer, scheduler,
                   num_updates_per_loop, batch_size,
                   num_tests, batch_size_test, verbose=True):
        '''Alternate training and validation for ``num_loops`` loops.

        The flow is evaluated once before any training, and again after
        every loop of ``num_updates_per_loop`` optimizer steps. After each
        loop, ``scheduler.step()`` is called unless ``scheduler`` is None.

        Args:
            num_loops: number of train/validate loops.
            optimizer: torch optimizer over the flow's parameters.
            scheduler: LR scheduler stepped once per loop, or None to skip.
            num_updates_per_loop: optimizer steps per loop.
            batch_size: base samples per training step.
            num_tests: number of validation batches per evaluation.
            batch_size_test: base samples per validation batch.
            verbose: print per-loop -ELBO and total elapsed time if True.

        Returns:
            dict with ``'model_state'`` (a deep copy of the final
            ``state_dict``) and ``'test_loss_hist'`` (tensor of shape
            ``(num_loops + 1,)``: initial -ELBO followed by one per loop).
        '''
        t_init = time()
        test_loss_lst = []

        t0 = time()
        test_loss = self.test_loop(num_tests, batch_size_test)
        test_loss_lst.append(test_loss)
        if verbose:
            print(f'Initial   -ELBO = {test_loss:>7f}  {round(time()-t0)} sec')

        for i_loop in range(1, num_loops+1):
            t0 = time()
            self.train_loop(num_updates_per_loop, batch_size, optimizer) # train
            test_loss = self.test_loop(num_tests, batch_size_test)    # validate
            if scheduler is not None:
                scheduler.step()                                      # step lr

            test_loss_lst.append(test_loss)
            if verbose:
                print(f'Loop {i_loop:>3}  -ELBO = {test_loss:>7f}  '
                      f'{round(time()-t0)} sec')

        model_state = deepcopy(self.flow.state_dict())
        if verbose:
            print(f'Time for Training and Testing {num_loops} loops'
                  f'  {round((time()-t_init)/60, 1)} min')
        return {'model_state':model_state,
                'test_loss_hist':torch.stack(test_loss_lst)} # (num_loops+1,)

    def train_loop(self, num_updates_per_loop, batch_size, optimizer):
        '''Run ``num_updates_per_loop`` gradient steps on fresh base samples.

        Base samples ``z ~ N(0, I)`` are drawn from the persistent training
        generator, so successive calls see new samples.
        '''
        self.flow.train()
        for _ in range(num_updates_per_loop):
            z = torch.randn(batch_size, self.dim,
                            generator=self.rng_train, device=self.device)
            zk, log_jac_det = self.flow(z)
            loss = self.loss_fn(zk, log_jac_det)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    @torch.no_grad()
    def test_loop(self, num_tests, batch_size_test):
        '''Estimate the -ELBO on a fixed set of base samples.

        The test generator is reseeded with ``seed_test`` on every call, so
        repeated evaluations of an unchanged flow give identical results.

        Returns:
            0-dim CPU tensor holding the mean -ELBO over ``num_tests``
            batches. The loss dtype is kept; there is no forced float32
            downcast.

        Raises:
            ValueError: if ``num_tests < 1`` (the mean would be NaN).
        '''
        if num_tests < 1:
            raise ValueError(f'num_tests must be at least 1, got {num_tests}')
        self.flow.eval()
        rng_test = torch.Generator(device=self.device)
        rng_test.manual_seed(self.seed_test)
        losses = []
        for _ in range(num_tests):
            z = torch.randn(batch_size_test, self.dim,
                            generator=rng_test, device=self.device)
            zk, log_jac_det = self.flow(z)
            losses.append(self.loss_fn(zk, log_jac_det))
        # Reduce on-device once, then move the scalar to the CPU, avoiding
        # a device->host copy per batch.
        return torch.stack(losses).mean().cpu()

    def loss_fn(self, zk, log_jac_det):
        '''Monte Carlo estimate of the negative ELBO.

        With z0 ~ N(0, I) and zk = f(z0):
            -ELBO = E[log q0(z0)] - E[log_jac_det + log p(zk, x)]
        E[log q0(z0)] for a standard normal base has the closed form
        -0.5 * dim * (log(2*pi) + 1) and is used instead of a sample mean.
        Assumes ``log_jac_det`` is log|det dzk/dz0| of the forward map,
        one value per sample -- TODO confirm against the flow implementation.
        '''
        logpzx    = self.logpzx(zk)                             # (B,)
        temp      = (log_jac_det + logpzx).mean()               # scalar
        Elogq0_z0 = - 0.5 * self.dim * (math.log(2*math.pi)+1)  # scalar
        neg_elbo  = Elogq0_z0 - temp
        return neg_elbo

    def logpzx(self, zk):
        '''Return the unnormalized log target density log p(zk, x), shape (B,).

        Must be implemented by subclasses.
        '''
        raise NotImplementedError("unnormalized density is not implemented.")

