#%%
import torch
import torch.nn as nn
import numpy as np
import math
import sys
import normflow as nf
import boltzgen as bg

import matplotlib.pyplot as plt

# Hyper-parameters shared by every EIHMC model built in this script.

# Size of the sampled variable: passed to the initial DiagGaussian and used
# as the length of each HMC layer's step-size and log-mass vectors.
dims = 1
# Number of HamiltonianMonteCarlo layers stacked inside an EIHMC model.
num_hmc_layers = 10
# Leapfrog step count handed to each HamiltonianMonteCarlo layer.
leapfrog_steps = 5
# Value the per-layer step-size vector is filled with; EIHMC.__init__ takes
# its log before passing it to the layer.
step_size_const = 0.1
# Value the per-layer log-mass vector is filled with.
log_mass_const = 0.0

class TargetDist(nn.Module):
    """Gaussian target density with mean ``u`` = 0 and variance ``var`` = 1."""

    def __init__(self):
        super().__init__()
        # Kept as plain tensor attributes (not parameters or buffers).
        self.u = torch.tensor(0.0)
        self.var = torch.tensor(1.0)

    def log_prob(self, z):
        """Return the element-wise Gaussian log-density of ``z``.

        A 1-D input is given a leading axis before evaluation.  Every
        size-1 axis is squeezed out of the result, so an ``(N, 1)`` input
        yields shape ``(N,)`` and a single-element input yields a 0-dim
        tensor.
        """
        batch = z.unsqueeze(0) if z.dim() == 1 else z
        # log N(z; u, var) = -0.5 * log(2 * pi * var) - 0.5 * (z - u)^2 / var
        log_normaliser = 0.5 * torch.log(2 * math.pi * self.var)
        quadratic = 0.5 * ((batch - self.u) ** 2 / self.var)
        return (-log_normaliser - quadratic).squeeze()
    
class EIHMC(nn.Module):
    """Stack of HMC layers applied to samples from a scaled initial Gaussian.

    The model draws from ``nf.distributions.DiagGaussian(dims)``, multiplies
    the draw by ``initial_dist_scale`` and then passes it through
    ``num_hmc_layers`` ``nf.flows.HamiltonianMonteCarlo`` layers that all
    target the same ``TargetDist`` instance.
    """

    def __init__(self, initial_dist_scale):
        """Build the initial distribution, the target and the HMC layers.

        Args:
            initial_dist_scale: factor multiplied onto every draw from the
                initial distribution.
        """
        super().__init__()
        self.initial_dist = nf.distributions.DiagGaussian(dims, trainable=False)
        self.initial_dist_scale = initial_dist_scale
        self.target_dist = TargetDist()

        # Every layer gets its own freshly created log step-size and
        # log-mass tensors, all filled with the same module-level constants.
        hmc_layers = []
        for _ in range(num_hmc_layers):
            log_step_size = torch.log(step_size_const * torch.ones((dims,)))
            log_mass = log_mass_const * torch.ones((dims,))
            hmc_layers.append(
                nf.flows.HamiltonianMonteCarlo(
                    self.target_dist, leapfrog_steps, log_step_size, log_mass
                )
            )
        self.flows = nn.ModuleList(hmc_layers)

    def get_hmc_parameters(self):
        """Return a flat list of the parameters of all HMC layers, in layer order."""
        return [param for flow in self.flows for param in flow.parameters()]

    def sample(self, num_samples):
        """Draw ``num_samples`` initial samples and push them through every layer."""
        state = self.sample_initial_dist(num_samples)
        for flow in self.flows:
            # The second element returned by the layer is not used here.
            state, _ = flow.forward(state)
        return state

    def sample_initial_dist(self, num_samples):
        """Draw ``num_samples`` from the initial distribution and rescale them."""
        draw, _ = self.initial_dist.forward(num_samples)
        return self.initial_dist_scale * draw

    def eval_log_target(self, num_samples):
        """Return the mean target log-density of ``num_samples`` final samples."""
        final_state = self.sample(num_samples)
        return torch.mean(self.target_dist.log_prob(final_state))

    def get_ELT_over_T(self, num_samples):
        """Return the mean target log-density before each layer and after the last.

        The result has ``num_hmc_layers + 1`` entries: entry ``i`` is measured
        on the samples entering layer ``i``, and the final entry on the output
        of the last layer.
        """
        elts = torch.zeros((num_hmc_layers + 1,))
        state = self.sample_initial_dist(num_samples)
        for idx, flow in enumerate(self.flows):
            elts[idx] = torch.mean(self.target_dist.log_prob(state))
            state, _ = flow.forward(state)
        elts[-1] = torch.mean(self.target_dist.log_prob(state))
        return elts


# Experiment: train the HMC layers of two EIHMC models that differ only in
# the scale of their initial distribution (0.5 = "thin", 2.0 = "wide"), and
# compare the mean target log-density across layers before and after training.
# Statement order is kept exactly as-is because every sampling call consumes
# the global RNG stream.

# --- Thin initial distribution (scale 0.5) ---
thin_eihmc = EIHMC(0.5)
thin_samples_before_train = thin_eihmc.sample(100000).detach().numpy().squeeze()
thin_elts_before_train = thin_eihmc.get_ELT_over_T(100000).detach().numpy()
print("training thin")
optimizer = torch.optim.Adam(thin_eihmc.get_hmc_parameters(), lr=0.05)
for i in range(100):
    optimizer.zero_grad()
    # Maximise the mean target log-density of the final samples by
    # minimising its negative.
    loss = -thin_eihmc.eval_log_target(256)
    loss.backward()
    optimizer.step()
thin_samples_after_train = thin_eihmc.sample(100000).detach().numpy().squeeze()
thin_elts_after_train = thin_eihmc.get_ELT_over_T(100000).detach().numpy()

# --- Wide initial distribution (scale 2.0) ---
wide_eihmc = EIHMC(2.0)
wide_elts_before_train = wide_eihmc.get_ELT_over_T(100000).detach().numpy()
wide_samples_before_train = wide_eihmc.sample(100000).detach().numpy().squeeze()
optimizer = torch.optim.Adam(wide_eihmc.get_hmc_parameters(), lr=0.05)
print("training wide")
for i in range(100):
    optimizer.zero_grad()
    # Same objective as for the thin model.
    loss = -wide_eihmc.eval_log_target(256)
    loss.backward()
    optimizer.step()
wide_elts_after_train = wide_eihmc.get_ELT_over_T(100000).detach().numpy()
wide_samples_after_train = wide_eihmc.sample(100000).detach().numpy().squeeze()

# --- Plot ---
# Index k on the x-axis is the value measured after k HMC layers
# (k = 0 is the scaled initial distribution).  Labels and a legend make the
# four curves distinguishable.
print("Plotting")
plt.plot(thin_elts_before_train, label="thin (scale 0.5), before training")
plt.plot(thin_elts_after_train, label="thin (scale 0.5), after training")
plt.plot(wide_elts_before_train, label="wide (scale 2.0), before training")
plt.plot(wide_elts_after_train, label="wide (scale 2.0), after training")
plt.xlabel("number of HMC layers applied")
plt.ylabel("mean target log-density")
plt.legend()
plt.show()
# %%
