import matplotlib.pyplot as plt
import torch
import numpy as np
import logging
from runners.Basic_runner import BasicRunner
from utils import (
    split_dataset,
    check_memory,
    get_RMSD,
    Kabsch)
from sampling import SDE_sampler_manifolds_CLangevin


class MD_prior_potential:
    """Harmonic restraint toward a reference structure after Kabsch alignment.

    The potential of a configuration ``x`` is
    ``kappa * 0.5 * sum((x - b - (xref - mean(xref)) @ R) ** 2)`` where ``R, b``
    come from ``Kabsch(x, xref)``. Optionally holds a bank of pre-generated
    prior samples that ``prior_sampler`` draws from uniformly with replacement.

    Args:
        xref: reference coordinates; ``xref.shape[0]`` is used as the number of
            atoms and ``xref.shape[1]`` as the per-atom dimension.
        kappa: stiffness; multiplies both the potential and its gradient.
        prior_sample: optional array-like (numpy array or tensor) sample bank;
            it is copied and stored as float32.
    """

    def __init__(self, xref, kappa, prior_sample=None):
        self.xref = xref
        self.kappa = kappa
        # Always define the attribute so a missing bank produces a clear error in
        # prior_sampler instead of an AttributeError.
        self.prior_sample = None
        if prior_sample is not None:
            # as_tensor accepts numpy arrays and tensors alike (torch.tensor warns on
            # tensor input); copy=True keeps the stored bank independent of the input.
            self.prior_sample = torch.as_tensor(prior_sample).detach().to(torch.float32, copy=True)

    def _as_batch(self, x):
        """Reshape ``x`` to (batch, natom, dim); accepts flat or already-batched input."""
        # reshape (not view) so non-contiguous inputs are handled too.
        return x.reshape(-1, self.xref.shape[0], self.xref.shape[1])

    def V(self, x):
        """Return the restraint potential per sample, shape (batch,).

        ``x`` may be (batch, natom, dim) or flattened (batch, natom * dim), the
        same layouts ``gradV`` accepts.
        """
        xref = self.xref.to(x)
        x = self._as_batch(x)
        R, b = Kabsch(x, xref)
        b0 = torch.mean(xref, 0, True)
        potential = 0.5 * torch.sum((x - b - torch.matmul(xref - b0, R)) ** 2, (1, 2))
        return potential * self.kappa

    @torch.no_grad()
    def gradV(self, x):
        """Return the restraint gradient, flattened to (batch, natom * dim).

        The Kabsch alignment is treated as fixed (computed under ``no_grad``), so
        the gradient is ``kappa`` times the aligned residual.
        """
        xref = self.xref.to(x)
        x = self._as_batch(x)
        R, b = Kabsch(x, xref)
        b0 = torch.mean(xref, 0, True)
        grad = x - b - torch.matmul(xref - b0, R)
        return (grad * self.kappa).reshape(-1, self.xref.shape[0] * self.xref.shape[1])

    def prior_sampler(self, n):
        """Draw ``n`` rows from the stored prior bank uniformly with replacement.

        Raises:
            RuntimeError: if the instance was created without ``prior_sample``.
        """
        if self.prior_sample is None:
            raise RuntimeError(
                "MD_prior_potential was created without prior_sample; cannot draw prior samples.")
        return self.prior_sample[torch.randint(0, self.prior_sample.shape[0], (n,))]


