# This code is inspired by https://github.com/SamsungLabs/BayesDLL/blob/main/methods/vi.py

import copy
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from torch import Tensor
from torch.optim import Optimizer
from tqdm import tqdm
from .utils import BayesianScheduler
from ..variational_distributions import GaussianNNParametersDistribution
from ..bnn_utils import freeze_model


class VImodel(nn.Module):
    def __init__(self, net0, config: dict, device, model_path, embedding=None):
        """
        Set up the variational inference wrapper around a base network.

        A deep copy of ``net0`` with every parameter set to zero is kept as the structural template; it is handed to ``BNNmodel`` and reused for functional calls and for materialising sampled networks. Specifying a custom prior at construction time is not implemented.

        Args:
            net0: The base neural network. It is deep-copied; the caller's instance is left untouched.
            config: A dictionary containing the configuration of the model. The keys "nb_networks" and "temperature" are read here, and the dictionary is forwarded to ``BNNmodel``.
            device: The device on which the model will be trained.
            model_path: Path stored on the instance and used as the directory for saved files.
            embedding: Optional module applied to ``x`` before the network. It is deep-copied and moved to ``device``; when omitted, ``nn.Identity`` is used.
        """

        super().__init__()

        # One cache slot per network drawn from the posterior.
        self.nets = [None] * int(config["nb_networks"])

        self.config = config
        self.device = device
        self.model_path = model_path

        # Template network: same architecture as net0, all parameters zeroed.
        template = copy.deepcopy(net0)
        with torch.no_grad():
            for weight in template.parameters():
                weight.zero_()
        self.net0 = template.to(self.device)

        # Variational inference model (nn.Module) holding the variational parameters.
        self.model = BNNmodel(self.net0, config).to(self.device)

        # Temperature for the cold posterior effect (1 = classical posterior likelihood, 0 = non-Bayesian network).
        self.temperature = float(config['temperature'])

        has_embedding = embedding is not None
        if has_embedding:
            self.embedding = copy.deepcopy(embedding).to(self.device)
        else:
            self.embedding = nn.Identity()
        # Checked when loading: embedding weights are only read back from disk when this is True.
        self.embedding_has_parameters = has_embedding

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        """
        Evaluate the network once with a freshly sampled set of weights.

        A new parameter dictionary is drawn from the variational model on every call, ``x`` is passed through the embedding, and the template network is evaluated functionally with the sampled parameters. This is meant for training only: the result relies on a single sampled network.

        Args:
            theta: The parameters of the network.
            x: The input of the network.

        Returns:
            (Tensor): The posterior approximation using only one network.
        """
        sampled_params = self.load_model()
        embedded_x = self.embedding(x)
        # strict=False: the sampled dictionary is not required to match the module's keys exactly.
        return torch.func.functional_call(
            self.net0, sampled_params, (theta, embedded_x), strict=False)

    def fill_network(self, params: dict) -> nn.Module:
        """
        Build a standalone network whose parameters are copied from ``params``.

        The values are copied under ``torch.no_grad()``, so the returned network is disconnected from the variational parameters: gradients will NOT be backpropagated to them. Use this only for forward passes, never to fill the network during training.

        Args:
            params: A dictionary containing the parameters of the network, keyed by parameter name. Every parameter name of the template network must be present.

        Returns:
            (nn.Module): A deep copy of the template network filled with the parameters.

        """
        filled = copy.deepcopy(self.net0)
        with torch.no_grad():
            for key, target in filled.named_parameters():
                target.copy_(params[key])
        return filled

    def load_model(self, id=None):
        """
        Loads a model drawn from the posterior of the weights.

        If id is None, a fresh set of parameters is sampled and returned. Otherwise the network with that id is returned from the in-memory cache, loaded from "net_{id}.pt" in the model directory if the file exists, or sampled, saved to that file and cached.

        Args:
            id: The id of the model to be loaded, with 0 <= id < nb_networks. If None, a random model is drawn from the posterior.

        Returns:
            If id is None, a dictionary containing the parameters of the model to be used as a functional call. If id is specified, a nn.Module containing the parameters of the model.

        Raises:
            AssertionError: If id is outside the range [0, nb_networks).
        """
        if id is None:
            return self.model.sample_network()

        nb_networks = int(self.config["nb_networks"])
        # Negative ids are rejected as well: they would index the cache from the
        # end and write files such as "net_-1.pt".
        assert 0 <= id < nb_networks, \
            "network id {} is outside [0, {})".format(id, nb_networks)

        cached = self.nets[id]
        if cached is not None:
            return cached

        checkpoint = os.path.join(self.model_path, "net_{}.pt".format(id))
        if os.path.exists(checkpoint):
            net = copy.deepcopy(self.net0)
            net.load_state_dict(torch.load(
                checkpoint, map_location=self.device))
        else:
            net = self.fill_network(self.model.sample_network())
            # Create the target directory on first use so saving cannot fail on a missing folder.
            directory = os.path.dirname(checkpoint)
            if directory:
                os.makedirs(directory, exist_ok=True)
            torch.save(net.state_dict(), checkpoint)

        self.nets[id] = net
        return net

    def save(self):
        """
        Write the embedding and the variational model state dicts to the model directory.

        The embedding goes to "embedding.pt" and the variational model to "VIparametrisaton.pt". The file names are used verbatim by ``load`` and must not be altered.
        """
        outputs = (
            (self.embedding, "embedding.pt"),
            (self.model, "VIparametrisaton.pt"),
        )
        for module, filename in outputs:
            torch.save(module.state_dict(),
                       os.path.join(self.model_path, filename))

    def load(self):
        """
        Restore the state dicts written by ``save`` from the model directory.

        The embedding file is only read when ``embedding_has_parameters`` is True; the variational model file "VIparametrisaton.pt" is always read. Tensors are mapped onto ``self.device``.
        """
        def read_state(filename):
            return torch.load(os.path.join(self.model_path, filename),
                              map_location=self.device)

        if self.embedding_has_parameters:
            self.embedding.load_state_dict(read_state("embedding.pt"))
        self.model.load_state_dict(read_state("VIparametrisaton.pt"))

    def set_temperature(self, temperature):
        """
        Store a new temperature on the model.

        The stored value is the one passed to the KL term on subsequent training steps; this method is also handed to the scheduler as its temperature-update callback.

        Args:
            temperature: The new temperature of the posterior approximation.
        """
        self.temperature = temperature

    def train_models(self, train_set, val_set, config, cls_loss, prior_distribution):
        """
        Trains the variational model and restores the best weights seen.

        Each epoch runs one pass over ``train_set`` (every step adds the KL term through ``VIstep``), then evaluates ``val_set`` without gradients. The scheduler is stepped with ``val_loss + kl_loss`` at the current temperature. The checkpoint kept is the one minimising ``val_loss + max_kl_loss``, where the KL term is evaluated at the maximum temperature. If no epoch produces a comparable improvement (zero epochs, or a NaN validation loss), the weights are left as they are at the end of training.

        Args:
            train_set: Iterable of (theta, x) batches used for training; ``len(train_set)`` is passed to the variational model as the training size.
            val_set: Iterable of (theta, x) batches used for validation.
            config: A dictionary containing the configuration of the model. Only "optimize_prior" and the optional "min_prior_std" are read from this argument; all other hyper-parameters come from the configuration given at construction time.
            cls_loss: Callable invoked as ``cls_loss(self)``; it must return the loss function, which is then called with ``(theta, x)``.
            prior_distribution: The prior over the network weights, passed to the KL computation. It is switched to eval mode.
        """

        # If the prior has been tuned, init the model to this tuned prior
        if config["optimize_prior"]:
            self.model = BNNmodel(
                self.net0, config, prior_distribution).to(self.device)

        min_prior_std = config.get("min_prior_std")

        prior_distribution.eval()

        # From here on the hyper-parameters are read from the construction-time
        # configuration, not from the `config` argument.
        model_config = self.config
        max_temp = float(model_config["max_temperature"])

        learning_rate = float(model_config["learning_rate"])
        epochs = model_config["epochs"]
        loss_fct = cls_loss(self)

        params = [param for param in self.parameters() if param.requires_grad]
        optimizer = optim.Adam(params, lr=learning_rate)
        step = VIstep(optimizer, self.model)
        scheduler = BayesianScheduler(optimizer, verbose=True, min_lr=float(
            model_config["min_lr"]), patience=int(model_config["patience"]), temperature=model_config["temperature"], max_temperature=max_temp, update_prior_weight=self.set_temperature)

        train_size = len(train_set)
        print("Training size: ", train_size)
        self.model.set_train_size(train_size)

        best_loss = float("inf")
        # Stays None until an epoch beats `best_loss`; guards the final restore.
        best_weights = None
        self.train()

        with tqdm(range(epochs), unit='epochs') as tq:
            for epoch in tq:

                self.model.train()

                losses = []
                for theta, x in train_set:

                    loss_nll = loss_fct(
                        theta.to(self.device), x.to(self.device))
                    losses.append(
                        step(loss_nll, prior_distribution, self.temperature, min_prior_std=min_prior_std))

                losses = torch.stack(losses)

                self.model.eval()
                val_losses = []
                with torch.no_grad():
                    for theta, x in val_set:
                        val_losses.append(loss_fct(
                            theta.to(self.device), x.to(self.device)))
                    kl_loss = self.model.get_kl_loss(
                        prior_distribution, temperature=self.temperature, min_prior_std=min_prior_std).item()
                    # Converted to a Python float so that `best_loss` never becomes a device tensor.
                    max_kl_loss = self.model.get_kl_loss(
                        prior_distribution, temperature=max_temp, min_prior_std=min_prior_std).item()

                val_losses = torch.stack(val_losses)

                loss = losses.mean().item()
                val_loss = val_losses.mean().item()

                scheduler.step(val_loss + kl_loss)

                tq.set_postfix(loss=loss, val_loss=val_loss, kl_loss=kl_loss)

                # Model selection uses the KL at the maximum temperature so that
                # epochs run at different temperatures remain comparable.
                candidate = val_loss + max_kl_loss
                if candidate < best_loss:
                    best_loss = candidate
                    best_weights = copy.deepcopy(self.model.state_dict())

        if best_weights is not None:
            self.model.load_state_dict(best_weights)

    def log_prob(self, theta, x, id_net=None):
        """
        Evaluate the log probability with one stored network or with the whole ensemble.

        With ``id_net`` given, the output of that single network is returned. Otherwise every network is evaluated and the outputs are combined as ``logsumexp(outputs) - log(nb_networks)``, i.e. the log of the average probability over networks. Not to be used during training: the stored networks are detached copies, so gradients will not be backpropagated correctly.

        Args:
            theta: The parameters of the network. Moved to ``self.device``.
            x: The input of the network. Moved to ``self.device``.
            id_net: The id of the network to be used. If None, all the networks are used.

        Returns:
            (Tensor): The log probability of the data given the parameters of the network(s).
        """
        n_nets = int(self.config["nb_networks"])

        self.model.eval()

        x = x.to(self.device)
        theta = theta.to(self.device)

        if id_net is None:
            # Ensemble average in log space over all stored networks.
            per_network = torch.stack(
                [self.load_model(k)(theta, x) for k in range(n_nets)])
            log_count = torch.log(torch.tensor(n_nets))
            return torch.logsumexp(per_network, dim=0) - log_count

        return self.load_model(id_net)(theta, x)

    def sample(self, x, shape, id_net=None):
        """
        Draw samples using one stored network or the whole ensemble.

        With ``id_net`` given, sampling is delegated to that network with the requested shape. Otherwise the total number of samples (the product of ``shape``) is split across the networks with a uniform multinomial draw, each network produces its share, and the results are concatenated in network order and reshaped to ``(*shape, -1)``. Every network is loaded even when its share is zero.

        Args:
            x: The input of the network.
            shape: The shape of the sample.
            id_net: The id of the network to be used. If None, all the networks are used.

        Returns:
            (Tensor): The sample from the posterior of the network(s).
        """
        if id_net is not None:
            return self.load_model(id_net).sample(x, shape)

        n_nets = int(self.config["nb_networks"])
        total = np.prod(np.array(shape))

        # Number of samples assigned to each network (uniform mixture weights).
        counts = np.random.multinomial(total, [1/n_nets]*n_nets)

        draws = []
        for net_id, count in enumerate(counts):
            net = self.load_model(net_id)
            if count <= 0:
                continue
            draws.append(net.sample(x, (count, )))

        return torch.cat(draws, dim=0).view(*shape, -1)


