from sde_lib import SDE_Brownian_manifolds
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from sampling import SDE_sampler_manifolds
from runners.MD_runner import MD_prior_potential
from manifolds.MD import Manifold_MD
from utils import set_seed_everywhere


@torch.no_grad()
def refine_dataset_md(sdf, samples, tol=1e-5, max_iter_n=1000, step_size=1e-1, rk4=True, keep_quiet=False):
    """Push samples onto the zero level set of ``sdf.constrain_fn``.

    Each still-unconverged sample ``x`` is moved along ``-xi(x) * grad xi(x)``
    (the gradient flow of ``0.5 * xi(x)**2``), integrated with classical RK4 or
    explicit Euler. Iteration stops when ``|xi(x)| < tol`` for every sample or
    after ``max_iter_n`` steps. A sample is frozen (no longer updated) as soon
    as it satisfies the tolerance.

    Args:
        sdf: object exposing ``constrain_fn(x)`` and ``constrain_grad_fn(x)``.
            ``constrain_fn`` output is squeezed along dim 1 and broadcast as
            ``constrain_fn(x)[:, :, None] * constrain_grad_fn(x)``. That
            presumably means shape ``(batch, 1)`` for the values and ``x``'s
            (3-D, batch-first) shape for the gradient. TODO confirm for new sdfs.
        samples: torch tensor or array-like; the batch runs along dim 0.
        tol: convergence threshold on ``|xi(x)|``.
        max_iter_n: maximum number of integration steps.
        step_size: integration step size.
        rk4: use RK4 when True, explicit Euler otherwise.
        keep_quiet: suppress the periodic progress messages.

    Returns:
        ``(all_samples, converged_samples)``, both float64 CPU tensors.
        ``all_samples`` holds every refined sample. ``converged_samples``
        holds only those whose final ``|xi(x)|`` is below ``tol``.
    """
    print("refine data set...")

    if isinstance(samples, torch.Tensor):
        # use double instead of float in order to improve convergence!
        samples = samples.clone().double()
        device = samples.device
    else:
        samples = torch.tensor(samples, dtype=torch.float64)
        device = torch.device("cpu")

    if samples.shape[0] == 0:
        # Nothing to refine; torch.max below would fail on an empty tensor.
        empty = samples.detach().cpu()
        return empty, empty.clone()

    def flow(x):
        # Right-hand side of the flow: xi(x) * grad xi(x).
        return sdf.constrain_fn(x)[:, :, None] * sdf.constrain_grad_fn(x)

    active_idx = torch.arange(samples.shape[0], dtype=torch.int64, device=device)

    iter_n = 0
    while iter_n < max_iter_n:
        x_active = samples[active_idx]
        xi_vals = sdf.constrain_fn(x_active)
        error = torch.abs(xi_vals).squeeze(dim=1)
        bad_idx = error >= tol
        if not bad_idx.any():
            print("break")
            break
        if iter_n % 50 == 0 and not keep_quiet:
            print(f'iter {iter_n}: max_err={torch.max(error):.3e}, {bad_idx.sum()} bad states, tol={tol:.3e}')

        # Only integrate samples that are still off the manifold.
        active_idx = active_idx[bad_idx]
        x = x_active[bad_idx]
        # Reuse the constraint values from the convergence check for the first stage.
        k1 = xi_vals[bad_idx][:, :, None] * sdf.constrain_grad_fn(x)
        if rk4:
            k2 = flow(x + k1 * step_size * 0.5)
            k3 = flow(x + k2 * step_size * 0.5)
            k4 = flow(x + k3 * step_size)
            samples[active_idx] = x - (k1 + 2 * k2 + 2 * k3 + k4) * step_size / 6.0
        else:
            samples[active_idx] = x - k1 * step_size

        iter_n += 1

    # Re-check every sample. This is correct for both exits (all converged /
    # max_iter_n reached) and also counts samples that converged on the last step.
    final_error = torch.abs(sdf.constrain_fn(samples)).squeeze(dim=1)
    max_error = torch.max(final_error)

    print(f'total steps={iter_n}, final error: {max_error: .3e}.')
    if max_error > tol * 1.1:
        print(f'Warning: tolerance ({tol: .3e}) not reached!')

    # Keep only converged samples; the mask lives on the same device as samples.
    samples_abridged = samples[final_error < tol]

    return samples.detach().cpu(), samples_abridged.detach().cpu()


