import itertools
import math
import os
from itertools import cycle

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

import wandb

from .stein_gradient_estimator import SpectralSteinEstimator


class MeasurementGenerator:
    """Build measurement sets that mix joint and shuffled (theta, x) pairs.

    ``train_set`` is any iterable yielding ``(theta, x)`` batches. A single
    iterator over it is wrapped in ``itertools.cycle``, so the batches seen on
    the first pass are cached and replayed in the same order afterwards.
    """

    def __init__(self, train_set):
        self.train_iterator = iter(train_set)

        self.data_iterator = cycle(self.train_iterator)

    def get(self, n_data):
        """Return ``(theta, x)`` holding ``2 * ceil(n_data / 2)`` rows each.

        The first half are joint pairs drawn from the data stream and randomly
        permuted. The second half reuses the same ``x`` rows, paired with
        ``theta`` rolled by one row, which approximates samples from the
        product of the marginals.
        """
        n_joint = math.ceil(n_data / 2)

        # Collect batches in lists and concatenate once. Concatenating inside
        # the loop would re-copy every row gathered so far on each iteration.
        theta_parts, x_parts = [], []
        n_theta = n_x = 0
        while not x_parts or n_x < n_data / 2 or n_theta < n_data / 2:
            theta_cur, x_cur = next(self.data_iterator)
            theta_parts.append(theta_cur)
            x_parts.append(x_cur)
            n_theta += theta_cur.shape[0]
            n_x += x_cur.shape[0]
        theta = torch.cat(theta_parts, dim=0)
        x = torch.cat(x_parts, dim=0)

        # Random permutation so a batch does not always select the same samples.
        perm = torch.randperm(x.shape[0])
        x = x[perm, ...]
        theta = theta[perm, ...]

        # Keep only the required number of joint samples.
        x = x[:n_joint, ...]
        theta = theta[:n_joint, ...]

        # Marginal samples: same x, theta shifted by one row.
        theta_prime = torch.roll(theta, 1, dims=0)

        theta = torch.cat((theta, theta_prime), dim=0)
        x = torch.cat((x, x), dim=0)

        return theta, x


class Critic(nn.Module):
    """MLP critic mapping a ``dim``-dimensional input to a single scalar.

    Layout for ``nb_layers > 0``::

        Linear(dim, n) -> Softplus
        -> [Linear(n, n) -> Softplus] * (nb_layers - 1)
        -> Linear(n, 1) [-> Sigmoid if classifier]

    With ``nb_layers == 0`` the network is ``Linear(dim, 1) -> Softplus`` and
    ``classifier`` has no effect.

    Args:
        dim: input feature size.
        critic_config: dict providing ``"nb_layers"`` and ``"nb_neurons"``.
        classifier: append a Sigmoid to the output (only when nb_layers > 0).
    """

    def __init__(self, dim, critic_config, classifier=False):
        super(Critic, self).__init__()
        nb_layers = critic_config["nb_layers"]
        nb_neurons = critic_config["nb_neurons"]

        layers = []

        if nb_layers == 0:
            layers.append(nn.Linear(dim, 1))
            layers.append(nn.Softplus())

        else:
            # Every hidden Linear is followed by a non-linearity. Without one
            # after the first layer, two consecutive Linears collapse into a
            # single affine map, and nb_layers == 1 would give a purely
            # affine critic.
            # NOTE: this shifts ModuleList indices relative to checkpoints
            # saved with the earlier layout.
            layers.append(nn.Linear(dim, nb_neurons))
            layers.append(nn.Softplus())
            for _ in range(nb_layers - 1):
                layers.append(nn.Linear(nb_neurons, nb_neurons))
                layers.append(nn.Softplus())

            layers.append(nn.Linear(nb_neurons, 1))
            if classifier:
                layers.append(nn.Sigmoid())

        self.head = nn.ModuleList(layers)

    def forward(self, x):
        """Apply the layers of ``self.head`` in order."""
        for layer in self.head:
            x = layer(x)
        return x


class CriticWithInputs(nn.Module):
    """Critic that scores function outputs together with their measurement inputs.

    Each measurement is described by ``theta`` (parameter_dim features), an
    embedding of ``x`` built by the benchmark (embedding_dim features) and one
    output channel (1 feature). These are concatenated per measurement and
    either:

    * ``"deep_set"``: processed independently by a shared MLP, then averaged
      over measurements (permutation invariant), or
    * ``"mlp"``: flattened across all ``measurement_size`` measurements into
      one feature vector.

    A head MLP then maps the result to one scalar per function sample.

    Args:
        measurement_size: number of measurements (used by the "mlp" variant).
        benchmark: object providing ``get_observable_shape``,
            ``get_embedding_dim``, ``get_parameter_dim`` and
            ``get_embedding_build``.
        critic_config: dict with ``"nb_layers"``, ``"nb_neurons"``,
            ``"batch_size"`` and ``"model"`` ("deep_set" or "mlp").
        classifier: append a Sigmoid to the head (only when nb_layers > 0).

    Raises:
        NotImplementedError: if ``critic_config["model"]`` is unknown.
    """

    def __init__(self, measurement_size, benchmark, critic_config, classifier=False):
        super(CriticWithInputs, self).__init__()

        nb_layers = critic_config["nb_layers"]
        nb_neurons = critic_config["nb_neurons"]
        self.batch_size = critic_config["batch_size"]

        self.observable_shape = benchmark.get_observable_shape()
        self.embedding_dim = benchmark.get_embedding_dim()
        self.parameter_dim = benchmark.get_parameter_dim()
        output_dim = 1
        dim = self.embedding_dim + self.parameter_dim + output_dim

        embedding_build = benchmark.get_embedding_build()
        self.embedding = embedding_build(self.embedding_dim, self.observable_shape)

        if critic_config["model"] == "deep_set":
            deep_set = []
            if nb_layers == 0:
                deep_set.append(nn.Linear(dim, 1))
                deep_set.append(nn.Softplus())
                # The per-measurement network already outputs a single feature.
                head_dim = 1
            else:
                # Softplus after every Linear so consecutive Linears never
                # collapse into one affine map.
                deep_set.append(nn.Linear(dim, nb_neurons))
                deep_set.append(nn.Softplus())
                for _ in range(nb_layers):
                    deep_set.append(nn.Linear(nb_neurons, nb_neurons))
                    deep_set.append(nn.Softplus())
                head_dim = nb_neurons

            self.deep_set = nn.ModuleList(deep_set)
            self.model_type = "deep_set"

        elif critic_config["model"] == "mlp":
            self.model_type = "mlp"
            head_dim = dim * measurement_size

        else:
            raise NotImplementedError("Lipschitz model not implemented.")

        head = []
        if nb_layers == 0:
            head.append(nn.Linear(head_dim, 1))
            head.append(nn.Softplus())

        else:
            head.append(nn.Linear(head_dim, nb_neurons))
            head.append(nn.Softplus())
            for _ in range(nb_layers - 1):
                head.append(nn.Linear(nb_neurons, nb_neurons))
                head.append(nn.Softplus())

            head.append(nn.Linear(nb_neurons, 1))
            if classifier:
                head.append(nn.Sigmoid())

        self.head = nn.ModuleList(head)

    def forward(self, theta, x, out):
        """Score each function sample.

        Args:
            theta: (n_samples, n_measurement, parameter_dim)
            x: (n_samples, n_measurement, *observable) measurement inputs.
            out: (n_samples, n_measurement, 1) one output channel.

        Returns:
            Tensor of shape (n_samples, 1).
        """
        n_samples = x.shape[0]
        n_measurement = x.shape[1]
        # reshape (not view) so non-contiguous inputs are accepted.
        embed = self.embedding(x.reshape((n_samples * n_measurement, -1)))
        embed = embed.reshape((n_samples, n_measurement, -1))
        y = torch.cat((theta, embed, out), dim=2)

        if self.model_type == "deep_set":
            # Process every measurement independently ...
            y = y.reshape((n_samples * n_measurement, -1))
            for layer in self.deep_set:
                y = layer(y)

            # ... then average over measurements (permutation invariant).
            y = y.reshape((n_samples, n_measurement, -1))
            y = torch.mean(y, dim=1)

        elif self.model_type == "mlp":
            # Concatenate all measurements into a single feature vector.
            y = y.reshape(n_samples, -1)

        for layer in self.head:
            y = layer(y)

        return y