class BNNmodel(nn.Module):
    '''
    Variational inference model.

    Represents q(theta) = N(theta; m, Diag(v)) where v = s^2 (s = torch.exp(log_s)).
    The distribution itself is held by a GaussianNNParametersDistribution; this class
    adds the training-set size needed by the KL term.
    '''

    def __init__(self, net0: nn.Module, config, prior_distribution=None):
        """
        Initialize the variational inference model.

        Args:
            net0: The neural network containing the initial weights' means.
            config: The script config. "std_init_value" is read when no prior distribution is given.
            prior_distribution: Optional distribution used to initialise the variational parameters. When None, a low-variance initialisation with "std_init_value" is used instead.
        """

        super().__init__()

        # Explicit None check: the prior is an arbitrary object whose truthiness
        # (e.g. through __len__) must not decide which initialisation is used.
        if prior_distribution is not None:
            self.VIparams = GaussianNNParametersDistribution(
                net0, shared=False, init_distribution=prior_distribution)
        else:
            self.VIparams = GaussianNNParametersDistribution(
                net0, shared=False, low_variance_init=True, std_init_value=float(config['std_init_value']))

        # Must be set through set_train_size before the KL loss can be computed.
        self.train_size = None

    def sample_network(self) -> dict:
        """
        Draw weights from the variational posterior and return a dictionary of parameters.

        Returns:
            (dict): A dictionary containing the parameters of the network.
        """

        return self.VIparams.sample_network()

    def set_train_size(self, train_size):
        """
        This method has to be called before training. It sets the size of the training set, which is passed to the KL divergence computation.

        Args:
            train_size: The size of the training set.
        """
        self.train_size = train_size

    def get_kl_loss(self, prior_distribution, temperature=1.0, min_prior_std=None) -> Tensor:
        """
        Compute the KL divergence between the variational posterior and the prior.

        The computation is delegated to the underlying distribution together with the stored training-set size. The temperature implements the cold posterior effect: with temperature=1 the classical posterior is used, with temperature>1 the prior has more influence on the posterior on the weights, and with temperature<1 the posterior is more influenced by the likelihood.

        Args:
            prior_distribution: The prior distribution over the network weights.
            temperature: The temperature of the posterior.
            min_prior_std: Optional value forwarded unchanged to the underlying KL computation.

        Returns:
            (Tensor): The KL divergence between the variational posterior and the prior.

        Raises:
            AssertionError: If set_train_size has not been called.
        """

        assert self.train_size is not None, \
            "set_train_size must be called before computing the KL loss"

        return self.VIparams.get_kl_loss(prior_distribution, self.train_size, temperature, min_prior_std=min_prior_std)