def plot_phi_psi_hist(phi, psi, savefig=""):
    """Save side-by-side density histograms of the phi and psi angles.

    Args:
        phi, psi: CPU tensors (``.numpy()`` is called on them), flattened
            before binning. Values are converted with ``/ pi * 180``, so they
            are treated as radians and plotted in degrees.
        savefig: suffix of the output file
            ``./datasets/figs/dipeptide/dipeptide_Hist_{savefig}.png``.
    """
    print(phi.shape[0])
    # About 100 samples per bin, but at least one bin. Plain integer division
    # yields bins=0 for fewer than 100 samples, which matplotlib rejects.
    bins = max(1, phi.shape[0] // 100)

    out_dir = './datasets/figs/dipeptide'
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 5))
    for pos, (angle, title) in enumerate(((phi, "phi"), (psi, "psi")), start=1):
        ax = fig.add_subplot(1, 2, pos)
        ax.hist(angle.reshape(-1).numpy() / torch.pi * 180, bins=bins, alpha=1.0, density=True)
        ax.set_title(title)

    fig.savefig(f'{out_dir}/dipeptide_Hist_{savefig}.png', dpi=300)
    plt.close(fig)


def generate_dipeptide_data():
    """Refine the raw dipeptide dataset with ``refine_dataset_md`` and save it.

    Steps:
      1. Seed with 777 and load ``./data/dipeptide/dipeptide.npy``.
      2. Plot the raw phi/psi histograms (suffix "ori").
      3. Refine with RK4, step 1e-3, tol 1e-5, at most 10000 steps.
      4. Save the converged subset to ``dipeptide_refined.npy`` and its first
         configuration to ``dipeptide_ref.npy``.
      5. Print the largest remaining constraint violation and plot the refined
         histograms (suffix "refined").
    """
    data_dir = './data/dipeptide'

    set_seed_everywhere(777)
    raw = torch.tensor(np.load(f'{data_dir}/dipeptide.npy')).float()

    manifold = Manifold_MD()
    plot_phi_psi_hist(manifold.angle_phi(raw), manifold.angle_psi(raw), savefig="ori")

    sdf = Manifold_MD()
    # NOTE(review): a stray comment suggested tol=1e-4 / step_size=1e-2 were
    # also tried at some point; confirm which settings are intended.
    _, refined = refine_dataset_md(
        sdf,
        raw,
        tol=1e-5,
        max_iter_n=10000,
        step_size=1e-3,
        rk4=True,
        keep_quiet=False,
    )
    np.save(f'{data_dir}/dipeptide_refined.npy', refined.numpy())
    np.save(f'{data_dir}/dipeptide_ref.npy', refined[0].numpy())

    print(sdf.constrain_fn(refined).abs().max())
    plot_phi_psi_hist(manifold.angle_phi(refined), manifold.angle_psi(refined), savefig="refined")


