import os
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

def Tensor2Numpy(*args):
    """Convert one or more torch tensors to NumPy arrays.

    Each tensor is moved to the CPU and detached from the autograd graph
    before conversion.

    Returns:
        A single ndarray when exactly one tensor is given, otherwise a list
        of ndarrays in argument order (an empty list for no arguments).
    """
    converted = [tensor.cpu().detach().numpy() for tensor in args]
    return converted[0] if len(converted) == 1 else converted

def plot_1d(x, save_path=None):
    """Plot a histogram and a KDE of a 1-D sample side by side.

    Args:
        x: samples as a torch tensor, NumPy array or other array-like; it is
            flattened to 1-D before plotting.
        save_path: path prefix without extension. When given, the figure is
            written to ``{save_path}.png`` and ``{save_path}.pdf`` and then
            closed. When ``None``, nothing is written (previously this wrote
            ``None.png``/``None.pdf``) and the open figure is returned so the
            caller can show or save it.

    Returns:
        The matplotlib figure when ``save_path`` is None, otherwise None.
    """
    if isinstance(x, torch.Tensor):
        x = Tensor2Numpy(x)
    # np.asarray also accepts lists and other array-likes, not only ndarrays.
    x = np.asarray(x).reshape(-1)

    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    axs[0].hist(
        x,
        bins=500,
        alpha=0.5,
        density=True,
        linewidth=4,
    )
    axs[0].set_title("Four Well Distribution (Histogram)")
    axs[0].set_xlabel("x")
    axs[0].set_ylabel("Density")

    # `shade` is deprecated in seaborn (>= 0.11) in favour of `fill`; passing
    # both only emitted a FutureWarning, so `fill=True` alone is equivalent.
    sns.kdeplot(
        x,
        ax=axs[1],
        cmap="coolwarm",
        fill=True,
    )
    axs[1].set_title("Four Well Distribution (KDE)")
    axs[1].set_xlabel("x")
    axs[1].set_ylabel("Density")

    if save_path is None:
        return fig

    fig.savefig(f'{save_path}.png', bbox_inches='tight')
    fig.savefig(f'{save_path}.pdf', bbox_inches='tight')
    plt.close(fig)
    return None


class AsymmetricDoubleWell(object):
    """One-dimensional confining potential with three Gaussian bumps.

    Energy::

        U(x) = a * (x**8
                    + 0.8 * exp(-80 x**2)
                    + 0.2 * exp(-80 (x - b)**2)
                    + 0.5 * exp(-40 (x + b)**2))

    All methods expect ``x`` to be a torch tensor (they use tensor ops).
    """

    def __init__(self, a=1., b=1.):
        # a: overall energy scale; b: offset of the two off-centre bumps.
        self.a, self.b = a, b

    def __call__(self, x):
        """Return the elementwise energy U(x)."""
        right = x - self.b
        left = x + self.b
        inner = (x ** 8
                 + 0.8 * torch.exp(-80. * x ** 2)
                 + 0.2 * torch.exp(-80. * right ** 2)
                 + 0.5 * torch.exp(-40 * left ** 2))
        return self.a * inner

    def grad(self, x):
        """Return the analytic elementwise derivative dU/dx."""
        right = x - self.b
        left = x + self.b
        confinement = 8 * x ** 7
        centre_bump = 0.8 * torch.exp(-80. * x ** 2) * (-160 * x)
        right_bump = 0.2 * torch.exp(-80. * right ** 2) * (-160 * right)
        left_bump = 0.5 * torch.exp(-40 * left ** 2) * (-80 * left)
        return self.a * (confinement + centre_bump + right_bump + left_bump)

    def step_langevin(self, x, dt=0.1, kT=15.0, mGamma=1):
        """Advance ``x`` by one Euler-Maruyama step of overdamped Langevin.

        x_new = x - dt / mGamma * U'(x) + xi * sqrt(2 kT dt / mGamma),
        with xi ~ N(0, 1) drawn via ``torch.randn_like(x)``.
        """
        drift = dt / mGamma * self.grad(x)
        noise_scale = ((2.0 * kT * dt) / (mGamma)) ** 0.5
        return x - drift + torch.randn_like(x) * noise_scale

