import logging
import models
from sde_lib import SDE_Brownian_manifolds
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
import tensorboardX
from sampling import SDE_sampler_manifolds
from utils import (
    save_model,
    load_model,
    check_memory,
    ExponentialMovingAverage)
import manifolds


class BasicRunner:
    def __init__(self, config):
        """Prepare output directories, the manifold, the SDE and the logger.

        Args:
            config: experiment configuration. This constructor reads
                ``device``, ``workdir``, ``problem.dataset``,
                ``problem.manifold`` (plus ``problem.mat_dim`` for "SOn"),
                ``model.sigma_min``, ``model.sigma_max``, ``model.N``,
                ``model.T``, ``training.nll_K`` and ``training.nll_bs``.

        Raises:
            NotImplementedError: if ``config.problem.manifold`` is not one of
                "S2", "SOn", "SDF" or "MD".
        """
        self.config = config
        self.device = torch.device(config.device)

        # ------------------------ get directory ------------------------
        self.workdir = self.config.workdir
        self.savefig_dir = os.path.join(self.workdir, 'figs')
        self.samples_dir = os.path.join(self.workdir, 'samples')
        self.validate_dir = os.path.join(self.workdir, 'validate')
        # exist_ok=True: a work directory may be reused (e.g. when a model
        # saved by an earlier run is loaded from it), so pre-existing output
        # folders must not abort construction.
        for directory in (self.savefig_dir, self.samples_dir, self.validate_dir):
            os.makedirs(directory, exist_ok=True)
        self.dataset_name = self.config.problem.dataset

        # ------------------------ get manifolds ------------------------
        manifold_name = self.config.problem.manifold
        if manifold_name == "S2":
            self.manifold = manifolds.Manifold_Sphere(dim=2)
        elif manifold_name == "SOn":
            self.manifold = manifolds.Manifold_SOn(self.config.problem.mat_dim)
        elif manifold_name == "SDF":
            # The object name is the dataset name up to the first underscore.
            self.obj = self.dataset_name.split('_')[0]
            self.sdf_model_path = f'./constraint/model/{self.obj}_whole_sdf.pt'
            mesh_path = f"./data/{self.obj}/{self.obj}_mesh_simple.ply"
            self.manifold = manifolds.Manifold_SDF(model_path=self.sdf_model_path, mesh_path=mesh_path)
            # Keep a copy of the constraint model's weights next to the run outputs.
            torch.save(self.manifold.model.state_dict(),
                       os.path.join(self.workdir, "sdf_constraint_model.pt"))
            self.manifold.model.to(self.device)
        elif manifold_name == "MD":
            self.manifold = manifolds.Manifold_MD()
        else:
            raise NotImplementedError(f"Unsupported manifold: {manifold_name!r}")

        # ------------------------ get sde ------------------------
        self.sde = SDE_Brownian_manifolds(sigma_min=self.config.model.sigma_min,
                                          sigma_max=self.config.model.sigma_max,
                                          N=self.config.model.N,
                                          T=self.config.model.T)
        # Needs self.sde and self.savefig_dir, both set above.
        self.plot_sde_info()

        # ------------------------ others ------------------------
        self.tb_logger = tensorboardX.SummaryWriter(log_dir=self.workdir)

        # Settings read by negative_log_likelihood_fn: repeats per data point
        # and number of data points per batch.
        self.nll_K = self.config.training.nll_K
        self.nll_bs = self.config.training.nll_bs

    def get_network(self):
        """Build the network selected by ``config.training.network_mode``.

        "MLP" maps ``manifold.out_dim + 1`` inputs to ``manifold.out_dim``
        outputs; "EMLP" maps ``3 * natom + 1`` inputs to ``3 * natom`` outputs
        and additionally receives ``self.xref``. The extra input is presumably
        the time variable (the network is called as ``network(x, t)`` in
        ``loss_fn``) -- confirm against ``models``. ``self.natom`` and
        ``self.xref`` are not set in this class and must be provided elsewhere.

        Returns:
            The freshly constructed network.

        Raises:
            NotImplementedError: for any other ``network_mode``.
        """
        train_cfg = self.config.training
        mode = train_cfg.network_mode

        if mode == "MLP":
            widths = [self.manifold.out_dim + 1] + train_cfg.hidden_layers + [self.manifold.out_dim]
            return models.MLP(widths, activation=train_cfg.activation)

        if mode == "EMLP":
            widths = [3*self.natom+1] + train_cfg.hidden_layers + [3*self.natom]
            return models.EMLP(widths, xref=self.xref, activation=train_cfg.activation)

        raise NotImplementedError

    def run(self):
        """Entry point: obtain a network, save it, then optionally test and sample.

        With ``config.if_train`` set, a new network is built and trained;
        otherwise it is loaded from ``config.load_model_path`` joined onto the
        work directory. In both cases the network is then saved as
        ``model.pt`` in the work directory. ``self.test`` and
        ``self.sample_on_manifolds`` are not defined in this class and must be
        provided elsewhere.
        """
        if not self.config.if_train:
            checkpoint = os.path.join(self.workdir, self.config.load_model_path)
            self.network = load_model(checkpoint)
        else:
            self.network = self.get_network()
            self.train_step()

        save_model(self.workdir, self.network, name="model.pt")

        # Optional post-processing stages, each gated by its own config flag.
        if self.config.if_test:
            self.test()
        if self.config.if_sample:
            self.sample_on_manifolds()

    def train_step(self):
        logging.info('Start training...')
        logging.info(f"Network:\n {self.network.__str__()}")
        logging.info(f"Param. No.: {sum(p.numel() for p in self.network.parameters() if p.requires_grad)}")

        self.network.to(self.device)
        optimizer = torch.optim.Adam(self.network.parameters(), lr=self.config.optim.lr, weight_decay=0.000,
                          betas=(0.9, 0.999), amsgrad=False)

        training_loader = torch.utils.data.DataLoader(torch.arange(0, self.training_set_path.shape[0]),
                                                      batch_size=self.config.training.batch_size,
                                                      shuffle=True,
                                                      drop_last=True)

        self.total_epochs = self.config.training.n_epochs
        val_freq = self.config.training.val_freq if self.config.training.val_freq > 0 else int(self.total_epochs/20)
        step = 0
        loss_train_list = []
        self.validate(mode='start')
        for epoch in range(self.total_epochs+1):

            for i, sample_indices in enumerate(training_loader):
                step += 1

                samples = self.training_set_path[sample_indices,:].to(self.device)
                loss = self.loss_fn(samples)

                optimizer.zero_grad()
                loss.backward()
                # if not hasattr(self.config.training, 'no_grad_clip'):
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=10.0)
                optimizer.step()

                if epoch > 0:
                    self.ema.update(self.network.parameters())

            "-------------------------------------validate---------------------------------------"
            if epoch == 0:
                self.ema = ExponentialMovingAverage(self.network.parameters(), 0.999)
            
            self.tb_logger.add_scalar(f'loss_{self.dataset_name}', loss, global_step=epoch)
            loss_train_list.append(loss.detach().cpu().numpy())

            if epoch % self.config.training.record_val_nll_freq == 0:
                logging.info(f'epoch: {epoch}/{self.total_epochs}, total step: {step}, loss: {loss.item():.6f}')

            if epoch % val_freq == 0:

                save_model(self.validate_dir, self.network, name=f"model_temp.pt")
                self.ema.store(self.network.parameters())
                self.ema.copy_to(self.network.parameters())
                self.validate(epoch=epoch, step=step, batch=samples[:, 0, :].detach().cpu().clone())
                self.ema.restore(self.network.parameters())

                fig = plt.figure()
                plt.plot(loss_train_list, c='b', label='loss')
                plt.savefig(self.savefig_dir + f"/aa_loss_training.png", dpi=300, bbox_inches='tight')
                plt.close(fig)

            if self.config.training.update_training_set_path_freq > 0 and epoch % self.config.training.update_training_set_path_freq == 0 and epoch > 0:
                del self.training_set_path
                check_memory(keep_quiet=True)
                self.training_set_path = self.generate_path_dataset(self.training_set, keep_quiet=True)
                training_loader = torch.utils.data.DataLoader(torch.arange(0, self.training_set_path.shape[0]), 
                                                              batch_size=self.config.training.batch_size, 
                                                              shuffle=True,
                                                              drop_last=True)

        self.validate(mode='end')
        self.ema.copy_to(self.network.parameters())

    def loss_fn(self, batch):
        # Shape of batch : [bsz, sde.N+1, dim]
        assert batch.shape[1] == self.sde.N+1, "The shape of batch is wrong!!!"

        t = torch.linspace(0., self.sde.T, self.sde.N + 1).to(batch)
        bsz = batch.shape[0]
        delta_t = torch.diff(t)
        _, diffusion = self.sde.sde(None, t)

        std = diffusion[:-1] * torch.sqrt(torch.abs(delta_t))

        std = std[None, :].repeat(bsz, 1).flatten(start_dim=0, end_dim=1)
        x_t = batch[:, 1:, :].flatten(start_dim=0, end_dim=1)
        diff_vec = torch.diff(batch, dim=1).flatten(start_dim=0, end_dim=1)
        tangent_vec = self.manifold.project_onto_tangent_space(diff_vec, base_point=x_t)
        t_temp = t[None, 1:].repeat(bsz, 1).flatten(start_dim=0, end_dim=1)

        b = self.sde.func_b(x_t).to(batch)

        scores = self.manifold.project_onto_tangent_space(self.network(x_t, t_temp), base_point=x_t)

        if len(tangent_vec.shape) == 2:
            target = - tangent_vec / std.reshape(-1, 1)**2 + b
            losses = 0.5 * torch.sum(scores * scores, dim=-1) - (scores * target).sum(dim=-1)
        else:
            target = - tangent_vec / std.reshape(-1, 1, 1) ** 2 + b
            losses = 0.5 * torch.sum(scores * scores, dim=(1, 2)) - (scores * target).sum(dim=(1, 2))

        loss = (losses * std ** 2).reshape(bsz, self.sde.N).sum(dim=1)
        return torch.mean(loss)

    def plot_sde_info(self):
        """Log the extreme per-step noise levels and save a plot of the SDE schedule.

        Logs ``g_0 * sqrt(T / N)`` and ``g_T * sqrt(T / N)``, then writes
        ``aa_sde_info.png`` into ``self.savefig_dir`` with two panels: the
        second output of ``sde.sde`` over 1001 points in [0, T], and that
        output at ``k * dt`` scaled by ``sqrt(dt)`` for k = 0..N-1.
        """
        sde = self.sde

        uniform_step_sqrt = np.sqrt(sde.T/sde.N)
        std_lo = sde.g_0*uniform_step_sqrt
        std_hi = sde.g_T*uniform_step_sqrt
        logging.info(f"Min std when discretize the sde: {std_lo.numpy().item():.4f}")
        logging.info(f"Max std when discretize the sde: {std_hi.numpy().item():.4f}")

        fig = plt.figure(figsize=(10, 5))

        # Left panel data: the coefficient on a fine time grid.
        times = torch.linspace(0, sde.T, 1001)
        g_vals = sde.sde(None, times)[1]

        # Right panel data: the per-step scale at each discrete index.
        steps = torch.arange(0, sde.N, dtype=torch.int64)
        sigma_k = sde.sde(None, steps * sde.dt)[1] * np.sqrt(sde.dt)

        panels = (
            (121, times, g_vals, 't', 'g_t'),
            (122, steps, sigma_k, 'k', 'sigma_k'),
        )
        for position, xs, ys, x_label, y_label in panels:
            axis = fig.add_subplot(position)
            axis.plot(xs, ys)
            axis.set_xlabel(x_label)
            axis.set_ylabel(y_label)

        plt.savefig(self.savefig_dir + "/aa_sde_info.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    def generate_path_dataset(self, data_init, keep_quiet=False):
        """Simulate forward paths starting from ``data_init``.

        The network is not used. The sampler is run with ``reverse=False`` and
        its second output (the path history) is returned with its first two
        axes swapped: per the original notes, (N+1, bsz, dim) -> (bsz, N+1, dim).

        Args:
            data_init: initial points; moved to ``self.device`` first.
            keep_quiet: forwarded to the sampler; the start banner is logged
                only when this is exactly ``False``.

        Returns:
            The transposed path history.
        """
        if keep_quiet is False:
            logging.info("-------------------------Start generating path dataset.-------------------------")

        target_device = self.device
        start_points = data_init.to(target_device)

        # Manifolds backed by a learned model need that model on the same device.
        if hasattr(self.manifold, "model"):
            self.manifold.model.to(target_device)

        _, path_history, _ = SDE_sampler_manifolds(
            self.sde,
            self.manifold,
            start_points,
            reverse=False,
            keep_quiet=keep_quiet,
        )
        return path_history.transpose(0, 1)


    @torch.no_grad()
    def negative_log_likelihood_fn(self, data, keep_quiet=True, return_mean=True):
        """
        Estimate the negative log-likelihood of `data` from simulated forward paths.

        Each data point is repeated self.nll_K times and pushed through the forward
        sampler. Along every path three per-step quantities are summed over the steps:
        a score term (loss_part), C_1 and C_2. The estimator is
            nll_K = C_0 - log( mean_K exp(-(loss_part + C_1 - C_2)) )
        and the bound is
            nll_upper_bound = mean_K(loss_part) + C_0 + mean_K(C_1) - mean_K(C_2),
        with C_0 = manifold.log_volume() and mean_K taken by mean_with_mask.

        Notation: x_t -> x_{n+1}, x_s -> x_n
        std_f: sigma in forward process
        std_b: sigma in backward process

        data: 2-D tensor, one point per row (asserted below).
        keep_quiet: forwarded to the sampler.
        return_mean: if True, return the means of the five quantities over data
            points; if False, return the per-data-point tensors.

        Returns (nll_K, nll_upper_bound, loss_part, C_1, C_2).

        NOTE(review): if `data` has no rows the batch loop never runs and the
        accumulators (loss_part, C_1, C_2, nll_K) are unbound at the return.
        """

        def mean_with_mask(data, mask):
            """
            shape of data and mask: [nll_bsz, self.nll_K]
            Per-row sum of data divided by the per-row sum of mask, for rows
            whose mask has at least one non-zero entry.

            NOTE(review): the numerator sums ALL entries of the row, not only
            those selected by mask -- confirm whether data should be masked
            (e.g. data * mask) before summing.
            """
            assert data.shape == mask.shape, "data and mask should have the same shape!"
            assert len(data.shape) == 2

            # rows with at least one non-zero mask entry
            idx = mask.sum(dim=1) > 0

            if return_mean:
                # rows with an all-zero mask are dropped, so temp.shape[0] may be < nll_bsz
                temp = data[idx, :].sum(dim=1) / mask[idx, :].sum(dim=1)
                return temp
            else:
                # temp.shape[0] = nll_bsz always holds; rows with an all-zero mask stay 0
                temp = torch.zeros(data.shape[0]).to(data)
                temp[idx] = data[idx, :].sum(dim=1) / mask[idx, :].sum(dim=1)
                return temp

        device = self.device
        data = data.to(device)
        self.network.to(device)
        assert len(data.shape) == 2, "The shape of data is wrong!!!"
        nll_bsz = self.nll_bs
        # iterate over row indices of data in batches of nll_bsz (no shuffling)
        data_loader = torch.utils.data.DataLoader(torch.arange(0, data.shape[0]), batch_size=nll_bsz)

        # time grid with N+1 points, its N step sizes, and the second output of sde.sde on the grid
        t = torch.linspace(0., self.sde.T, self.sde.N + 1).to(device)
        delta_t = torch.diff(t)
        _, diffusion = self.sde.sde(torch.zeros(1).to(device), t)
        C_0 = torch.tensor(self.manifold.log_volume()).float().to(device)

        for i, sample_indices in enumerate(data_loader):
            # each data point appears nll_K times consecutively
            samples = torch.repeat_interleave(data[sample_indices], repeats=self.nll_K, dim=0)
            
            x, x_hist, other_dict = SDE_sampler_manifolds(self.sde, self.manifold, init=samples,
                                                          reverse=False,
                                                          keep_quiet=keep_quiet)

            # Shape of mask: [nll_bsz, self.nll_K]
            # The x_hist returned above is replaced by other_dict["x_hist_all"]. The reshapes to
            # (sde.N, ...) of quantities built from x_hist[1:] below require its leading
            # dimension to be sde.N + 1: [sde.N + 1, nll_bsz * self.nll_K, dim].
            mask = other_dict["converged_traj"].reshape(sample_indices.shape[0], self.nll_K).float()
            x_hist = other_dict["x_hist_all"]

            bsz = x_hist.shape[1] # nll_bsz * self.nll_K
            
            # per-step std tiled to one entry per (step, trajectory) pair, step-major order
            # NOTE(review): std_f and std_b are computed by identical expressions.
            std_f = diffusion[:-1] * torch.sqrt(torch.abs(delta_t))
            std_f = std_f[:, None].repeat(1, bsz).flatten(start_dim=0, end_dim=1)
            std_b = diffusion[:-1] * torch.sqrt(torch.abs(delta_t))
            std_b = std_b[:, None].repeat(1, bsz).flatten(start_dim=0, end_dim=1)

            # consecutive states of every trajectory: x_s is step n, x_t is step n+1
            x_t = x_hist[1:].flatten(start_dim=0, end_dim=1)
            x_s = x_hist[:-1].flatten(start_dim=0, end_dim=1)

            # v: forward increment in the tangent space at x_s
            v = self.manifold.project_onto_tangent_space(x_t - x_s, base_point=x_s)
            t_temp = t[1:, None].repeat(1, bsz).flatten(start_dim=0, end_dim=1)
            # v_prime_1: backward increment scaled by 1/std_b, in the tangent space at x_t
            v_prime_1 = self.manifold.project_onto_tangent_space((x_s - x_t)/std_b.reshape(-1, 1), base_point=x_t)
            # v_prime_2: network output at (x_t, t) scaled by std_b, in the tangent space at x_t
            v_prime_2 = self.manifold.project_onto_tangent_space(std_b.reshape(-1, 1) * self.network(x_t, t_temp).detach(), base_point=x_t)

            # per-step terms, each reshaped to (N, nll_bsz, K) and summed over the N steps
            loss_part_temp = 0.5 * torch.sum(v_prime_2 * v_prime_2, dim=-1) - (v_prime_2 * v_prime_1).sum(dim=-1)
            loss_part_batch = loss_part_temp.reshape(self.sde.N, sample_indices.shape[0], self.nll_K).sum(dim=0)
            C_1_temp = 0.5 * v_prime_1.norm(dim=1)**2
            C_1_batch = C_1_temp.reshape(self.sde.N, sample_indices.shape[0], self.nll_K).sum(dim=0)
            C_2_temp = v.norm(dim=1)**2/(2*std_f**2)
            C_2_batch = C_2_temp.reshape(self.sde.N, sample_indices.shape[0], self.nll_K).sum(dim=0)

            tmp = loss_part_batch + C_1_batch - C_2_batch
            # the exponent is capped at 100, so every entry of temp is at least exp(-100);
            # double precision is used for the exponential
            tmp = torch.clamp(tmp, max=100).double()
            temp = torch.exp(-tmp) # shape: [nll_bsz, self.nll_K]

            temp1 = mean_with_mask(temp, mask)

            if (not return_mean) and self.config.calculate_mesh_nll_bunny_spot:
                nll_K_batch = torch.zeros_like(temp1).float()
                # entries below 0.9 * exp(-100) (e.g. the zeros mean_with_mask leaves for rows
                # with an all-zero mask) get NaN instead of the log of a near-zero value
                idx = (temp1 >= 0.9 * torch.exp(torch.tensor([-100.])).to(device))
                if (~idx).sum() > 0:
                    logging.info("Warning: zero in temp1 is replaced by NAN!")
                nll_K_batch[idx] = C_0 - torch.log(temp1[idx]).float()
                nll_K_batch[~idx] = torch.nan
            else:
                nll_K_batch = C_0 - torch.log(temp1).float()

            # append this batch's results to the running tensors (the first batch initializes them)
            loss_part = torch.cat((loss_part, mean_with_mask(loss_part_batch, mask))) if i > 0 else mean_with_mask(loss_part_batch, mask)
            C_1 = torch.cat((C_1, mean_with_mask(C_1_batch, mask))) if i > 0 else mean_with_mask(C_1_batch, mask)
            C_2 = torch.cat((C_2, mean_with_mask(C_2_batch, mask))) if i > 0 else mean_with_mask(C_2_batch, mask)
            nll_K = torch.cat((nll_K, nll_K_batch)) if i > 0 else nll_K_batch

            # progress output: number of data points processed so far
            if (not return_mean) and (self.config.problem.manifold=="SDF"):
                logging.info(nll_K.shape)

        nll_upper_bound = loss_part + C_0 + C_1 - C_2
        if return_mean:
            return nll_K.mean(), nll_upper_bound.mean(), loss_part.mean(), C_1.mean(), C_2.mean()
        else:
            return nll_K, nll_upper_bound, loss_part, C_1, C_2

    def calculate_constrain(self, samples):
        samples_constrain = self.manifold.constrain_fn(samples)
        logging.info(f'Generated samples constrain: min: {samples_constrain.min().item():.6f}, max: {samples_constrain.max().item():.6f}')