class VIstep:
    """
    Callable performing one optimisation step of variational inference.

    The KL divergence returned by the model is added to the supplied loss before backpropagation. Steps are skipped when the supplied loss is not finite, and, when clipping is enabled, when the gradient norm is not finite.
    """

    def __init__(self, optimizer: Optimizer, model: BNNmodel, clip: float = None):
        """
        Args:
            optimizer: The optimizer to step; its parameter groups define the parameters subject to gradient clipping.
            model: The variational model providing ``get_kl_loss``.
            clip: Maximum gradient norm. If None, gradients are not clipped.
        """
        self.optimizer = optimizer
        self.model = model
        self.clip = clip
        # Flat list of every parameter managed by the optimizer.
        self.parameters = [
            param
            for group in optimizer.param_groups
            for param in group['params']
        ]

    def __call__(self, loss: Tensor, prior_distribution, temperature: float, min_prior_std=None) -> Tensor:
        """
        Run one step and return the detached input loss (without the KL term).

        Args:
            loss: The data loss for the current batch.
            prior_distribution: The prior passed to the KL computation.
            temperature: The temperature passed to the KL computation.
            min_prior_std: Optional value forwarded to the KL computation.

        Returns:
            (Tensor): ``loss`` detached from the graph.
        """
        if not loss.isfinite().all():
            # Non-finite loss: leave gradients and parameters untouched.
            return loss.detach()

        self.optimizer.zero_grad()

        total_loss = loss + self.model.get_kl_loss(
            prior_distribution, temperature, min_prior_std=min_prior_std)
        total_loss.backward()

        do_step = True
        if self.clip is not None:
            grad_norm = nn.utils.clip_grad_norm_(self.parameters, self.clip)
            do_step = grad_norm.isfinite()
        if do_step:
            self.optimizer.step()

        return loss.detach()
