import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.distributions import Dirichlet, Normal
from resnet import ResNet, DropoutResNet, RadialResNet, Rank1ResNet


def smooth_softmax(raw_logits, gamma=1e-4):
    """Softmax over the last dim, blended with the uniform distribution.

    Computes ``(1 - gamma) * softmax(raw_logits) + gamma / K`` where ``K`` is the
    size of the last dimension, so every entry is bounded below by ``gamma / K``
    and a subsequent ``log`` stays finite.
    """
    uniform_mass = 1. / raw_logits.shape[-1]
    probs = F.softmax(raw_logits, dim=-1)
    return (1. - gamma) * probs + gamma * uniform_mass


def estimate_dirichlet(raw_logits, max_precision, max_iter=1000, tol=1e-5):
    """Fit a Dirichlet to softmax samples and return its concentration parameters.

    The Dirichlet mean is fixed to the average smoothed softmax over the sample
    axis. Only the precision ``s`` (sum of the concentrations) is fitted: Newton
    iterations on ``1 / s`` maximise the Dirichlet log-likelihood (cf. Minka,
    "Estimating a Dirichlet distribution"). ``s`` is clamped to
    ``[n_classes, max_precision]`` after every step.

    Args:
        raw_logits: tensor of shape (batch, n_samples, n_classes).
        max_precision: upper bound on the precision. It is also used directly
            when ``n_samples == 1``, since a single sample carries no spread
            information.
        max_iter: safety cap on the number of Newton iterations. The last
            iterate is returned if the cap is reached.
        tol: absolute tolerance on the change of ``1 / s`` that stops the iteration.

    Returns:
        Concentration tensor ``s * mean`` of shape (batch, n_classes).

    Raises:
        Exception: if a Newton step produces NaN.
    """
    _, n_samples, n_classes = raw_logits.shape
    min_precision = n_classes
    # apply softmax and some smoothing for numerical robustness
    smooth_probs = smooth_softmax(raw_logits)
    mean = smooth_probs.mean(dim=-2)
    log_p_bar = smooth_probs.log().mean(dim=-2)
    if n_samples == 1:
        return max_precision * mean

    # Closed-form starting point for the precision, clamped into the allowed range.
    precision = -((n_classes - 1.) / 2.) / (mean * (log_p_bar - mean.log())).sum(dim=-1, keepdim=True)
    precision = precision.clamp(min_precision, max_precision)
    variance = 1. / precision
    for _ in range(max_iter):
        # First and second derivatives of the log-likelihood w.r.t. the precision.
        ds = n_samples * (precision.digamma()
                          - (mean * (precision * mean).digamma()).sum(dim=-1, keepdim=True)
                          + (mean * log_p_bar).sum(dim=-1, keepdim=True))
        dds = n_samples * (precision.polygamma(1)
                           - (mean ** 2 * (precision * mean).polygamma(1)).sum(dim=-1, keepdim=True))
        previous = variance
        # Newton step on 1 / precision.
        variance = variance + (variance ** 2) * ds / dds
        if torch.isnan(variance).any():
            raise Exception("NAN encountered.")
        precision = (1. / variance).clamp(min_precision, max_precision)
        # Re-derive variance from the *clamped* precision. Two reasons:
        # - the next Newton step then starts where ds/dds are evaluated;
        # - an estimate pinned at a clamp bound stops changing and counts as converged.
        variance = 1. / precision
        if torch.allclose(previous, variance, atol=tol):
            break
    return precision * mean