class MDRunner(BasicRunner):
    """Runner for the dipeptide experiment restricted to psi window(s).

    On construction it loads the data (see ``load_data``), logs phi/psi of the
    reference structure, saves histograms of the training set and of the prior
    bank, and, when training or sampling is enabled, plots snapshots of the
    forward path.
    """

    def __init__(self, config):
        super().__init__(config)

        self.load_data()

        # ----------------------------- exhibit dataset -----------------------------
        # Manifold angles are in radians; convert to degrees for display.
        phi_ref = self.manifold.angle_phi(self.xref.unsqueeze(0)) / torch.pi * 180
        self.psi_ref = self.manifold.angle_psi(self.xref.unsqueeze(0)) / torch.pi * 180
        logging.info(f"x reference: phi: {phi_ref.item():.2f}, psi: {self.psi_ref.item():.2f}.")

        samples_test = self.training_set[:self.config.sample.sample_num].detach()
        self.plot_sample_hist(samples_test.cpu().numpy(), savefig="training_set")
        self.plot_sample_hist(self.prior.prior_sample.cpu().numpy(), savefig="prior_dist")

        if self.config.if_train or self.config.if_sample:
            self._plot_forward_path()

    def _plot_forward_path(self):
        """Plot snapshots of the forward path of the training set and save the whole path."""
        # After the transpose, axis 0 indexes the SDE step (x_hist[-1] is the path end).
        x_hist = self.training_set_path[:self.config.sample.sample_num].clone().transpose(0, 1)
        self.plot_angle_and_RMSD_hist(x_hist[-1].cpu().numpy(), savefig='forward_end')
        # Every 1% over the first 10% of the path, then every 10%, plus the first 5 steps.
        plot_idx = list(range(10)) + list(range(10, 101, 10))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_idx) or (i < 5):
                self.plot_angle_and_RMSD_hist(x_hist[i].cpu().numpy(), savefig=f'generating_fwd_{i}')
        np.save(f"{self.samples_dir}/{self.dataset_name}_hist_fwd.npy", x_hist.cpu().detach().numpy())

    def load_data(self):
        """Load the dipeptide reference, centre, training data and MD prior bank.

        Sets ``self.xref``, ``self.x_center``, ``self.prior``, ``self.natom``, the
        forward drift ``self.sde.func_b`` and the shuffled, flattened
        ``training_set`` / ``test_set`` / ``val_set`` splits. When training or
        sampling is enabled it also builds ``self.training_set_path``.
        """
        self.kappa = self.config.training.kappa
        data_path = './data/dipeptide/'

        # Reference structure of the constrained (phi / psi-window) dataset; the
        # prior potential restrains samples toward it.
        ref_path = f'{data_path}dipeptide_ref_phi_psiwin.npy'
        self.xref = torch.tensor(np.load(ref_path)).float()

        # Single well centre (psi approx. 150 degrees), used as the RMSD reference in plots.
        center_path = f'{data_path}dipeptide_center.npy'
        self.x_center = torch.tensor(np.load(center_path)).float()

        # Training data from the constrained dataset.
        dataset_path = f'{data_path}dipeptide_refined_phi_psiwin.npy'
        data_ori = torch.tensor(np.load(dataset_path)).float()

        # Prior bank generated by MD simulation; the file name is keyed by int(kappa).
        prior_path = f'{data_path}dipeptide_prior_{int(self.kappa)}.npy'
        prior_sample = np.load(prior_path)

        self.prior = MD_prior_potential(self.xref, kappa=self.kappa, prior_sample=prior_sample)
        self.natom = 10

        # Forward-SDE drift is the negative gradient of the prior potential.
        self.sde.func_b = lambda x: - self.prior.gradV(x)

        # Advanced indexing already returns a new tensor; no extra clone needed.
        self.data_set = data_ori[torch.randperm(data_ori.shape[0])]
        self.training_set, self.test_set, self.val_set = split_dataset(self.data_set, self.config.seed)

        # Flatten every split to (n, natom * 3).
        self.training_set = self.training_set.reshape(-1, self.natom * 3)
        self.test_set = self.test_set.reshape(-1, self.natom * 3)
        self.val_set = self.val_set.reshape(-1, self.natom * 3)

        if self.config.if_train or self.config.if_sample:
            self.training_set_path = self.generate_path_dataset(self.training_set, keep_quiet=False)[0]
            check_memory(self.training_set_path)

    @staticmethod
    def _hist_bins(n):
        """Return the histogram bin count: one bin per 100 samples, never fewer than 10."""
        # Clamp at 10 so that 100-999 samples do not collapse to only 1-9 bins.
        return max(10, n // 100)

    def _shade_psi_windows(self, ax):
        """Shade the manifold's psi windows (converted to degrees) on ``ax``, if any are defined."""
        if hasattr(self.manifold, 'psi_windows_rad') and self.manifold.l > 0:
            # .cpu() so this also works when the windows live on an accelerator.
            psi_windows_deg = np.rad2deg(self.manifold.psi_windows_rad.detach().cpu().numpy())
            for low, high in psi_windows_deg:
                ax.axvspan(low, high, color='blue', alpha=0.1, zorder=-1)
                ax.axvline(x=low, color='blue', linestyle=':', linewidth=1.5)
                ax.axvline(x=high, color='blue', linestyle=':', linewidth=1.5)

    def _draw_prior_init(self):
        """Draw ``sample_num`` prior samples, flattened to (n, natom * 3), on ``self.device``."""
        n = self.config.sample.sample_num
        return self.prior.prior_sampler(n).reshape(-1, self.natom * 3).to(self.device)

    def plot_sample_hist(self, samples, savefig=None):
        """Save side-by-side phi / psi histograms (in degrees) of ``samples``.

        Args:
            samples: array-like that can be reshaped to (-1, natom, 3).
            savefig: tag inserted into the output file name.
        """
        samples = torch.tensor(samples).reshape(-1, self.natom, 3).float()

        phi, psi = self.manifold.angle_phi(samples), self.manifold.angle_psi(samples)
        fig = plt.figure(figsize=(10, 5))
        bins = self._hist_bins(phi.shape[0])

        ax = plt.subplot(1, 2, 1)
        ax.hist(phi.reshape(-1).numpy() / torch.pi * 180, bins=bins, alpha=1.0, density=True)
        ax.set_title("phi")

        ax = plt.subplot(1, 2, 2)
        ax.hist(psi.reshape(-1).numpy() / torch.pi * 180, bins=bins, alpha=1.0, density=True)
        ax.set_title("psi")
        # axvline expects a scalar x; psi_ref holds a single value (see __init__).
        ax.axvline(x=self.psi_ref.item(), color='red', linestyle='--')
        self._shade_psi_windows(ax)

        plt.suptitle(f"Histogram for {samples.shape[0]} samples.")
        plt.savefig(self.savefig_dir + f"/Hist_psi_phi_{savefig}.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def plot_angle_and_RMSD_hist(self, samples, savefig=None):
        """Save psi and RMSD-to-centre histograms of ``samples`` against the training set.

        Only psi and the RMSD with respect to ``self.x_center`` are plotted;
        generated samples are green and training data red.

        Args:
            samples: tensor or array-like that can be reshaped to (-1, natom, 3).
            savefig: tag inserted into the output file name.
        """
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples).float()
        samples = samples.reshape(-1, self.natom, 3)
        true_set = self.training_set.detach().reshape(-1, self.natom, 3)

        psi = self.manifold.angle_psi(samples)
        psi_true = self.manifold.angle_psi(true_set)
        psi_angle = psi.reshape(-1).numpy() / np.pi * 180
        psi_angle_true = psi_true.reshape(-1).numpy() / np.pi * 180
        psi_center = self.manifold.angle_psi(self.x_center.unsqueeze(0)) / torch.pi * 180

        RMSD = get_RMSD(samples, self.x_center).numpy()
        RMSD_true = get_RMSD(true_set, self.x_center).numpy()

        fig = plt.figure(figsize=(10, 5))
        bins = self._hist_bins(psi.shape[0])

        # Left: psi angle distribution.
        ax = plt.subplot(1, 2, 1)
        ax.hist(psi_angle, bins=bins, alpha=0.5, density=True, color='green', label='Generated')
        ax.hist(psi_angle_true, bins=bins, alpha=0.5, density=True, color='red', label='True')
        ax.legend()
        # axvline expects scalars; draw one line per centre value.
        for center_deg in psi_center.reshape(-1).tolist():
            ax.axvline(x=center_deg, color='black', linestyle='--')
        ax.set_title("psi")
        self._shade_psi_windows(ax)

        # Right: RMSD to the centre structure.
        ax = plt.subplot(1, 2, 2)
        ax.hist(RMSD, bins=bins, alpha=0.5, density=True, color='green')
        ax.hist(RMSD_true, bins=bins, alpha=0.5, density=True, color='red')
        ax.set_title("RMSD to Center (psi~150)")

        plt.suptitle(f"Histogram for {samples.shape[0]} samples.")
        plt.savefig(self.savefig_dir + f"/Hist_angel_RMSD_{savefig}.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def validate(self, mode=None, epoch=0, **kwargs):
        """Sample the backward SDE from the prior with the current network and plot the result.

        Does nothing when ``mode`` is ``'start'`` or ``'end'``.
        """
        if mode in ('start', 'end'):
            return

        logging.info(f"-------------------------Start validating: Epoch {epoch}-------------------------")
        init = self._draw_prior_init()
        x, _, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                             reverse=True,
                                             score_net=self.network,
                                             keep_quiet=True, **self.sde_kwargs)
        self.plot_angle_and_RMSD_hist(x.cpu().numpy(), savefig=f'val_{epoch}_generated')
        logging.info("-------------------------End validating.-------------------------")

    def test(self):
        return

    def sample_on_manifolds(self):
        """Run the backward SDE from the prior, plot snapshots, and save samples and path."""
        logging.info('Start sampling on manifolds.')
        if self.network is not None:
            self.network.to(self.device)

        logging.info("Start sampling backward SDE.")
        # Flatten the prior draw exactly as validate() does: the drift sde.func_b
        # (prior.gradV) returns flat (n, natom * 3) arrays, so the state uses that layout.
        init = self._draw_prior_init()
        x, x_hist, _ = self.SDE_sampler_manifolds(self.sde, self.manifold, init,
                                                  reverse=True,
                                                  score_net=self.network,
                                                  keep_quiet=False, **self.sde_kwargs)
        self.calculate_constrain(x)
        x = x.cpu().numpy()
        self.plot_angle_and_RMSD_hist(x, savefig='generated')

        # Every 10% of the path, every 1% over the last 10%, plus the final 5 steps.
        plot_idx = list(range(0, 100, 10)) + list(range(90, 101))
        for i in range(self.sde.N + 1):
            if (100 * i / self.sde.N in plot_idx) or (i > self.sde.N - 5):
                self.plot_angle_and_RMSD_hist(x_hist[i].cpu().numpy(), savefig=f'generating_bwd_{i}')

        np.save(f"{self.samples_dir}/{self.dataset_name}_samples_generated.npy", x)
        np.save(f"{self.samples_dir}/{self.dataset_name}_hist_bwd.npy", x_hist.cpu().detach().numpy())