def gen_data(manifold, repeat_num=1, kappa=100., N=100000, T=1000, sigma=1.):
    """Simulate forward manifold SDE dynamics in the reference-structure prior.

    The reference configuration is loaded from
    ``./data/dipeptide/dipeptide_ref.npy``. The drift is ``-prior.gradV(x)``
    of an ``MD_prior_potential`` built around that reference with strength
    ``kappa``. ``repeat_num`` copies of the reference are integrated with
    ``SDE_sampler_manifolds`` (``reverse=False``).

    Returns:
        ``other_dict["x_hist_all"]`` from the sampler, as a CPU numpy array.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    xref = torch.tensor(np.load('./data/dipeptide/dipeptide_ref.npy')).float()
    prior = MD_prior_potential(xref, kappa=kappa)

    sde = SDE_Brownian_manifolds(sigma_min=sigma, sigma_max=sigma, N=N, T=T)

    def drift(x):
        # Negative gradient of the prior potential.
        return -prior.gradV(x)

    sde.func_b = drift

    # NOTE(review): xref is handed to the prior on CPU while the initial states
    # are moved to `device`. If CUDA is available and gradV mixes x with xref,
    # this may be a device mismatch; confirm against MD_prior_potential.
    start = xref.unsqueeze(0).repeat(repeat_num, 1, 1).to(device)
    _, _, other_dict = SDE_sampler_manifolds(sde, manifold, start,
                                             reverse=False, keep_quiet=False)
    return other_dict["x_hist_all"].detach().cpu().numpy()


def plot_Hist(manifold, dataset_list):
    """Overlay step histograms of the psi angle for several datasets.

    Each entry of ``dataset_list`` is passed to ``manifold.angle_psi``. The
    result is converted with ``.numpy()`` and flattened, then plotted as a
    density histogram labelled by its list index. Values are not converted to
    degrees here. The figure is saved to
    ``<parent of this file's directory>/datasets/figs/dipeptide/temp.png``
    and then shown.
    """
    abs_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    out_dir = f"{abs_path}/datasets/figs/dipeptide"
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    for i, data in enumerate(dataset_list):
        psi = manifold.angle_psi(data).numpy().reshape(-1)
        # About 100 samples per bin, at least one (matplotlib rejects bins=0).
        bins = max(1, psi.shape[0] // 100)
        ax.hist(psi, bins=bins, histtype='step', alpha=1., density=True, label=f'{i}')

    ax.legend()
    fig.savefig(f"{out_dir}/temp.png", dpi=300)
    plt.show()
    plt.close(fig)


def generate_dipeptide_prior_data(kappa=50.):
    """Sample the prior-potential dynamics and save a sliced trajectory.

    Runs ``gen_data`` (N=100000, T=50, sigma=1) after seeding with 777. It then
    keeps ``[1000::10]`` along dim 0 of the returned history and reshapes to
    ``(-1, 10, 3)``. The slice is presumably burn-in plus thinning over time,
    and the shape presumably 10 atoms x 3 coordinates; TODO confirm the
    ``x_hist_all`` layout. The result is saved to
    ``./data/dipeptide/dipeptide_prior_{int(kappa)}.npy``.
    """
    set_seed_everywhere(777)
    manifold = Manifold_MD()
    trajectory = gen_data(manifold, kappa=kappa, N=100000, T=50., sigma=1.)

    skip, stride = 1000, 10
    x_prior = trajectory[skip::stride].reshape(-1, 10, 3)

    np.save(f'./data/dipeptide/dipeptide_prior_{int(kappa)}.npy', x_prior)


def find_well_center():
    """Pick one representative configuration for each of two psi targets.

    Loads ``./data/dipeptide/dipeptide_refined.npy`` and converts psi to
    degrees (``/ pi * 180``). It saves the sample closest to psi = -20 deg as
    ``dipeptide_center_1.npy`` and the one closest to psi = 150 deg as
    ``dipeptide_center_2.npy``, both in ``./data/dipeptide/``.

    Closeness is measured on the circle. For example, psi = -175 deg is 35 deg
    from 150 deg, not 325 deg, so samples across the +/-180 deg seam are
    handled correctly.
    """
    manifold = Manifold_MD()

    xvec = torch.tensor(np.load('./data/dipeptide/dipeptide_refined.npy')).float()
    psi = manifold.angle_psi(xvec).reshape(-1) / torch.pi * 180

    def closest_index(target_deg):
        # Wrap the signed difference into [-180, 180) before taking |.|.
        diff = torch.remainder(psi - target_deg + 180., 360.) - 180.
        return torch.argmin(diff.abs())

    for k, target in enumerate((-20., 150.), start=1):
        center = xvec[closest_index(target)].clone()
        np.save(f'./data/dipeptide/dipeptide_center_{k}.npy', center.numpy())