class MLPMAPClassifier(nn.Module):
    """Deterministic (point-estimate) MLP classifier.

    The network is a stack of ``nn.Linear`` layers with widths
    ``[n_features, *hidden_dims, n_classes]``, with ``activation()`` inserted
    between consecutive linear layers.

    Args:
        n_classes: number of output classes.
        n_features: size of each flattened input.
        max_precision: precision handed to ``estimate_dirichlet`` in ``predict``.
        hidden_dims: widths of the hidden layers (empty means a single linear map).
        bias: whether the linear layers have a bias term.
        activation: zero-argument callable returning the nonlinearity module.
    """

    def __init__(self, n_classes, n_features, max_precision,
                 hidden_dims=(), bias=True, activation=nn.ReLU):
        super().__init__()
        self.n_classes = n_classes
        self.max_precision = max_precision
        dims = [n_features] + list(hidden_dims) + [n_classes]
        layers = [nn.Linear(dims[0], dims[1], bias=bias)]
        for j in range(1, len(dims) - 1):
            layers.append(activation())
            layers.append(nn.Linear(dims[j], dims[j + 1], bias=bias))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """Flatten all but the first dim of ``input`` and return raw logits (batch, n_classes)."""
        # reshape instead of view: non-contiguous inputs (e.g. permuted batches) no longer raise
        raw_logits = self.layers(input.reshape(input.shape[0], -1))
        return raw_logits

    def nll_loss(self, raw_logits, labels):
        """Mean cross-entropy of ``raw_logits`` against ``labels``."""
        return F.cross_entropy(raw_logits, labels)

    def predict(self, data):
        """Return Dirichlet concentrations of shape (batch, n_classes) for ``data``.

        There is one logit sample per input, so ``estimate_dirichlet`` uses
        ``max_precision`` as the precision.
        """
        return estimate_dirichlet(self.forward(data).unsqueeze(dim=-2), self.max_precision)


class MLPMAPFVIClassifier(MLPMAPClassifier):
    """``MLPMAPClassifier`` with an additional Dirichlet-based ``fkl_loss`` term."""

    def fkl_loss(self, raw_logits, prior_param=None):
        """Batch mean of ``log q(z) - log p(z)``.

        ``z`` is the smoothed softmax of ``raw_logits``. ``q`` is the Dirichlet
        returned by ``estimate_dirichlet`` for the single sample per row,
        computed under ``torch.no_grad()``. ``p`` is the Dirichlet prior.

        Args:
            raw_logits: unnormalised logits whose last dim indexes classes
                (presumably (batch, n_classes) — confirm with caller).
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits.unsqueeze(dim=-2), self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        fkl = (var_post.log_prob(z) - prior.log_prob(z)).mean()
        return fkl


class MLPDropoutClassifier(nn.Module):
    """MLP classifier with dropout, used for Monte-Carlo-dropout prediction.

    The first layer is ``Linear(n_features, ...)``. Each further width adds
    ``activation()``, then ``Dropout(p=dropout)``, then ``Linear``. With
    ``hidden_dims=()`` there is no Dropout layer, so all MC samples coincide.

    Args:
        n_classes: number of output classes.
        n_features: size of each flattened input.
        max_precision: upper precision bound passed to ``estimate_dirichlet``.
        hidden_dims: widths of the hidden layers.
        bias: whether the linear layers have a bias term.
        activation: zero-argument callable returning the nonlinearity module.
        dropout: drop probability of every ``nn.Dropout`` layer.
    """

    def __init__(self, n_classes, n_features, max_precision,
                 hidden_dims=(), bias=True, activation=nn.ReLU, dropout=0.2):
        super().__init__()
        self.n_classes = n_classes
        self.max_precision = max_precision
        dims = [n_features] + list(hidden_dims) + [n_classes]
        layers = [nn.Linear(dims[0], dims[1], bias=bias)]
        for j in range(1, len(dims) - 1):
            layers.append(activation())
            layers.append(nn.Dropout(p=dropout))
            layers.append(nn.Linear(dims[j], dims[j + 1], bias=bias))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """Flatten all but the first dim of ``input`` and return raw logits (batch, n_classes)."""
        # reshape instead of view: non-contiguous inputs no longer raise
        raw_logits = self.layers(input.reshape(input.shape[0], -1))
        return raw_logits

    def nll_loss(self, raw_logits, labels):
        """Mean cross-entropy of ``raw_logits`` against ``labels``."""
        return F.cross_entropy(raw_logits, labels)

    def _sample_raw_logits(self, data, n_samples):
        """Stack ``max(n_samples, 1)`` forward passes into (batch, n_samples, n_classes)."""
        # Build the list once and stack: repeated torch.cat re-copies the growing tensor (O(n^2)).
        samples = [self.forward(data) for _ in range(max(n_samples, 1))]
        return torch.stack(samples, dim=-2)

    def predict(self, data, n_samples=100):
        """Return Dirichlet concentrations (batch, n_classes) fitted to ``n_samples`` forward passes.

        The passes only differ while dropout is active (training mode). In eval
        mode every sample is identical.
        """
        return estimate_dirichlet(self._sample_raw_logits(data, n_samples), self.max_precision)


class MLPDropoutFVIClassifier(MLPDropoutClassifier):
    """``MLPDropoutClassifier`` with an additional Dirichlet-based ``fkl_loss`` term."""

    def fkl_loss(self, raw_logits, prior_param=None):
        """Batch mean of ``log q(z) - log p(z)``.

        ``z`` is the smoothed softmax of ``raw_logits``. ``q`` is the Dirichlet
        returned by ``estimate_dirichlet`` for the single sample per row,
        computed under ``torch.no_grad()``. ``p`` is the Dirichlet prior.

        Args:
            raw_logits: unnormalised logits whose last dim indexes classes
                (presumably (batch, n_classes) — confirm with caller).
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits.unsqueeze(dim=-2), self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        fkl = (var_post.log_prob(z) - prior.log_prob(z)).mean()
        return fkl


