"""
Implementation of SWAG (SWA-Gaussian), adapted from the authors' original code at
https://github.com/wjmaddox/swa_gaussian/blob/master/swag/posteriors/swag.py
"""

from typing import Any, Dict, List, Optional, Tuple, Type

import gpytorch  # type: ignore
import numpy as np  # type: ignore
import torch
from gpytorch.distributions import MultivariateNormal  # type: ignore
from gpytorch.lazy import DiagLazyTensor  # type: ignore
from gpytorch.lazy import AddedDiagLazyTensor, RootLazyTensor
from torch import nn
from torch.distributions.normal import Normal

# Shorthand for a list of tensors, e.g. per-parameter statistics that are
# concatenated by `flatten` and split again by `unflatten_like`.
TensorList = List[torch.Tensor]


def flatten(lst: TensorList) -> torch.Tensor:
    """Concatenate the tensors in ``lst`` into a single 1-D tensor.

    Each tensor is flattened in row-major order and the pieces are joined
    end to end in the order they appear in ``lst``.
    """
    pieces = [tensor.reshape(-1) for tensor in lst]
    return torch.cat(pieces)


def unflatten_like(vector: torch.Tensor, likeTensorList: "TensorList") -> "TensorList":
    """Split a flat tensor into tensors shaped like those in ``likeTensorList``.

    Args:
        vector: the concatenated values, either as a 1-D tensor of length N
            or as a single-row 2-D tensor of shape ``(1, N)``.
        likeTensorList: tensors whose shapes (and order) define the split;
            their values are not used.

    Returns:
        One view into ``vector`` per entry of ``likeTensorList``, each reshaped
        to that entry's shape.
    """
    # Historically callers passed a (1, N) row; also accept a plain 1-D vector.
    if vector.dim() == 1:
        vector = vector.unsqueeze(0)

    out_list = []
    offset = 0
    for tensor in likeTensorList:
        n = tensor.numel()
        out_list.append(vector[:, offset : offset + n].view(tensor.shape))
        offset += n
    return out_list


def swag_parameters(
    module: nn.Module,
    params: List[Tuple[nn.Module, str]],
    no_cov_mat: bool = True,
) -> None:
    """Replace ``module``'s own parameters with SWAG statistic buffers.

    For every non-None parameter ``<name>`` registered directly on ``module``
    (submodules are not visited here; apply it recursively, e.g. with
    ``nn.Module.apply``), the parameter is removed and zero-initialised
    buffers ``<name>_mean`` and ``<name>_sq_mean`` of the same shape are
    registered. When ``no_cov_mat`` is False, an empty ``(0, numel)`` buffer
    ``<name>_cov_mat_sqrt`` is registered as well.

    Args:
        module: module whose parameters are converted.
        params: list that receives one ``(module, name)`` pair per converted
            parameter, whether or not covariance buffers are created.
        no_cov_mat: if False, also register the deviation-matrix buffer.
    """
    # Iterate over a snapshot of the keys because entries are popped below.
    for name in list(module._parameters.keys()):  # type: ignore
        param = module._parameters[name]  # type: ignore
        if param is None:
            continue
        data = param.data
        module._parameters.pop(name)  # type: ignore
        module.register_buffer("%s_mean" % name, data.new(data.size()).zero_())  # type: ignore
        module.register_buffer("%s_sq_mean" % name, data.new(data.size()).zero_())  # type: ignore

        if no_cov_mat is False:
            module.register_buffer(
                "%s_cov_mat_sqrt" % name, data.new_empty((0, data.numel())).zero_()
            )

        # Every converted parameter must be tracked, also in the diagonal-only
        # (no_cov_mat=True) mode: SWAG's collect/sample loops iterate this list.
        params.append((module, name))


