import matplotlib.pyplot as plt
import torch
import numpy as np
import logging
from runners.Basic_runner import BasicRunner
from utils import (
    split_dataset,
    check_memory,
    get_RMSD,
    Kabsch)
from sampling import SDE_sampler_manifolds


class MD_prior_potential:
    """Quadratic prior potential around a reference structure, up to alignment.

    ``V(x) = 0.5 * kappa * sum((x - b - (xref - mean(xref)) @ R) ** 2)`` where
    ``R, b = Kabsch(x, xref)``. Prior samples are drawn by resampling a stored
    array of pre-generated samples (``prior_sample``).
    """

    def __init__(self, xref, kappa, prior_sample=None):
        """
        Args:
            xref: reference structure tensor; moved to the device/dtype of the
                input on every evaluation.
            kappa: scalar factor multiplying both the potential and its gradient.
            prior_sample: optional array-like or tensor of pre-generated prior
                samples (first axis indexes samples). Stored as a float32 copy;
                required by :meth:`prior_sampler`.
        """
        self.xref = xref
        self.kappa = kappa
        # Always define the attribute so that a missing prior produces a clear
        # error in prior_sampler rather than an AttributeError.
        self.prior_sample = None
        if prior_sample is not None:
            # as_tensor accepts arrays and tensors alike (torch.tensor warns on
            # tensor input); clone() guarantees we never alias the caller's buffer.
            self.prior_sample = torch.as_tensor(prior_sample, dtype=torch.float32).detach().clone()

    def _aligned_residual(self, x):
        """Return ``x - b - (xref - mean(xref)) @ R`` with ``R, b = Kabsch(x, xref)``."""
        xref = self.xref.to(x)
        R, b = Kabsch(x, xref)
        b0 = torch.mean(xref, 0, True)
        return x - b - torch.matmul(xref - b0, R)

    def V(self, x):
        """Return ``0.5 * kappa * sum(residual ** 2)``, summed over dims 1 and 2.

        One value is returned per entry of dim 0 of ``x``. Gradients are tracked,
        including through ``Kabsch``.
        """
        potential = 0.5 * torch.sum(self._aligned_residual(x) ** 2, (1, 2))
        return potential * self.kappa

    @torch.no_grad()
    def gradV(self, x):
        """Return ``kappa * residual`` (same shape as ``x``) without autograd tracking.

        ``R`` and ``b`` from ``Kabsch`` are not differentiated here.
        """
        return self._aligned_residual(x) * self.kappa

    def prior_sampler(self, n):
        """Draw ``n`` rows of ``prior_sample`` uniformly at random, with replacement.

        Raises:
            RuntimeError: if the potential was constructed without ``prior_sample``.
        """
        if self.prior_sample is None:
            raise RuntimeError(
                "MD_prior_potential was created without prior_sample; "
                "cannot sample from the prior."
            )
        idx = torch.randint(0, self.prior_sample.shape[0], (n,))
        return self.prior_sample[idx]


