from collections import namedtuple

import numpy as np
import torch
from torch import nn

from scipy.stats import multivariate_normal
from scipy.spatial import ConvexHull

from pdPINN.model.model_utiities import gradient, divergence
from pdPINN.model.modules import NormalizeLayer, InverseStandardizeLayer, FCLayer
from pdPINN.model.sampler import *
from pdPINN.model.siren import to_numpy, to_tensor
import ot


class MassCons1d(nn.Module):
    mse = nn.MSELoss()
    networks = namedtuple("networks", "density w")

    def __init__(self,
                 X, Y,
                 tmin, tmax, zmin, zmax,
                 in_features: int,
                 hidden_features=64,
                 vf_hidden_features=64,
                 density_hidden_layers=3,
                 vf_hidden_layers=3,
                 u_0=(-0.1, 0.),
                 u_1=(1.1, 0.),
                 nonlinearity='relu',
                 **kwargs
                 ):
        """
        Build the networks, data-derived priors and the collocation-point sampler.

        Args:
            X: observation coordinates as a numpy array; column 0 is compared against
                tmin/tmax elsewhere in the class and column 1 spans the spatial hull
                (presumably shape (N, 2) -- TODO confirm with caller).
            Y: observation targets; ``Y[..., [0]]`` feeds the mass standardisation layer
                and ``Y[..., [1]]`` the velocity-field standardisation layer.
            tmin, tmax: bounds of the first input coordinate.
            zmin, zmax: bounds of the second input coordinate.
            in_features (int): input dimensionality of the networks and the Sobol engine.
            hidden_features: width of the density branch.
            vf_hidden_features: width of the velocity branch and of the latent network.
            density_hidden_layers: number of hidden layers of the density branch.
            vf_hidden_layers: number of hidden layers of the velocity branch / latent network.
            u_0, u_1: stored unchanged on the instance.
            nonlinearity: activation name forwarded to every ``FCLayer``.
            **kwargs: ``mcmc_noiselevel`` (default 1e-4), ``n_samples_constraints``, ``rar``
                and ``ot_rar`` are read here; all kwargs are also forwarded to the density
                branch ``FCLayer``.

        Raises:
            AssertionError: if both ``rar`` and ``ot_rar`` are truthy.
        """
        super().__init__()

        self.u_0 = u_0
        self.u_1 = u_1

        self.tmin, self.tmax = tmin, tmax
        self.zmin, self.zmax = zmin, zmax

        # The domain bounds are constants. ``requires_grad=False`` has to be given to
        # ``nn.Parameter`` itself: the Parameter constructor defaults to
        # ``requires_grad=True`` and ignores the flag of the wrapped tensor, which made the
        # bounds trainable (and raised for integer bounds). The explicit float dtype keeps
        # samples scaled by these bounds in the default floating-point dtype.
        bounds_dtype = torch.get_default_dtype()
        self.min = torch.nn.Parameter(torch.tensor([self.tmin, self.zmin], dtype=bounds_dtype),
                                      requires_grad=False)
        self.max = torch.nn.Parameter(torch.tensor([self.tmax, self.zmax], dtype=bounds_dtype),
                                      requires_grad=False)

        # NOTE(review): this is the product of the *squared* side lengths, not the plain
        # area of the box -- kept as in the original, confirm that this is intended.
        self.sample_area = torch.nn.Parameter((self.max - self.min).pow(2).prod(), requires_grad=False)

        self.mcmc_noiselevel = kwargs.get("mcmc_noiselevel", 1e-4)

        self.in_features = in_features

        self.sobol = torch.quasirandom.SobolEngine(dimension=in_features, scramble=True)
        self.hidden_features = hidden_features

        # NOTE(review): the constant offset [0, .3] is subtracted before fitting the
        # normalisation layer; its origin is not visible in this file.
        self.normalize_layer = NormalizeLayer.from_data(X=X - np.array([0, .3]), ignore_dims=[], trainable=False)
        self.standardize_mass_layer = InverseStandardizeLayer(X=Y[..., [0]], trainable=False)
        self.standardize_vf_layer = InverseStandardizeLayer(X=Y[..., [1]], trainable=False)
        self.prior_sp, self.prior_torch, self.prior_torch_space = self.estimate_gaussian_from_data(X, scale_cov=5)

        self.density_branch = FCLayer(in_features=in_features, out_features=1,
                                      hidden_features=hidden_features, num_hidden_layers=density_hidden_layers,
                                      outermost_linear=True, outermost_positive=False,
                                      nonlinearity=nonlinearity, **kwargs)

        self.w_branch = FCLayer(in_features=in_features, out_features=1, outermost_linear=True,
                                outermost_positive=False,
                                hidden_features=vf_hidden_features, num_hidden_layers=vf_hidden_layers,
                                weight_init=None,
                                nonlinearity=nonlinearity, sine_frequency=7)

        # Network whose input gradient defines both density and velocity
        # (see ``sqrt_density`` and ``velocity``).
        self.latent_net = FCLayer(in_features=in_features, out_features=2, outermost_linear=True,
                                  outermost_positive=False,
                                  hidden_features=vf_hidden_features, num_hidden_layers=vf_hidden_layers,
                                  weight_init=None,
                                  nonlinearity=nonlinearity, sine_frequency=7)

        # Sampler targeting the (unnormalised) log-density of the model; chains are
        # initialised from / regularised by the scipy Gaussian prior fitted to X.
        self.sampler = LocalSampler(
            log_p_unnormalized=self._log_p,
            num_chains=100,
            data_dim=2,
            eps=0.02,
            prior_sample=self.prior_sp.rvs,
            prior_logprob=self.prior_sp.logpdf
        )

        self.fraction_mcmc = 0.
        self.num_samples = kwargs.get("n_samples_constraints")
        self.rar = kwargs.get("rar")
        self.ot_rar = kwargs.get("ot_rar")
        assert not (self.rar and self.ot_rar), "Only one adaptive refinement method at once allowed."

        self.bounding_act = torch.nn.Hardtanh(min_val=-1e-5, max_val=1e12)

        # Minimum and maximum of column 1 of X, stacked as two "vertices" of shape (2, 1).
        self.convex_hull_np = np.stack([X.min(0)[[1]], X.max(0)[[1]]], axis=0)
        self.convex_hull_tensor = torch.nn.Parameter(to_tensor(self.convex_hull_np, device=self.device),
                                                     requires_grad=False)

    def init(self):
        """
        Create ``self.dir``: a Dirichlet distribution with unit concentration over the
        vertices in ``self.convex_hull_np``, expanded to a batch of ``self.num_samples``.
        """
        n_vertices = self.convex_hull_np.shape[0]
        concentration = to_tensor(np.ones(n_vertices), device=self.device)
        base_dirichlet = torch.distributions.Dirichlet(concentration)
        self.dir = base_dirichlet.expand((self.num_samples,))

    def estimate_gaussian_from_data(self, X: np.array, scale_cov: float = 0.):
        """
        Fit a Gaussian to the rows of ``X`` and return it in three flavours.

        The empirical covariance is multiplied element-wise by
        ``scale_cov * I + 1`` (the vector of ones broadcasts to a full matrix of ones),
        i.e. the diagonal is inflated by ``1 + scale_cov`` while off-diagonals are kept.

        Args:
            X (np.array): data matrix with one sample per row.
            scale_cov (float): extra inflation applied to the variances.

        Returns:
            tuple of (scipy ``multivariate_normal``, torch ``MultivariateNormal`` over all
            columns, torch ``MultivariateNormal`` over all columns except the first).
        """
        n_dims = X.shape[1]
        mean = np.average(X, 0)
        inflation = scale_cov * np.eye(n_dims) + np.ones(n_dims)
        covariance = np.cov(X, rowvar=False) * inflation

        target_device = self.device
        scipy_prior = multivariate_normal(mean=mean, cov=covariance)
        full_prior = torch.distributions.MultivariateNormal(
            loc=to_tensor(mean, device=target_device),
            covariance_matrix=to_tensor(covariance, device=target_device),
        )
        # Marginal over every column but the first one.
        marginal_prior = torch.distributions.MultivariateNormal(
            loc=to_tensor(mean[1:], device=target_device),
            covariance_matrix=to_tensor(covariance[1:, 1:], device=target_device),
        )
        return scipy_prior, full_prior, marginal_prior

    def _log_p(self, x):
        """
        Numpy-facing wrapper around ``sample_log_prob``: converts ``x`` to a float32 tensor
        on the model device and converts the result back with ``to_numpy``.
        """
        points = torch.as_tensor(x, dtype=torch.float32).to(self.device)
        log_density = self.sample_log_prob(points)
        return to_numpy(log_density)

    def sample_log_prob(self, x, normalize=True):
        rho = self.density(x, normalize_inputs=normalize) * (
                (x[..., 0] >= self.tmin) * (x[..., 0] <= self.tmax)).reshape(-1, 1)

        return torch.log(rho + 1e-10)
        # constraint_loss = self.constraint_loss(self.allow_input_derivatives(x), get_losses=True)
        # return torch.log(constraint_loss).reshape(-1,1).detach()

    @property
    def device(self):
        """Device of the first parameter yielded by ``self.parameters()``."""
        first_parameter = next(self.parameters())
        return first_parameter.device

    @staticmethod
    def dist_to_boundary(x):
        phi1 = torch.abs(1.1 - x[..., [-1]])  # self.dist_1(x[..., [-1]])
        phi2 = torch.abs(0.0 - x[..., [-1]])  # self.dist_2(x[..., [-1]])
        val = (phi1 * phi2) / (torch.sqrt(phi1 ** 2 + phi2 ** 2))
        x0 = 0.
        k = 200.
        return 1. / (1 + torch.exp(-k * (val - x0)))

    def latent(self, orig_coords, normalize_inputs=True):
        """
        Evaluate the latent network on the raw coordinates.

        Args:
            orig_coords: input coordinates, passed to ``latent_net`` unchanged.
            normalize_inputs: accepted for interface compatibility; currently unused
                (no input normalisation is applied).

        Returns:
            Output of ``self.latent_net``.
        """
        latent_values = self.latent_net(orig_coords)
        return latent_values

    def sqrt_density(self, orig_coords, normalize_inputs=True):
        """
        Square root of the (unbounded) density.

        The density is the absolute value of column 1 of the gradient of the latent
        network w.r.t. its input coordinates; its square root is returned.

        Args:
            orig_coords: input coordinates; a detached copy that tracks gradients is
                created internally.
            normalize_inputs: accepted for interface compatibility; currently unused.

        Returns:
            Tensor with a single trailing column.
        """
        leaf_coords = self.allow_input_derivatives(orig_coords)

        potential = self.latent(leaf_coords)
        potential_grad = gradient(potential, leaf_coords).squeeze()
        signed_density = potential_grad[..., [1]]

        return torch.abs(signed_density) ** 0.5

    def density(self, coords, normalize_inputs=True):
        """
        Density at ``coords``: the squared output of ``sqrt_density`` passed through
        ``self.bounding_act``.
        """
        root = self.sqrt_density(coords, normalize_inputs)
        squared = torch.square(root)
        return self.bounding_act(squared)

    def velocity(self, orig_coords, normalize_inputs=True):
        """
        Velocity derived from the latent network.

        Computed as ``-g0 / (g1 + 1e-10)`` where ``g0`` and ``g1`` are columns 0 and 1 of
        the gradient of the latent network w.r.t. its input coordinates.

        Args:
            orig_coords: input coordinates; a detached copy that tracks gradients is
                created internally.
            normalize_inputs: accepted for interface compatibility; currently unused.

        Returns:
            Tensor with a single trailing column.
        """
        leaf_coords = self.allow_input_derivatives(orig_coords)

        potential = self.latent(leaf_coords)
        potential_grad = gradient(potential, leaf_coords).squeeze()

        grad_col0 = potential_grad[..., [0]]
        grad_col1 = potential_grad[..., [1]]
        # The small constant guards against division by an exactly zero gradient.
        return -grad_col0 / (grad_col1 + 1e-10)

    def flux(self, coords, normalize_inputs=True):
        """Mass flux at ``coords``: the product of ``density`` and ``velocity``."""
        rho = self.density(coords, normalize_inputs)
        w = self.velocity(coords, normalize_inputs)
        return rho * w

    def constraint_loss(self, coords=None, get_losses=False):
        """
        Residual of the continuity equation ``d(rho)/dt + div(rho * w)``.

        Args:
            coords: collocation points; defaults to ``self.grid_samples`` when ``None``.
            get_losses: if True, return the squared residual per point instead of the
                aggregated loss dictionary.

        Returns:
            Either the per-point squared residual, or a dict with ``cons_of_mass``
            (root-mean-square residual divided by the detached range of ``rho``) and
            ``hessian`` (constant 0.).

        NOTE(review): ``density`` and ``velocity`` evaluate the network on a detached copy
        of the coordinates (via ``allow_input_derivatives``), so ``rho`` may not be
        connected to the ``coords`` differentiated here -- confirm how ``gradient`` and
        ``divergence`` behave in that case.
        """
        points = self.grid_samples if coords is None else coords

        rho = self.density(points, normalize_inputs=True)
        w = self.velocity(points, normalize_inputs=True)

        # Spatial part: divergence of the mass flux, skipping the first (time) column.
        flux_divergence = divergence(rho * w, points, x_offset=1).squeeze()

        # Temporal part: derivative of the density w.r.t. column 0.
        rho_time_derivative = gradient(rho, points)[..., 0].squeeze()
        assert flux_divergence.shape == rho_time_derivative.shape

        # Range of the density, used to make the loss scale-free; no gradient flows through it.
        rho_range = (torch.max(rho) - torch.min(rho)).detach()

        squared_residual = (rho_time_derivative + flux_divergence).pow(2)
        if get_losses:
            return squared_residual

        rms_residual = squared_residual.mean().sqrt() / rho_range
        return dict(cons_of_mass=rms_residual, hessian=0.)

    def sample_background(self, num_samples: int, sampling_method: str = "gaussian"):
        """
        Draw background collocation points.

        Args:
            num_samples (int): number of samples to draw.
            sampling_method (str): ``"gaussian"`` samples the Gaussian fitted to the
                observations, ``"uniform"`` draws scrambled Sobol points scaled to the box
                ``[self.min, self.max]``, ``"dirichlet"`` samples inside the spatial hull
                via ``sample_within_convex_hull``.

        Returns:
            Tensor of samples on the model device with gradient tracking enabled.

        Raises:
            ValueError: if ``sampling_method`` is none of the three options.
        """
        if sampling_method not in ("gaussian", "uniform", "dirichlet"):
            raise ValueError(f"Unknown background method '{sampling_method}'.")

        if sampling_method == "gaussian":
            drawn = self.prior_torch.sample((num_samples,))
        elif sampling_method == "uniform":
            unit_cube = self.sobol.draw(num_samples).to(self.device)
            # Affine map from the unit cube onto the domain box.
            drawn = unit_cube * (self.max - self.min) + self.min
        else:
            hull_points = self.sample_within_convex_hull(num_samples, numpy=True)
            drawn = to_tensor(hull_points, device=self.device)

        return self.allow_input_derivatives(drawn)

    def sample_signal_domain(self, num_samples, sampling_method, burnin=500, rar=False, ot_rar=False):
        """
        Draw ``num_samples`` collocation points for the PDE constraint.

        Args:
            num_samples: number of points to return.
            sampling_method: one of the background methods (``"uniform"``, ``"gaussian"``,
                ``"dirichlet"``), an MCMC variant (``"mcmc"``, ``"mcmc_euler"``,
                ``"mcmc_noise"``) or ``"importance_sampling"``.
            burnin: number of burn-in steps of the MCMC chains.
            rar: if True, keep the ``num_samples`` points with the largest residual among
                the fresh samples and the current ``self.grid_samples``.
            ot_rar: if True, transport the samples with a linear optimal-transport map
                learnt from uniform points to the high-residual points.

        Returns:
            tuple ``(samples, probs)``; ``samples`` tracks gradients and lives on the model
            device, ``probs`` holds one weight per initially drawn sample.
            NOTE(review): ``probs`` is not re-aligned after rar / ot_rar / importance
            resampling (unchanged from the original behaviour).

        Raises:
            ValueError: if ``sampling_method`` is unknown.
        """
        background_methods = ("uniform", "gaussian", "dirichlet")
        mcmc_methods = ("mcmc", "mcmc_euler", "mcmc_noise")
        # Validate up-front so that a typo does not only surface after an expensive MCMC run.
        if sampling_method not in background_methods + mcmc_methods + ("importance_sampling",):
            raise ValueError(f"sampling_method='{sampling_method}' is unknown.")

        # Number of points taken from the chains; the remainder is defined by subtraction so
        # that both parts always add up to ``num_samples`` (two independent ``round`` calls
        # do not, e.g. for fraction 0.5 and 5 samples).
        num_mcmc_samples = 0
        if sampling_method in mcmc_methods and self.fraction_mcmc > 0.:
            num_mcmc_samples = min(num_samples, round(self.fraction_mcmc * num_samples))

        if num_mcmc_samples == 0:
            # Non-background methods (importance sampling, MCMC with a zero fraction) draw
            # from the Gaussian background -- the same background the MCMC branch mixes in.
            # Forwarding their own name to ``sample_background`` raised a ValueError.
            # NOTE(review): confirm that the Gaussian prior is the intended proposal for
            # importance sampling.
            background_method = sampling_method if sampling_method in background_methods else "gaussian"
            samples = self.sample_background(num_samples, background_method)
            probs = torch.ones(samples.shape[0], 1, device=self.device)
        else:
            num_background_samples = num_samples - num_mcmc_samples

            samples_background = self.sample_background(num_background_samples, "gaussian")
            xt_init = self.sample_within_convex_hull(self.sampler.num_chains)

            samples_mcmc_all = self.sampler.sample_chains(x_init=xt_init, burnin=burnin,
                                                          n_samples=burnin + num_mcmc_samples)

            # Sub-sample the requested number of points from everything the chains returned.
            idx = np.random.randint(samples_mcmc_all.shape[0], size=num_mcmc_samples)
            samples_mcmc = to_tensor(samples_mcmc_all[idx], device=self.device)
            assert samples_mcmc.shape[0] == num_mcmc_samples

            if sampling_method == "mcmc_euler":
                # Move every point along the flow by a random time step dt ~ N(0, 0.2^2).
                vel = self.velocity(samples_mcmc)
                dt = torch.normal(torch.zeros_like(vel), std=torch.ones_like(vel) * 2e-1)
                samples_mcmc = (samples_mcmc + torch.cat([dt, dt * vel], -1))
            elif sampling_method == "mcmc_noise":
                samples_mcmc = torch.normal(samples_mcmc, std=self.mcmc_noiselevel)

            samples = torch.cat([samples_background, samples_mcmc], 0)

            # One weight per sample: ones for the background part, followed by the MCMC
            # weights. The original sized the ones with the already concatenated tensor and
            # thereby produced ``num_mcmc_samples`` entries too many.
            # NOTE(review): the MCMC weight divides by the *log*-density -- confirm.
            probs = torch.cat([torch.ones(num_background_samples, 1, device=self.device),
                               self.sample_area / self.sample_log_prob(samples_mcmc)], 0).detach()

        if rar:
            # Residual-based adaptive refinement: keep the points with the largest residual.
            samples_cat = torch.cat([self.allow_input_derivatives(samples), self.grid_samples], dim=0)
            losses_cat = self.constraint_loss(samples_cat, get_losses=True)
            _, idx = torch.topk(losses_cat, num_samples)
            samples = samples_cat[idx, ...]
        if ot_rar:
            # Learn a linear OT map from uniform points to the top-k residual points and
            # push the current samples through it.
            k = num_samples // 4
            samples_searchmax = self.sample_background(num_samples, "uniform")
            losses_cat = self.constraint_loss(samples_searchmax, get_losses=True)
            _, idx = torch.topk(losses_cat, k=k)
            samples_topk = samples_searchmax[idx]
            samples_uniform_learnmap = self.sample_background(k, "uniform")
            Ae, be = ot.da.OT_mapping_linear(samples_uniform_learnmap, samples_topk)
            samples = samples @ Ae + be

        if sampling_method == "importance_sampling":
            # Resample the candidates with probability proportional to their residual.
            losses_samples = self.constraint_loss(self.allow_input_derivatives(samples), get_losses=True)
            weights = torch.abs(losses_samples).detach().reshape(-1)
            total = weights.sum()
            if not bool(torch.isfinite(total) and total > 0):
                # All residuals vanish (or are not finite): fall back to uniform resampling
                # instead of handing an invalid distribution to ``torch.multinomial``.
                weights = torch.ones_like(weights)
                total = weights.sum()
            idx = torch.multinomial(weights / total, num_samples, replacement=True)
            samples = samples[idx, ...]

        assert samples.shape[0] == num_samples, f"{samples.shape[0]} instead of {num_samples}."
        return self.allow_input_derivatives(samples), probs

    def do_sample(self, num_samples, sampling_method, first=False) -> None:
        """
        Refresh ``self.grid_samples`` and ``self.grid_probs``.

        Args:
            num_samples: number of collocation points to draw.
            sampling_method: forwarded to ``sample_signal_domain``.
            first: if True, both adaptive refinement options are switched off for this
                call; otherwise ``self.rar`` / ``self.ot_rar`` are used.
        """
        use_rar = False if first else self.rar
        use_ot_rar = False if first else self.ot_rar
        drawn = self.sample_signal_domain(num_samples, sampling_method=sampling_method,
                                          rar=use_rar, ot_rar=use_ot_rar)
        self.grid_samples, self.grid_probs = drawn

    def sample_within_convex_hull(self, num_samples, numpy=True):
        """
        Sample points whose spatial part is a random convex combination of the hull
        vertices and whose first coordinate is uniform in ``[tmin, tmax]``.

        The convex weights are Dirichlet distributed with unit concentration.

        Args:
            num_samples: number of points to draw.
            numpy: if True, sample with numpy and return an array; otherwise sample with
                torch on the device of ``self.convex_hull_tensor`` and return a tensor.

        Returns:
            Array/tensor with ``num_samples`` rows; column 0 is the uniformly drawn
            coordinate, the remaining columns are the hull samples.
        """
        if numpy:
            n_vertices = self.convex_hull_np.shape[0]
            mixtures = np.random.dirichlet(np.ones(n_vertices), num_samples)
            samples = mixtures @ self.convex_hull_np

            time = np.random.uniform(self.tmin, self.tmax, samples.shape[0]).reshape(-1, 1)
            samples = np.concatenate([time, samples], -1)
        else:
            hull = self.convex_hull_tensor
            # Draw exactly ``num_samples`` weight vectors. The previous implementation used
            # the pre-expanded ``self.dir`` and therefore ignored ``num_samples`` and failed
            # unless ``init()`` had been called before.
            concentration = torch.ones(hull.shape[0], dtype=hull.dtype, device=hull.device)
            mixtures = torch.distributions.Dirichlet(concentration).sample((num_samples,))
            samples = mixtures @ hull

            time = torch.rand(samples.shape[0], 1, dtype=hull.dtype, device=hull.device)
            time = time * (self.tmax - self.tmin) + self.tmin
            samples = torch.cat([time, samples], -1)
        return samples

    @staticmethod
    def weighted_mse_loss(input, target, weight):
        return (weight * (input - target) ** 2).mean()

    def reconstruction_loss(self, coords, target):
        """
        Data-fit losses of the model at the observed coordinates.

        Args:
            coords: coordinates of the observations.
            target: observations; column 0 is compared against the predicted square-root
                density, column 1 against the predicted velocity.

        Returns:
            dict with
                ``density``: 1e4 times the MSE between predicted sqrt-density and target column 0,
                ``w``: MSE of the velocity restricted to rows whose target column 0 is
                    positive (zero if there is no such row),
                ``mse_density_weighted``: the unscaled density MSE.
        """
        w_hat = self.velocity(coords)
        sqrt_density_hat = self.sqrt_density(coords)
        target_density = target[..., [0]]
        target_w = target[..., [1]]

        loss_density = self.mse(input=sqrt_density_hat, target=target_density)

        # The velocity is only compared where mass was observed.
        mask = target_density > 0
        if bool(mask.any()):
            loss_w = self.mse(input=torch.masked_select(w_hat, mask),
                              target=torch.masked_select(target_w, mask))
        else:
            # The MSE over an empty selection is NaN and would poison the total loss;
            # a batch without any positive density contributes no velocity error instead.
            loss_w = w_hat.new_zeros(())

        return dict(density=1e4 * loss_density, w=loss_w,
                    mse_density_weighted=loss_density
                    )

    def training_loss(self, coords, target, constraint_weight=1e-5, include_w=True):
        """
        Total training objective and its individual terms.

        Args:
            coords: coordinates of the observations.
            target: observations, forwarded to ``reconstruction_loss``.
            constraint_weight: weight of the mass-conservation term; if 0 the constraint
                is not evaluated at all.
            include_w: whether the velocity reconstruction term enters the total loss.

        Returns:
            tuple ``(total_loss, loss_dict)`` where ``loss_dict`` merges the reconstruction
            and the constraint terms.
        """
        rec_loss = self.reconstruction_loss(coords, target)
        if constraint_weight == 0:
            constraint_loss = dict(cons_of_mass=0., hessian=0.)
        else:
            constraint_loss = self.constraint_loss()
        loss_dict = {**rec_loss, **constraint_loss, }

        total_loss = 0.15 * rec_loss["density"]
        if include_w:
            # Added conditionally instead of multiplying by ``int(include_w)``:
            # ``0 * nan`` is ``nan``, so a non-finite velocity loss would otherwise leak
            # into the total even when the term is switched off.
            total_loss = total_loss + 1e4 * rec_loss["w"]
        total_loss = (total_loss
                      + constraint_weight * constraint_loss["cons_of_mass"]
                      + 4e-2 * constraint_loss["hessian"]
                      )
        return total_loss, loss_dict

    def allow_input_derivatives(self, coords):
        """
        Return a detached copy of ``coords`` that tracks gradients, moved to the model
        device, so that derivatives w.r.t. the inputs can be taken.
        """
        detached_copy = coords.clone().detach()
        detached_copy.requires_grad_(True)
        return detached_copy.to(self.device)