class SWAG(torch.nn.Module):
    """SWAG (SWA-Gaussian) posterior over the weights of a wrapped model.

    On construction the wrapped model's parameters are replaced (through
    ``swag_parameters``) by buffers that hold running first and second moments
    and, when ``no_cov_mat`` is False, a matrix of stored deviations that
    forms a low-rank covariance term. ``collect_model`` updates these
    statistics from another model's parameters, and ``sample`` writes a weight
    draw onto the wrapped model so that ``forward`` can be evaluated.
    """

    def __init__(
        self,
        base: Type[nn.Module],
        device: torch.device,
        no_cov_mat: bool = True,
        max_num_models: int = 0,
        var_clamp: float = 1e-30,
        model_args: Optional[List[Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the wrapped model and convert its parameters into SWAG buffers.

        Args:
            base: model class, instantiated as ``base(*model_args, **model_kwargs)``.
            device: device that ``sample_fullrank`` moves sampled weights to.
            no_cov_mat: if True, keep only mean / second-moment buffers; if
                False, also keep per-parameter deviation matrices.
            max_num_models: maximum number of deviation rows kept per
                parameter by ``collect_model``; also normalises the low-rank
                sample term.
            var_clamp: lower bound applied to variance estimates.
            model_args: positional arguments for ``base`` (default: none).
            model_kwargs: keyword arguments for ``base`` (default: none).
        """
        super().__init__()

        self.register_buffer("n_models", torch.zeros([1], dtype=torch.long))
        # (owning module, parameter name) for every parameter replaced by buffers.
        self.params: List[Tuple[nn.Module, str]] = []

        self.no_cov_mat = no_cov_mat
        self.max_num_models = max_num_models

        self.var_clamp = var_clamp
        self.device = device

        # None defaults avoid sharing a single mutable list/dict between calls.
        args = [] if model_args is None else model_args
        kwargs = {} if model_kwargs is None else model_kwargs
        self.base = base(*args, **kwargs)  # type: ignore
        self.base.apply(
            lambda module: swag_parameters(
                module=module, params=self.params, no_cov_mat=self.no_cov_mat
            )
        )

    def _check_cov_collected(self) -> None:
        """Raise ValueError if the deviation buffers were never created."""
        if self.no_cov_mat:
            raise ValueError(
                "SWAG was built with no_cov_mat=True; no covariance buffers exist"
            )

    def _low_rank_scale(self) -> float:
        """Return sqrt(max_num_models - 1), the normaliser of the low-rank sample."""
        self._check_cov_collected()
        if self.max_num_models < 2:
            # (max_num_models - 1) ** 0.5 would divide by zero or turn complex.
            raise ValueError(
                "low-rank sampling needs max_num_models >= 2, got %d"
                % self.max_num_models
            )
        return (self.max_num_models - 1) ** 0.5

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped model with whatever weights are currently set on it."""
        return self.base(*args, **kwargs)  # type: ignore

    def mc(self, x: torch.Tensor, samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``samples`` forward passes, resampling the weights before each.

        The low-rank term is used only when it was collected
        (``no_cov_mat`` is False).

        Returns:
            ``(mus, logvars)``, each of shape ``(samples, x.size(0))``.
            NOTE(review): assumes the wrapped model returns a ``(mu, logvar)``
            pair whose entries each fit a length-``x.size(0)`` row — confirm.
        """
        use_cov = not self.no_cov_mat
        mus = torch.zeros(samples, x.size(0), device=x.device)
        logvars = torch.zeros(samples, x.size(0), device=x.device)
        for i in range(samples):
            self.sample(cov=use_cov)
            mus[i], logvars[i] = self(x)

        return mus, logvars

    def mc_class(self, x: torch.Tensor, samples: int) -> torch.Tensor:
        """Run ``samples`` forward passes, resampling the weights before each.

        The low-rank term is used only when it was collected
        (``no_cov_mat`` is False).

        Returns:
            Tensor of shape ``(samples, x.size(0), self.base.y_dim)``.
            NOTE(review): relies on the wrapped model exposing ``y_dim``.
        """
        use_cov = not self.no_cov_mat
        out = torch.zeros(samples, x.size(0), self.base.y_dim, device=x.device)  # type: ignore
        for i in range(samples):
            self.sample(cov=use_cov)
            out[i] = self(x)

        return out

    def sample(
        self,
        scale: float = 1.0,
        cov: bool = False,
        block: bool = False,
        fullrank: bool = True,
    ) -> None:
        """Draw one weight sample and write it onto the wrapped model.

        Args:
            scale: sample scale factor (applied as ``sqrt(scale)`` in the
                joint sampler and as ``scale`` in the blockwise sampler).
            cov: include the low-rank deviation term; raises ValueError if it
                was not collected or ``max_num_models < 2``.
            block: sample each parameter independently instead of jointly.
            fullrank: only used by the blockwise sampler.
        """
        if not block:
            self.sample_fullrank(scale, cov, fullrank)
        else:
            self.sample_blockwise(scale, cov, fullrank)

    def sample_blockwise(self, scale: float, cov: bool, fullrank: bool) -> None:
        """Sample each parameter independently and set it on its module.

        Per parameter: ``w = mean + scale * sqrt(var) * eps``. When ``cov``
        and ``fullrank`` are both True, ``scale / sqrt(max_num_models - 1) *
        D^T z`` is added, where ``D`` is the stored deviation matrix.
        """
        cov_coeff = scale / self._low_rank_scale() if cov else 0.0

        for module, name in self.params:
            mean = module.__getattr__("%s_mean" % name)  # type: ignore
            sq_mean = module.__getattr__("%s_sq_mean" % name)  # type: ignore
            eps = torch.randn_like(mean)

            var = torch.clamp(sq_mean - mean ** 2, self.var_clamp)
            w = mean + scale * torch.sqrt(var) * eps

            if cov:
                cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name)  # type: ignore
                # The low-rank draw is taken even when fullrank is False (and
                # then discarded), so RNG consumption does not depend on it.
                z = cov_mat_sqrt.new_empty((cov_mat_sqrt.size(0), 1)).normal_()
                cov_sample = cov_coeff * cov_mat_sqrt.t().matmul(z).view_as(mean)
                if fullrank:
                    w = w + cov_sample

            module.__setattr__(name, w)  # type: ignore

    def sample_fullrank(self, scale: float, cov: bool, fullrank: bool) -> None:
        """Draw one joint sample over all parameters and set it on the modules.

        ``w = mean + sqrt(scale) * (sqrt(var) * eps + D^T z / sqrt(max_num_models - 1))``,
        with the low-rank term only when ``cov`` is True. Statistics are
        gathered on the CPU and each sampled tensor is moved to ``self.device``.
        ``fullrank`` is accepted for signature parity with
        ``sample_blockwise`` and is not used here.
        """
        scale_sqrt = scale ** 0.5
        denom = self._low_rank_scale() if cov else 1.0

        mean_list = []
        sq_mean_list = []
        cov_mat_sqrt_list = []

        for module, name in self.params:
            mean_list.append(module.__getattr__("%s_mean" % name).cpu())  # type: ignore
            sq_mean_list.append(module.__getattr__("%s_sq_mean" % name).cpu())  # type: ignore
            if cov:
                cov_mat_sqrt_list.append(
                    module.__getattr__("%s_cov_mat_sqrt" % name).cpu()  # type: ignore
                )

        mean = flatten(mean_list)
        sq_mean = flatten(sq_mean_list)

        # diagonal part
        var = torch.clamp(sq_mean - mean ** 2, self.var_clamp)
        rand_sample = var.sqrt() * torch.randn_like(var, requires_grad=False)

        # low-rank part: deviation rows of all parameters side by side
        if cov:
            cov_mat_sqrt = torch.cat(cov_mat_sqrt_list, dim=1)
            z = cov_mat_sqrt.new_empty(
                (cov_mat_sqrt.size(0),), requires_grad=False
            ).normal_()
            cov_sample = cov_mat_sqrt.t().matmul(z)
            cov_sample /= denom
            rand_sample = rand_sample + cov_sample

        sample = (mean + scale_sqrt * rand_sample).unsqueeze(0)

        # split the flat (1, N) sample back into per-parameter tensors
        samples_list = unflatten_like(sample, mean_list)

        for (module, name), value in zip(self.params, samples_list):
            module.__setattr__(name, value.to(self.device))  # type: ignore

    def collect_model(self, base_model: nn.Module) -> None:
        """Fold ``base_model``'s current parameters into the SWAG statistics.

        Updates the running means of the parameters and of their squares. When
        ``no_cov_mat`` is False, it also appends the deviation from the updated
        mean as a new row of ``<name>_cov_mat_sqrt``, dropping the oldest row
        once more than ``max_num_models`` rows are stored. Finally it
        increments ``n_models``.

        NOTE(review): ``self.params`` is paired with ``base_model.parameters()``
        positionally via ``zip``. This assumes both enumerate parameters in the
        same order and silently stops at the shorter of the two.
        """
        n = self.n_models.item()  # type: ignore
        for (module, name), base_param in zip(self.params, base_model.parameters()):
            mean = module.__getattr__("%s_mean" % name)  # type: ignore
            sq_mean = module.__getattr__("%s_sq_mean" % name)  # type: ignore

            # running first and second moments
            mean = mean * n / (n + 1.0) + base_param.data / (n + 1.0)
            sq_mean = sq_mean * n / (n + 1.0) + base_param.data ** 2 / (n + 1.0)

            if self.no_cov_mat is False:
                cov_mat_sqrt = module.__getattr__("%s_cov_mat_sqrt" % name)  # type: ignore

                # store the deviation from the updated mean as one new row
                dev = (base_param.data - mean).view(1, -1)
                cov_mat_sqrt = torch.cat((cov_mat_sqrt, dev), dim=0)

                # drop the oldest row (row 0) once too many models are stored
                if (n + 1) > self.max_num_models:
                    cov_mat_sqrt = cov_mat_sqrt[1:, :]
                module.__setattr__("%s_cov_mat_sqrt" % name, cov_mat_sqrt)

            module.__setattr__("%s_mean" % name, mean)
            module.__setattr__("%s_sq_mean" % name, sq_mean)
        self.n_models.add_(1)  # type: ignore

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True) -> Any:
        """Load SWAG state, resizing the deviation buffers to the saved rank first.

        ``<name>_cov_mat_sqrt`` buffers grow as models are collected, so they
        are re-created with ``min(n_models, max_num_models)`` rows before
        delegating to ``nn.Module.load_state_dict``. That call's result
        (missing / unexpected keys) is returned.
        """
        if not self.no_cov_mat:
            n_models = state_dict["n_models"].item()
            rank = min(n_models, self.max_num_models)
            for module, name in self.params:
                mean = module.__getattr__("%s_mean" % name)  # type: ignore
                module.__setattr__(
                    "%s_cov_mat_sqrt" % name,
                    mean.new_empty((rank, mean.numel())).zero_(),
                )
        return super().load_state_dict(state_dict, strict)

    def export_numpy_params(self, export_cov_mat: bool = False) -> Tuple[Any, ...]:
        """Export the statistics as NumPy arrays.

        Returns:
            ``(mean, var)``: the concatenated raveled means and
            ``sq_mean - mean**2`` (not clamped). If ``export_cov_mat`` is
            True, a third element is added: a list with one raveled
            deviation matrix per parameter.
        """
        mean_list = []
        sq_mean_list = []
        cov_mat_list = []

        for module, name in self.params:
            mean_list.append(module.__getattr__("%s_mean" % name).cpu().numpy().ravel())  # type: ignore
            sq_mean_list.append(
                module.__getattr__("%s_sq_mean" % name).cpu().numpy().ravel()  # type: ignore
            )
            if export_cov_mat:
                cov_mat_list.append(
                    module.__getattr__("%s_cov_mat_sqrt" % name).cpu().numpy().ravel()  # type: ignore
                )
        mean = np.concatenate(mean_list)
        sq_mean = np.concatenate(sq_mean_list)
        var = sq_mean - np.square(mean)

        if export_cov_mat:
            return mean, var, cov_mat_list
        return mean, var

    def import_numpy_weights(self, w: np.ndarray) -> None:
        """Set each tracked parameter from consecutive slices of the flat array ``w``.

        Each slice is reshaped to the parameter's mean shape and converted
        with the mean buffer's dtype and device.
        """
        offset = 0
        for module, name in self.params:
            mean = module.__getattr__("%s_mean" % name)  # type: ignore
            # numel() is an int even for 0-d tensors (np.prod(()) is a float).
            size = mean.numel()
            chunk = w[offset : offset + size].reshape(mean.shape)
            module.__setattr__(name, mean.new_tensor(chunk))  # type: ignore
            offset += size

    def generate_mean_var_covar(self) -> Tuple["TensorList", ...]:
        """Return per-parameter means, variances and deviation matrices.

        Variances are ``sq_mean - mean**2`` clamped below at ``var_clamp``
        (the same clamp the samplers use), so round-off cannot make them
        negative. When ``no_cov_mat`` is True no deviation buffers exist and
        the third list is empty.
        """
        mean_list = []
        var_list = []
        cov_mat_root_list = []
        for module, name in self.params:
            mean = module.__getattr__("%s_mean" % name)  # type: ignore
            sq_mean = module.__getattr__("%s_sq_mean" % name)  # type: ignore

            mean_list.append(mean)
            var_list.append(torch.clamp(sq_mean - mean ** 2.0, self.var_clamp))
            if not self.no_cov_mat:
                cov_mat_root_list.append(
                    module.__getattr__("%s_cov_mat_sqrt" % name)  # type: ignore
                )
        return mean_list, var_list, cov_mat_root_list

    def compute_ll_for_block(
        self,
        vec: "TensorList",
        mean: "TensorList",
        var: "TensorList",
        cov_mat_root: torch.Tensor,
    ) -> torch.Tensor:
        """Log-density of ``vec`` under ``N(mean, diag(var + 1e-6) + R^T R)``.

        ``vec``, ``mean`` and ``var`` are lists of tensors that are concatenated
        with ``flatten``. ``cov_mat_root`` is ``R``, with one stored deviation
        per row.
        """
        t_vec = flatten(vec)
        t_mean = flatten(mean)
        t_var = flatten(var)

        cov_mat_lt = RootLazyTensor(cov_mat_root.t())
        var_lt = DiagLazyTensor(t_var + 1e-6)
        covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)
        qdist = MultivariateNormal(t_mean, covar_lt)

        # Enter both settings contexts (``a and b`` would only enter ``b``).
        trace_ctx = gpytorch.settings.num_trace_samples(1)
        cg_ctx = gpytorch.settings.max_cg_iterations(25)
        with trace_ctx, cg_ctx:
            return qdist.log_prob(t_vec)  # type: ignore

    def block_logdet(self, var: "TensorList", cov_mat_root: torch.Tensor) -> torch.Tensor:
        """Log-determinant of ``diag(var + 1e-6) + R^T R`` with ``R = cov_mat_root``."""
        t_var = flatten(var)

        cov_mat_lt = RootLazyTensor(cov_mat_root.t())
        var_lt = DiagLazyTensor(t_var + 1e-6)
        covar_lt = AddedDiagLazyTensor(var_lt, cov_mat_lt)

        return covar_lt.log_det()  # type: ignore

    def block_logll(
        self, param_list, mean_list, var_list, cov_mat_root_list
    ) -> torch.Tensor:
        """Sum of per-parameter log-densities (one covariance block per parameter)."""
        full_logprob = 0.0
        for param, mean, var, cov_mat_root in zip(
            param_list, mean_list, var_list, cov_mat_root_list
        ):
            # Wrap each tensor in a list so flatten reshapes it in one step
            # instead of iterating over its first dimension.
            full_logprob += self.compute_ll_for_block(  # type: ignore
                [param], [mean], [var], cov_mat_root
            )

        return full_logprob  # type: ignore

    def full_logll(self, param_list, mean_list, var_list, cov_mat_root_list):
        """Log-density under one joint covariance spanning all parameters."""
        cov_mat_root = torch.cat(cov_mat_root_list, dim=1)
        # compute_ll_for_block flattens the lists itself; pre-flattening would
        # make flatten iterate element by element.
        return self.compute_ll_for_block(param_list, mean_list, var_list, cov_mat_root)

    def compute_logdet(self, block=False):
        """Log-determinant of the SWAG covariance (block-diagonal if ``block``).

        Raises ValueError when the deviation buffers were not collected.
        """
        self._check_cov_collected()
        _, var_list, cov_mat_root_list = self.generate_mean_var_covar()

        if block:
            full_logdet = 0
            for var, cov_mat_root in zip(var_list, cov_mat_root_list):
                full_logdet += self.block_logdet([var], cov_mat_root)
        else:
            cov_mat_root = torch.cat(cov_mat_root_list, dim=1)
            full_logdet = self.block_logdet(var_list, cov_mat_root)

        return full_logdet

    def diag_logll(self, param_list, mean_list, var_list):
        """Log-density of the parameters under the diagonal Gaussian N(mean, var).

        ``var_list`` holds variances, but ``Normal`` takes a standard
        deviation. The square root of the variance, clamped at
        ``var_clamp``, is therefore passed.
        """
        logprob = 0.0
        for param, mean, var in zip(param_list, mean_list, var_list):
            std = torch.clamp(var, self.var_clamp).sqrt()
            logprob += Normal(mean, std).log_prob(param).sum()
        return logprob

    def compute_logprob(self, vec=None, block=False, diag=False):
        """Log-density of the current (or supplied) weights under the SWAG posterior.

        Args:
            vec: optional flat weight vector (1-D or ``(1, N)``). When None,
                the weights currently set on the wrapped modules are used.
            block: use a block-diagonal covariance (one block per parameter).
            diag: use only the diagonal Gaussian. This path does not need the
                deviation buffers; the other paths raise ValueError without them.
        """
        if not diag:
            self._check_cov_collected()
        mean_list, var_list, cov_mat_root_list = self.generate_mean_var_covar()

        if vec is None:
            param_list = [getattr(module, name) for module, name in self.params]
        else:
            # unflatten_like slices along dim 1, so present vec as one row.
            param_list = unflatten_like(vec.reshape(1, -1), mean_list)

        if diag:
            return self.diag_logll(param_list, mean_list, var_list)
        if block:
            return self.block_logll(param_list, mean_list, var_list, cov_mat_root_list)
        return self.full_logll(param_list, mean_list, var_list, cov_mat_root_list)