class MLPEnsembleLayer(nn.Module):
    """A bank of ``n_models`` independent linear maps evaluated in one einsum.

    The input has shape (batch, n_models, in_features) and the output has shape
    (batch, n_models, out_features). Member ``i`` only sees slice
    ``input[:, i, :]`` and uses ``weight[i]`` (and ``bias[i]``).
    """

    def __init__(self, n_models, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(n_models, out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(n_models, out_features)) if bias else None
        self._reset_parameters()

    def _reset_parameters(self):
        """He-style uniform init for the weights; fan-in-scaled uniform init for the bias."""
        fan_in = self.weight.shape[-1]
        # A uniform on [-b, b] has std b / sqrt(3), so b = sqrt(3) * sqrt(2) / sqrt(fan_in).
        weight_bound = math.sqrt(3.) * (math.sqrt(2.) / math.sqrt(fan_in))
        init.uniform_(self.weight, -weight_bound, weight_bound)
        if self.bias is None:
            return
        bias_bound = 1. / math.sqrt(fan_in)
        init.uniform_(self.bias, -bias_bound, bias_bound)

    def forward(self, input):
        # subscripts: b = batch, i = ensemble member, j = output unit, k = input unit
        projected = torch.einsum('ijk, bik -> bij', self.weight, input)
        return projected if self.bias is None else projected + self.bias


class MLPEnsembleClassifier(nn.Module):
    """Ensemble of ``n_models`` MLPs evaluated in parallel with ``MLPEnsembleLayer``.

    Args:
        n_classes: number of output classes.
        n_features: size of each flattened input.
        max_precision: upper precision bound passed to ``estimate_dirichlet``.
        hidden_dims: widths of the hidden layers.
        bias: whether the ensemble layers have a bias term.
        activation: zero-argument callable returning the nonlinearity module.
        n_models: number of ensemble members.
    """

    def __init__(self, n_classes, n_features, max_precision,
                 hidden_dims=(), bias=True, activation=nn.ReLU, n_models=10):
        super().__init__()
        self.n_classes = n_classes
        self.max_precision = max_precision
        self.n_models = n_models
        dims = [n_features] + list(hidden_dims) + [n_classes]
        layers = [MLPEnsembleLayer(n_models, dims[0], dims[1], bias=bias)]
        for j in range(1, len(dims) - 1):
            layers.append(activation())
            layers.append(MLPEnsembleLayer(n_models, dims[j], dims[j + 1], bias=bias))
        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        """Feed the same flattened input to every member; return (batch, n_models, n_classes)."""
        # reshape instead of view: non-contiguous inputs no longer raise
        flat = input.reshape(input.shape[0], -1)
        # (batch, features) -> (n_models, batch, features) -> (batch, n_models, features)
        raw_logits = self.layers(flat.repeat(self.n_models, 1, 1).transpose(0, 1))
        return raw_logits

    def nll_loss(self, raw_logits, labels):
        """Average over members of each member's mean cross-entropy."""
        nll = torch.tensor(0., device=raw_logits.device)
        for j in range(self.n_models):
            nll += F.cross_entropy(raw_logits[:, j, :], labels) / self.n_models
        return nll

    def predict(self, data):
        """Return Dirichlet concentrations (batch, n_classes) fitted across the members."""
        return estimate_dirichlet(self.forward(data), self.max_precision)


class MLPEnsembleFVIClassifier(MLPEnsembleClassifier):
    """``MLPEnsembleClassifier`` with an additional Dirichlet-based ``fkl_loss`` term."""

    def fkl_loss(self, raw_logits, prior_param=None):
        """Mean over members and batch of ``log q(z) - log p(z)``.

        Args:
            raw_logits: (batch, n_models, n_classes) member logits, as returned by ``forward``.
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor. ``q`` is the Dirichlet that ``estimate_dirichlet``
            fits across the members under ``torch.no_grad()``; ``z`` is each
            member's smoothed softmax.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits, self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        # Put the member axis first so log_prob broadcasts against q's (batch,) batch shape.
        # Every member has the same batch size, so one mean equals the per-member mean averaged over members.
        z_members = z.transpose(0, 1)
        return (var_post.log_prob(z_members) - prior.log_prob(z_members)).mean()


class CNNMAPClassifier(ResNet):
    """Deterministic (point-estimate) ResNet classifier.

    Args:
        n_classes: number of output classes (also ResNet's ``num_classes``).
        _: unused positional slot (presumably ``n_features`` in the signature
            shared with the MLP classifiers — confirm with caller).
        max_precision: precision handed to ``estimate_dirichlet`` in ``predict``.
        block, n_blocks: forwarded to ``ResNet``.
    """

    def __init__(self, n_classes, _, max_precision,
                 block, n_blocks):
        # These attributes are assigned before ResNet.__init__ runs.
        self.n_classes = n_classes
        self.max_precision = max_precision
        super().__init__(block, n_blocks, num_classes=n_classes)

    def nll_loss(self, raw_logits, labels):
        """Mean cross-entropy of ``raw_logits`` against ``labels``."""
        loss = F.cross_entropy(raw_logits, labels)
        return loss

    def predict(self, data):
        """Return Dirichlet concentrations from a single forward pass.

        ``estimate_dirichlet`` receives one sample per input.
        """
        single_sample = self.forward(data).unsqueeze(dim=-2)
        return estimate_dirichlet(single_sample, self.max_precision)


class CNNMAPFVIClassifier(CNNMAPClassifier):
    """``CNNMAPClassifier`` with an additional Dirichlet-based ``fkl_loss`` term."""

    def fkl_loss(self, raw_logits, prior_param=None):
        """Batch mean of ``log q(z) - log p(z)``.

        ``z`` is the smoothed softmax of ``raw_logits``. ``q`` is the Dirichlet
        returned by ``estimate_dirichlet`` for the single sample per row,
        computed under ``torch.no_grad()``. ``p`` is the Dirichlet prior.

        Args:
            raw_logits: unnormalised logits whose last dim indexes classes
                (presumably (batch, n_classes) — confirm with caller).
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits.unsqueeze(dim=-2), self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        fkl = (var_post.log_prob(z) - prior.log_prob(z))
        return fkl.mean()


class CNNDropoutClassifier(DropoutResNet):
    """DropoutResNet classifier with Monte-Carlo-dropout prediction.

    Args:
        n_classes: number of output classes (also DropoutResNet's ``num_classes``).
        _: unused positional slot (presumably ``n_features`` in the signature
            shared with the MLP classifiers — confirm with caller).
        max_precision: upper precision bound passed to ``estimate_dirichlet``.
        block, n_blocks, dropout: forwarded to ``DropoutResNet``.
    """

    def __init__(self, n_classes, _, max_precision,
                 block, n_blocks, dropout):
        # These attributes are assigned before DropoutResNet.__init__ runs.
        self.n_classes = n_classes
        self.max_precision = max_precision
        self.dropout = dropout
        super().__init__(block, n_blocks, dropout, num_classes=n_classes)

    def nll_loss(self, raw_logits, labels):
        """Mean cross-entropy of ``raw_logits`` against ``labels``."""
        return F.cross_entropy(raw_logits, labels)

    def predict(self, data, n_samples=10):
        """Return Dirichlet concentrations fitted to ``max(n_samples, 1)`` forward passes.

        The passes differ only if dropout is active inside ``forward``
        (presumably in training mode — confirm DropoutResNet's behaviour).
        """
        # Build the list once and stack: repeated torch.cat re-copies the growing tensor (O(n^2)).
        samples = [self.forward(data) for _ in range(max(n_samples, 1))]
        return estimate_dirichlet(torch.stack(samples, dim=-2), self.max_precision)


class CNNDropoutFVIClassifier(CNNDropoutClassifier):
    """``CNNDropoutClassifier`` with an additional Dirichlet-based ``fkl_loss`` term."""

    def fkl_loss(self, raw_logits, prior_param=None):
        """Batch mean of ``log q(z) - log p(z)``.

        ``z`` is the smoothed softmax of ``raw_logits``. ``q`` is the Dirichlet
        returned by ``estimate_dirichlet`` for the single sample per row,
        computed under ``torch.no_grad()``. ``p`` is the Dirichlet prior.

        Args:
            raw_logits: unnormalised logits whose last dim indexes classes
                (presumably (batch, n_classes) — confirm with caller).
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits.unsqueeze(dim=-2), self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        fkl = (var_post.log_prob(z) - prior.log_prob(z)).mean()
        return fkl


class CNNEnsembleClassifier(nn.Module):
    """Deep ensemble of ``n_models`` independently initialised ResNets.

    The members are held in an ``nn.ModuleList``, so the standard ``nn.Module``
    machinery reaches them. That covers ``to``/``cuda``/``half``, ``train``/``eval``,
    ``named_parameters``, ``modules`` and ``apply``, and each returns ``self``
    where PyTorch does. ``state_dict``/``load_state_dict`` keep the per-member
    checkpoint format ``{member_index: member_state_dict}``.

    Args:
        n_classes: number of output classes.
        _: unused positional slot (presumably ``n_features`` in the signature
            shared with the MLP classifiers — confirm with caller).
        max_precision: upper precision bound passed to ``estimate_dirichlet``.
        block, n_blocks: forwarded to each member ``ResNet``.
        n_models: number of ensemble members.
    """

    def __init__(self, n_classes, _, max_precision,
                 block, n_blocks, n_models=5):
        super().__init__()
        self.n_classes = n_classes
        self.max_precision = max_precision
        self.n_models = n_models
        self.models = nn.ModuleList([ResNet(block, n_blocks, num_classes=n_classes)
                                     for _ in range(n_models)])

    def parameters(self, **kwargs):
        """Return all member parameters as a list. A list, not a generator, keeps existing callers working."""
        return list(super().parameters(**kwargs))

    def forward(self, input):
        """Run every member on ``input``; return logits of shape (batch, n_models, n_classes)."""
        raw_logits = torch.zeros(input.shape[0], self.n_models, self.n_classes, device=input.device)
        for i, model in enumerate(self.models):
            raw_logits[:, i, :] = model(input)
        return raw_logits

    def nll_loss(self, raw_logits, labels):
        """Average over members of each member's mean cross-entropy."""
        loss = torch.tensor(0., device=raw_logits.device)
        for j in range(self.n_models):
            loss += F.cross_entropy(raw_logits[:, j, :], labels) / self.n_models
        return loss

    def predict(self, data):
        """Return Dirichlet concentrations (batch, n_classes) fitted across the members."""
        return estimate_dirichlet(self.forward(data), self.max_precision)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return ``{i: state_dict of member i}``.

        ``destination`` is accepted for signature compatibility only. Every
        member gets its own fresh dict; a single shared dict would make all
        entries alias one object, overwritten by the last member.
        """
        return {i: model.state_dict(prefix=prefix, keep_vars=keep_vars)
                for i, model in enumerate(self.models)}

    def load_state_dict(self, state_dict, strict=True):
        """Load a checkpoint in the ``{i: member_state_dict}`` format produced by ``state_dict``."""
        for i, model in enumerate(self.models):
            model.load_state_dict(state_dict[i], strict=strict)


class CNNEnsembleFVIClassifier(CNNEnsembleClassifier):
    """``CNNEnsembleClassifier`` with an additional Dirichlet-based ``fkl_loss`` term."""

    def fkl_loss(self, raw_logits, prior_param=None):
        """Mean over members and batch of ``log q(z) - log p(z)``.

        Args:
            raw_logits: (batch, n_models, n_classes) member logits, as returned by ``forward``.
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor. ``q`` is the Dirichlet that ``estimate_dirichlet``
            fits across the members under ``torch.no_grad()``; ``z`` is each
            member's smoothed softmax.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits, self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        # Put the member axis first so log_prob broadcasts against q's (batch,) batch shape.
        # Every member has the same batch size, so one mean equals the per-member mean averaged over members.
        z_members = z.transpose(0, 1)
        return (var_post.log_prob(z_members) - prior.log_prob(z_members)).mean()


class CNNRadialClassifier(RadialResNet):
    """RadialResNet classifier trained with cross-entropy plus a weight-KL penalty.

    Args:
        n_classes: number of output classes (also RadialResNet's ``num_classes``).
        _: unused positional slot (presumably ``n_features`` in the signature
            shared with the MLP classifiers — confirm with caller).
        max_precision: upper precision bound for ``estimate_dirichlet``. It also
            divides the weight KL in ``nll_loss``.
        block, n_blocks: forwarded to ``RadialResNet``.
    """

    def __init__(self, n_classes, _, max_precision,
                 block, n_blocks):
        # These attributes are assigned before RadialResNet.__init__ runs.
        self.n_classes = n_classes
        self.max_precision = max_precision
        super().__init__(block, n_blocks, num_classes=n_classes)

    def nll_loss(self, raw_logits, labels):
        """Cross-entropy plus ``weight_kl() / max_precision``."""
        nll = F.cross_entropy(raw_logits, labels)
        # NOTE(review): max_precision appears to double as the KL down-weighting factor — confirm intent.
        wkl = self.weight_kl(device=raw_logits.device) / self.max_precision
        return nll + wkl

    def weight_kl(self, w_mu_prior=0., w_sigma_prior=.1, b_mu_prior=0., b_sigma_prior=.1, device='cpu'):
        """Sum of ``KL(N(mu, softplus(rho)) || prior)`` over all parameters whose name contains "conv".

        Names containing "rho" are scale parameters; all others are location
        parameters. Each location parameter is paired with the scale parameter
        whose name, minus its last four characters, equals it (presumably a
        "_rho" suffix). "weight" parameters use the ``w_*`` prior and "bias"
        parameters use the ``b_*`` prior.

        Returns:
            Scalar tensor on ``device``.

        Raises:
            Exception: if a location parameter name contains neither "weight" nor "bias".
        """
        kl = torch.tensor(0., device=device)
        w_prior = Normal(w_mu_prior, w_sigma_prior)
        b_prior = Normal(b_mu_prior, b_sigma_prior)
        mu = []
        rho = []
        for name, param in self.named_parameters():
            if "conv" in name:
                if "rho" in name:
                    rho.append((name, param))
                else:
                    mu.append((name, param))
        # Sorting both lists by name lines each location up with its scale; checked below.
        mu.sort(key=lambda x: x[0])
        rho.sort(key=lambda x: x[0])
        assert len(mu) == len(rho)
        for (name, loc), (rho_name, rho_param) in zip(mu, rho):
            assert name == rho_name[:-4]
            scale = F.softplus(rho_param)
            if "weight" in name:
                prior = w_prior
            elif "bias" in name:
                prior = b_prior
            else:
                raise Exception("Expected 'weight' or 'bias' parameter, received '{}' instead.".format(name))
            kl += torch.distributions.kl_divergence(Normal(loc, scale), prior).sum()
        return kl

    def predict(self, data, n_samples=10):
        """Return Dirichlet concentrations fitted to ``max(n_samples, 1)`` forward passes.

        The passes are presumably stochastic because ``forward`` samples the
        radial posterior — confirm against RadialResNet.
        """
        # Build the list once and stack: repeated torch.cat re-copies the growing tensor (O(n^2)).
        samples = [self.forward(data) for _ in range(max(n_samples, 1))]
        return estimate_dirichlet(torch.stack(samples, dim=-2), self.max_precision)


class CNNRadialFVIClassifier(CNNRadialClassifier):
    """``CNNRadialClassifier`` variant with plain cross-entropy and a Dirichlet ``fkl_loss``.

    ``nll_loss`` is overridden to drop the weight-KL penalty.
    """

    def nll_loss(self, raw_logits, labels):
        """Mean cross-entropy only (no weight KL)."""
        nll = F.cross_entropy(raw_logits, labels)
        return nll

    def fkl_loss(self, raw_logits, prior_param=None):
        """Batch mean of ``log q(z) - log p(z)``.

        ``z`` is the smoothed softmax of ``raw_logits``. ``q`` is the Dirichlet
        returned by ``estimate_dirichlet`` for the single sample per row,
        computed under ``torch.no_grad()``. ``p`` is the Dirichlet prior.

        Args:
            raw_logits: unnormalised logits whose last dim indexes classes
                (presumably (batch, n_classes) — confirm with caller).
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits.unsqueeze(dim=-2), self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        fkl = (var_post.log_prob(z) - prior.log_prob(z)).mean()
        return fkl


class CNNRank1Classifier(Rank1ResNet):
    """Rank1ResNet classifier trained with cross-entropy plus a KL / weight-decay penalty.

    Args:
        n_classes: number of output classes (also Rank1ResNet's ``num_classes``).
        _: unused positional slot (presumably ``n_features`` in the signature
            shared with the MLP classifiers — confirm with caller).
        max_precision: upper precision bound for ``estimate_dirichlet``. It also
            divides the penalty in ``nll_loss``.
        block, n_blocks, n_models: forwarded to ``Rank1ResNet``.
    """

    def __init__(self, n_classes, _, max_precision,
                 block, n_blocks, n_models):
        # These attributes are assigned before Rank1ResNet.__init__ runs.
        self.n_classes = n_classes
        self.max_precision = max_precision
        super().__init__(block, n_blocks, n_models, num_classes=n_classes)

    def nll_loss(self, raw_logits, labels):
        """Cross-entropy plus ``weight_kl() / max_precision``."""
        cross_entropy = F.cross_entropy(raw_logits, labels)
        penalty = self.weight_kl(device=raw_logits.device) / self.max_precision
        return cross_entropy + penalty

    def weight_kl(self, alpha_mu_prior=1., alpha_sigma_prior=.1,
                  gamma_mu_prior=1., gamma_sigma_prior=.1, weight_decay=0.0001, device='cpu'):
        """KL of the rank-1 factors against Normal priors, plus L2 decay on the conv weights.

        Only parameters whose name contains "conv" are considered. Each one is
        classified by the first of "weight", "alpha_mu", "alpha_rho", "gamma_mu",
        "gamma_rho" contained in its name. The groups are sorted by name and
        aligned, and the assert checks the alignment via the shared name prefix.
        Scales are ``softplus(rho)``.

        Returns:
            Scalar tensor on ``device``.

        Raises:
            Exception: for a "conv" parameter matching none of the five tags.
        """
        total = torch.tensor(0., device=device)
        alpha_prior = Normal(alpha_mu_prior, alpha_sigma_prior)
        gamma_prior = Normal(gamma_mu_prior, gamma_sigma_prior)
        # Dict order is the matching priority: the first tag found in the name wins.
        groups = {"weight": [], "alpha_mu": [], "alpha_rho": [], "gamma_mu": [], "gamma_rho": []}
        for param_name, param in self.named_parameters():
            if "conv" not in param_name:
                continue
            tag = next((key for key in groups if key in param_name), None)
            if tag is None:
                raise Exception("Unexpected model parameter: '{}'".format(param_name))
            groups[tag].append((param_name, param))
        for members in groups.values():
            members.sort(key=lambda pair: pair[0])
        weights = groups["weight"]
        alpha_mus, alpha_rhos = groups["alpha_mu"], groups["alpha_rho"]
        gamma_mus, gamma_rhos = groups["gamma_mu"], groups["gamma_rho"]
        assert len(alpha_mus) == len(alpha_rhos) == len(gamma_mus) == len(gamma_rhos)
        for idx in range(len(alpha_mus)):
            prefix = weights[idx][0][:-6]
            assert prefix == alpha_mus[idx][0][:-8] == alpha_rhos[idx][0][:-9] \
                == gamma_mus[idx][0][:-8] == gamma_rhos[idx][0][:-9]
            alpha_mean = alpha_mus[idx][1]
            alpha_std = F.softplus(alpha_rhos[idx][1])
            gamma_mean = gamma_mus[idx][1]
            gamma_std = F.softplus(gamma_rhos[idx][1])
            total += torch.distributions.kl_divergence(Normal(alpha_mean, alpha_std), alpha_prior).sum()
            total += torch.distributions.kl_divergence(Normal(gamma_mean, gamma_std), gamma_prior).sum()
            total += weight_decay * (weights[idx][1] ** 2).sum()
        return total

    def predict(self, data):
        """Evaluate every member on ``data``; return Dirichlet concentrations fitted across them.

        The batch is tiled ``n_models`` times along dim 0. The output rows are
        read back member-major as (n_models, batch, n_classes). ``data`` is
        presumably (batch, channels, height, width) — confirm with caller.
        """
        batch_size = data.shape[0]
        logits = self.forward(data.repeat(self.n_models, 1, 1, 1))
        per_member = logits.view(self.n_models, batch_size, logits.shape[-1]).transpose(0, 1)
        return estimate_dirichlet(per_member, self.max_precision)


class CNNRank1FVIClassifier(CNNRank1Classifier):
    """``CNNRank1Classifier`` variant with plain cross-entropy and a Dirichlet ``fkl_loss``.

    ``nll_loss`` is overridden to drop the KL / weight-decay penalty.
    """

    def nll_loss(self, raw_logits, labels):
        """Mean cross-entropy only (no weight KL)."""
        nll = F.cross_entropy(raw_logits, labels)
        return nll

    def fkl_loss(self, raw_logits, prior_param=None):
        """Batch mean of ``log q(z) - log p(z)``.

        ``z`` is the smoothed softmax of ``raw_logits``. ``q`` is the Dirichlet
        returned by ``estimate_dirichlet`` for the single sample per row,
        computed under ``torch.no_grad()``. ``p`` is the Dirichlet prior.

        Args:
            raw_logits: unnormalised logits whose last dim indexes classes
                (presumably (batch, n_classes) — confirm with caller).
            prior_param: prior concentration (list, array or tensor of length
                ``n_classes``); all ones when ``None``.

        Returns:
            Scalar tensor.
        """
        z = smooth_softmax(raw_logits)
        with torch.no_grad():
            var_post = Dirichlet(estimate_dirichlet(raw_logits.unsqueeze(dim=-2), self.max_precision))
        if prior_param is not None:
            # as_tensor: no copy-construct warning (or copy) when a tensor is supplied
            concentration = torch.as_tensor(prior_param, device=raw_logits.device)
        else:
            concentration = torch.ones(self.n_classes, device=raw_logits.device)
        prior = Dirichlet(concentration)
        fkl = (var_post.log_prob(z) - prior.log_prob(z)).mean()
        return fkl
