import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
from flowvi import *


class ToyVI(FlowVI):
    """Common base for the 2-D toy targets; adds a density plotting helper.

    ``contour_surface`` relies on ``self.logpzx``, which is not defined in
    this class and must be supplied by a subclass (or by ``FlowVI``).
    """

    def __init__(self, flow, seed_train, seed_test):
        # Straight pass-through; the meaning of the arguments is defined by
        # FlowVI.__init__.
        super().__init__(flow, seed_train, seed_test)

    def contour_surface(self, figsize=(6, 3),
                        xlim=4, ylim=4, x_points=200, y_points=200,
                        alpha=0.1, lw=0.2,
                        elev=50, azim=-80, roll=0):
        """Draw ``exp(logpzx)`` on a regular grid as a contour and a surface.

        A new two-panel figure is created: filled contours on the left and a
        3-D surface on the right. The figure is neither shown nor returned.

        Args:
            figsize: Figure size handed to ``plt.figure``.
            xlim, ylim: The grid covers ``[-xlim, xlim] x [-ylim, ylim]``.
            x_points, y_points: Number of grid samples along each axis.
            alpha: Opacity of the surface.
            lw: Line width of the surface mesh edges.
            elev, azim, roll: Camera angles passed to ``view_init``.
        """
        xs = torch.linspace(-xlim, xlim, x_points)
        ys = torch.linspace(-ylim, ylim, y_points)
        X, Y = torch.meshgrid(xs, ys, indexing='xy')

        # With 'xy' indexing X and Y are (y_points, x_points); stacking on a
        # new trailing axis yields (y_points, x_points, 2) grid coordinates.
        grid = torch.stack((X, Y), dim=2)
        density = torch.exp(self.logpzx(grid))

        fig = plt.figure(figsize=figsize, layout="constrained")

        # Left panel: filled contour plot.
        flat_ax = fig.add_subplot(1, 2, 1)
        flat_ax.contourf(X, Y, density, cmap='coolwarm')
        flat_ax.locator_params(axis='both', nbins=5)

        # Right panel: translucent 3-D surface.
        surf_ax = fig.add_subplot(1, 2, 2, projection='3d')
        surf_ax.plot_surface(X, Y, density, edgecolor='royalblue', alpha=alpha, lw=lw)
        surf_ax.view_init(elev=elev, azim=azim, roll=roll)


class Toy1VI(ToyVI):
    """Toy target 1: a ring of radius 2 with two lobes at ``z[..., 0] = +/-2``."""

    def __init__(self, flow, seed_train=11235813, seed_test=31415926):
        # Only supplies default seeds; everything else is handled upstream.
        super().__init__(flow, seed_train, seed_test)

    def logpzx(self, z):
        """Evaluate the log-density at ``z``.

        Args:
            z: Tensor whose last axis holds the coordinates; the norm is
                taken over that axis and only component 0 is read directly.

        Returns:
            Tensor of shape ``z.shape[:-1]``.
        """
        # Gaussian penalty (std 0.4) on the distance of |z| from 2.
        radius = torch.sqrt(torch.sum(z ** 2, dim=-1))
        ring = -0.5 * ((radius - 2) / 0.4) ** 2

        # Two Gaussian bumps (std 0.6) in the first coordinate, at +2 and -2.
        first = z[..., 0]
        bump_pos = -0.5 * ((first - 2) / 0.6) ** 2
        bump_neg = -0.5 * ((first + 2) / 0.6) ** 2

        # NOTE(review): presumably the log normalising constant of the
        # density, obtained numerically -- confirm its origin.
        log_norm = 1.8775015
        return ring + torch.logaddexp(bump_pos, bump_neg) - log_norm