class PerturbedDoubleWell(object):
    """Confining multi-bump potential plus a Gaussian perturbation (bias).

    Energies (all elementwise on a torch tensor ``x``)::

        U(x) = a * (x**8 + 0.8 exp(-80 x**2) + 0.2 exp(-80 (x - b)**2)
                    + 0.5 exp(-40 (x + b)**2))
        V(x) = 2 exp(-k (x - center)**2)            # perturbation

    ``step_langevin`` integrates overdamped Langevin dynamics on U + V and can
    also return per-step Girsanov increments relating the sampled (biased)
    path to the dynamics on U alone.
    """

    def __init__(self, a=4., b=0.5, k=1., center=0.):
        self.a = a
        self.b = b
        self.k = k
        # Centre of the Gaussian perturbation. It was previously accepted but
        # ignored; the default 0. reproduces the old behaviour exactly.
        self.center = center
        # Per-step state written by step_langevin and read by step_GR.
        self.dt = None
        self.noise = None
        self.sigma = None
        self.mGamma = None

    def __call__(self, x):
        """Return the total elementwise energy U(x) + V(x)."""
        energy = self.a*(x**8+0.8*(-80.*x**2).exp()+0.2*(-80.*(x-self.b)**2).exp()+0.5*(-40*(x+self.b)**2).exp())
        perturb_energy = self.pertub_energy(x)
        return energy + perturb_energy

    def grad_perturb(self, x):
        """Return dV/dx for V(x) = 2 exp(-k (x - center)**2)."""
        dx = x - self.center
        return 2.*(-self.k*dx**2).exp()*(-2*self.k*dx)

    def pertub_energy(self, x):
        """Return the perturbation energy V(x) = 2 exp(-k (x - center)**2)."""
        dx = x - self.center
        return 2.*(-self.k*dx**2).exp()

    # Correctly spelled alias; the original name is kept for existing callers.
    perturb_energy = pertub_energy

    def grad(self, x):
        """Return the total elementwise derivative dU/dx + dV/dx."""
        part1 = 8*x**7
        part2 = 0.8*(-80.*x**2).exp()*(-160*x)
        part3 = 0.2*(-80.*(x-self.b)**2).exp()*(-160*(x-self.b))
        part4 = 0.5*(-40*(x+self.b)**2).exp()*(-80*(x+self.b))
        part5 = self.grad_perturb(x)
        return self.a*(part1 + part2 + part3 + part4) + part5

    def step_GR(self, x):
        """Return the Girsanov increment and V(x) for the step taken from ``x``.

        The biased dynamics have drift -(U' + V')/mGamma; the unbiased ones
        have -U'/mGamma. Both use diffusion sigma = sqrt(2 kT / mGamma). With
        theta = V'(x) / (mGamma * sigma), the increment is the Ito term
        theta * xi * sqrt(dt) plus the correction -0.5 * theta**2 * dt. Here xi
        is the same noise used for the position update.

        Must be called after ``self.dt``, ``self.noise`` and ``self.sigma``
        are set (as ``step_langevin`` does).

        Returns:
            (increment, log_g) where ``log_g`` is V(x) at the pre-step point.
            NOTE(review): presumably consumed as an exp(beta V) reweighting
            factor downstream; confirm against the analysis code.
        """
        # Fall back to mGamma = 1 if the state was set without it.
        mGamma = 1 if self.mGamma is None else self.mGamma
        theta = self.grad_perturb(x) / mGamma / self.sigma
        I_a = theta * self.noise * self.dt**0.5
        R_a = -0.5 * theta**2 * self.dt
        log_g = self.pertub_energy(x)
        return I_a + R_a, log_g

    def step_langevin(self, x, dt=0.1, kT=15.0, mGamma=1, return_GR=False):
        """Advance ``x`` by one Euler-Maruyama step on U + V.

        x_new = x - dt / mGamma * (U' + V')(x) + xi * sigma * sqrt(dt), where
        sigma = sqrt(2 kT / mGamma) and xi ~ N(0, 1). This matches the
        overdamped noise scale sqrt(2 kT dt / mGamma). The previous extra
        division by mGamma sampled the wrong temperature whenever
        mGamma != 1.

        Returns:
            ``x_new``, or ``(x_new, gr, log_g)`` when ``return_GR`` is true.
        """
        self.dt = dt
        self.mGamma = mGamma
        self.noise = torch.randn_like(x)
        self.sigma = (2.0 * kT / mGamma) ** 0.5
        update_x = x - dt / mGamma * self.grad(x) + self.noise * self.sigma * (dt ** 0.5)
        if not return_GR:
            # The GR terms use no randomness, so skipping them here leaves
            # the sampled trajectory unchanged.
            return update_x
        gr, log_g = self.step_GR(x)
        return update_x, gr, log_g


