import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for saving figures
import logging
import models
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
import tensorboardX
import manifolds
import sys
import time 

from sampling import SDE_sampler_manifolds_CLangevin, SDE_sampler_manifolds_OLLA, SDE_sampler_manifolds_CHMC_OBABO, SDE_sampler_manifolds_CHMC_OABOA, SDE_sampler_manifolds_ULLA_OABOA, SDE_sampler_manifolds_CHMC_EM, SDE_sampler_manifolds_ULLA_EM
from sde_lib import SDE_Brownian_manifolds
from utils import ExponentialMovingAverage, save_model, load_model, check_memory, get_temperature
from loss_utils import loss_overdamped_path, loss_underdamped_path_OBABO, loss_underdamped_path_OABOA, loss_underdamped_path_EM, nll_overdamped_path, nll_underdamped_path_OBABO, nll_underdamped_path_EM

# Add the parent directory to the path to allow importing 'constraints'
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from constraints import get_constraint_functions


class BasicRunner:
    def __init__(self, config):
        """Set up output directories, manifold, SDE, sampler and NLL estimator.

        Args:
            config: Experiment configuration. The constructor reads
                ``config.device``, ``config.workdir`` and fields of the
                ``problem``, ``model``, ``sample`` and ``training`` sections.

        Raises:
            NotImplementedError: If ``config.problem.manifold`` or
                ``config.sample.sampler`` is not one of the supported names.
        """
        self.config = config  # store the configuration
        self.device = config.device  # set the device (CPU or GPU)

        # Output locations; created up front so later saves can rely on them.
        self.workdir = self.config.workdir
        self.savefig_dir = os.path.join(self.workdir, 'figs')
        self.samples_dir = os.path.join(self.workdir, 'samples')
        self.validate_dir = os.path.join(self.workdir, 'validate')
        for directory in (self.savefig_dir, self.samples_dir, self.validate_dir):
            os.makedirs(directory, exist_ok=True)
        self.dataset_name = self.config.problem.dataset  # dataset name from the configuration

        # Wall-clock accumulators (seconds) updated by path generation and training.
        self.trajectory_gen_time = 0.0
        self.cumulative_training_time = 0.0

        self.manifold = self._build_manifold()
        self.sde = self._build_sde()
        self.plot_sde_info()  # needs self.sde, self.config and self.savefig_dir

        # Get extras
        self.tb_logger = tensorboardX.SummaryWriter(log_dir=self.workdir)  # TensorBoard logger for visualization
        self.nll_K = self.config.training.nll_K  # Number of Langevin steps for NLL estimation
        self.nll_bs = self.config.training.nll_bs  # Batch size for NLL estimation

        # Closure over the config so the schedule parameters are read at call time.
        self.get_temperature = lambda epoch: get_temperature(
            epoch,
            self.config.training.n_epochs,
            T_min=self.config.training.T_min,
            T_max=self.config.training.T_max,
            T_min_cutoff=self.config.training.T_min_cutoff,
            T_max_cutoff=self.config.training.T_max_cutoff,
            mode="cosine",
        )

        self._init_sampler()
        self._init_nll_fn()

    def _build_manifold(self):
        """Instantiate the manifold selected by ``config.problem.manifold``."""
        problem = self.config.problem
        kind = problem.manifold

        if kind == 'S2':
            return manifolds.Manifold_Sphere(dim=2)

        if kind == "SOn":
            return manifolds.Manifold_SOn(problem.mat_dim)

        if kind == "SDF":
            # The object name is the dataset-name prefix before the first underscore.
            self.obj = self.dataset_name.split('_')[0]
            self.sdf_model_path = f'./constraint/model/{self.obj}_whole_sdf.pt'
            mesh_path = f"./data/{self.obj}/{self.obj}_mesh_simple.ply"
            manifold = manifolds.Manifold_SDF(model_path=self.sdf_model_path, mesh_path=mesh_path)
            # Keep a copy of the constraint model's weights next to the run outputs.
            torch.save(manifold.model.state_dict(),
                       os.path.join(self.workdir, "sdf_constraint_model.pt"))
            manifold.model.to(self.device)
            return manifold

        if kind == "general":
            # Dynamically get constraint functions based on the dataset name
            h_func, g_func = get_constraint_functions(self.dataset_name)
            return manifolds.Manifold_general(
                dim=problem.dim,
                m=problem.m,
                l=problem.l,
                h=h_func,
                g=g_func,
            )

        if kind == "MD":
            return manifolds.Manifold_MD(
                psi_windows=[(problem.psi_windows_low, problem.psi_windows_high)],
                boundary_repulsion=self.config.sample.epsilon,
            )

        if kind == "Robot":
            return manifolds.Manifold_Robot(
                time_steps=problem.time_steps,
                target_ee_z=problem.target_ee_z,
                obstacles_info=problem.obstacles_info,
                safety_margin=problem.safety_margin,
                obstacle_radius=problem.obstacle_radius,
                boundary_repulsion_rate=self.config.sample.epsilon,
            )

        raise NotImplementedError(f"Manifold {kind} is not implemented.")

    def _build_sde(self):
        """Create the SDE using the overdamped or underdamped hyper-parameter set."""
        model_cfg = self.config.model
        sample_cfg = self.config.sample

        # 'CLangevin' and 'OLLA' use the *_overdamped settings; every other
        # sampler name falls through to the *_underdamped settings.
        if sample_cfg.sampler in ('CLangevin', 'OLLA'):
            sigma_min = model_cfg.sigma_min_overdamped
            sigma_max = model_cfg.sigma_max_overdamped
            tau_min = model_cfg.tau_min_overdamped
            tau_max = model_cfg.tau_max_overdamped
            horizon = model_cfg.T_overdamped
        else:
            sigma_min = model_cfg.sigma_min_underdamped
            sigma_max = model_cfg.sigma_max_underdamped
            tau_min = model_cfg.tau_min_underdamped
            tau_max = model_cfg.tau_max_underdamped
            horizon = model_cfg.T_underdamped

        return SDE_Brownian_manifolds(
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            tau_min=tau_min,
            tau_max=tau_max,
            N=model_cfg.N,
            T=horizon,
            sampler=sample_cfg.sampler,
            drift_mode=getattr(sample_cfg, 'drift_mode', 'zero'),  # Default to 'zero' if not specified
        )

    def _init_sampler(self):
        """Select the sampler function and the extra keyword arguments it is called with."""
        sample_cfg = self.config.sample
        name = sample_cfg.sampler

        if name == 'CLangevin':
            self.SDE_sampler_manifolds = SDE_sampler_manifolds_CLangevin
            self.sde_kwargs = {}
            logging.info("Using CLangevin sampler")
            return

        if name == 'OLLA':
            self.SDE_sampler_manifolds = SDE_sampler_manifolds_OLLA
            self.sde_kwargs = {'alpha': sample_cfg.sampler_OLLA_alpha}
            logging.info(f"Using OLLA sampler with alpha = {sample_cfg.sampler_OLLA_alpha}")
            return

        # Remaining samplers share the mass/gamma settings; ULLA variants add alpha.
        underdamped = {
            'CHMC_OBABO': (SDE_sampler_manifolds_CHMC_OBABO, "Using CHMC sampler with 2nd order approximation"),
            'CHMC_OABOA': (SDE_sampler_manifolds_CHMC_OABOA, "Using CHMC sampler with 2nd order approximation"),
            'CHMC_EM': (SDE_sampler_manifolds_CHMC_EM, "Using CHMC sampler with Exact Manifold MCMC"),
            'ULLA_OABOA': (SDE_sampler_manifolds_ULLA_OABOA, "Using ULLA sampler with OABOA"),
            'ULLA_EM': (SDE_sampler_manifolds_ULLA_EM, "Using ULLA sampler with EM"),
        }
        if name not in underdamped:
            raise NotImplementedError(f"Sampler {name} is not implemented.")

        self.SDE_sampler_manifolds, message = underdamped[name]
        logging.info(message)
        self.sde_kwargs = dict(mass=sample_cfg.sampler_CHMC_mass, gamma=sample_cfg.sampler_CHMC_gamma)
        if name in ('ULLA_OABOA', 'ULLA_EM'):
            self.sde_kwargs['alpha'] = sample_cfg.sampler_ULLA_alpha

    def _init_nll_fn(self):
        """Bind ``negative_log_likelihood_fn`` to the NLL estimator matching the sampler."""
        name = self.config.sample.sampler
        if name in ('CLangevin', 'OLLA'):
            nll_path_fn = nll_overdamped_path
        elif name in ('CHMC_OBABO', 'CHMC_OABOA', 'ULLA_OABOA'):
            nll_path_fn = nll_underdamped_path_OBABO
        elif name in ('CHMC_EM', 'ULLA_EM'):
            nll_path_fn = nll_underdamped_path_EM
        else:
            raise NotImplementedError(f"Sampler {name} is not implemented for NLL computation.")

        def negative_log_likelihood_fn(data, return_mean=True):
            # Attributes are looked up at call time because self.network is
            # only assigned later, in run().
            return nll_path_fn(data, self.network, self.sde, self.manifold,
                               self.SDE_sampler_manifolds,
                               nll_bs=self.nll_bs, nll_K=self.nll_K,
                               device=self.device,
                               sde_kwargs=self.sde_kwargs,
                               keep_quiet=True, return_mean=return_mean)

        self.negative_log_likelihood_fn = negative_log_likelihood_fn
                                    
    def get_network(self):
        network_mode = self.config.training.network_mode  # Get the network mode from the configuration
        if network_mode == 'MLP':
            if self.config.sample.sampler in ['CHMC_OBABO', 'CHMC_OABOA', 'CHMC_EM', 'ULLA_OABOA', 'ULLA_EM']:
                network_input_dim = 2 * self.manifold.out_dim + 1
            else:
                network_input_dim = self.manifold.out_dim + 1

            layers = [network_input_dim] + self.config.training.hidden_layers + [self.manifold.out_dim]

            network = models.MLP(layers, activation=self.config.training.activation)
            
        elif network_mode == 'EMLP':
            if self.config.sample.sampler in ['CHMC_OBABO', 'CHMC_OABOA', 'CHMC_EM', 'ULLA_OABOA', 'ULLA_EM']:
                layers = [6*self.natom+1] + self.config.training.hidden_layers + [3*self.natom]
            else:
                layers = [3*self.natom+1] + self.config.training.hidden_layers + [3*self.natom]
            network = models.EMLP(layers, xref=self.xref, activation=self.config.training.activation)
        else:
            raise NotImplementedError(f"Network mode {network_mode} is not implemented.")
        return network


    def run(self):
        """Train or load the network, save it, then optionally test and sample.

        The stages are controlled by ``config.if_train``, ``config.if_test``
        and ``config.if_sample``. When not training, the checkpoint at
        ``config.load_model_path`` (relative to the work directory) is loaded.
        """
        cfg = self.config

        if not cfg.if_train:
            # Load the pre-trained model into a freshly built network.
            checkpoint = os.path.join(self.workdir, cfg.load_model_path)
            self.network = load_model(self.get_network(), checkpoint)
        else:
            self.network = self.get_network()
            self.train_step()

        # The model is written to the work directory whether it was trained or loaded.
        save_model(self.workdir, self.network, name='model.pt')

        if cfg.if_test:
            self.test()
        if cfg.if_sample:
            # Sample from the trained model on the manifold.
            self.sample_on_manifolds()

    def train_step(self):
        """Run the full training loop with periodic EMA validation and plotting.

        Reads ``self.training_set_path`` (indexed as ``[path_index, :]``) and,
        when ``config.training.update_training_set_path_freq > 0``, regenerates
        it from ``self.training_set`` every that many epochs. Neither attribute
        is assigned in this method; they must exist before it is called.

        After the loop the network parameters are replaced by the EMA weights.

        Raises:
            ValueError: If the dataset holds fewer paths than one batch, so the
                loader (which drops incomplete batches) would yield nothing.
        """
        logging.info('Start training...')
        logging.info(f"Network:\n {self.network.__str__()}")
        logging.info(f"Numbers of parameters: {sum(p.numel() for p in self.network.parameters() if p.requires_grad)}")

        self.network.to(self.device)  # Move the network to the specified device
        optimizer = torch.optim.Adam(self.network.parameters(), lr=self.config.optim.lr,
                                     weight_decay=0.000, betas=(0.9, 0.999), amsgrad=False)

        training_loader = self._make_training_loader()
        self.total_epochs = self.config.training.n_epochs  # Total number of epochs for training
        if self.config.training.val_freq > 0:
            val_freq = self.config.training.val_freq
        else:
            # Default to ~20 validations per run; clamp to 1 so that short runs
            # (fewer than 20 epochs) do not end up with a zero modulus below.
            val_freq = max(1, int(self.total_epochs / 20))
        refresh_freq = self.config.training.update_training_set_path_freq

        step = 0
        loss_train_list = []
        self.validate(mode='start')
        for epoch in range(self.total_epochs + 1):
            epoch_start_time = time.time()
            for sample_indices in training_loader:
                step += 1

                samples = self.training_set_path[sample_indices, :].to(self.device)  # Get a batch of training samples
                loss = self.loss_fn(samples, epoch)  # Compute the loss

                optimizer.zero_grad()  # Zero the gradients
                loss.backward()  # Backpropagate the loss
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=10.0)  # Clip gradients to prevent exploding gradients
                optimizer.step()  # Update the model parameters

                if epoch > 0:
                    # The EMA only exists from the end of epoch 0 onwards.
                    self.ema.update(self.network.parameters())
            self.cumulative_training_time += time.time() - epoch_start_time

            if epoch == 0:
                # Start averaging from the parameters reached after the first epoch.
                self.ema = ExponentialMovingAverage(self.network.parameters(), self.config.optim.ema)

            # Only the last batch's loss of each epoch is recorded.
            loss_train_list.append(loss.detach().cpu().numpy())

            if epoch % val_freq == 0:  # use shadow params only for validation o/w use current params
                save_model(self.validate_dir, self.network, name=f"model_temp.pt")
                self.ema.store(self.network.parameters())  # Store the current parameters in EMA (collected params)
                self.ema.copy_to(self.network.parameters())  # Copy the EMA parameters to the network (shadow params)

                self.validate(epoch=epoch, step=step,
                              batch=samples[:, 0, :self.manifold.out_dim].detach().cpu().clone())
                self.ema.restore(self.network.parameters())  # Restore the original parameters from EMA

                self._plot_training_progress(loss_train_list, val_freq)

            if refresh_freq > 0 and epoch % refresh_freq == 0 and epoch > 0:
                # Drop the old paths before generating new ones to limit peak memory.
                del self.training_set_path
                check_memory(keep_quiet=True)
                self.training_set_path, h_val = self.generate_path_dataset(self.training_set, keep_quiet=True)

                training_loader = self._make_training_loader()
                mean_violation = h_val.mean().detach().cpu().item()
                self.h_val = getattr(self, 'h_val', []) + [mean_violation]

        self.validate(mode='end')
        self.ema.copy_to(self.network.parameters())

    def _make_training_loader(self):
        """Build a shuffled loader over the path indices of ``self.training_set_path``."""
        loader = torch.utils.data.DataLoader(torch.arange(0, self.training_set_path.shape[0]),
                                             batch_size=self.config.training.batch_size,
                                             shuffle=True,
                                             drop_last=True)
        if len(loader) == 0:
            # With drop_last=True a dataset smaller than one batch yields no
            # batches at all; fail here instead of later on an unbound loss.
            raise ValueError(
                f"Training set has {self.training_set_path.shape[0]} paths, fewer than "
                f"batch_size={self.config.training.batch_size}; no batch can be formed."
            )
        return loader

    def _plot_training_progress(self, loss_train_list, val_freq):
        """Save a two-panel figure: per-epoch training loss and recorded constraint violation."""
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # (1) Training loss on the left
        epochs_loss = np.arange(len(loss_train_list))
        axes[0].plot(epochs_loss, loss_train_list, label='Training Loss')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Training Loss')
        axes[0].grid(True)
        axes[0].legend()

        # (2) Constraint-violation h_val on the right
        if hasattr(self, 'h_val') and len(self.h_val) > 0:
            h_vals = [h.mean() if isinstance(h, np.ndarray) else h for h in self.h_val]
            record_step = (self.config.training.update_training_set_path_freq or val_freq)
            epochs_h = np.arange(len(h_vals)) * record_step
            axes[1].plot(epochs_h, h_vals, label='Constraint Violation')
            axes[1].legend()  # only when a labelled line exists
        else:
            axes[1].text(0.5, 0.5, 'no data', ha='center', va='center')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('Constraint Violation')
        axes[1].grid(True)

        fig.savefig(self.savefig_dir + f"/aa_loss_constraint_{self.dataset_name}.png",
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    def loss_fn(self, path_batch, epoch = None):
        data_hist = path_batch.transpose(0, 1).contiguous()
        sigmas = self.sde.sde(None, torch.linspace(0., self.sde.T, self.sde.N+1, device=data_hist.device))[1][:-1]
        taus = self.sde.get_tau_scheduler(torch.linspace(0., self.sde.T, self.sde.N+1, device=data_hist.device))[:-1]

        if self.config.sample.sampler in ('CLangevin', 'OLLA'):
            loss = loss_overdamped_path(
                        self.manifold, data_hist,
                        score_net = self.network,
                        func_b    = self.sde.func_b,
                        sigmas    = sigmas,
                        dt        = self.sde.dt)
        elif self.config.sample.sampler == 'CHMC_OBABO':
            loss = loss_underdamped_path_OBABO(
                        self.manifold, data_hist,
                        score_net = self.network,
                        func_b = self.sde.func_b,
                        sigmas   = sigmas,
                        dt     = self.sde.dt,
                        mass   = self.config.sample.sampler_CHMC_mass,
                        gamma  = self.config.sample.sampler_CHMC_gamma
                        )

        elif self.config.sample.sampler in ['CHMC_OABOA', 'ULLA_OABOA']:
            loss = loss_underdamped_path_OABOA(
                        self.manifold, data_hist,
                        score_net = self.network,
                        func_b = self.sde.func_b,
                        sigmas   = sigmas,
                        taus = taus,
                        dt     = self.sde.dt,
                        mass   = self.config.sample.sampler_CHMC_mass,
                        gamma  = self.config.sample.sampler_CHMC_gamma
                        )
        elif self.config.sample.sampler in ['CHMC_EM', 'ULLA_EM']:
            loss = loss_underdamped_path_EM(
                        self.manifold, data_hist,
                        score_net = self.network,
                        func_b = self.sde.func_b,
                        sigmas   = sigmas,
                        dt     = self.sde.dt,
                        mass   = self.config.sample.sampler_CHMC_mass,
                        gamma  = self.config.sample.sampler_CHMC_gamma,
                        temp   = self.get_temperature(epoch)
                        )
        
        else:
            raise NotImplementedError
        return loss

    def plot_sde_info(self):
        """Log the discretized noise range and save diagnostic plots of the SDE schedule.

        Writes ``aa_sde_info.png`` to the figure directory and, for the
        ``CHMC_OBABO`` / ``CHMC_OABOA`` samplers, ``aa_chmc_ak_values.png``.
        """
        sde = self.sde

        # Per-step standard deviation at both ends of the schedule.
        step_scale = np.sqrt(sde.T / sde.N)
        min_std = sde.g_0 * step_scale
        max_std = sde.g_T * step_scale
        logging.info(f"Minimum std: {min_std.numpy().item():.4f}, Maximum std: {max_std.numpy().item():.4f} when discretized")

        def finish_axis(axis, xlabel, ylabel):
            # Shared labelling for every panel drawn below.
            axis.set_xlabel(xlabel)
            axis.set_ylabel(ylabel)
            axis.grid(True)
            axis.legend()

        fig = plt.figure(figsize=(10, 5))

        # Left panel: continuous-time diffusion coefficient on a 1001-point grid.
        t = torch.linspace(0, sde.T, 1001)
        g_t = sde.sde(None, t)[1]
        left = fig.add_subplot(121)
        left.plot(t, g_t, label='Diffusion Coefficient $g(t)$')
        finish_axis(left, 'Time $t$', '$g(t)$')

        # Right panel: coefficient at each discrete step, scaled by sqrt(dt).
        steps = torch.arange(0, sde.N, dtype=torch.int64)
        sigma = sde.sde(None, steps * sde.dt)[1] * np.sqrt(sde.dt)
        right = fig.add_subplot(122)
        right.plot(steps, sigma, label='Discretized Diffusion Coefficient $g_k$')
        finish_axis(right, 'Time Step $k$', '$g_k$')

        fig.suptitle(f"SDE Info: $\\sigma_{{min}}$ = {sde.sigma_min:.4f}, $\\sigma_{{max}}$ = {sde.sigma_max:.4f}, N = {sde.N}, T = {sde.T} - g_k = $g(t_k) \\sqrt{{\\Delta t}}$")

        fig.savefig(self.savefig_dir + "/aa_sde_info.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        if self.config.sample.sampler not in ['CHMC_OBABO', 'CHMC_OABOA']:
            return

        # Extra plot for the two CHMC splitting samplers: a_k per time step.
        delta_t = sde.dt
        mass = self.config.sample.sampler_CHMC_mass
        gamma = self.config.sample.sampler_CHMC_gamma

        steps = torch.arange(0, sde.N, dtype=torch.int64)
        diffusion = sde.sde(None, steps * sde.dt)[1]
        a_k = torch.exp(-diffusion**2 * delta_t * gamma / (4. * mass))

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(steps, a_k, label='$a_k$ values for CHMC')
        finish_axis(ax, 'Time Step $k$', '$a_k$')
        ax.set_title(f'CHMC $a_k$ values: mass={mass}, gamma={gamma}')

        fig.savefig(self.savefig_dir + "/aa_chmc_ak_values.png", dpi=300, bbox_inches='tight')
        plt.close(fig)


    def generate_path_dataset(self, data_init, keep_quiet=False):
        if not keep_quiet:
            logging.info("-------------------------Start generating path dataset.-------------------------")
        
        sampling_start_time = time.time()
        
        device = self.device
        x_init = data_init.to(device)
        
        _, data_hist, _ = self.SDE_sampler_manifolds(
            self.sde, self.manifold, x_init,
            reverse=False,
            keep_quiet=keep_quiet, **self.sde_kwargs
        )
        
        sampling_end_time = time.time()
        self.trajectory_gen_time += sampling_end_time - sampling_start_time

        if not keep_quiet:
            logging.info(f"Forward sampling time: {sampling_end_time - sampling_start_time:.2f} seconds.")

        if self.config.sample.sampler in ['CHMC_OBABO', 'CHMC_OABOA', 'ULLA_OABOA', 'CHMC_EM']:
            x_hist = data_hist[..., :self.manifold.out_dim]
        else:
            x_hist = data_hist

        x_hist_flat = x_hist.reshape(-1, x_hist.shape[-1])

        h_val = self.manifold.constrain_fn(x_hist_flat)

        return data_hist.detach().transpose(0, 1), h_val.detach()

    def calculate_constrain(self, samples):
        samples_constrain = self.manifold.constrain_fn(samples)
        logging.info(f'Generated samples constrain: min: {samples_constrain.min().item():.6f}, max: {samples_constrain.max().item():.6f}')

    def plot_frobenius_norm(self, epoch):
        """Save a line plot of ``self.frobenius_norms``; do nothing if it was never recorded.

        Args:
            epoch: Used only to name the output file.
        """
        try:
            norms = self.frobenius_norms
        except AttributeError:
            # Nothing has been recorded on this runner, so there is nothing to plot.
            return

        fig, ax = plt.subplots()
        ax.plot(range(len(norms)), norms)
        ax.set_xlabel("Validation Epoch")
        ax.set_ylabel("Frobenius Norm of Jacobian")
        ax.set_title("Frobenius Norm of Jacobian vs. Validation Epoch")

        save_path = os.path.join(self.savefig_dir, f"frobenius_norm_plot_epoch_{epoch}.png")
        fig.savefig(save_path)
        plt.close(fig)
