import numpy as np
import torch
from torch import nn
from scipy.stats import multivariate_normal
from scipy.spatial import ConvexHull
from mpl_toolkits.axes_grid1 import make_axes_locatable
import ot

from pdPINN.model.model_utiities import divergence, gradient, custom_hessian, ST
from pdPINN.model.modules import NormalizeLayer, InverseStandardizeLayer, FCLayer
from pdPINN.model.sampler import *
from pdPINN.model.siren import to_numpy, to_tensor, batch_predict2, dataloader_from_np

import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsRegressor


class MassCons2d(nn.Module):
    """Neural model of a density field and a velocity field coupled by mass conservation.

    Inputs are coordinates ``(t, x, y)`` (see ``min_extent`` / ``max_extent``).
    A shared trunk (``network``) feeds two heads:

    * ``density_net`` -> square root of the density (before the prior weighting in
      :meth:`sqrt_density`), and
    * ``velocity_net`` -> velocity components.

    The continuity equation ``d rho / dt + div(rho * v) = 0`` is enforced as a soft
    penalty (:meth:`constraint_loss`) on collocation points drawn by :meth:`do_sample`.
    :meth:`_compile` must be called before sampling or evaluating the density.
    """
    mse = nn.MSELoss()

    def __init__(self,
                 X, Y,
                 tmin, tmax,
                 xmin, xmax,
                 ymin, ymax,
                 hidden_features=64,
                 vf_hidden_features=64,
                 density_hidden_layers=3,
                 vf_hidden_layers=3,
                 nonlinearity='relu',
                 **kwargs
                 ):
        """Build the networks and the fixed, data-derived quantities.

        Args:
            X: training inputs as a numpy array; column 0 is time and the remaining
                columns are spatial coordinates.
            Y: training targets; column 0 is the target for :meth:`sqrt_density` and
                the remaining columns are velocity targets.
            tmin, tmax, xmin, xmax, ymin, ymax: extent of the domain used for
                uniform (Sobol) collocation sampling.
            hidden_features: width of the shared trunk and the density head.
            vf_hidden_features: width of the velocity head.
            density_hidden_layers, vf_hidden_layers: depth of the two heads.
            nonlinearity: activation name forwarded to ``FCLayer``.
            **kwargs: optional ``sine_frequency``, ``mcmc_noiselevel`` (default 1e-4),
                ``n_samples_constraints``, ``rar`` and ``ot_rar``.
        """
        super().__init__()

        self.in_features = X.shape[1]

        self.hidden_features = hidden_features
        self.sobol_engine = torch.quasirandom.SobolEngine(dimension=self.in_features, scramble=True)

        self.tmin, self.tmax = tmin, tmax
        self.xmin, self.xmax = xmin, xmax
        self.ymin, self.ymax = ymin, ymax

        # nn.Parameter ignores the wrapped tensor's requires_grad flag (it defaults to
        # True), so the flag has to be passed to Parameter itself to keep these fixed.
        self.min_extent = torch.nn.Parameter(torch.tensor([self.tmin, self.xmin, self.ymin]),
                                             requires_grad=False)
        self.max_extent = torch.nn.Parameter(torch.tensor([self.tmax, self.xmax, self.ymax]),
                                             requires_grad=False)

        self.normalize_layer = torch.jit.script(NormalizeLayer.from_data(X=X, ignore_dims=[], trainable=False))
        self.standardize_mass_layer = torch.jit.script(InverseStandardizeLayer(X=Y[..., [0]], trainable=True, no_shift=False))
        self.standardize_vf_layer = torch.jit.script(InverseStandardizeLayer(X=Y[..., 1:], trainable=False))

        self.density_net = torch.jit.script(FCLayer(in_features=hidden_features, out_features=1,
                                                    hidden_features=hidden_features,
                                                    num_hidden_layers=density_hidden_layers,
                                                    outermost_linear=True, outermost_positive=False,
                                                    nonlinearity=nonlinearity,
                                                    sine_frequency=kwargs.get("sine_frequency", None)))

        self.velocity_net = torch.jit.script(FCLayer(in_features=hidden_features, out_features=Y.shape[1] - 1,
                                                     outermost_linear=True,
                                                     outermost_positive=False,
                                                     hidden_features=vf_hidden_features,
                                                     num_hidden_layers=vf_hidden_layers,
                                                     weight_init=None,
                                                     nonlinearity=nonlinearity,
                                                     sine_frequency=kwargs.get("sine_frequency", None)
                                                     ))

        # Shared trunk whose output feeds both heads.
        self.network = torch.jit.script(FCLayer(in_features=self.in_features, out_features=hidden_features,
                                                hidden_features=hidden_features, num_hidden_layers=1,
                                                outermost_linear=False, outermost_positive=False,
                                                sine_frequency=kwargs.get("sine_frequency", None),
                                                nonlinearity=nonlinearity))

        # Trainable observation-noise parameters used by reconstruction_loss.
        self.density_sd = torch.nn.Parameter(torch.exp(torch.FloatTensor([0.])) + 1e-4,
                                             requires_grad=True)
        self.velocity_sd = torch.nn.Parameter(torch.exp(torch.FloatTensor([0., 0.])) + 1e-4,
                                              requires_grad=True)
        self.velocity_cov_tril = torch.nn.Parameter(torch.tril(torch.rand((2, 2)) + torch.eye(2)).detach())

        self.mean = torch.nn.Parameter(torch.tensor(np.mean(X[..., 1:], 0), dtype=torch.float32), requires_grad=False)
        self.cov = torch.nn.Parameter(torch.tensor(np.cov(X[..., 1:].T), dtype=torch.float32),
                                      requires_grad=False)

        # Share of collocation points drawn from the density-guided sampler in the
        # pdPINN sampling modes (the rest come from the convex hull of the data).
        self.fraction_mcmc = .5
        self.mcmc_noiselevel = kwargs.get("mcmc_noiselevel", 1e-4)

        self._periodicity = torch.nn.Parameter(torch.tensor(0., dtype=torch.float32) ** 2, requires_grad=True)
        self.num_samples = kwargs.get("n_samples_constraints")
        self.rar = kwargs.get("rar")
        self.ot_rar = kwargs.get("ot_rar")
        assert not (self.rar and self.ot_rar), "Only one adaptive refinement method at once allowed."

        # Convex hull of the observed spatial locations, used for Dirichlet sampling.
        unique_x = np.unique(X[..., 1:], axis=0)
        hull = ConvexHull(unique_x)
        self.convex_hull_np = unique_x[hull.vertices]
        self.convex_hull_tensor = torch.nn.Parameter(to_tensor(self.convex_hull_np, device=self.device),
                                                     requires_grad=False)
        # Clamps the density to [-1e-5, 1e12].
        self.bounding_act = torch.nn.Hardtanh(min_val=-1e-5, max_val=1e12)

    def _compile(self, X):
        """Fit the Gaussian priors and build the MCMC sampler from training inputs ``X``.

        Must run before :meth:`sqrt_density` / :meth:`density` (which read
        ``prior_torch_space`` and ``log_prior_max``) and before sampling.
        """
        # Kept for backward compatibility; sample_within_convex_hull no longer relies on
        # its fixed batch size.
        self.dir = torch.distributions.Dirichlet(to_tensor(np.ones(self.convex_hull_np.shape[0]),
                                                           device=self.device)).expand((self.num_samples,))
        self.prior_sp, self.prior_torch, self.prior_torch_space = self.estimate_gaussian_from_data(X, scale_cov=0.)
        self.prior_sp2d, self.prior_torch2d, self.prior_torch_space2d = self.estimate_gaussian_from_data(
            X[..., 1:], scale_cov=0.5)

        def proposal(num_samples):
            # Time uniform over [tmin, tmax]; space from the widened spatial Gaussian.
            time = np.random.uniform(self.tmin, self.tmax, num_samples)
            space = self.prior_sp2d.rvs(num_samples)
            return np.concatenate([time[..., np.newaxis], space], -1)

        self.num_mcmc_chains = 500
        self.sampler = LocalSampler(log_p_unnormalized=self._log_p,
                                    num_chains=self.num_mcmc_chains,
                                    data_dim=X.shape[1],
                                    eps=0.05,
                                    prior_sample=proposal,
                                    prior_logprob=self.prior_sp.logpdf
                                    )

        # Peak log-density of the full (t, x, y) prior; used to scale the weight in
        # sqrt_density.
        self.log_prior_max = torch.nn.Parameter(
            self.prior_torch.log_prob(self.prior_torch.mean),
            requires_grad=False
        )
        self.periodicity = self._periodicity

    def estimate_gaussian_from_data(self, X: np.ndarray, scale_cov: float = 5):
        """Fit a Gaussian to the rows of ``X``.

        The empirical covariance is multiplied elementwise by ``scale_cov * I + 1``,
        i.e. its diagonal is inflated by a factor ``1 + scale_cov``.

        Returns:
            ``(scipy_mvn, torch_mvn, torch_mvn_without_first_column)``.
        """
        mu_prior = np.average(X, 0)
        cov_prior = np.cov(X, rowvar=False) * (scale_cov * np.eye(X.shape[1]) + np.ones(X.shape[1]))

        prior_sp = multivariate_normal(mean=mu_prior, cov=cov_prior)
        prior_torch = torch.distributions.MultivariateNormal(loc=to_tensor(mu_prior, device=self.device),
                                                             covariance_matrix=to_tensor(cov_prior, device=self.device))
        prior_torch_space = torch.distributions.MultivariateNormal(
            loc=to_tensor(mu_prior[1:], device=self.device),
            covariance_matrix=to_tensor(cov_prior[1:, 1:], device=self.device))
        return prior_sp, prior_torch, prior_torch_space

    @property
    def device(self):
        """Device of the model's first registered parameter."""
        return next(self.parameters()).device

    def _log_p(self, x):
        """numpy-in / numpy-out wrapper around :meth:`sample_log_prob` for the MCMC sampler."""
        return to_numpy(self.sample_log_prob(to_tensor(x, dtype=torch.float32, device=self.device)))

    def sample_log_prob(self, x):
        """Unnormalised log-density for sampling: ``log(rho + 1e-10)``, with rho zeroed outside [tmin, tmax]."""
        in_time_window = ((x[..., 0] >= self.tmin) * (x[..., 0] <= self.tmax)).reshape(-1, 1)
        rho = self.density(x, normalize_inputs=True) * in_time_window
        return torch.log(rho + 1e-10)

    def velocity_branch(self, coords):
        """Raw velocity head output for (already normalised) ``coords``."""
        return self.velocity_net(self.network(coords))

    def density_branch(self, coords):
        """Raw density head output for (already normalised) ``coords``."""
        return self.density_net(self.network(coords))

    def velocity(self, coords, normalize_inputs=True):
        """Predicted velocity: the first two outputs of the velocity head.

        :param coords: input coordinates; raw unless ``normalize_inputs`` is False.
        :param normalize_inputs: apply ``normalize_layer`` to ``coords`` first.
        :return: tensor whose last dimension holds the two velocity components.
        """
        if normalize_inputs:
            coords = self.normalize_layer(coords)
        return self.velocity_branch(coords)[..., :2]

    def sqrt_density(self, coords, normalize_inputs=True):
        """Square root of the density, damped by a Gaussian prior over space.

        The network output (after ``standardize_mass_layer``) is multiplied by
        ``exp(log N_space(coords[..., 1:]) - log_prior_max)``. ``N_space`` was fitted
        on raw data in :meth:`_compile`, so the weight is only meaningful when
        ``coords`` are raw. With ``normalize_inputs=False`` the caller is
        responsible for that.

        NOTE(review): ``log_prior_max`` is the peak of the 3-D (t, x, y) prior while
        the weight uses the 2-D spatial prior, so the weight's maximum is a constant
        that is not necessarily 1. It is presumably absorbed by the trainable
        ``standardize_mass_layer``; confirm.
        """
        if normalize_inputs:
            coords_transformed = self.normalize_layer(coords)
        else:
            coords_transformed = coords

        sqrt_density = self.standardize_mass_layer(self.density_branch(coords_transformed))
        weight = torch.exp(self.prior_torch_space.log_prob(coords[..., 1:]) - self.log_prior_max).reshape(-1, 1)
        return sqrt_density * weight

    def density(self, coords, normalize_inputs=True):
        """Density ``rho = sqrt_density**2``, clamped by ``bounding_act``."""
        rho = torch.pow(self.sqrt_density(coords, normalize_inputs), 2)
        return self.bounding_act(rho)

    def flux(self, coords, normalize_inputs=True):
        """Mass flux ``rho * v``."""
        return self.density(coords, normalize_inputs) * self.velocity(coords, normalize_inputs)

    def constraint_loss(self, coords=None, get_losses=False, reweight=False):
        """Residual of the continuity equation ``d rho/dt + div(rho v) = 0``.

        Args:
            coords: raw collocation points with ``requires_grad`` enabled
                (defaults to ``self.grid_samples``).
            get_losses: if True, return the per-point squared residual instead of
                the aggregated loss dict.
            reweight: divide each squared residual by ``self.sample_weights``.

        Returns:
            Per-point squared residuals (``get_losses=True``), otherwise
            ``dict(pde=<RMS residual>, hessian=0.)``.
        """
        if coords is None:
            coords = self.grid_samples

        # Evaluate on raw coordinates and let density()/velocity() normalise
        # internally: the prior weight in sqrt_density was fitted on raw data, so
        # handing it pre-normalised coordinates would constrain a different density
        # from the one fitted in reconstruction_loss.
        rho = self.density(coords, normalize_inputs=True)
        velocity = self.velocity(coords, normalize_inputs=True)

        mass_flux = rho * velocity
        # Spatial divergence; x_offset=1 presumably skips the time column - confirm
        # in model_utiities.divergence.
        div_massflux = divergence(mass_flux, coords, x_offset=1)
        drho_dt = gradient(rho, coords)[..., [0]]
        assert div_massflux.shape == drho_dt.shape

        residual_sq = (drho_dt + div_massflux).pow(2)
        if get_losses:
            return residual_sq

        pde_term_sq = residual_sq
        if reweight:
            # sample_weights is 1-D (one entry per point); make it a column so the
            # division is elementwise rather than broadcasting to an (N, N) matrix.
            # NOTE(review): for importance sampling sample_weights holds 1/q, so this
            # division multiplies by q; the textbook correction multiplies by 1/q.
            # Confirm the intended direction.
            pde_term_sq = pde_term_sq / self.sample_weights.reshape(-1, 1)

        loss_consistency = pde_term_sq.mean().sqrt()
        hessian = 0.
        return dict(pde=loss_consistency, hessian=hessian)

    def sample_background(self, sampling_method: ST, num_samples):
        """Draw ``num_samples`` points from a data-independent (or prior) distribution.

        Args:
            sampling_method: ``ST.uniform`` (scrambled Sobol over the domain box),
                ``ST.gaussian`` (the fitted ``prior_torch``) or ``ST.dirichlet``
                (uniform time, space inside the data's convex hull).
            num_samples: number of points.

        Returns:
            Detached points with ``requires_grad`` enabled.

        Raises:
            ValueError: for any other ``sampling_method``.
        """
        if sampling_method == ST.uniform:
            coords = self.sobol_engine.draw(num_samples).to(self.device)
            coords = coords * (self.max_extent - self.min_extent) + self.min_extent
        elif sampling_method == ST.gaussian:
            coords = self.prior_torch.sample((num_samples,))
        elif sampling_method == ST.dirichlet:
            coords = self.sample_within_convex_hull(num_samples, numpy=False)
        else:
            raise ValueError(f"Unknown background sampling method '{sampling_method}'.")
        return self.allow_input_derivatives(coords)

    def _importance_samples(self, num_samples):
        """Pick points with probability driven by the nearest-seed log squared residual.

        Returns ``(samples, sample_weights)``; each weight is the inverse of the point's
        selection probability among the 100k proposals.
        """
        n_seeds = max(10_000, num_samples * 10)
        seeds = self.sample_background(sampling_method=ST.uniform, num_samples=n_seeds)
        # +1e-10 avoids log(0) = -inf (which would break the regressor) for exact zeros.
        logits_seeds = torch.log(self.constraint_loss(seeds, get_losses=True) + 1e-10)

        # Interpolate seed logits onto a dense proposal set with 1-nearest-neighbour.
        knn = KNeighborsRegressor(1, weights="uniform")
        proposals = self.sample_background(sampling_method=ST.uniform, num_samples=100_000)
        logits_proposals = knn.fit(to_numpy(seeds.detach()),
                                   to_numpy(logits_seeds.detach())).predict(to_numpy(proposals.detach()))

        dist = torch.distributions.Categorical(logits=to_tensor(logits_proposals, device=self.device).squeeze())
        idx = dist.sample((num_samples,))
        samples = proposals[idx, ...]
        sample_weights = (1 / torch.exp(dist.log_prob(idx))).detach()
        return samples, sample_weights

    def _mixture_samples(self, num_samples, sampling_method, burnin):
        """pdPINN modes: convex-hull points plus points concentrated where the density is high.

        ``fraction_mcmc`` of the points come from the density-guided part (categorical
        resampling for ``ST.it_pdpinn``, the MCMC sampler otherwise, optionally perturbed
        for ``ST.pdpinn_euler`` / ``ST.pdpinn_noise``). The remainder come from the
        convex hull.
        """
        # Validate first so an unknown method fails before the expensive MCMC run.
        if sampling_method not in (ST.it_pdpinn, ST.mh_pdpinn, ST.pdpinn_euler, ST.pdpinn_noise):
            raise ValueError(f"sampling_method='{sampling_method}' is unknown.")

        num_mcmc_samples = int(self.fraction_mcmc * num_samples)
        # Subtract rather than truncate (1 - fraction) * n separately, so the two parts
        # always add up to num_samples.
        num_uniform_samples = num_samples - num_mcmc_samples

        background = self.sample_background(sampling_method=ST.dirichlet, num_samples=num_uniform_samples)

        if sampling_method == ST.it_pdpinn:
            proposals = self.sample_background(sampling_method=ST.uniform, num_samples=100_000)
            # Selection logits need no gradients; skip building a graph over 100k points.
            with torch.no_grad():
                logits_proposals = torch.log(self.density(proposals) + 1e-10)
            dist = torch.distributions.Categorical(logits=logits_proposals.squeeze())
            idx = dist.sample((num_mcmc_samples,))
            samples_mcmc = proposals[idx, ...]
        else:
            xt_init = self.sample_within_convex_hull(self.sampler.num_chains)
            samples_mcmc_all = self.sampler.sample_chains(x_init=xt_init, burnin=burnin,
                                                          n_samples=burnin + num_mcmc_samples)
            # Subsample (with replacement) the pooled chain output.
            pick = np.random.randint(samples_mcmc_all.shape[0], size=num_mcmc_samples)
            samples_mcmc = samples_mcmc_all[pick, ...]
            assert samples_mcmc.shape[0] == num_mcmc_samples
            samples_mcmc = to_tensor(samples_mcmc, device=self.device)

            if sampling_method == ST.pdpinn_euler:
                # Move each point by one explicit Euler step along the predicted velocity,
                # with a random (signed) time step dt ~ N(0, 0.1^2).
                with torch.no_grad():
                    vel = self.velocity(samples_mcmc, normalize_inputs=True)
                dt = torch.normal(torch.zeros((vel.shape[0], 1), device=self.device, dtype=torch.float32),
                                  std=torch.ones((vel.shape[0], 1), device=self.device) * 1e-1)
                samples_mcmc = samples_mcmc + torch.cat([dt, dt * vel], -1)
            elif sampling_method == ST.pdpinn_noise:
                samples_mcmc = torch.normal(samples_mcmc, std=self.mcmc_noiselevel)

        return torch.cat([background, samples_mcmc], 0)

    def sample_signal_domain(self, num_samples, sampling_method: ST = ST.uniform, burnin=500, rar=False, ot_rar=False):
        """Draw collocation points for the PDE constraint.

        Args:
            num_samples: requested number of points.
            sampling_method: background methods (uniform / gaussian / dirichlet),
                ``ST.importance_sampling``, or one of the pdPINN modes.
            burnin: MCMC burn-in length for the MCMC-based pdPINN modes.
            rar: residual-based refinement - keep the ``num_samples`` points with the
                largest residual among the new points and ``self.grid_samples``.
            ot_rar: map the points with a linear optimal-transport map learned from
                uniform points onto high-residual points.

        Returns:
            ``(samples, sample_weights)``: samples have ``requires_grad`` enabled, and
            ``sample_weights`` is 1-D with one entry per sample (all ones unless
            importance sampling was used and not followed by ``rar``).
        """
        sample_weights = None
        if sampling_method in [ST.uniform, ST.gaussian, ST.dirichlet]:
            samples = self.sample_background(sampling_method=sampling_method, num_samples=num_samples)
        elif sampling_method == ST.importance_sampling:
            samples, sample_weights = self._importance_samples(num_samples)
        else:
            samples = self._mixture_samples(num_samples, sampling_method, burnin)

        if rar:
            samples_cat = torch.cat([self.allow_input_derivatives(samples), self.grid_samples], dim=0)
            losses_cat = self.constraint_loss(samples_cat, get_losses=True).reshape(-1)
            _, idx = torch.topk(losses_cat, num_samples)
            samples = samples_cat[idx, ...]
            # Reselection breaks the correspondence between points and importance weights.
            sample_weights = None

        if ot_rar:
            k = num_samples // 4
            samples_searchmax = self.sample_background(num_samples=num_samples, sampling_method=ST.uniform)
            losses_cat = self.constraint_loss(samples_searchmax, get_losses=True)

            _, idx = torch.topk(losses_cat.squeeze(), k=k)
            samples_topk = samples_searchmax[idx]
            samples_uniform_learnmap = self.sample_background(num_samples=k, sampling_method=ST.uniform)
            Ae, be = ot.da.OT_mapping_linear(samples_uniform_learnmap, samples_topk)
            samples = samples @ Ae + be

        if sample_weights is None:
            # One weight per returned sample, whatever the number actually drawn.
            sample_weights = torch.ones((samples.shape[0],), device=self.device)
        return self.allow_input_derivatives(samples), sample_weights

    def sample_within_convex_hull(self, num_samples, numpy=True):
        """Sample points with uniform time and space drawn from the data's convex hull.

        Spatial points are Dirichlet(1, ..., 1) convex combinations of the hull
        vertices.

        Args:
            num_samples: number of points to return.
            numpy: return a numpy array (numpy RNG) if True, otherwise a tensor on
                ``self.device`` (torch RNG).

        Returns:
            ``(num_samples, 1 + spatial_dim)`` array/tensor with time in column 0.
        """
        if numpy:
            A = np.random.dirichlet(np.ones(self.convex_hull_np.shape[0]), num_samples)
            samples = A @ self.convex_hull_np
            time = np.random.uniform(self.tmin, self.tmax, samples.shape[0]).reshape(-1, 1)
            return np.concatenate([time, samples], -1)

        # Build the Dirichlet here so the number of rows honours num_samples (the
        # pre-expanded self.dir always produced self.num_samples rows).
        concentration = torch.ones(self.convex_hull_tensor.shape[0],
                                   dtype=self.convex_hull_tensor.dtype, device=self.device)
        A = torch.distributions.Dirichlet(concentration).sample((num_samples,))
        samples = A @ self.convex_hull_tensor
        time = torch.rand(samples.shape[0], device=self.device, dtype=samples.dtype).reshape(-1, 1) \
            * (self.tmax - self.tmin) + self.tmin
        return torch.cat([time, samples], -1)

    def reconstruction_loss(self, coords, target, return_numpy=False):
        """Data-fit losses for observations ``target`` at ``coords``.

        ``target[..., 0]`` is compared with :meth:`sqrt_density` and
        ``target[..., 1:]`` with :meth:`velocity`. The MSE velocity terms only use
        rows where ``target[..., 0] > 1e-3``.

        Returns:
            dict with ``sqrt_density``, ``u``, ``v`` (MSEs), ``velocity`` and
            ``nlogprob_density`` (Gaussian negative log-likelihoods); values are
            converted with ``to_numpy`` when ``return_numpy`` is True.
        """
        velocity_hat, density_hat = self.velocity(coords), self.sqrt_density(coords)
        density_target = target[..., [0]]

        mask = density_target > 1e-3

        # The scale parameters are trained freely; use magnitudes so the
        # distributions stay valid if an optimiser step pushes them through zero
        # (identical to the raw values while they are positive).
        density_scale = self.density_sd.abs()
        scale_tril = (torch.tril(self.velocity_cov_tril, diagonal=-1)
                      + torch.diag(torch.diagonal(self.velocity_cov_tril).abs()))

        density_dist = torch.distributions.Normal(loc=density_hat, scale=density_scale)
        velocity_dist = torch.distributions.MultivariateNormal(loc=velocity_hat, scale_tril=scale_tril)
        loss_density = -density_dist.log_prob(density_target).mean()
        loss_velocity = -velocity_dist.log_prob(target[..., 1:]).mean()

        mse_density = self.mse(input=density_hat, target=density_target)
        mse_u = self.mse(input=torch.masked_select(velocity_hat[..., [0]], mask),
                         target=torch.masked_select(target[..., [1]], mask))
        mse_v = self.mse(input=torch.masked_select(velocity_hat[..., [1]], mask),
                         target=torch.masked_select(target[..., [2]], mask))

        loss_dict = dict(sqrt_density=mse_density, u=mse_u, v=mse_v,
                         velocity=loss_velocity, nlogprob_density=loss_density)
        if return_numpy:
            loss_dict = {key: to_numpy(value) for key, value in loss_dict.items()}
        return loss_dict

    def training_loss(self, coords, target, weight_dict: dict):
        """Weighted sum of the losses named in ``weight_dict``.

        The PDE loss (reweighted by ``self.sample_weights``) is only computed when
        ``"pde"`` is a key of ``weight_dict``.

        Returns:
            ``(total_loss, loss_dict)`` where ``loss_dict`` holds every individual term.
        """
        rec_loss = self.reconstruction_loss(coords, target)
        if "pde" in weight_dict:
            constraint_loss = self.constraint_loss(reweight=True)
        else:
            constraint_loss = {"pde": 0., "hessian": 0}

        loss_dict = {**rec_loss, **constraint_loss}
        total_loss = 0.
        for key, weight in weight_dict.items():
            total_loss += weight * loss_dict[key]

        return total_loss, loss_dict

    def do_sample(self, num_samples, sampling_method: ST = ST.uniform, first=False) -> None:
        """Refresh ``self.grid_samples`` / ``self.sample_weights`` (no-op for ``ST.none``).

        Adaptive refinement (``rar`` / ``ot_rar``) is skipped when ``first`` is True,
        because there are no previous ``grid_samples`` yet.
        """
        if sampling_method != ST.none:
            self.grid_samples, self.sample_weights = self.sample_signal_domain(int(num_samples),
                                                                               sampling_method=sampling_method,
                                                                               rar=self.rar if not first else False,
                                                                               ot_rar=self.ot_rar if not first else False)

    def allow_input_derivatives(self, coords):
        """Return a detached leaf copy of ``coords`` on ``self.device`` with ``requires_grad`` set.

        Detaching and moving before enabling gradients guarantees the result is a
        leaf tensor even when ``coords`` lives on another device.
        """
        return coords.detach().to(self.device, copy=True).requires_grad_(True)