class LangevinSampler(object):
    """Repeatedly call a potential's ``step_langevin`` and record the path.

    The potential must provide ``step_langevin(x, dt, kT, mGamma)``, which
    returns the next state. When ``return_GR`` is used it must also accept a
    fifth positional ``return_GR`` argument and return ``(x, gr, log_g)``.
    """

    def __init__(self, potential, x0=0.0, dt=0.001, kT=15.0, mGamma=1.0):
        self.potential = potential
        # The potentials use tensor ops (.exp(), randn_like) and run() reads
        # .shape, so a scalar/array-like start (including the default 0.0) is
        # promoted to a tensor. Tensors are kept as-is (device/dtype intact).
        if not isinstance(x0, torch.Tensor):
            x0 = torch.as_tensor(x0, dtype=torch.get_default_dtype())
        self.x = x0
        self.dt = dt
        self.kT = kT
        self.mGamma = mGamma
        # Latest Girsanov increment and log_g, set by step(return_GR=True).
        self.gr = None
        self.log_g = None

    def step(self, return_GR=False):
        """Advance the state by one step, optionally storing the GR terms."""
        if return_GR:
            self.x, self.gr, self.log_g = self.potential.step_langevin(self.x, self.dt, self.kT, self.mGamma, return_GR)
        else:
            self.x = self.potential.step_langevin(self.x, self.dt, self.kT, self.mGamma)

    def run(self, nsteps, return_GR=False):
        """Run ``nsteps`` steps and return the recorded states.

        The buffers are CPU tensors of shape ``(nsteps, *x.shape)`` with the
        default dtype; each step's state is copied into them.

        Returns:
            ``x`` or, when ``return_GR`` is true, ``(x, gr, log_g)``.

        Raises:
            ValueError: if the state contains NaN after a step.
        """
        shape = (nsteps, *self.x.shape)
        x = torch.zeros(shape)
        gr = torch.zeros(shape) if return_GR else None
        log_g = torch.zeros(shape) if return_GR else None

        for i in tqdm(range(nsteps)):
            self.step(return_GR)
            x[i] = self.x
            if return_GR:
                gr[i] = self.gr
                log_g[i] = self.log_g
            if torch.isnan(self.x).any():
                raise ValueError(f"NaN encountered at step {i}: x = {self.x}")

        if return_GR:
            return x, gr, log_g
        return x

if __name__ == "__main__":
    # Run overdamped Langevin trajectories and save each one plus the
    # concatenated result as .npz, then plot the pooled sample distribution.
    # With return_GR=True the trajectories run on the perturbed (biased)
    # potential and Girsanov terms are recorded. With return_GR=False they
    # run on the unperturbed potential.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    n_trajs = 1
    beta = 2.
    a = 4
    b = 0.5
    k = 15.
    n_samples = 10_000_000
    dt = 0.001
    mGamma = 1
    # Set once here (not inside the loop) so it is always bound below.
    return_GR = True
    prefix = 'biased' if return_GR else 'unbiased'

    output_dir = ''  # TODO: add output dir ('' writes to the current directory)
    if output_dir:
        # os.makedirs('') raises FileNotFoundError, so only create real paths.
        os.makedirs(output_dir, exist_ok=True)

    def _save_npz(name, results):
        # Skip None entries: np.savez would store them as object arrays that
        # np.load refuses to read unless allow_pickle=True.
        arrays = {key: val for key, val in results.items() if val is not None}
        np.savez(os.path.join(output_dir, name), **arrays)

    x0s = torch.zeros(n_trajs).to(device)

    samples = []
    grs = []
    log_gs = []
    for i, x0 in tqdm(enumerate(x0s), total=n_trajs):
        print(f'Starting trajectory {i+1:02d}/{n_trajs:02d}')
        if return_GR:
            sampler = LangevinSampler(PerturbedDoubleWell(a=a, b=b, k=k), x0=x0, dt=dt, kT=1/beta, mGamma=mGamma)
            x1, gr, log_g = sampler.run(n_samples, return_GR=True)
        else:
            sampler = LangevinSampler(AsymmetricDoubleWell(a=a, b=b), x0=x0, dt=dt, kT=1/beta, mGamma=mGamma)
            x1 = sampler.run(n_samples, return_GR=False)
            gr = log_g = None

        # Optional burn-in / thinning: x1 = x1[1000::10]
        results = {
            'sample': Tensor2Numpy(x1),
            'gr_weights': Tensor2Numpy(gr) if return_GR else None,
            'log_likeli_ratio': Tensor2Numpy(log_g) if return_GR else None,
        }
        _save_npz(f'{prefix}_traj{i}.npz', results)
        samples.append(x1)
        if return_GR:
            grs.append(gr)
            log_gs.append(log_g)

    # LangevinSampler.run returns CPU tensors, so .numpy() is safe here.
    samples = torch.cat(samples).reshape([n_trajs, -1, 1]).detach().numpy()
    if return_GR:
        grs = torch.cat(grs).reshape([n_trajs, -1]).detach().numpy()
        log_gs = torch.cat(log_gs).reshape([n_trajs, -1]).detach().numpy()
        plot_1d(samples.reshape(-1), save_path=os.path.join(output_dir, 'biase_FW_1d'))
    else:
        plot_1d(samples.reshape(-1), save_path=os.path.join(output_dir, 'unbiase_FW_1d'))
    results = {
        'sample': samples,
        'gr_weights': grs if return_GR else None,
        'log_likeli_ratio': log_gs if return_GR else None,
    }
    _save_npz(f'{prefix}_trajtotal.npz', results)
    print(samples.shape)

     