class MDRunner(BasicRunner):
    """Runner for the dipeptide molecular-dynamics experiment.

    On construction it loads the dipeptide data and an ``MD_prior_potential``
    prior, installs ``-prior.gradV`` as ``self.sde.func_b``, and writes
    diagnostic histograms of the ``phi``/``psi`` angles (from
    ``self.manifold.angle_phi``/``angle_psi``) and of the RMSD to two reference
    "center" conformations into ``self.savefig_dir``.
    """

    def __init__(self, config):
        super().__init__(config)

        self.load_data()

        # --------------------------- exhibit dataset ---------------------------
        # Angles of the reference structure, converted from radians to degrees.
        xref_batch = self.xref.unsqueeze(0)
        phi_ref = self.manifold.angle_phi(xref_batch) / torch.pi * 180
        self.psi_ref = self.manifold.angle_psi(xref_batch) / torch.pi * 180
        logging.info(f"x reference: phi: {phi_ref.item():.2f}, psi: {self.psi_ref.item():.2f}.")

        samples_test = self.training_set[:self.config.sample.sample_num].detach()
        self.plot_sample_hist(samples_test.cpu().numpy(), savefig="training_set")
        self.plot_sample_hist(self.prior.prior_sample.cpu().numpy(), savefig="prior_dist")

        if self.config.if_train or self.config.if_sample:
            # After the transpose, axis 0 is used as the time-step index:
            # x_hist[i] is the batch at step i and x_hist[-1] is the path end.
            x_hist = self.training_set_path[:self.config.sample.sample_num].clone().transpose(0, 1)
            self.plot_angle_and_RMSD_hist(x_hist[-1].cpu().numpy(), savefig='forward_end')
            # Every 1% over the first 10% of the path, every 10% afterwards,
            # plus the first five steps.
            plot_idx = list(range(10)) + list(range(10, 101, 10))
            self._plot_path_snapshots(x_hist, plot_idx, lambda i: i < 5, 'generating_fwd')
            np.save(f"{self.samples_dir}/{self.dataset_name}_hist_fwd.npy", x_hist.cpu().detach().numpy())

    def load_data(self):
        """Load the dipeptide data, build the prior and split the dataset.

        Sets ``kappa``, ``xref``, ``x_center_1``, ``x_center_2``, ``prior``,
        ``natom``, ``data_set`` and ``training_set``/``test_set``/``val_set``,
        and installs ``-prior.gradV`` as ``self.sde.func_b``. When training or
        sampling is enabled it also builds ``training_set_path`` with
        ``generate_path_dataset``.
        """
        data_dir = './data/dipeptide'
        self.kappa = self.config.training.kappa
        self.xref = torch.tensor(np.load(f'{data_dir}/dipeptide_ref.npy')).float()
        self.x_center_1 = torch.tensor(np.load(f'{data_dir}/dipeptide_center_1.npy')).float()
        self.x_center_2 = torch.tensor(np.load(f'{data_dir}/dipeptide_center_2.npy')).float()
        data_ori = torch.tensor(np.load(f'{data_dir}/dipeptide_refined.npy')).float()

        # The prior-sample file name is keyed by the integer part of kappa.
        prior_sample = np.load(f'{data_dir}/dipeptide_prior_{int(self.kappa)}.npy')
        self.prior = MD_prior_potential(self.xref, kappa=self.kappa, prior_sample=prior_sample)
        # NOTE(review): hard-coded atom count; presumably equals xref.shape[0] — confirm.
        self.natom = 10

        self.sde.func_b = lambda x: - self.prior.gradV(x)

        # NOTE(review): this shuffle uses the global torch RNG, so the split is
        # reproducible only if that RNG was seeded beforehand (presumably by
        # BasicRunner — confirm); split_dataset itself receives config.seed.
        self.data_set = data_ori[torch.randperm(data_ori.shape[0])].clone()
        self.training_set, self.test_set, self.val_set = split_dataset(self.data_set, self.config.seed)

        if self.config.if_train or self.config.if_sample:
            self.training_set_path = self.generate_path_dataset(self.training_set, keep_quiet=False)
            check_memory(self.training_set_path)

    @staticmethod
    def _num_bins(n_samples):
        """Return one histogram bin per 100 samples, but never fewer than one.

        numpy/matplotlib reject ``bins=0``, which ``int(n / 100)`` alone yields
        for fewer than 100 samples.
        """
        return max(1, int(n_samples / 100))

    def _plot_path_snapshots(self, x_hist, plot_idx, also_plot, prefix):
        """Plot angle/RMSD histograms for selected time steps of a trajectory.

        Step ``i`` (``0 <= i <= self.sde.N``) is plotted when ``100 * i / N`` is
        in ``plot_idx`` or ``also_plot(i)`` is true; its file tag is
        ``f'{prefix}_{i}'``. ``x_hist[i]`` must hold the batch at step ``i``.
        """
        n_steps = self.sde.N
        for step in range(n_steps + 1):
            if (100 * step / n_steps in plot_idx) or also_plot(step):
                self.plot_angle_and_RMSD_hist(x_hist[step].cpu().numpy(), savefig=f'{prefix}_{step}')

    def plot_sample_hist(self, samples, savefig=None):
        """Save side-by-side histograms of the phi and psi angles of ``samples``.

        Args:
            samples: batch of conformations (tensor or array-like). Non-tensors
                are converted to a float tensor first, as in
                ``plot_angle_and_RMSD_hist``.
            savefig: tag used in ``{savefig_dir}/Hist_psi_phi_{savefig}.png``.

        Angles are shown in degrees. A dashed red line marks the reference
        structure's psi (``self.psi_ref``).
        """
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples).float()
        phi, psi = self.manifold.angle_phi(samples), self.manifold.angle_psi(samples)
        fig = plt.figure(figsize=(10, 5))
        bins = self._num_bins(phi.shape[0])

        ax = fig.add_subplot(1, 2, 1)
        ax.hist(phi.reshape(-1).numpy() / torch.pi * 180, bins=bins, alpha=1.0, density=True)
        ax.set_title("phi")

        ax = fig.add_subplot(1, 2, 2)
        ax.hist(psi.reshape(-1).numpy() / torch.pi * 180, bins=bins, alpha=1.0, density=True)
        ax.set_title("psi")
        # axvline expects a scalar position, not a 1-element array.
        ax.axvline(x=self.psi_ref.item(), color='red', linestyle='--')

        fig.suptitle(f"Histogram for {samples.shape[0]} samples.")
        fig.savefig(self.savefig_dir + f"/Hist_psi_phi_{savefig}.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def plot_angle_and_RMSD_hist(self, samples, savefig=None):
        """Compare ``samples`` (green) with the training set (red) in three histograms.

        The panels are: the psi angle in degrees, with dashed black lines at the
        psi of ``x_center_1`` and ``x_center_2``; the RMSD to ``x_center_1``;
        and the RMSD to ``x_center_2``. The figure is saved to
        ``{savefig_dir}/Hist_angel_RMSD_{savefig}.png``.
        """
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples).float()

        true_set = self.training_set.detach()

        psi = self.manifold.angle_psi(samples)
        psi_angle = psi.reshape(-1).numpy() / np.pi * 180
        psi_angle_true = self.manifold.angle_psi(true_set).reshape(-1).numpy() / np.pi * 180

        psi_center_1 = self.manifold.angle_psi(self.x_center_1.unsqueeze(0)) / torch.pi * 180
        psi_center_2 = self.manifold.angle_psi(self.x_center_2.unsqueeze(0)) / torch.pi * 180

        RMSD1 = get_RMSD(samples, self.x_center_1).numpy()
        RMSD1_true = get_RMSD(true_set, self.x_center_1).numpy()
        RMSD2 = get_RMSD(samples, self.x_center_2).numpy()
        RMSD2_true = get_RMSD(true_set, self.x_center_2).numpy()

        fig = plt.figure(figsize=(15, 5))
        bins = self._num_bins(psi.shape[0])

        ax = fig.add_subplot(1, 3, 1)
        ax.hist(psi_angle, bins=bins, alpha=0.5, density=True, color='green')
        ax.hist(psi_angle_true, bins=bins, alpha=0.5, density=True, color='red')
        # axvline expects scalar positions, not 1-element arrays.
        ax.axvline(x=psi_center_1.item(), color='black', linestyle='--')
        ax.axvline(x=psi_center_2.item(), color='black', linestyle='--')
        ax.set_title("psi")

        ax = fig.add_subplot(1, 3, 2)
        ax.hist(RMSD1, bins=bins, alpha=0.5, density=True, color='green')
        ax.hist(RMSD1_true, bins=bins, alpha=0.5, density=True, color='red')
        ax.set_title("RMSD1")

        ax = fig.add_subplot(1, 3, 3)
        ax.hist(RMSD2, bins=bins, alpha=0.5, density=True, color='green')
        ax.hist(RMSD2_true, bins=bins, alpha=0.5, density=True, color='red')
        ax.set_title("RMSD2")

        fig.suptitle(f"Histogram for {samples.shape[0]} samples.")
        # NOTE: 'angel' (sic) is kept so existing output file names do not change.
        fig.savefig(self.savefig_dir + f"/Hist_angel_RMSD_{savefig}.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def validate(self, mode=None, epoch=0, **kwargs):
        """Sample with the reverse SDE and plot histograms of the result.

        Does nothing for ``mode`` ``'start'`` or ``'end'``. Otherwise it pushes
        ``config.sample.sample_num`` prior samples through
        ``SDE_sampler_manifolds(reverse=True)`` and saves the plot with tag
        ``val_{epoch}_generated``.
        """
        if mode in ('start', 'end'):
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")

        # sample
        init = self.prior.prior_sampler(self.config.sample.sample_num).to(self.device)
        x, _, _ = SDE_sampler_manifolds(self.sde, self.manifold, init,
                                        reverse=True,
                                        score_net=self.network,
                                        keep_quiet=True)
        self.plot_angle_and_RMSD_hist(x.cpu().numpy(), savefig=f'val_{epoch}_generated')

        logging.info("-------------------------End validating.-------------------------")

    def test(self):
        """Do nothing; this runner has no test procedure."""
        return

    def sample_on_manifolds(self):
        """Run the reverse SDE from prior samples, then plot and save the results.

        Saves ``{dataset_name}_samples_generated.npy`` (the final samples) and
        ``{dataset_name}_hist_bwd.npy`` (the whole trajectory) in
        ``self.samples_dir``.
        """
        logging.info('Start sampling on manifolds.')
        device = self.device
        if self.network is not None:
            self.network.to(device)

        # backward
        logging.info("Start sampling backward SDE.")
        init = self.prior.prior_sampler(self.config.sample.sample_num).to(device)
        x, x_hist, _ = SDE_sampler_manifolds(self.sde, self.manifold, init,
                                             reverse=True,
                                             score_net=self.network,
                                             keep_quiet=False)
        self.calculate_constrain(x)
        x = x.cpu().numpy()
        self.plot_angle_and_RMSD_hist(x, savefig='generated')
        # Every 10% of the path, every 1% over the last 10%, plus the last five steps.
        plot_idx = list(range(0, 100, 10)) + list(range(90, 101))
        n_steps = self.sde.N
        self._plot_path_snapshots(x_hist, plot_idx, lambda i: i > n_steps - 5, 'generating_bwd')

        np.save(f"{self.samples_dir}/{self.dataset_name}_samples_generated.npy", x)
        np.save(f"{self.samples_dir}/{self.dataset_name}_hist_bwd.npy", x_hist.cpu().detach().numpy())


