# This code is inspired by https://github.com/SamsungLabs/BayesDLL/blob/main/src/bayesdll/sgld.py and https://github.com/JavierAntoran/Bayesian-Neural-Networks

import copy
import os
import shutil
import math
import torch
import torch.nn as nn
import numpy as np

from tqdm import tqdm
from ..sghmc import H_SA_SGHMC
from .utils import BayesianScheduler


class HMCmodel(nn.Module):
    """Ensemble of network weights sampled with stochastic-gradient HMC.

    Each chain is trained with the ``H_SA_SGHMC`` optimizer. After the burn-in
    epochs, weight snapshots are written to ``model_path`` as
    ``net_{chain_id}_{index}.pt``. ``log_prob`` and ``sample`` then reload
    ``samples_per_chain`` snapshots from each of ``nb_chains`` chains and
    treat them as an equally weighted mixture.

    Args:
        net: network evaluated as ``net(theta, x)``; its parameters are the
            ones sampled during training.
        config: configuration dict; ``samples_per_chain`` and ``nb_chains``
            are read here.
        device: device the inputs are moved to before evaluation.
        model_path: directory holding snapshots and training artefacts.
    """

    def __init__(self, net, config: dict, device, model_path):

        super().__init__()

        self.net = net
        self.config = config

        self.expected_nb_models = config["samples_per_chain"]
        self.nb_chains = config["nb_chains"]
        self.model_path = model_path
        self.device = device

        # Recent global gradient norms, used by the optional dynamic clipping.
        self.grad_buff = []

    def forward(self, theta, x):
        """Evaluate the network currently held in ``self.net`` on ``(theta, x)``.

        No snapshot is reloaded and ``self.net`` is left untouched; use
        ``log_prob`` to evaluate the saved ensemble.
        """
        return self.net(theta, x)

    def get_loss_fct(self, config, cls_loss):
        """Return ``cls_loss`` unchanged (``config`` is unused)."""
        return cls_loss

    def _model_file(self, chain_id, index):
        """Return the path of snapshot ``index`` of chain ``chain_id``."""
        return os.path.join(self.model_path, "net_{}_{}.pt".format(chain_id, index))

    def load_model(self, chain_id, index):
        """Return a copy of ``self.net`` holding the weights of a saved snapshot.

        ``self.net`` itself is not modified.
        """
        net = copy.deepcopy(self.net)
        # map_location lets snapshots saved on another device (e.g. GPU) load here.
        state_dict = torch.load(self._model_file(chain_id, index), map_location=self.device)
        net.load_state_dict(state_dict)
        return net

    def save_model(self, chain_id, index):
        """Write the current weights of ``self.net`` as snapshot ``index`` of chain ``chain_id``."""
        torch.save(self.net.state_dict(), self._model_file(chain_id, index))

    @staticmethod
    def _remove_path(path):
        """Delete ``path`` whether it is a file or a directory; ignore it if missing."""
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    def delete_models(self, chain_id):
        """Remove the artefacts left by a previous run of chain ``chain_id``.

        Deletes the chain's snapshot files ``net_{chain_id}_*.pt``, its
        ``trained_{chain_id}.pt`` marker and a ``chain_{chain_id}`` directory
        if present. Missing paths (including a missing ``model_path``) are
        ignored.
        """
        prefix = "net_{}_".format(chain_id)
        if os.path.isdir(self.model_path):
            for file_name in os.listdir(self.model_path):
                # "net_1_" does not match "net_10_*", so other chains are kept.
                if file_name.startswith(prefix) and file_name.endswith(".pt"):
                    self._remove_path(os.path.join(self.model_path, file_name))

        self._remove_path(os.path.join(self.model_path, "chain_{}".format(chain_id)))
        self._remove_path(os.path.join(self.model_path, "trained_{}.pt".format(chain_id)))

        # NOTE(review): rmtree with ignore_errors is a no-op on a regular file, so
        # a bnn_prior.pt *file* is left in place. Kept as-is because the prior may
        # be shared between chains -- confirm whether it should really be deleted.
        shutil.rmtree(os.path.join(self.model_path, "bnn_prior.pt"), ignore_errors=True)

    def set_temperature(self, temperature):
        """Store ``temperature``; passed to BayesianScheduler as ``update_prior_weight``."""
        self.temperature = temperature

    def _add_prior_grads(self, bnn_prior, min_prior_std):
        """Add ``bnn_prior.get_prior_grad`` to the gradient of every trainable parameter."""
        for name, param in self.net.named_parameters():
            if not param.requires_grad:
                continue
            grad = bnn_prior.get_prior_grad(name, param.data, min_std=min_prior_std)
            if param.grad is None:
                # Parameter did not enter the loss: its gradient is the prior term alone.
                param.grad = torch.zeros_like(param)
            param.grad.data.add_(grad)

    def _clip_gradients(self, config):
        """Clip the global gradient norm against running statistics of past norms.

        While fewer than 50 norms are buffered, the threshold is 1e20, which
        effectively disables clipping. Afterwards it is
        ``mean + config["hmc_grad_clipping_factor"] * std`` of the buffered
        norms. The buffer keeps roughly the last 1000 norms. A norm that reaches
        the threshold is printed and not recorded, so outliers do not inflate
        the statistics.
        """
        if len(self.grad_buff) < 50:
            max_grad = 1e20
        else:
            max_grad = np.mean(self.grad_buff) + config["hmc_grad_clipping_factor"] * np.std(self.grad_buff)
        if len(self.grad_buff) > 1000:
            self.grad_buff.pop(0)

        # clip_grad_norm_ returns the total norm measured before clipping.
        total_norm = nn.utils.clip_grad_norm_(
            parameters=self.net.parameters(), max_norm=max_grad, norm_type=2
        ).cpu().item()
        if total_norm >= max_grad:
            print(max_grad, total_norm)
        else:
            self.grad_buff.append(total_norm)

    def _validation_loss(self, val_set, loss_fct):
        """Return the mean (unscaled) loss over ``val_set``, computed in eval mode."""
        self.eval()
        try:
            with torch.no_grad():
                val_losses = [
                    loss_fct(theta.to(self.device), x.to(self.device)).cpu().item()
                    for theta, x in val_set
                ]
        finally:
            self.train()
        return torch.Tensor(val_losses).mean().item()

    def train_models(self, train_set, val_set, config, chain_id, cls_loss, bnn_prior):
        """Run one SG-HMC chain and save its post-burn-in weight snapshots.

        Args:
            train_set: iterable of ``(theta, x)`` training batches.
            val_set: iterable of ``(theta, x)`` validation batches. Their mean
                loss is passed to the ``BayesianScheduler`` after every epoch.
            config: training hyper-parameters (keys read below).
            chain_id: chain index, used in the snapshot file names.
            cls_loss: factory called as ``cls_loss(self.net)`` that returns a
                loss function ``loss_fct(theta, x)``.
            bnn_prior: weight prior providing ``get_prior_grad`` and, when
                ``init_hmc_to_prior`` is set, the means ``m``.

        Side effects:
            Writes ``net_{chain_id}_{i}.pt`` snapshots,
            ``train_losses_{chain_id}.pt`` and the ``trained_{chain_id}.pt``
            marker to ``model_path``.
        """
        self.delete_models(chain_id)

        learning_rate = float(config["learning_rate"])
        # NOTE(review): config["momentum_decay"] used to be read here but was never
        # forwarded to H_SA_SGHMC -- confirm whether the optimizer should receive it.
        epochs = config["epochs"]
        burn_in_epochs = config["burn_in_epochs"]
        samples_per_chain = config["samples_per_chain"]

        # Space roughly ``samples_per_chain`` snapshots evenly over the
        # post-burn-in iterations (at least one iteration apart).
        # NOTE(review): len(train_set) is divided by the batch size, which assumes
        # it counts samples; if train_set is a DataLoader it counts batches -- confirm.
        save_every = max(1, math.ceil(
            (len(train_set) * (epochs - burn_in_epochs) / config["train_batch_size"])
            / samples_per_chain
        ))

        resample_momentum_every = config["resample_momentum_every"]
        resample_prior_every = config["resample_prior_every"]
        min_prior_std = config.get("min_prior_std")
        use_grad_clipping = config.get("hmc_grad_clipping")

        if config.get("init_hmc_to_prior"):
            with torch.no_grad():
                for name, param in self.net.named_parameters():
                    param.copy_(bnn_prior.m[bnn_prior.transform_param_name(name)])

        self.train()
        optimizer = H_SA_SGHMC(params=self.net.parameters(), lr=learning_rate)
        scheduler = BayesianScheduler(
            optimizer,
            verbose=True,
            min_lr=float(config["min_lr"]),
            patience=int(config["patience"]),
            temperature=config["temperature"],
            max_temperature=config["max_temperature"],
            update_prior_weight=self.set_temperature,
        )

        n_train = len(train_set)
        loss_fct = cls_loss(self.net)

        train_losses = []
        current_model_index = 0
        current_iteration = 0
        batch_id = 0  # counts post-burn-in iterations only

        with tqdm(range(epochs), unit="epochs") as tq:
            for epoch in tq:
                self.train()
                burn_in = epoch < burn_in_epochs

                for theta, x in train_set:
                    optimizer.zero_grad()
                    # The mini-batch loss is scaled by len(train_set) before the prior
                    # gradient is added (see the NOTE on len(train_set) above).
                    loss = loss_fct(theta.to(self.device), x.to(self.device)) * n_train
                    loss.backward()

                    self._add_prior_grads(bnn_prior, min_prior_std)

                    if use_grad_clipping:
                        self._clip_gradients(config)

                    optimizer.step(
                        burn_in=burn_in,
                        resample_momentum=(current_iteration % resample_momentum_every == 0),
                        resample_prior=(current_iteration % resample_prior_every == 0),
                    )

                    if not burn_in:
                        if batch_id % save_every == 0:
                            self.save_model(chain_id, current_model_index)
                            current_model_index += 1
                        batch_id += 1

                    current_iteration += 1
                    train_losses.append(loss.item())

                val_loss = self._validation_loss(val_set, loss_fct)
                scheduler.step(val_loss)

                # Displayed training loss: mean over all iterations so far, unscaled.
                tq.set_postfix(
                    loss=torch.Tensor(train_losses).mean().item() / n_train,
                    val_loss=val_loss,
                )

        torch.save(
            train_losses,
            os.path.join(self.model_path, "train_losses_{}.pt".format(chain_id)),
        )
        # Empty marker file, written once the chain has finished training.
        torch.save(
            torch.Tensor([]),
            os.path.join(self.model_path, "trained_{}.pt".format(chain_id)),
        )

    def get_total_nb_models(self):
        """Return the number of snapshots in the ensemble (all chains)."""
        return self.expected_nb_models * self.nb_chains

    def save(self):
        """No-op: snapshots are written during ``train_models``."""
        pass

    def load(self):
        """No-op: snapshots are read on demand by ``load_model``."""
        pass

    def log_prob(self, theta, x, id_net=None):
        """Return the ensemble log-density, or that of one snapshot if ``id_net`` is given.

        The ensemble value is ``log(mean_k exp(net_k(theta, x)))`` over all
        ``nb_chains * samples_per_chain`` snapshots, computed with logsumexp.
        ``id_net`` indexes snapshots chain-major: chain
        ``id_net // samples_per_chain``, snapshot ``id_net % samples_per_chain``.
        """
        x = x.to(self.device)
        theta = theta.to(self.device)

        if id_net is not None:
            net = self.load_model(
                id_net // self.expected_nb_models, id_net % self.expected_nb_models)
            return net(theta, x)

        outputs = torch.stack(
            [
                self.load_model(chain, index)(theta, x)
                for chain in range(self.nb_chains)
                for index in range(self.expected_nb_models)
            ],
            dim=-1,
        )
        return torch.logsumexp(outputs, dim=-1) - math.log(self.get_total_nb_models())

    def sample(self, x, shape, id_net=None):
        """Draw samples from the ensemble, or from one snapshot if ``id_net`` is given.

        The ``prod(shape)`` samples are split across snapshots with a uniform
        multinomial draw (numpy global RNG). Each snapshot contributes
        ``net.sample(x, (count,))``. Results are concatenated in snapshot order
        and reshaped to ``(*shape, -1)``.
        """
        if id_net is not None:
            net = self.load_model(
                id_net // self.expected_nb_models, id_net % self.expected_nb_models)
            return net.sample(x, shape)

        nb_samples = np.prod(np.array(shape))
        nb_networks = self.get_total_nb_models()
        samples_per_network = np.random.multinomial(
            nb_samples, [1 / nb_networks] * nb_networks)

        samples = []
        for i in range(self.nb_chains):
            for j in range(self.expected_nb_models):
                nb_sample = int(samples_per_network[i * self.expected_nb_models + j])
                if nb_sample > 0:
                    # Only read snapshots that actually contribute samples.
                    samples.append(self.load_model(i, j).sample(x, (nb_sample,)))

        return torch.cat(samples, dim=0).view(*shape, -1)