class Toy2VI(ToyVI):
    """Toy target 2: a single sinusoidal ridge under a wide Gaussian envelope."""

    def __init__(self, flow, seed_train=11235813, seed_test=31415926):
        # Only supplies default seeds; everything else is handled upstream.
        super().__init__(flow, seed_train, seed_test)

    def logpzx(self, z):
        """Evaluate the log-density at ``z``.

        Args:
            z: Tensor whose last axis has at least two components; indices
                0 and 1 are read.

        Returns:
            Tensor of shape ``z.shape[:-1]``.
        """
        horiz, vert = z[..., 0], z[..., 1]

        # Wide Gaussian envelope (std 5) along the first coordinate.
        envelope = -0.5 * (horiz / 5) ** 2

        # The ridge follows sin(pi * z0 / 2) with std 0.4 in the second
        # coordinate.
        ridge = torch.sin(math.pi * horiz / 2)
        resid = (vert - ridge) / 0.4

        # Analytic normaliser: 5*sqrt(2*pi) * 0.4*sqrt(2*pi) = 4*pi.
        log_norm = math.log(4 * math.pi)
        return envelope - 0.5 * resid ** 2 - log_norm


class Toy3VI(ToyVI):
    """Toy target 3: a sinusoidal ridge that locally splits into two branches."""

    def __init__(self, flow, seed_train=11235813, seed_test=31415926):
        # Only supplies default seeds; everything else is handled upstream.
        super().__init__(flow, seed_train, seed_test)

    def logpzx(self, z):
        """Evaluate the log-density at ``z``.

        Args:
            z: Tensor whose last axis has at least two components; indices
                0 and 1 are read.

        Returns:
            Tensor of shape ``z.shape[:-1]``.
        """
        horiz, vert = z[..., 0], z[..., 1]

        # Wide Gaussian envelope (std 5) along the first coordinate.
        envelope = -0.5 * (horiz / 5) ** 2

        # Distance of the second coordinate from the main sinusoidal ridge.
        ridge = torch.sin(math.pi * horiz / 2)
        offset = vert - ridge

        # The second branch is shifted by a Gaussian bump of height 3
        # centred at z0 = 1 (width 0.6), so it separates only near z0 = 1.
        bump_arg = (horiz - 1) / 0.6
        shift = 3 * torch.exp(-0.5 * bump_arg ** 2)

        main_branch = -0.5 * (offset / 0.35) ** 2
        side_branch = -0.5 * ((offset + shift) / 0.35) ** 2

        # Analytic normaliser: 5*sqrt(2*pi) * 2*0.35*sqrt(2*pi) = 7*pi.
        log_norm = math.log(7 * math.pi)
        return envelope + torch.logaddexp(main_branch, side_branch) - log_norm


class Toy4VI(ToyVI):
    """Toy target 4: a sinusoidal ridge with a sigmoid-gated second branch."""

    def __init__(self, flow, seed_train=11235813, seed_test=31415926):
        # Only supplies default seeds; everything else is handled upstream.
        super().__init__(flow, seed_train, seed_test)

    def logpzx(self, z):
        """Evaluate the log-density at ``z``.

        Args:
            z: Tensor whose last axis has at least two components; indices
                0 and 1 are read.

        Returns:
            Tensor of shape ``z.shape[:-1]``.
        """
        horiz, vert = z[..., 0], z[..., 1]

        # Wide Gaussian envelope (std 5) along the first coordinate.
        envelope = -0.5 * (horiz / 5) ** 2

        # Distance of the second coordinate from the main sinusoidal ridge.
        ridge = torch.sin(math.pi * horiz / 2)
        offset = vert - ridge

        # The second branch is shifted by a sigmoid step of height 3 that
        # switches on around z0 = 1 (scale 0.3).
        gate_arg = (horiz - 1) / 0.3
        shift = 3 * torch.sigmoid(gate_arg)

        # The two branches deliberately use different widths (0.4 and 0.35).
        main_branch = -0.5 * (offset / 0.4) ** 2
        side_branch = -0.5 * ((offset + shift) / 0.35) ** 2

        # Analytic normaliser: 5*sqrt(2*pi) * (0.4+0.35)*sqrt(2*pi) = 7.5*pi.
        log_norm = math.log(7.5 * math.pi)
        return envelope + torch.logaddexp(main_branch, side_branch) - log_norm