def weights_init(m):
    """Re-initialise an ``nn.Linear`` module in place; other modules are ignored.

    Weights get Xavier-normal initialisation and biases (when present) a
    standard normal. Intended for use with ``nn.Module.apply``.

    Reproduced from https://github.com/tranbahien/you-need-a-good-prior
    """
    if isinstance(m, nn.Linear):
        # torch.nn.init functions already run under no_grad, so the
        # parameters can be passed directly instead of via ``.data``.
        torch.nn.init.xavier_normal_(m.weight)
        # Linear layers built with bias=False have ``bias is None``.
        if m.bias is not None:
            torch.nn.init.normal_(m.bias)


class WassersteinDistance:
    """Adversarially estimated Wasserstein-1 distance between BNN and NP function samples.

    Adapted from https://github.com/tranbahien/you-need-a-good-prior

    A ``CriticWithInputs`` receives the measurement inputs ``(theta, x)``
    together with one output channel of a sampled function. It is trained by
    ``distance_optimisation`` to maximise ``E[f(bnn)] - E[f(np)]`` under a
    Lipschitz constraint: a gradient penalty ("gp" or "lp") when
    ``use_lipschitz_constraint`` is true, otherwise weight clipping to
    [-0.1, 0.1].
    """

    def __init__(
        self,
        bnn,
        np_model,
        measurement_size,
        critic_config,
        benchmark,
        output_dim,
        use_lipschitz_constraint=True,
        lipschitz_constraint_type="gp",
        wasserstein_lr=0.01,
        device="cpu",
        gpu_np_model=True,
        threshold=None,
        n_steps=10,
        data_generator=None,
        penalty_coef=10,
        restart_lipschitz=False,
        clip_gradient=False,
        clipping_norm=None,
    ):
        self.bnn = bnn
        self.np_model = np_model
        self.device = device
        self.output_dim = output_dim
        self.measurement_size = measurement_size
        self.n_steps = n_steps
        self.data_generator = data_generator
        self.restart_lipschitz = restart_lipschitz
        self.lipschitz_constraint_type = lipschitz_constraint_type
        assert self.lipschitz_constraint_type in ["gp", "lp"]
        self.clip_gradient = clip_gradient
        self.clipping_norm = clipping_norm

        self.lipschitz_f = CriticWithInputs(measurement_size, benchmark, critic_config)
        self.lipschitz_f = self.lipschitz_f.to(self.device)
        self.lipschitz_f.apply(weights_init)
        self.gpu_np_model = gpu_np_model
        self.values_log = []

        self.optimiser = torch.optim.Adagrad(
            self.lipschitz_f.parameters(), lr=wasserstein_lr
        )
        self.use_lipschitz_constraint = use_lipschitz_constraint
        self.penalty_coef = penalty_coef
        self.threshold = threshold

    def calculate(self, theta, x, nnet_samples, np_samples):
        """Return the critic gap ``mean f(bnn) - mean f(np)``, summed over output channels.

        Args:
            theta, x: measurement inputs; repeated once per function sample.
            nnet_samples, np_samples: function samples indexed
                ``[sample, measurement, output_channel]``.

        Returns:
            Tensor holding the mean critic output over the sample axis.
        """
        d = 0.0
        theta = theta.repeat(nnet_samples.shape[0], 1, 1)
        x = x.repeat(nnet_samples.shape[0], 1, 1)
        for dim in range(self.output_dim):
            f_samples = self.lipschitz_f(
                theta, x, nnet_samples[:, :, dim].unsqueeze(dim=2)
            )
            f_np = self.lipschitz_f(theta, x, np_samples[:, :, dim].unsqueeze(dim=2))
            d += torch.mean(f_samples, 0) - torch.mean(f_np, 0)
        return d

    def compute_gradient_penalty(self, theta, x, samples_p, samples_q):
        """Lipschitz penalty on random interpolates between two sample sets.

        Args:
            theta, x: measurement inputs; repeated once per function sample.
            samples_p, samples_q: one output channel, shaped
                ``[n_samples, n_measurement]``.

        Returns:
            Scalar tensor: ``mean((|grad| - 1)^2)`` for "gp", or the one-sided
            ``mean(max(|grad| - 1, 0)^2)`` for "lp".
        """
        theta = theta.repeat(samples_p.shape[0], 1, 1)
        x = x.repeat(samples_p.shape[0], 1, 1)

        # One interpolation coefficient per function sample.
        eps = torch.rand((samples_p.shape[0], 1)).to(samples_p.device)
        out = eps * samples_p.detach() + (1 - eps) * samples_q.detach()
        out = out.unsqueeze(dim=2)

        # Only the gradient with respect to ``out`` is needed.
        out.requires_grad = True
        Y = self.lipschitz_f(theta, x, out)

        gradients_out = torch.autograd.grad(
            Y,
            out,
            grad_outputs=torch.ones(Y.size(), device=self.device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Flatten per sample instead of squeeze(). squeeze() also drops the
        # sample axis when a single function is drawn (or the measurement axis
        # when there is one measurement), which broke norm(dim=1).
        gradients_out = gradients_out.flatten(start_dim=1)
        f_gradient_norm = gradients_out.norm(2, dim=1)

        if self.lipschitz_constraint_type == "gp":
            # Gulrajani2017, Improved Training of Wasserstein GANs
            return ((f_gradient_norm - 1) ** 2).mean()

        elif self.lipschitz_constraint_type == "lp":
            # Henning2018, On the Regularization of Wasserstein GANs
            # Eq (8) in Section 5
            return ((torch.clamp(f_gradient_norm - 1, 0.0, np.inf)) ** 2).mean()

        else:
            raise NotImplementedError("Lipschitz constraint not implemented.")

    def reset_distance(self):
        """Re-initialise the critic's Linear layers."""
        self.lipschitz_f.apply(weights_init)

    def distance_optimisation(self, theta, x, n_samples, debug=False):
        """Train the critic for up to ``n_steps`` Adagrad steps.

        When ``data_generator`` is set, fresh ``(theta, x)`` are drawn every
        step and the arguments are ignored. Training stops early when the
        critic's gradient norm falls below ``threshold`` (if set). With
        ``debug`` the distance estimate of each step is appended to
        ``values_log``; on early stop the last value is padded to ``n_steps``.
        Critic parameters are frozen (``requires_grad=False``) on return.
        """
        if self.restart_lipschitz:
            self.reset_distance()

        for p in self.lipschitz_f.parameters():
            p.requires_grad = True

        for i in range(self.n_steps):

            print("wasserstein step {}".format(i))
            if self.data_generator:
                theta, x = self.data_generator.get(self.measurement_size)

            if self.gpu_np_model:
                x = x.to(self.device)
                theta = theta.to(self.device)
            else:
                x = x.to("cpu")
                theta = theta.to("cpu")

            # Draw functions from the target (NP) model.
            np_samples = (
                self.np_model.sample_functions(theta, x, n_samples)
                .detach()
                .float()
                .to(self.device)
            )
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            x = x.to(self.device)
            theta = theta.to(self.device)

            self.optimiser.zero_grad()

            # Draw functions from the Bayesian neural network.
            nnet_samples = (
                self.bnn.sample_functions(theta, x, n_samples)
                .detach()
                .float()
                .to(self.device)
            )
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            #  It was of size: [n_dim, N, n_out]
            # will be of size: [N, n_dim, n_out]
            np_samples = np_samples.transpose(0, 1)
            nnet_samples = nnet_samples.transpose(0, 1)

            objective = -self.calculate(theta, x, nnet_samples, np_samples)

            if debug:
                self.values_log.append(-objective.item())

            if self.use_lipschitz_constraint:
                penalty = 0.0
                for dim in range(self.output_dim):
                    penalty += self.compute_gradient_penalty(
                        theta, x, nnet_samples[:, :, dim], np_samples[:, :, dim]
                    )
                objective += self.penalty_coef * penalty
            objective.backward()

            if self.threshold is not None:
                # Gradient norm before any clipping.
                params = self.lipschitz_f.parameters()
                grad_norm = (
                    torch.cat([p.grad.data.flatten() for p in params]).norm().item()
                )

            if self.clip_gradient:
                nn.utils.clip_grad_norm_(
                    self.lipschitz_f.parameters(), self.clipping_norm
                )

            self.optimiser.step()
            if not self.use_lipschitz_constraint:
                # Weight clipping (original WGAN) when no penalty is used.
                for p in self.lipschitz_f.parameters():
                    p.data.clamp_(-0.1, 0.1)
            if self.threshold is not None and grad_norm < self.threshold:
                print(
                    "WARNING: Grad norm (%.3f) lower than threshold (%.3f). "
                    % (grad_norm, self.threshold),
                    end="",
                )
                print("Stopping optimization at step %d" % (i))
                if debug:
                    # '-1' because the last wssr value is not recorded
                    self.values_log = self.values_log + [self.values_log[-1]] * (
                        self.n_steps - i - 1
                    )
                break
        for p in self.lipschitz_f.parameters():
            p.requires_grad = False

    def is_adversarial(self):
        return True


class WassersteinDistanceOld:
    """Wasserstein-1 distance with a plain MLP critic over whole function vectors.

    Adapted from https://github.com/tranbahien/you-need-a-good-prior

    Unlike ``WassersteinDistance``, the critic here (``Critic``) sees only the
    function values at the measurement points, concatenated into a vector of
    size ``lipschitz_f_dim``. It does not see the inputs ``(theta, x)``.

    Args:
        penalty_coef: weight of the Lipschitz penalty. Defaults to 10, the
            value previously hard-coded.
    """

    def __init__(
        self,
        bnn,
        np_model,
        lipschitz_f_dim,
        critic_config,
        output_dim,
        use_lipschitz_constraint=True,
        lipschitz_constraint_type="gp",
        wasserstein_lr=0.01,
        device="cpu",
        gpu_np_model=True,
        threshold=None,
        n_steps=10,
        restart_lipschitz=False,
        clip_gradient=False,
        clipping_norm=None,
        penalty_coef=10,
    ):
        self.bnn = bnn
        self.np_model = np_model
        self.device = device
        self.output_dim = output_dim
        self.lipschitz_f_dim = lipschitz_f_dim
        self.threshold = threshold
        self.n_steps = n_steps
        self.restart_lipschitz = restart_lipschitz
        self.lipschitz_constraint_type = lipschitz_constraint_type
        assert self.lipschitz_constraint_type in ["gp", "lp"]
        self.clip_gradient = clip_gradient
        self.clipping_norm = clipping_norm

        self.lipschitz_f = Critic(lipschitz_f_dim, critic_config)
        self.lipschitz_f = self.lipschitz_f.to(self.device)
        self.lipschitz_f.apply(weights_init)
        self.gpu_np_model = gpu_np_model
        self.values_log = []

        self.optimiser = torch.optim.Adagrad(
            self.lipschitz_f.parameters(), lr=wasserstein_lr
        )
        self.use_lipschitz_constraint = use_lipschitz_constraint
        self.penalty_coef = penalty_coef

    def calculate(self, nnet_samples, np_samples):
        """Return ``mean f(bnn) - mean f(np)`` summed over output channels.

        Samples are indexed ``[measurement, sample, output_channel]``. Each
        channel slice is transposed so the critic receives one row per sample.
        """
        d = 0.0
        for dim in range(self.output_dim):
            f_samples = self.lipschitz_f(nnet_samples[:, :, dim].T)
            f_np = self.lipschitz_f(np_samples[:, :, dim].T)
            d += torch.mean(torch.mean(f_samples, 0) - torch.mean(f_np, 0))
        return d

    def compute_gradient_penalty(self, samples_p, samples_q):
        """Lipschitz penalty on random interpolates (inputs ``[measurement, sample]``)."""
        eps = torch.rand(samples_p.shape[1], 1).to(samples_p.device)
        X = eps * samples_p.t().detach() + (1 - eps) * samples_q.t().detach()
        X.requires_grad = True
        Y = self.lipschitz_f(X)
        gradients = torch.autograd.grad(
            Y,
            X,
            grad_outputs=torch.ones(Y.size(), device=self.device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        f_gradient_norm = gradients.norm(2, dim=1)

        if self.lipschitz_constraint_type == "gp":
            # Gulrajani2017, Improved Training of Wasserstein GANs
            return ((f_gradient_norm - 1) ** 2).mean()

        elif self.lipschitz_constraint_type == "lp":
            # Henning2018, On the Regularization of Wasserstein GANs
            # Eq (8) in Section 5
            return ((torch.clamp(f_gradient_norm - 1, 0.0, np.inf)) ** 2).mean()

        else:
            raise NotImplementedError("Lipschitz constraint not implemented.")

    def reset_distance(self):
        """Re-initialise the critic's Linear layers."""
        self.lipschitz_f.apply(weights_init)

    def distance_optimisation(self, theta, x, n_samples, debug=False):
        """Train the critic for up to ``n_steps`` steps on one bag of samples.

        Function samples are drawn once from both models and replayed in
        batches of ``n_samples``. Training stops early when the gradient norm
        drops below ``threshold`` (if set). Critic parameters are frozen on
        return.
        """
        if self.restart_lipschitz:
            self.reset_distance()

        for p in self.lipschitz_f.parameters():
            p.requires_grad = True

        n_samples_bag = n_samples * 1
        if not self.gpu_np_model:
            x = x.to("cpu")
            theta = theta.to("cpu")

        # Draw functions from the target (NP) model.
        np_samples_bag = (
            self.np_model.sample_functions(theta, x, n_samples_bag)
            .detach()
            .float()
            .to(self.device)
        )
        if self.output_dim > 1:
            np_samples_bag = np_samples_bag.squeeze()

        if not self.gpu_np_model:
            x = x.to(self.device)
            theta = theta.to(self.device)

        # Draw functions from the Bayesian neural network.
        nnet_samples_bag = (
            self.bnn.sample_functions(theta, x, n_samples_bag)
            .detach()
            .float()
            .to(self.device)
        )
        if self.output_dim > 1:
            nnet_samples_bag = nnet_samples_bag.squeeze()

        #  It was of size: [n_dim, N, n_out]
        # will be of size: [N, n_dim, n_out]
        np_samples_bag = np_samples_bag.transpose(0, 1)
        nnet_samples_bag = nnet_samples_bag.transpose(0, 1)
        dataset = TensorDataset(np_samples_bag, nnet_samples_bag)
        data_loader = DataLoader(dataset, batch_size=n_samples, num_workers=0)
        batch_generator = itertools.cycle(data_loader)

        for i in range(self.n_steps):
            np_samples, nnet_samples = next(batch_generator)
            #         was of size: [N, n_dim, n_out]
            # needs to be of size: [n_dim, N, n_out]
            np_samples = np_samples.transpose(0, 1)
            nnet_samples = nnet_samples.transpose(0, 1)

            self.optimiser.zero_grad()
            objective = -self.calculate(nnet_samples, np_samples)
            if debug:
                self.values_log.append(-objective.item())

            if self.use_lipschitz_constraint:
                penalty = 0.0
                for dim in range(self.output_dim):
                    penalty += self.compute_gradient_penalty(
                        nnet_samples[:, :, dim], np_samples[:, :, dim]
                    )
                objective += self.penalty_coef * penalty
            objective.backward()

            if self.threshold is not None:
                # Gradient norm before any clipping.
                params = self.lipschitz_f.parameters()
                grad_norm = (
                    torch.cat([p.grad.data.flatten() for p in params]).norm().item()
                )

            if self.clip_gradient:
                nn.utils.clip_grad_norm_(
                    self.lipschitz_f.parameters(), self.clipping_norm
                )

            self.optimiser.step()
            if not self.use_lipschitz_constraint:
                # Weight clipping (original WGAN) when no penalty is used.
                for p in self.lipschitz_f.parameters():
                    p.data.clamp_(-0.1, 0.1)
            if self.threshold is not None and grad_norm < self.threshold:
                print(
                    "WARNING: Grad norm (%.3f) lower than threshold (%.3f). "
                    % (grad_norm, self.threshold),
                    end="",
                )
                print("Stopping optimization at step %d" % (i))
                if debug:
                    # '-1' because the last wssr value is not recorded
                    self.values_log = self.values_log + [self.values_log[-1]] * (
                        self.n_steps - i - 1
                    )
                break
        for p in self.lipschitz_f.parameters():
            p.requires_grad = False

    def is_adversarial(self):
        return True


class DiscriminatorDistance:
    """Adversarial distance based on a binary classifier (GAN-style discriminator).

    A ``CriticWithInputs`` outputs logits. It is trained with binary
    cross-entropy to label BNN function samples 1 and NP samples 0. Optionally
    a penalty on the squared input-gradient norm of ``sigmoid(D)`` at the NP
    samples is added (``regularize_discriminator``).

    Raises:
        ValueError: if ``regularize_discriminator`` is set without a
            ``penalty_coef``.
    """

    def __init__(
        self,
        bnn,
        np_model,
        measurement_size,
        critic_config,
        benchmark,
        output_dim,
        discriminator_lr=0.01,
        device="cpu",
        gpu_np_model=True,
        n_steps=10,
        data_generator=None,
        restart_discriminator=False,
        regularize_discriminator=False,
        penalty_coef=None,
        clip_gradient=False,
        clipping_norm=None,
    ):
        if regularize_discriminator and penalty_coef is None:
            # Fail early instead of with a TypeError inside the training loop.
            raise ValueError(
                "penalty_coef must be set when regularize_discriminator is True."
            )

        self.bnn = bnn
        self.np_model = np_model
        self.device = device
        self.output_dim = output_dim
        self.measurement_size = measurement_size
        self.n_steps = n_steps
        self.data_generator = data_generator
        self.restart_discriminator = restart_discriminator
        self.regularize_discriminator = regularize_discriminator
        self.penalty_coef = penalty_coef
        self.clip_gradient = clip_gradient
        self.clipping_norm = clipping_norm

        self.discriminator = CriticWithInputs(
            measurement_size, benchmark, critic_config
        )
        self.discriminator = self.discriminator.to(self.device)
        self.discriminator.apply(weights_init)
        self.gpu_np_model = gpu_np_model
        self.values_log = []

        self.optimiser = torch.optim.Adagrad(
            self.discriminator.parameters(), lr=discriminator_lr
        )

    def calculate(self, theta, x, nnet_samples, np_samples):
        """Return minus the mean BCE (BNN labelled 1, NP labelled 0), summed over channels."""
        d = 0.0
        theta = theta.repeat(nnet_samples.shape[0], 1, 1)
        x = x.repeat(nnet_samples.shape[0], 1, 1)
        for dim in range(self.output_dim):
            f_nnet = self.discriminator(
                theta, x, nnet_samples[:, :, dim].unsqueeze(dim=2)
            )
            f_np = self.discriminator(theta, x, np_samples[:, :, dim].unsqueeze(dim=2))
            d_nnet = F.binary_cross_entropy_with_logits(
                f_nnet, f_nnet.new_full(size=f_nnet.size(), fill_value=1)
            )
            d_np = F.binary_cross_entropy_with_logits(
                f_np, f_np.new_full(size=f_np.size(), fill_value=0)
            )
            d += -(d_nnet + d_np) / 2

        return d

    def reset_distance(self):
        """Re-initialise the discriminator's Linear layers."""
        self.discriminator.apply(weights_init)

    def compute_gradient_penalty(self, theta, x, samples_np):
        """Mean squared norm of d sigmoid(D) / d samples at the NP samples.

        Args:
            samples_np: one output channel, shaped ``[n_samples, n_measurement]``.
        """
        theta = theta.repeat(samples_np.shape[0], 1, 1)
        x = x.repeat(samples_np.shape[0], 1, 1)

        samples_np = samples_np.unsqueeze(dim=2)

        # Only the gradient with respect to the samples is needed.
        samples_np.requires_grad = True
        Y = torch.sigmoid(self.discriminator(theta, x, samples_np))

        gradients_samples_np = torch.autograd.grad(
            Y,
            samples_np,
            grad_outputs=torch.ones(Y.size(), device=self.device),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

        # Flatten per sample; squeeze() would also drop the sample axis when a
        # single function is drawn and break norm(dim=1).
        gradients_samples_np = gradients_samples_np.flatten(start_dim=1)
        f_gradient_norm = gradients_samples_np.norm(2, dim=1)

        return (f_gradient_norm**2).mean()

    def distance_optimisation(self, theta, x, n_samples, debug=False):
        """Train the discriminator for ``n_steps`` Adagrad steps.

        When ``data_generator`` is set, fresh ``(theta, x)`` are drawn every
        step. Parameters are frozen (``requires_grad=False``) on return.
        """
        if self.restart_discriminator:
            self.reset_distance()

        for p in self.discriminator.parameters():
            p.requires_grad = True

        for i in range(self.n_steps):

            print("Discriminator step {}".format(i))
            if self.data_generator:
                theta, x = self.data_generator.get(self.measurement_size)

            if self.gpu_np_model:
                x = x.to(self.device)
                theta = theta.to(self.device)
            else:
                x = x.to("cpu")
                theta = theta.to("cpu")

            # Draw functions from the target (NP) model.
            np_samples = (
                self.np_model.sample_functions(theta, x, n_samples)
                .detach()
                .float()
                .to(self.device)
            )
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            x = x.to(self.device)
            theta = theta.to(self.device)

            self.optimiser.zero_grad()

            # Draw functions from the Bayesian neural network.
            nnet_samples = (
                self.bnn.sample_functions(theta, x, n_samples)
                .detach()
                .float()
                .to(self.device)
            )
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            #  It was of size: [n_dim, N, n_out]
            # will be of size: [N, n_dim, n_out]
            np_samples = np_samples.transpose(0, 1)
            nnet_samples = nnet_samples.transpose(0, 1)

            objective = -self.calculate(theta, x, nnet_samples, np_samples)

            if debug:
                self.values_log.append(-objective.item())

            if self.regularize_discriminator:
                penalty = 0.0
                for dim in range(self.output_dim):
                    penalty += self.compute_gradient_penalty(
                        theta, x, np_samples[:, :, dim]
                    )
                objective += self.penalty_coef * penalty

            objective.backward()

            # Clip after backward(): before it, the gradients were just zeroed,
            # so clipping had no effect.
            if self.clip_gradient:
                nn.utils.clip_grad_norm_(
                    self.discriminator.parameters(), self.clipping_norm
                )

            self.optimiser.step()

        for p in self.discriminator.parameters():
            p.requires_grad = False

    def is_adversarial(self):
        return True


class KLDistance:
    """Cross-entropy of BNN function samples under the target model's log-density.

    Only the cross-entropy part, E_q[-log p_NP(f)], is computed. No entropy
    term of the BNN distribution is estimated here. Nothing is trained, so the
    distance is non-adversarial.
    """

    def __init__(self, np_model, output_dim, gpu_np_model=False, device="cpu"):
        self.device = device
        self.gpu_np_model = gpu_np_model
        self.output_dim = output_dim
        self.np_model = np_model

    def calculate(self, theta, x, nnet_samples):
        """Return ``-sum_k mean(np_model.functions_log_prob(theta, x, f[:, :, k]))``.

        The first two axes of ``nnet_samples`` are swapped before scoring:
        per the file's convention, from [N, n_dim, n_out] to [n_dim, N, n_out].
        """
        points_first = nnet_samples.transpose(0, 1)

        total = 0.0
        for channel in range(self.output_dim):
            channel_log_prob = self.np_model.functions_log_prob(
                theta, x, points_first[:, :, channel]
            )
            # Average over every entry of the returned log-probabilities.
            total -= channel_log_prob.mean()

        return total

    def is_adversarial(self):
        return False


class SteinKLDistance:
    """KL-style surrogate: NP cross-entropy minus a Stein-estimated BNN entropy.

    The entropy term is a surrogate whose gradient, not its value, is what
    matters. The score estimate is computed under ``no_grad``, so
    differentiating ``mean(-score * f)`` with respect to ``f`` yields
    ``-score``. This presumes ``score_estimator`` returns an estimate of
    grad_f log q(f) at the given samples — TODO confirm against
    SpectralSteinEstimator.
    """

    def __init__(
        self, np_model, output_dim, num_eigs, eta, gpu_np_model=False, device="cpu"
    ):
        self.np_model = np_model
        self.output_dim = output_dim
        self.gpu_np_model = gpu_np_model
        self.device = device
        self.score_estimator = SpectralSteinEstimator(eta, num_eigs)

    def calculate(self, theta, x, nnet_samples):
        """Return ``cross_entropy - entropy_surrogate``, summed over output channels.

        ``nnet_samples`` has its first two axes swapped first (per the file's
        convention, [N, n_dim, n_out] -> [n_dim, N, n_out]).

        NOTE(review): the entropy surrogate averages over both axes while the
        cross-entropy averages only the log-prob output. This may scale the two
        terms differently. The score estimator also receives the
        [n_dim, N] slice; confirm which axis it treats as samples.
        """
        points_first = nnet_samples.transpose(0, 1)
        channels = range(self.output_dim)

        cross_entropy = 0.0
        for c in channels:
            channel_log_prob = self.np_model.functions_log_prob(
                theta, x, points_first[:, :, c]
            )
            cross_entropy -= channel_log_prob.mean()

        entropy = 0.0
        for c in channels:
            f_c = points_first[:, :, c]
            # Treat the score as a constant so only the sample path carries gradients.
            with torch.no_grad():
                score = self.score_estimator(f_c)
            entropy += (-score * f_c).mean()

        return cross_entropy - entropy

    def is_adversarial(self):
        return False


class PriorMapper(object):
    """Common state and MMD diagnostics for mapping a BNN prior onto a target model.

    Args:
        np_model, bnn: objects exposing ``sample_functions(theta, x, n)`` and
            ``.to(device)``. ``np_model`` is only moved to ``device`` when
            ``gpu_np_model`` is true.
        logger: optional object with an ``info`` method; when None, messages
            are printed to stdout.
    """

    def __init__(
        self,
        np_model,
        bnn,
        data_generator,
        out_dir,
        log_space,
        output_dim=1,
        n_data=256,
        gpu_np_model=False,
        logger=None,
        device="cpu",
        clip_gradient=False,
        clipping_norm=None,
        use_wandb=False,
    ):

        self.np_model = np_model
        self.bnn = bnn
        self.data_generator = data_generator
        self.n_data = n_data
        self.output_dim = output_dim
        self.out_dir = out_dir
        self.log_space = log_space
        self.device = device
        self.gpu_np_model = gpu_np_model
        self.clip_gradient = clip_gradient
        self.clipping_norm = clipping_norm
        self.use_wandb = use_wandb

        # Move models to configured device
        if gpu_np_model:
            self.np_model = self.np_model.to(self.device)
        self.bnn = self.bnn.to(self.device)

        # Setup logger. The conditional must select between two callables.
        # Putting it inside the lambda body made print_info return
        # ``logger.info`` uncalled, so nothing was logged.
        if logger is None:
            self.print_info = lambda msg: print(msg, flush=True)
        else:
            self.print_info = logger.info

    def compute_mmd_given_samples(self, nnet_samples, np_samples):
        """Linear-time MMD^2 estimate with an RBF kernel.

        Both inputs are shaped [n_samples, n_measurements, ...]. Each is split
        into two equal halves of ``n_samples // 2`` rows; with an odd count the
        last row is dropped. The kernel lengthscale is sqrt(n_measurements).
        Fewer than two samples yields NaN (empty means).

        Returns:
            Python float ``k(np1, np2) + k(nn1, nn2) - 2 k(nn1, np1)``, each
            term averaged over the paired rows.
        """
        half = nnet_samples.shape[0] // 2

        # Equal-sized halves, so pairwise differences never mismatch in shape.
        nnet_samples_1 = nnet_samples[:half, ...]
        nnet_samples_2 = nnet_samples[half : 2 * half, ...]
        np_samples_1 = np_samples[:half, ...]
        np_samples_2 = np_samples[half : 2 * half, ...]

        lengthscale = math.sqrt(nnet_samples_1.shape[1])

        def rbf_kernel(p_samples, q_samples):
            # Row-wise kernel between paired samples (no cross pairs).
            return torch.exp(
                -torch.sum((p_samples - q_samples) ** 2, dim=1) / (2 * lengthscale**2)
            )

        term_1 = torch.mean(rbf_kernel(np_samples_1, np_samples_2)).item()
        term_2 = torch.mean(rbf_kernel(nnet_samples_1, nnet_samples_2)).item()
        term_3 = torch.mean(rbf_kernel(nnet_samples_1, np_samples_1)).item()

        return term_1 + term_2 - 2 * term_3

    def compute_mmd(self, theta, x, n_samples):
        """Draw ``n_samples`` functions from both models and return their MMD^2 estimate."""
        with torch.no_grad():
            # The NP model runs on ``device`` only when it was moved there.
            np_device = self.device if self.gpu_np_model else "cpu"

            np_samples = (
                self.np_model.sample_functions(
                    theta.to(np_device), x.to(np_device), n_samples
                )
                .float()
                .to(self.device)
            )
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            x = x.to(self.device)
            theta = theta.to(self.device)

            nnet_samples = (
                self.bnn.sample_functions(theta, x, n_samples).float().to(self.device)
            )
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            # sample_functions returns [n_dim, N, n_out] (see the distance
            # classes). MMD compares functions, so put the sample axis first.
            return self.compute_mmd_given_samples(
                nnet_samples.transpose(0, 1), np_samples.transpose(0, 1)
            )


class DistanceBasedPriorMapper(PriorMapper):
    """Fit a BNN prior by minimising a distance to a neural-process (NP) model.

    Adapted from https://github.com/tranbahien/you-need-a-good-prior

    The distance is chosen by ``distance_type`` and built in ``__init__``.
    ``optimize`` then repeatedly draws function samples from both models and
    takes optimiser steps on ``self.bnn.parameters()`` to reduce that distance.
    """

    def __init__(
        self,
        np_model,
        bnn,
        data_generator,
        out_dir,
        log_space,
        distance_type,
        distance_config,
        output_dim=1,
        n_data=256,
        gpu_np_model=False,
        logger=None,
        device="cpu",
        clip_gradient=False,
        clipping_norm=None,
        use_wandb=False,
    ):
        """Create the mapper and instantiate the requested distance.

        Args:
            np_model: Target model exposing ``sample_functions(theta, x, n)``.
            bnn: Model whose ``parameters()`` are optimised; it also exposes
                ``sample_functions(theta, x, n)``.
            data_generator: Object whose ``get(n_data)`` returns ``(theta, x)``.
            out_dir: Directory where ``optimize(debug=True)`` writes its log.
            log_space: Forwarded to ``PriorMapper`` (its meaning is defined
                there).
            distance_type: One of ``"wasserstein_old"``, ``"wasserstein"``,
                ``"discriminator"``, ``"KL"`` or ``"stein_KL"``.
            distance_config: Dict of settings for the chosen distance. The
                keys read depend on ``distance_type`` (see the branches below).
            output_dim, n_data, gpu_np_model, logger, device, clip_gradient,
            clipping_norm, use_wandb: Forwarded to ``PriorMapper``.

        Raises:
            NotImplementedError: If ``distance_type`` is not recognised.
        """
        super().__init__(
            np_model,
            bnn,
            data_generator,
            out_dir,
            log_space,
            output_dim=output_dim,
            n_data=n_data,
            gpu_np_model=gpu_np_model,
            logger=logger,
            device=device,
            clip_gradient=clip_gradient,
            clipping_norm=clipping_norm,
            use_wandb=use_wandb,
        )

        self.distance_type = distance_type

        # Initialize the distance
        if distance_type == "wasserstein_old":
            self.distance = WassersteinDistanceOld(
                self.bnn,
                self.np_model,
                self.n_data,
                distance_config["critic_config"],
                self.output_dim,
                wasserstein_lr=distance_config["wasserstein_lr"],
                device=self.device,
                gpu_np_model=self.gpu_np_model,
                lipschitz_constraint_type=distance_config["lipschitz_constraint_type"],
                threshold=distance_config["wasserstein_threshold"],
                n_steps=distance_config["wasserstein_steps"],
                restart_lipschitz=distance_config["restart_lipschitz"],
            )

        elif distance_type == "wasserstein":
            self.distance = WassersteinDistance(
                self.bnn,
                self.np_model,
                self.n_data,
                distance_config["critic_config"],
                distance_config["benchmark"],
                self.output_dim,
                wasserstein_lr=distance_config["wasserstein_lr"],
                device=self.device,
                gpu_np_model=self.gpu_np_model,
                lipschitz_constraint_type=distance_config["lipschitz_constraint_type"],
                threshold=distance_config["wasserstein_threshold"],
                n_steps=distance_config["wasserstein_steps"],
                data_generator=distance_config["discriminator_generator"],
                penalty_coef=distance_config["penalty_coef"],
                restart_lipschitz=distance_config["restart_lipschitz"],
                clip_gradient=distance_config["clip_gradient"],
                clipping_norm=distance_config["clipping_norm"],
            )

        elif distance_type == "discriminator":
            self.distance = DiscriminatorDistance(
                self.bnn,
                self.np_model,
                self.n_data,
                distance_config["critic_config"],
                distance_config["benchmark"],
                self.output_dim,
                discriminator_lr=distance_config["discriminator_lr"],
                device=self.device,
                gpu_np_model=self.gpu_np_model,
                n_steps=distance_config["discriminator_steps"],
                data_generator=distance_config["discriminator_generator"],
                restart_discriminator=distance_config["restart_discriminator"],
                regularize_discriminator=distance_config["regularize_discriminator"],
                penalty_coef=distance_config["penalty_coef"],
                clip_gradient=distance_config["clip_gradient"],
                clipping_norm=distance_config["clipping_norm"],
            )

        elif distance_type == "KL":
            self.distance = KLDistance(
                self.np_model,
                self.output_dim,
                device=self.device,
                gpu_np_model=self.gpu_np_model,
            )

        elif distance_type == "stein_KL":
            self.distance = SteinKLDistance(
                self.np_model,
                self.output_dim,
                distance_config["num_eigs"],
                distance_config["eta"],
                device=self.device,
                gpu_np_model=self.gpu_np_model,
            )

        else:
            raise NotImplementedError(
                "Distance version '{}' does not exist.".format(distance_type)
            )

    def optimize(self, num_iters, n_samples=128, lr=1e-2, print_every=1, debug=False):
        """Optimise the BNN prior parameters against the configured distance.

        Each iteration does the following:

        1. Draw ``(theta, x)`` from ``self.data_generator``.
        2. Sample ``n_samples`` functions from the NP model and from the BNN.
        3. For adversarial distances, train the critic or discriminator.
        4. Back-propagate the distance into ``self.bnn.parameters()``.
        5. Take an optimiser step.

        Args:
            num_iters: Number of prior-optimisation iterations.
            n_samples: Function samples drawn from each model per iteration.
            lr: Learning rate of the prior optimiser.
            print_every: Print progress every this many iterations (and at
                iteration 1).
            debug: Forwarded to ``distance_optimisation``. When true, the
                distance's ``values_log`` is also saved to
                ``<out_dir>/dist_intermediate_values.log``.

        Returns:
            ``(dist_hist, mmd_hist)``: per-iteration distance values and MMD
            estimates, as lists of floats.
        """
        dist_hist = []
        mmd_hist = []

        # Assumed constant for a given distance object, so it is queried once.
        adversarial = self.distance.is_adversarial()

        # Adversarial distances use RMSprop (presumably following the WGAN
        # recipe of the reference implementation); the others use Adam.
        if adversarial:
            prior_optimizer = torch.optim.RMSprop(self.bnn.parameters(), lr=lr)
        else:
            prior_optimizer = torch.optim.Adam(self.bnn.parameters(), lr=lr)

        # Initialize distance
        if adversarial:
            self.distance.reset_distance()

        # Prior loop
        for it in range(1, num_iters + 1):
            # Draw (theta, x); the NP model runs on CPU unless gpu_np_model.
            theta, x = self.data_generator.get(self.n_data)
            x = x.to(self.device)
            theta = theta.to(self.device)
            if not self.gpu_np_model:
                x = x.to("cpu")
                theta = theta.to("cpu")

            # Draw functions from NP model (no gradient needed: it is the target)
            np_samples = (
                self.np_model.sample_functions(theta, x, n_samples)
                .detach()
                .float()
                .to(self.device)
            )
            if self.output_dim > 1:
                np_samples = np_samples.squeeze()

            if not self.gpu_np_model:
                x = x.to(self.device)
                theta = theta.to(self.device)

            # Draw functions from BNN (keeps the graph to its parameters)
            nnet_samples = (
                self.bnn.sample_functions(theta, x, n_samples).float().to(self.device)
            )
            if self.output_dim > 1:
                nnet_samples = nnet_samples.squeeze()

            # Optimisation of distance (critic / discriminator)
            if adversarial:
                self.distance.distance_optimisation(theta, x, n_samples, debug=debug)

            # Clear gradients from the previous prior step, and any that
            # distance_optimisation may have left on the BNN parameters, before
            # back-propagating this iteration's distance. This must happen for
            # every distance type. Otherwise gradients accumulate across
            # iterations for non-adversarial distances (KL, stein_KL).
            prior_optimizer.zero_grad()

            if self.distance_type == "wasserstein_old":
                dist = self.distance.calculate(nnet_samples, np_samples)
                # Transposed only for the MMD computed below.
                np_samples = np_samples.transpose(0, 1)
                nnet_samples = nnet_samples.transpose(0, 1)
            else:
                #  It was of size: [n_dim, N, n_out]
                # will be of size: [N, n_dim, n_out]
                nnet_samples = nnet_samples.transpose(0, 1)
                np_samples = np_samples.transpose(0, 1)
                if adversarial:
                    dist = self.distance.calculate(theta, x, nnet_samples, np_samples)
                else:
                    dist = self.distance.calculate(theta, x, nnet_samples)

            dist.backward()

            if self.clip_gradient:
                nn.utils.clip_grad_norm_(self.bnn.parameters(), self.clipping_norm)

            prior_optimizer.step()

            with torch.no_grad():
                mmd = self.compute_mmd_given_samples(nnet_samples, np_samples)

            dist_hist.append(float(dist))
            mmd_hist.append(mmd)

            if self.use_wandb:
                # One call so both metrics are recorded at the same W&B step.
                wandb.log({"running_dist": dist.item(), "running_mmd": mmd})

            if (it % print_every == 0) or it == 1:
                self.print_info(
                    ">>> Iteration # {:3d}: "
                    "Dist {:.4f} "
                    "MMD {:.4f}".format(it, float(dist), float(mmd))
                )

        # Save accumulated list of intermediate distance values
        if debug:
            values = np.array(self.distance.values_log).reshape(-1, 1)
            path = os.path.join(self.out_dir, "dist_intermediate_values.log")
            np.savetxt(path, values, fmt="%.6e")
            self.print_info("Saved intermediate distance values in: " + path)

        return dist_hist, mmd_hist


class GradientBasedPriorMapper(PriorMapper):
    """Placeholder for a gradient-based prior mapper.

    Only construction is supported; ``optimize`` is not implemented yet.
    """

    def __init__(
        self,
        np_model,
        bnn,
        data_generator,
        out_dir,
        output_dim=1,
        n_data=256,
        gpu_np_model=False,
        logger=None,
        device="cpu",
        clip_gradient=False,
        clipping_norm=None,
        use_wandb=False,
        log_space=False,
    ):
        """Create the mapper; all arguments are forwarded to ``PriorMapper``.

        ``log_space`` is a trailing keyword so existing positional callers keep
        working.
        """
        # ``log_space`` goes in the fifth positional slot, directly after
        # ``out_dir``, matching how DistanceBasedPriorMapper calls
        # PriorMapper.__init__.
        # NOTE(review): the default of False is an assumption; confirm it
        # against PriorMapper.__init__.
        super().__init__(
            np_model,
            bnn,
            data_generator,
            out_dir,
            log_space,
            output_dim=output_dim,
            n_data=n_data,
            gpu_np_model=gpu_np_model,
            logger=logger,
            device=device,
            clip_gradient=clip_gradient,
            clipping_norm=clipping_norm,
            use_wandb=use_wandb,
        )

    def optimize(self, *args, **kwargs):
        """Not implemented.

        Raising here makes the failure explicit instead of silently returning
        ``None`` to callers that unpack ``(dist_hist, mmd_hist)``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "GradientBasedPriorMapper.optimize is not implemented."
        )
