# ------------------------------------------------------------------------
# Advancing Out-of-Distribution Detection via Local Neuroplasticity
# Copyright (c) 2024 Alessandro Canevaro. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from TabMedOOD (https://github.com/mazizmalayeri/TabMedOOD/)
# Copyright (c) 2023 Mohammad Azizmalayeri. All Rights Reserved.
# ------------------------------------------------------------------------

from __future__ import annotations

from functools import partial, reduce
from typing import Tuple, Iterable

import torch
import torch.nn as nn
import torch.distributions as dists
from torch.nn.functional import softplus
from torch.distributions import constraints
from torch.distributions.utils import logits_to_probs

import pytorch_lightning as pl

import torch
import torch.distributions as dist
from torch.distributions import constraints
from torch.distributions.utils import probs_to_logits, logits_to_probs, broadcast_all
from torch.distributions.relaxed_categorical import ExpRelaxedCategorical
from torch.distributions.one_hot_categorical import OneHotCategorical
import numpy as np

def get_distribution_by_name(name):
    """Resolve a likelihood name to its class (or to a table of "trick" classes).

    Plain names such as ``'normal'`` or ``'poisson'`` map directly to a
    likelihood class. The two special names ``'bernoullitrick'`` and
    ``'gammatrick'`` map to a nested dict that is itself keyed by a plain
    likelihood name.

    Raises:
        KeyError: if ``name`` is not a registered key.
    """
    bernoulli_trick = {
        'categorical': CategoricalBernoulliTrick,  # , 'ordinal': OrdinalBernoulliTrick
    }
    gamma_trick = {
        'bernoulli': BernoulliGammaTrick,
        'poisson': PoissonGammaTrick,
        'categorical': CategoricalGammaTrick,  # , 'ordinal': OrdinalGammaTrick
    }

    registry = {
        'normal': Normal,
        'lognormal': LogNormal,
        'bernoulli': Bernoulli,
        'poisson': Poisson,
        'categorical': Categorical,
        'ordinal': Categorical,  # ordinal data shares the categorical likelihood
        'geometric': Geometric,
        'negative_binomial': NegativeBinomial,
        'truncated_normal': TruncatedNormal,
        'bernoullitrick': bernoulli_trick,
        'gammatrick': gamma_trick,
    }
    return registry[name]

     
def _get_distributions(num_vars, x_train) -> List[Base]:
    """Create one likelihood object per input variable.

    Every variable is currently modelled with the ``'normal'`` likelihood;
    ``x_train`` is accepted but not inspected.

    Args:
        num_vars: number of variables (columns) to model.
        x_train: training data; unused by the current implementation.

    Returns:
        A list with ``num_vars`` freshly constructed likelihood instances.
    """
    likelihood_names = ['normal'] * num_vars
    return [get_distribution_by_name(likelihood)() for likelihood in likelihood_names]
    
class ProbabilisticModel(object):
    """Ordered collection of per-variable likelihoods.

    ``dists`` holds one likelihood object per variable. A likelihood may be
    composed of several sub-distributions (``num_dists > 1``); ``indexes`` is
    the flattened list of ``[variable, component]`` pairs, so integer indexing
    and iteration over the model visit every component in order.
    """

    def __init__(self, num_vars, x_train):
        self.dists = _get_distributions(num_vars, x_train)
        self.indexes = self._flatten_indexes(self.dists)

    @staticmethod
    def _flatten_indexes(dists):
        """Return ``[variable, component]`` pairs for every component of ``dists``.

        A flat comprehension is used instead of ``reduce(list.__add__, ...)``
        so that an empty ``dists`` yields ``[]`` rather than raising
        ``TypeError``, and so the list is built in linear time.
        """
        return [[i, j] for i, d in enumerate(dists) for j in range(d.num_dists)]

    def to(self, device):
        """Move the weight tensor of every component to ``device``; returns ``self``."""
        for d in self:
            d._weight = d._weight.to(device)
        return self

    @property
    def weights(self):
        """Weights of all components, in flattened order."""
        return [d.weight for d in self]

    @weights.setter
    def weights(self, values):
        if isinstance(values, torch.Tensor):
            values = values.detach().tolist()

        for w, d in zip(values, self):
            d.weight = w

    def scale_data(self, x):
        """Apply each component's ``>>`` operator to the matching column of ``x``."""
        new_x = [d >> x[:, i] for i, d in enumerate(self)]
        return torch.stack(new_x, dim=-1)

    def __rshift__(self, data):
        return self.scale_data(data)

    def params_from_data(self, x, mask):
        """Collect the parameters each component estimates from its column of ``x``.

        When ``mask`` is given and is not all-true at the position returned by
        ``gathered_index``, only the entries selected by that mask are used.
        """
        params = []
        for i, d in enumerate(self):
            pos = self.gathered_index(i)
            if mask is None or mask[..., pos].all():
                data = x[..., i]
            else:
                data = torch.masked_select(x[..., i], mask[..., pos])
            params += d.params_from_data(data)
        return params

    def preprocess_data(self, x, mask=None):
        """Run each likelihood's ``preprocess_data`` on its column and restack.

        Columns of ``x`` beyond ``len(self.dists)`` are appended unchanged.
        """
        new_x = []
        for i, dist_i in enumerate(self.dists):
            new_x += dist_i.preprocess_data(x[:, i], mask)

        for i in range(len(self.dists), x.size(1)):
            new_x += [x[:, i]]

        return torch.stack(new_x, 1)

    def gathered_index(self, index):
        """Return the variable index that flattened component ``index`` belongs to."""
        return self.indexes[index][0]

    def __len__(self):
        return len(self.indexes)

    def __getitem__(self, item) -> "Base":
        # An int is a position in the flattened list; anything else is
        # treated as a (variable, component) pair.
        if isinstance(item, int):
            return self.__getitem__(self.indexes[item])

        return self.dists[item[0]][item[1]]

    @property
    def gathered(self):
        """View of the model grouped per variable instead of per component."""

        class GatherProbabilisticModel(object):
            def __init__(self, model):
                self.model = model

            def __len__(self):
                return len(self.model.dists)

            def __getitem__(self, item):
                # Flattened positions occupied by variable `item`, plus its likelihood.
                offset = sum([d.num_dists for d in self.model.dists[: item]])
                idxs = range(offset, offset + self.model.dists[item].num_dists)

                return idxs, self.model.dists[item]

            @property
            def weights(self):
                return [d.weight for [_, d] in self]

            @weights.setter
            def weights(self, values):
                if isinstance(values, torch.Tensor):
                    values = values.detach().tolist()

                for w, [_, d] in zip(values, self):
                    d.weight = w

            def __iter__(self):
                # Yields (list of flattened positions, likelihood) per variable.
                offset = 0
                for d in self.model.dists:
                    yield list(range(offset, offset + d.num_dists)), d
                    offset += d.num_dists

            def get_param_names(self):
                """Build a display name for every parameter of every variable."""
                names = []
                for i, dist_i in enumerate(self.model.dists):
                    if dist_i.num_dists > 1 or dist_i.size_params[0] > 1:
                        # Multi-class case: one name per class, built from the first real parameter.
                        param_name = dist_i.real_parameters[0]
                        num_classes = dist_i.size_params[0] if dist_i.num_dists == 1 else dist_i.num_dists
                        names += [f'{dist_i}_{param_name}{j}_dim{i}' for j in range(num_classes)]
                    else:
                        names += [f'{dist_i}_{v}_dim{i}' for v in dist_i.real_parameters]

                return names

            def scale_data(self, x):
                new_x = [d >> x[:, i] for i, [_, d] in enumerate(self)]
                return torch.stack(new_x, dim=-1)

            def __rshift__(self, data):
                return self.scale_data(data)

        return GatherProbabilisticModel(self)
        
class Base(object):
    """Common interface for the likelihood wrappers in this module.

    Subclasses supply ``dist`` (a ``torch.distributions`` class), ``f`` (one
    scaling function per parameter), ``to_params`` and ``is_discrete``.
    ``etas`` denotes the parameters handled by the model; ``to_params`` maps
    them to the keyword arguments of ``dist``.
    """

    def __init__(self):
        self._weight = torch.tensor([1.0])
        self.arg_constraints = {}
        self.size = 1

    @property
    def weight(self):
        return self._weight

    @weight.setter
    def weight(self, value):
        """Store ``value`` as the weight.

        Tensors are stored as-is. A non-tensor iterable must hold exactly one
        element, which is unwrapped; scalars are wrapped in a 1-element tensor.
        """
        if not isinstance(value, torch.Tensor) and isinstance(value, Iterable):
            assert len(value) == 1, value
            # Unwrap the single element. Storing the iterator itself made
            # torch.tensor([...]) fail for list input.
            value = next(iter(value))

        self._weight = value if isinstance(value, torch.Tensor) else torch.tensor([value])

    @property
    def expanded_weight(self):
        """Each weight repeated once per scaling function of its sub-distribution."""
        return [w for i, w in enumerate(self.weight) for _ in range(len(self[i].f))]

    @property
    def parameters(self):
        """Parameter names, taken from the wrapped distribution's ``arg_constraints``."""
        return list(self.dist.arg_constraints.keys())

    @property
    def real_parameters(self):
        real = self.real_dist
        return self.parameters if real is self else real.real_parameters

    def __getitem__(self, item):
        # A plain likelihood has a single component: itself.
        assert item == 0
        return self

    def preprocess_data(self, x, mask=None):
        return x,

    def scale_data(self, x, weight=None):
        """Multiply ``x`` by ``weight`` (defaults to ``self.weight``)."""
        # `is None` rather than `or`: a zero weight is a valid value and a
        # multi-element tensor has no unambiguous truth value.
        weight = self.weight if weight is None else weight
        return x * weight

    def unscale_data(self, x, weight=None):
        """Divide ``x`` by ``weight`` (defaults to ``self.weight``)."""
        weight = self.weight if weight is None else weight
        return x / weight

    @property
    def f(self):
        raise NotImplementedError()

    def sample(self, size, etas):
        """Draw ``size`` samples from the real distribution parameterised by ``etas``."""
        real_params = self.to_real_params(etas)
        real_params = dict(zip(self.real_parameters, real_params))
        return self.real_dist.dist(**real_params).sample(torch.Size([size]))

    def impute(self, etas):
        raise NotImplementedError()

    def mean(self, etas):
        params = self.to_params(etas)
        params = dict(zip(self.parameters, params))
        return self.dist(**params).mean

    def to_text(self, etas):
        """Human-readable summary of the real parameters and, if available, the mean."""
        params = self.to_real_params(etas)
        params = [x.cpu().tolist() for x in params]
        params = dict(zip(self.real_parameters, params))
        try:
            mean = self.mean(etas).item()
        except NotImplementedError:
            mean = None

        return f'{self.real_dist} params={params}' + (f' mean={mean}' if mean is not None else '')

    def params_from_data(self, x):
        raise NotImplementedError()

    def real_params_from_data(self, x):
        etas = self.real_dist.params_from_data(x)
        return self.real_dist.to_real_params(etas)

    @property
    def real_dist(self) -> "Base":
        return self

    def to_real_params(self, etas):
        return self.to_params(etas)

    @property
    def num_params(self):
        return len(self.arg_constraints)

    @property
    def size_params(self):
        return [1] * self.num_params

    @property
    def num_suff_stats(self):
        return self.num_params

    @property
    def num_dists(self):
        return 1

    def log_prob(self, x, etas):
        params = self.to_params(etas)
        params = dict(zip(self.parameters, params))
        return self.dist(**params).log_prob(x)

    def real_log_prob(self, x, etas):
        real_params = self.to_real_params(etas)
        real_params = dict(zip(self.real_parameters, real_params))
        return self.real_dist.dist(**real_params).log_prob(x)

    @property
    def dist(self):
        raise NotImplementedError()

    def _param_scales(self, etas):
        """Per-parameter factors ``f_i(weight_i)``, shaped like ``etas``."""
        factors = torch.ones_like(etas)
        # Computed once: the property rebuilds the list on every access.
        weights = self.expanded_weight
        for i, f in enumerate(self.f):
            factors[i].mul_(f(weights[i]).item())
        return factors

    def unscale_params(self, etas):
        """Multiply each parameter by its weight-dependent factor."""
        return etas * self._param_scales(etas)

    def scale_params(self, etas):
        """Divide each parameter by its weight-dependent factor."""
        return etas / self._param_scales(etas)

    def __str__(self):
        raise NotImplementedError()

    def to_params(self, etas):
        raise NotImplementedError()

    def to_naturals(self, params):
        raise NotImplementedError()

    @property
    def is_discrete(self):
        raise NotImplementedError()

    @property
    def is_continuous(self):
        return not self.is_discrete

    def __rshift__(self, data):
        # `dist >> data` scales data.
        return self.scale_data(data)

    def __lshift__(self, etas):
        # `dist << etas` unscales parameters.
        return self.unscale_params(etas)
        return self.unscale_params(etas)


class Normal(Base):
    """Normal likelihood handled through two parameters ``(eta1, eta2)``.

    ``to_params`` converts ``(eta1, eta2)`` to ``(loc, scale)`` and
    ``to_naturals`` performs the inverse conversion.
    """

    def __init__(self):
        super().__init__()

        # One constraint per parameter, in order.
        self.arg_constraints = [
            constraints.real,          # eta1
            constraints.less_than(0),  # eta2
        ]

    def __str__(self):
        return 'normal'

    @property
    def dist(self):
        return dist.Normal

    @property
    def is_discrete(self):
        return False

    @property
    def f(self):
        # eta1 scales with the weight, eta2 with its square.
        def linear(w):
            return w

        def quadratic(w):
            return w**2

        return [linear, quadratic]

    def to_params(self, etas):
        eta1, eta2 = etas
        loc = -0.5 * eta1 / eta2
        scale = torch.sqrt(-0.5 / eta2)
        return loc, scale

    def to_naturals(self, params):
        loc, std = params

        second = -0.5 / std ** 2
        first = -2 * loc * second

        return first, second

    def params_from_data(self, x):
        """Estimate ``(eta1, eta2)`` from the mean and standard deviation of ``x``."""
        moments = [x.mean(), x.std()]
        return self.to_naturals(moments)

    def impute(self, etas):
        return self.mean(etas)
        
class LogNormal(Normal):
    """Log-normal likelihood; reuses the Normal parameter handling on ``log(x)``."""

    def __str__(self):
        return 'lognormal'

    @property
    def dist(self):
        return dist.LogNormal

    def scale_data(self, x, weight=None):
        """Raise ``x`` to the power ``weight`` and clamp to ``[1e-20, 1e20]``."""
        if weight is None:
            weight = self.weight
        powered = torch.pow(x, weight)
        return torch.clamp(powered, min=1e-20, max=1e20)

    def unscale_data(self, x, weight=None):
        """Inverse of ``scale_data``: raise ``x`` to ``1 / weight`` and clamp."""
        if weight is None:
            weight = self.weight
        rooted = torch.pow(x, 1./weight)
        return torch.clamp(rooted, min=1e-20, max=1e20)

    def params_from_data(self, x):
        log_x = torch.log(x)
        return super().params_from_data(log_x)

    def sample(self, size, etas):
        draws = super().sample(size, etas)
        return torch.clamp(draws, min=1e-20, max=1e20)

    def impute(self, etas):
        # exp(mu - sigma^2), clamped to the same range as the data.
        mu, sigma = self.to_real_params(etas)
        point = torch.exp(mu - sigma**2)
        return torch.clamp(point, min=1e-20, max=1e20)


class TruncatedNormal(Normal):
    """Normal likelihood evaluated on unit-width bins and reported as discrete.

    The log-probability of ``x`` is the log of the probability mass the
    underlying distribution places on ``[x, x + 1]``; entries with ``x <= 0``
    instead receive the log of the CDF evaluated at zero.
    """

    def __str__(self):
        return 'truncated_normal'

    @property
    def is_discrete(self):
        return True

    @staticmethod
    def _binned_log_prob(distribution, x):
        """Log-mass of ``[x, x + 1]`` under ``distribution``, with ``x <= 0`` using ``cdf(0)``."""
        bin_mass = distribution.cdf(x+1) - distribution.cdf(x)
        result = torch.log(torch.clamp(bin_mass, min=1e-20))  # Slicing

        non_positive = x <= 0
        lower_tail = distribution.cdf(torch.zeros_like(result))
        result[non_positive] = lower_tail[non_positive].clamp_min(1e-20).log()

        return result

    def log_prob(self, x, etas):
        kwargs = dict(zip(self.parameters, self.to_params(etas)))
        return self._binned_log_prob(self.dist(**kwargs), x)

    def real_log_prob(self, x, etas):
        kwargs = dict(zip(self.real_parameters, self.to_real_params(etas)))
        return self._binned_log_prob(self.real_dist.dist(**kwargs), x)

    def impute(self, etas):
        # Floor of the mean, never below zero.
        return self.mean(etas).floor().clamp_min(0)


# NOTE(review): this is a second declaration of ``LogNormal`` in this module;
# being later in the file, it is the one the name is bound to at import time.
class LogNormal(Normal):
    """Log-normal likelihood; reuses the Normal parameter handling on ``log(x)``."""

    def __str__(self):
        return 'lognormal'

    @property
    def dist(self):
        return dist.LogNormal

    def params_from_data(self, x):
        return super().params_from_data(torch.log(x))

    def scale_data(self, x, weight=None):
        """Raise ``x`` to the power ``weight`` and clamp to ``[1e-20, 1e20]``."""
        exponent = weight if weight is not None else self.weight
        return torch.clamp(torch.pow(x, exponent), min=1e-20, max=1e20)

    def unscale_data(self, x, weight=None):
        """Inverse of ``scale_data``: raise ``x`` to ``1 / weight`` and clamp."""
        exponent = weight if weight is not None else self.weight
        return torch.clamp(torch.pow(x, 1./exponent), min=1e-20, max=1e20)

    def sample(self, size, etas):
        raw = super().sample(size, etas)
        return torch.clamp(raw, min=1e-20, max=1e20)

    def impute(self, etas):
        # exp(mu - sigma^2), clamped to the same range as the data.
        mu, sigma = self.to_real_params(etas)
        return torch.clamp(torch.exp(mu - sigma**2), min=1e-20, max=1e20)


class Gamma(Base):
    """Gamma likelihood with ``eta1 = concentration - 1`` and ``eta2 = -rate``."""

    def __init__(self):
        super().__init__()

        self.arg_constraints = [
            constraints.greater_than(-1),  # eta1
            constraints.less_than(0),  # eta2
        ]

    def __str__(self):
        return 'gamma'

    @property
    def dist(self):
        return dist.Gamma

    @property
    def is_discrete(self):
        return False

    @property
    def f(self):
        # eta1 is left unscaled (factor 1); eta2 scales with the weight.
        def unit(w):
            return torch.ones_like(w)

        def linear(w):
            return w

        return [unit, linear]

    def params_from_data(self, x):
        """Estimate ``(eta1, eta2)`` from ``x``.

        The shape is initialised from a closed-form expression of
        ``s = log(mean(x)) - mean(log(x))`` and refined with 50 Newton-style
        updates; the rate is then ``shape / mean``.
        """
        mean = x.mean()
        mean_of_log = x.log().mean()
        s = mean.log() - mean_of_log

        shape = (3 - s + ((s-3)**2 + 24*s).sqrt()) / (12 * s)
        for _ in range(50):
            residual = shape.log() - torch.digamma(shape) - s
            slope = 1 / shape - torch.polygamma(1, shape)
            shape = shape - residual / slope

        rate = shape / mean

        return shape - 1, -rate

    def to_params(self, etas):
        eta1, eta2 = etas
        concentration = eta1 + 1
        rate = -eta2
        return concentration, rate

    def impute(self, etas):
        # (alpha - 1) / beta, floored at zero.
        alpha, beta = self.to_real_params(etas)
        return torch.clamp((alpha - 1) / beta, min=0.0)


class Exponential(Base):
    """Exponential likelihood with a single parameter ``eta1 = -rate``."""

    def __init__(self):
        super().__init__()

        self.arg_constraints = [
            constraints.less_than(0),  # eta1
        ]

    def __str__(self):
        return "exponential"

    @property
    def dist(self):
        return dist.Exponential

    @property
    def is_discrete(self):
        return False

    @property
    def f(self):
        def linear(w):
            return w

        return [linear]

    def params_from_data(self, x):
        """Return the one-element tuple ``(-1 / mean(x),)``."""
        sample_mean = x.mean()
        return (-1 / sample_mean,)

    def to_params(self, etas):
        return (-etas[0],)

    def impute(self, etas):
        raise NotImplementedError()


class Bernoulli(Base):
    """Bernoulli likelihood parameterised by a single logit."""

    def __init__(self):
        super().__init__()
        self.size = 2
        self.arg_constraints = [constraints.real]

    def __str__(self):
        return 'bernoulli'

    @property
    def dist(self):
        return dist.Bernoulli

    @property
    def is_discrete(self):
        return True

    @property
    def parameters(self):
        return ('logits',)

    @property
    def real_parameters(self):
        return ('probs',)

    @property
    def f(self):
        # The logit is not rescaled by the weight.
        def unit(w):
            return torch.ones_like(w)

        return [unit]

    def scale_data(self, x, weight=None):
        # Data is passed through untouched.
        return x

    def params_from_data(self, x):
        """Return the logit of the empirical mean of ``x`` as a one-element tuple."""
        empirical = x.mean()
        return (probs_to_logits(empirical, is_binary=True),)

    def to_params(self, etas):
        return (etas[0],)

    def to_real_params(self, etas):
        logits = self.to_params(etas)[0]
        return (logits_to_probs(logits, is_binary=True),)

    def impute(self, etas):
        # Predict 1.0 wherever the probability is at least one half.
        (probs,) = self.to_real_params(etas)
        return (probs >= 0.5).float()


class Poisson(Base):
    """Poisson likelihood parameterised by the log of its rate."""

    def __init__(self):
        super().__init__()

        self.arg_constraints = [constraints.real]

    def __str__(self):
        return 'poisson'

    @property
    def dist(self):
        return dist.Poisson

    @property
    def is_discrete(self):
        return True

    @property
    def f(self):
        # The log-rate is not rescaled by the weight.
        def unit(w):
            return torch.ones_like(w)

        return [unit]

    def scale_data(self, x, weight=None):
        # Data is passed through untouched.
        return x

    def params_from_data(self, x):
        """Return ``(log(mean(x)),)`` with the mean clamped to at least ``1e-20``."""
        safe_mean = torch.clamp(x.mean(), min=1e-20)
        return (torch.log(safe_mean),)

    def to_params(self, etas):
        rate = torch.exp(etas[0]).clamp(min=1e-6, max=1e20)  # TODO
        return (rate,)

    def impute(self, etas):
        (rate,) = self.to_real_params(etas)
        return rate.floor()


class Geometric(Base):
    """Geometric likelihood parameterised by a single logit."""

    def __init__(self):
        super().__init__()

        self.arg_constraints = [constraints.real]

    def __str__(self):
        return 'geometric'

    @property
    def dist(self):
        return dist.Geometric

    @property
    def is_discrete(self):
        return True

    @property
    def parameters(self):
        return ('logits',)

    @property
    def real_parameters(self):
        return ('probs',)

    @property
    def f(self):
        # The logit is not rescaled by the weight.
        def unit(w):
            return torch.ones_like(w)

        return [unit]

    def scale_data(self, x, weight=None):
        # Data is passed through untouched.
        return x

    def params_from_data(self, x):
        """Return the logit of ``1 / (mean(x) + 1)`` as a one-element tuple."""
        success_prob = torch.reciprocal(x.mean() + 1)
        return (probs_to_logits(success_prob, is_binary=True),)

    def to_params(self, etas):
        return (etas[0],)

    def to_real_params(self, etas):
        logits = self.to_params(etas)[0]
        return (logits_to_probs(logits, is_binary=True),)

    def impute(self, etas):
        expected = self.mean(etas)
        return torch.floor(expected)


class NegativeBinomial(Base):
    """Negative-binomial likelihood parameterised by ``(total_count, logits)``."""

    def __init__(self):
        super().__init__()

        self.arg_constraints = [
            constraints.greater_than(0),
            constraints.real,
        ]

    def __str__(self):
        return 'negative binomial'

    @property
    def dist(self):
        return dist.NegativeBinomial

    @property
    def is_discrete(self):
        return True

    @property
    def parameters(self):
        return ('total_count', 'logits')

    @property
    def real_parameters(self):
        return ('total_count', 'probs')

    @property
    def f(self):
        # Neither parameter is rescaled; the same function object is listed twice.
        return [lambda w: torch.ones_like(w)] * 2

    def scale_data(self, x, weight=None):
        # Data is passed through untouched.
        return x

    def params_from_data(self, x):
        """Estimate ``(total_count, logits)`` from the mean and std of ``x``."""
        m = x.mean()
        s = x.std()

        probs = torch.clamp(1. - m / s**2, 0, 1)
        total_count = torch.clamp(m * (1 - probs) / probs, 1e-20)
        logits = probs_to_logits(probs, is_binary=True)

        return total_count, logits

    def to_params(self, etas):
        return etas

    def to_real_params(self, etas):
        probs = logits_to_probs(self.to_params(etas)[1], is_binary=True)
        return etas[0], probs

    def impute(self, etas):
        # floor(p (r - 1) / (1 - p)), forced to zero wherever r <= 1.
        total_count, probs = self.to_real_params(etas)
        candidate = torch.floor(probs * (total_count - 1) / (1 - probs))
        return torch.masked_fill(candidate, total_count <= 1., 0.)


class TruncatedNormalDistribution(dist.Distribution):
    """Normal distribution restricted to ``[minimum_value, maximum_value]``.

    All four arguments are broadcast against each other; Python numbers are
    accepted alongside tensors.
    """
    arg_constraints = {'minimum_value': constraints.real,
                       'maximum_value': constraints.real,
                       'loc': constraints.real,
                       'scale': constraints.positive}  # See https://en.wikipedia.org/wiki/Truncated_normal_distribution
    support = constraints.real
    has_rsample = True

    def __init__(self, minimum_value: torch.Tensor, maximum_value: torch.Tensor, loc: torch.Tensor,
                 scale: torch.Tensor, validate_args: bool = True):
        self.minimum_value, self.maximum_value, self.loc, self.scale = broadcast_all(
            minimum_value, maximum_value, loc, scale
        )
        assert torch.all(self.minimum_value <= self.maximum_value)

        # Standard normal built from the *broadcast* loc: the raw argument may
        # be a Python number, which torch.zeros_like cannot handle.
        self._normal = dist.Normal(torch.zeros_like(self.loc), 1.)

        # Truncation bounds in standardised units.
        self.alpha = (self.minimum_value - self.loc) / self.scale
        self.beta = (self.maximum_value - self.loc) / self.scale
        # Probability mass between the bounds, clamped away from zero so that
        # the divisions and logs below stay finite.
        self.normalizing_constant = self._normal.cdf(self.beta) - self._normal.cdf(self.alpha)
        self.normalizing_constant = torch.clamp_min(self.normalizing_constant, 1e-20)

        batch_shape = self.loc.size()
        super(TruncatedNormalDistribution, self).__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self):
        extra_loc = (self._normal.log_prob(self.alpha).exp() - self._normal.log_prob(self.beta).exp())
        extra_loc = extra_loc / self.normalizing_constant * self.scale
        return self.loc + extra_loc

    @property
    def param_shape(self):
        return self.loc.size()

    @property
    def variance(self):
        # TODO Not needed
        raise NotImplementedError()

    def cdf(self, value):
        return (self._normal.cdf((value - self.loc) / self.scale) - self._normal.cdf(self.alpha)) \
               / self.normalizing_constant

    def rsample(self, sample_shape=torch.Size()):
        """Inverse-CDF sampling: map uniform noise into the truncated range.

        Implemented because the class declares ``has_rsample = True``.
        """
        uniforms = dist.Uniform(torch.zeros_like(self.loc), torch.ones_like(self.loc)).sample(sample_shape)
        cdf_alpha, cdf_beta = self._normal.cdf(self.alpha), self._normal.cdf(self.beta)
        return self._normal.icdf(cdf_alpha + uniforms * (cdf_beta - cdf_alpha)) * self.scale + self.loc

    def sample(self, sample_shape=torch.Size()):
        # Same computation as rsample; deliberately not wrapped in no_grad,
        # matching the behaviour this method has always had.
        return self.rsample(sample_shape)

    def log_prob(self, value):
        """Log-density of ``value``; asserts that it lies within the bounds."""
        assert torch.all(self.minimum_value <= value) and torch.all(value <= self.maximum_value)
        if self._validate_args:
            self._validate_sample(value)

        log_prob = self._normal.log_prob((value - self.loc) / self.scale) - self.scale.log()
        log_prob -= self.normalizing_constant.log()
        return log_prob


class TruncatedNormal2(Base):
    """Likelihood backed by ``TruncatedNormalDistribution`` with fixed bounds 0 and 1.

    Only ``loc`` and ``scale`` are learned; ``to_params`` prepends the
    constant bounds.
    """

    def __init__(self):
        super().__init__()

        # The bounds (min, max) are constants, so only two entries are listed.
        self.arg_constraints = [
            constraints.real,      # loc
            constraints.positive,  # scale
        ]

    def __str__(self):
        return 'truncated_normal'

    @property
    def dist(self):
        return TruncatedNormalDistribution

    @property
    def is_discrete(self):
        return True

    @property
    def parameters(self):
        return ('minimum_value', 'maximum_value', 'loc', 'scale')

    @property
    def f(self):
        # Neither parameter is rescaled; the same function object is listed twice.
        return [lambda w: torch.ones_like(w)] * 2

    def params_from_data(self, x):
        raise NotImplementedError

    def to_params(self, etas):
        loc, scale = etas[0], etas[1]
        return 0., 1., loc, scale

    def impute(self, etas):
        """Return the mean, replaced by 0 where ``loc < 0`` and by 1 where ``loc > 1``."""
        loc = etas[0]
        estimate = self.mean(etas)
        estimate = torch.masked_fill(estimate, loc < 0., 0.)
        estimate = torch.masked_fill(estimate, loc > 1., 1.)
        return estimate


class BernoulliGammaTrick(Gamma):
    """Gamma likelihood fitted to shifted, noised binary data.

    Data is preprocessed as ``x + 1 + noise`` with ``noise ~ Beta(1.1, 30)``;
    ``mean`` undoes the shift and the expected noise and clamps to ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.noise_dist = dist.Beta(1.1, 30)

    def __str__(self):
        return f'{self.real_dist}*'

    @property
    def real_dist(self) -> Base:
        # A fresh Bernoulli wrapper is created on every access.
        return Bernoulli()

    def preprocess_data(self, x, mask=None):
        base = super().preprocess_data(x)[0]
        jitter = self.noise_dist.sample([base.size(0)])
        return (base + 1 + jitter,)

    def mean(self, etas):
        shifted = super().mean(etas) - 1 - self.noise_dist.mean
        return torch.clamp(shifted, min=0., max=1.)

    def to_real_params(self, etas):
        return (self.mean(etas),)

    def impute(self, etas):
        # Predict 1.0 wherever the recovered probability is at least one half.
        (probs,) = self.to_real_params(etas)
        return (probs >= 0.5).float()


class PoissonGammaTrick(Gamma):
    """Gamma likelihood fitted to shifted, noised count data.

    Data is preprocessed as ``x + 1 + noise`` with ``noise ~ Beta(1.1, 30)``;
    ``mean`` subtracts the shift and the expected noise again.
    """

    def __init__(self):
        super().__init__()
        self.noise_dist = dist.Beta(1.1, 30)

    def __str__(self):
        return f'{self.real_dist}*'

    @property
    def real_dist(self) -> Base:
        # A fresh Poisson wrapper is created on every access.
        return Poisson()

    def preprocess_data(self, x, mask=None):
        base = super().preprocess_data(x)[0]
        jitter = self.noise_dist.sample([base.size(0)])
        return (base + 1 + jitter,)

    def mean(self, etas):
        gamma_mean = super().mean(etas)
        return gamma_mean - 1 - self.noise_dist.mean

    def to_real_params(self, etas):
        rate = torch.clamp(self.mean(etas), min=1e-10)  # rate > 0
        return (rate,)

    def impute(self, etas):
        (rate,) = self.to_real_params(etas)
        return rate.floor()


class Categorical(Base):
    """Categorical likelihood over ``size`` classes, parameterised by logits."""

    def __init__(self, size):
        super().__init__()
        self.arg_constraints = [constraints.real_vector]
        self.size = size

    def __str__(self):
        return f'categorical({self.size})'

    @property
    def dist(self):
        return dist.Categorical

    @property
    def is_discrete(self):
        return True

    @property
    def parameters(self):
        return ('logits',)

    @property
    def real_parameters(self):
        return ('probs',)

    @property
    def size_params(self):
        # The single parameter is a vector with one entry per class.
        return [self.size]

    @property
    def f(self):
        # Logits are not rescaled by the weight.
        def unit(w):
            return torch.ones_like(w)

        return [unit]

    def scale_data(self, x, weight=None):
        # Data is passed through untouched.
        return x

    def params_from_data(self, x):
        """Return the logits of the empirical class frequencies of ``x``."""
        counts = to_one_hot(x, self.size).sum(dim=0)
        return (probs_to_logits(counts / x.size(0)),)

    def to_params(self, etas):
        return (etas[0],)

    def to_real_params(self, etas):
        logits = self.to_params(etas)[0]
        return (logits_to_probs(logits),)

    def mean(self, etas):
        raise NotImplementedError()

    def impute(self, etas):
        # Index of the most probable class along the last dimension.
        kwargs = dict(zip(self.real_parameters, self.to_real_params(etas)))
        class_probs = self.real_dist.dist(**kwargs).probs
        return class_probs.max(dim=-1)[1]


class CategoricalBernoulliTrick(Base):
    """Categorical likelihood assembled from ``size`` independent Bernoulli parts.

    Each class gets its own Bernoulli sub-distribution (reachable through
    indexing); the per-class probabilities are renormalised to sum to one in
    ``to_real_params``.
    """

    def __init__(self, size):
        super().__init__()
        # The weight lives on the sub-distributions, exposed via the property below.
        del self._weight

        self.dists = [Bernoulli() for _ in range(size)]
        self.arg_constraints = reduce(list.__add__, [part.arg_constraints for part in self.dists])
        self.size = size

    def __str__(self):
        return f'categorical({self.size})+'

    def __getitem__(self, item):
        return self.dists[item]

    @property
    def dist(self):
        return dist.Categorical

    @property
    def is_discrete(self):
        return True

    @property
    def real_parameters(self):
        return ('probs',)

    @property
    def num_dists(self):
        return len(self.dists)

    @property
    def num_params(self):
        return sum(part.num_params for part in self.dists)

    @property
    def f(self):
        return [lambda w: torch.tensor([1.0])]

    @property
    def weight(self):
        return torch.tensor([part.weight for part in self.dists])

    @weight.setter
    def weight(self, value):
        assert self.num_dists == len(value)

        for part, part_weight in zip(self.dists, value):
            part.weight = part_weight

    def params_from_data(self, x):
        raise NotImplementedError()

    def to_params(self, etas):
        raise NotImplementedError()

    def mean(self, etas):
        raise NotImplementedError()

    def real_params_from_data(self, x):
        # Delegates to a plain Categorical of the same size.
        reference = Categorical(self.size)
        return reference.real_params_from_data(x)

    def scale_data(self, x, weight=None):
        # Data is passed through untouched.
        return x

    def preprocess_data(self, x, mask=None):
        """One-hot encode ``x`` and preprocess each class column with its sub-distribution."""
        encoded = to_one_hot(super().preprocess_data(x)[0], self.size)

        columns = []
        for k in range(self.size):
            columns += self.dists[k].preprocess_data(encoded[..., k])
        return columns

    def to_real_params(self, etas):
        """Return the renormalised class probabilities as a one-element tuple."""
        per_class = []
        offset = 0
        for part in self.dists:
            width = part.num_params
            per_class.append(part.to_real_params(etas[offset: offset + width])[0])  # .detach())
            offset += width

        stacked = torch.clamp(torch.stack(per_class, dim=-1), min=1e-20)
        return (stacked / stacked.sum(dim=-1, keepdim=True),)

    def impute(self, etas):
        # Index of the most probable class along the last dimension.
        kwargs = dict(zip(self.real_parameters, self.to_real_params(etas)))
        class_probs = self.real_dist.dist(**kwargs).probs
        return class_probs.max(dim=-1)[1]

    def real_log_prob(self, x, etas):
        kwargs = dict(zip(self.real_parameters, self.to_real_params(etas)))

        # Inputs with more than two dimensions are used as given; anything
        # else is one-hot encoded first.
        if x.dim() > 2:
            target = x
        else:
            target = to_one_hot(x, self.size)
        return dist.OneHotCategorical(**kwargs).log_prob(target)


class CategoricalGammaTrick(CategoricalBernoulliTrick):
    """Categorical likelihood whose per-class parts are ``BernoulliGammaTrick``."""

    def __init__(self, size):
        super().__init__(size)
        # Replace the Bernoulli parts created by the parent constructor.
        # NOTE(review): `arg_constraints` is left as computed by the parent
        # from the Bernoulli parts -- confirm this is intended.
        self.dists = [BernoulliGammaTrick() for _ in range(size)]

    def __str__(self):
        return f'categorical({self.size})*'

    @property
    def real_dist(self) -> Base:
        # A fresh wrapper is created on every access.
        return CategoricalBernoulliTrick(self.size)

    @property
    def is_discrete(self):
        return False

    @property
    def f(self):
        # Concatenation of the scaling functions of every part, in order.
        per_part = [part.f for part in self.dists]
        return reduce(list.__add__, per_part)
        
class GumbelDistribution(ExpRelaxedCategorical):
    """Relaxed categorical with hard one-hot ``sample`` and soft ``rsample``.

    ``sample`` and ``log_prob`` go through a ``OneHotCategorical`` built from
    this distribution's probabilities; ``rsample`` exponentiates the parent's
    reparameterised sample.
    """

    def _hard(self):
        """One-hot categorical sharing this distribution's probabilities."""
        return OneHotCategorical(probs=self.probs)

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            return self._hard().sample(sample_shape)

    def rsample(self, sample_shape=torch.Size()):
        log_sample = super().rsample(sample_shape)
        return torch.exp(log_sample)

    @property
    def mean(self):
        return self.probs

    def expand(self, batch_shape, _instance=None):
        # The last entry of the requested shape is dropped before delegating.
        leading = batch_shape[:-1]
        return super().expand(leading, _instance)

    def log_prob(self, value):
        return self._hard().log_prob(value)
        
        
def to_one_hot(x, size):
    """One-hot encode a 1-D tensor of class indices.

    Args:
        x: tensor of shape ``(N,)`` whose entries are class indices (they are
            cast to ``long`` before use).
        size: number of classes, i.e. the width of the result.

    Returns:
        A floating-point tensor of shape ``(N, size)`` on the same device as
        ``x``. Floating-point inputs keep their dtype; any other dtype yields
        ``torch.float``.
    """
    # The previous code called `.float()` on the scatter_ result and threw it
    # away, so integer inputs produced an integer one-hot. Allocate the
    # result with a floating dtype instead.
    out_dtype = x.dtype if x.is_floating_point() else torch.float
    x_one_hot = x.new_zeros(x.size(0), size, dtype=out_dtype)
    x_one_hot.scatter_(1, x.unsqueeze(-1).long(), 1)
    return x_one_hot
    
    
def init_weights(m, gain=1.):
    """Initialise a linear layer in place; every other module type is ignored.

    Weights are drawn from a normal distribution with standard deviation
    0.05 and the bias, when present, is filled with 0.01. ``gain`` is
    accepted but not used.
    """
    # Exact type match: subclasses of nn.Linear are deliberately not touched.
    if type(m) is not nn.Linear:
        return

    nn.init.normal_(m.weight, std=0.05)
    bias = m.bias
    if bias is not None:
        bias.data.fill_(0.01)


class Encoder(nn.Module):
    """Encoder producing a discrete latent ``s`` and the parameters of a Gaussian ``z``.

    ``s`` is predicted from the input alone; the location and log-scale of
    ``z`` are predicted from the input concatenated with the sampled ``s``.

    Args:
        prob_model: accepted for interface compatibility; not used here.
        size_s: number of categories of the discrete latent.
        size_z: dimensionality of the continuous latent.
        input_size: number of input features.
    """

    def __init__(self, prob_model, size_s, size_z, input_size):
        super().__init__()
        #input_size = sum([d.size for d in prob_model])

        # Encoder
        self.encoder_s = nn.Linear(input_size, size_s)

        self.encoder_z = nn.Identity()  # Just in case we want to increase this part
        self.q_z_loc = nn.Linear(input_size + size_s, size_z)
        self.q_z_log_scale = nn.Linear(input_size + size_s, size_z)

        self.encoder_z.apply(init_weights)
        self.q_z_loc.apply(init_weights)
        self.q_z_log_scale.apply(init_weights)

        self.temperature = 1.

    def q_z(self, loc, log_scale):
        """Gaussian posterior over ``z``; the scale is clamped to ``[1e-20, 1e20]``."""
        scale = torch.exp(log_scale)
        scale = torch.clamp(scale, min=1e-20, max=1e20)
        return dists.Normal(loc, scale)

    def q_s(self, logits):
        """Relaxed categorical posterior over ``s`` at the current temperature."""
        return GumbelDistribution(logits=logits, temperature=self.temperature)

    def _draw_s(self, s_logits):
        """Reparameterised sample while training, hard sample otherwise."""
        posterior = self.q_s(s_logits)
        return posterior.rsample() if self.training else posterior.sample()

    def forward(self, x, mode=False):
        """Encode ``x``.

        Args:
            x: input batch.
            mode: when true, ``s`` is the one-hot arg-max of the logits
                instead of a sample.

        Returns:
            ``(s_samples, [s_logits, z_loc, z_log_scale])``.
        """
        s_logits = self.encoder_s(x)
        if mode:
            s_samples = to_one_hot(torch.argmax(s_logits, dim=-1), s_logits.size(-1)).float()
        else:
            try:
                s_samples = self._draw_s(s_logits)
            except Exception:
                # Best-effort fallback kept from the original code: clamp and
                # renormalise the logits, then retry. Narrowed from a bare
                # `except:` so KeyboardInterrupt/SystemExit are not swallowed.
                s_logits = torch.clamp(s_logits, min=1e-20)
                s_logits = s_logits / torch.sum(s_logits)
                s_samples = self._draw_s(s_logits)

        x_and_s = torch.cat((x, s_samples), dim=-1)  # batch_size x (input_size + latent_s_size)

        h = self.encoder_z(x_and_s)
        z_loc = self.q_z_loc(h)
        z_log_scale = self.q_z_log_scale(h)
        # z_log_scale = torch.clamp(z_log_scale, -7.5, 7.5)

        return s_samples, [s_logits, z_loc, z_log_scale]


class HIVAEHead(nn.Module):
    """Decoder head producing the parameters of a single distribution from ``(z, s)``.

    The first parameter is a (bias-free) linear function of ``[y, s]`` where
    ``y = net_y(z)``; any remaining parameters are a (bias-free) linear
    function of ``s`` alone.

    Args:
        dist: Distribution descriptor; this class reads its ``size``,
            ``size_params``, ``num_params`` and ``arg_constraints`` attributes.
        size_s: Dimension of ``s``.
        size_z: Dimension of ``z``.
        size_y: Dimension of the head-specific hidden vector ``y``.
    """

    def __init__(self, dist, size_s, size_z, size_y):
        super().__init__()
        self.dist = dist
        widths = self.dist.size_params

        # Head-specific y computed from z
        self.net_y = nn.Linear(size_z, size_y)

        # First parameter: depends on both y and s
        self.head_y_and_s = nn.Linear(size_y + size_s, widths[0], bias=False)

        # Remaining parameters (if there are any): depend on s only
        if len(widths) > 1:
            self.head_s = nn.Linear(size_s, sum(widths[1:]), bias=False)
            self.head_s.apply(init_weights)
        else:
            self.head_s = None

        self.net_y.apply(init_weights)
        self.head_y_and_s.apply(init_weights)

    def unpack_params(self, theta, first_parameter):
        """Split raw head outputs into per-parameter tensors and enforce constraints.

        Args:
            theta: Raw outputs; the last dimension holds the concatenated
                parameters. May be modified in place (see below).
            first_parameter: If true only parameter 0 is unpacked, otherwise
                parameters ``1 .. num_params - 1``.

        Returns:
            The unpacked parameters stacked along a new leading dimension.
        """
        eps = 1e-15
        if first_parameter:
            indices = [0]
        else:
            indices = range(1, self.dist.num_params)

        unpacked = []
        start = 0
        for idx in indices:
            width = self.dist.size_params[idx]
            chunk = theta[..., start: start + width].squeeze(-1)
            constraint = self.dist.arg_constraints[idx]

            if isinstance(constraint, constraints.greater_than):
                # Map onto (lower_bound, inf)
                chunk = constraint.lower_bound + eps + softplus(chunk)
            elif isinstance(constraint, constraints.less_than):
                # Map onto (-inf, upper_bound)
                chunk = constraint.upper_bound - eps - softplus(chunk)
            elif constraint == constraints.simplex:
                chunk = logits_to_probs(chunk)
            elif self.dist.size > 1:
                # Zero the first entry; this writes through the view into ``theta``
                chunk[..., 0] = chunk[..., 0] * 0.

            unpacked.append(chunk)
            start += width

        return torch.stack(unpacked, dim=0)

    def forward(self, z, s):
        """Return the stacked distribution parameters for latent ``z`` and code ``s``."""
        y = self.net_y(z)
        y_and_s = torch.cat((y, s), dim=-1)  # batch_size x (hidden_size + latent_s_size)

        first = self.unpack_params(self.head_y_and_s(y_and_s), first_parameter=True)
        if self.head_s is None:
            return first

        rest = self.unpack_params(self.head_s(s), first_parameter=False)
        return torch.cat((first, rest), dim=0)

class MOOForLoop(nn.Module):
    """Replicate an input once per head and optionally re-combine the per-head gradients.

    In the forward pass the input is expanded along a new leading axis of
    length ``num_heads``. If a ``moo_method`` is given, a full backward hook
    passes the gradient w.r.t. that expanded output through ``moo_method`` and
    returns the sum over heads as the gradient w.r.t. the input.

    Args:
        num_heads: Number of copies produced by ``forward``.
        moo_method: Callable ``(grads, inputs, outputs) -> grads`` or ``None``
            to leave the backward pass untouched.
    """

    inputs: Optional[torch.Tensor]

    def __init__(self, num_heads: int, moo_method: Optional[nn.Module] = None):
        super().__init__()

        # The user should explicitly set up the optimizers for the algorithms that have learnable parameters.
        # Wrapping the method in a list keeps it out of this module's registered submodules.
        self._moo_method = [moo_method]
        self.num_heads = num_heads
        self.inputs = None
        self.outputs = None

        if self.moo_method is not None:
            self.register_full_backward_hook(MOOForLoop._hook)

    @property
    def moo_method(self):
        """The wrapped MOO method (or ``None``)."""
        return self._moo_method[0]

    # @typechecked
    def _hook(self, grads_input: Tuple[torch.Tensor], grads_output: Any) -> Tuple[torch.Tensor]:
        """Backward hook: combine the per-head gradients and return the input gradient.

        Args:
            grads_input: Gradients w.r.t. the module input (unused).
            grads_output: Gradients w.r.t. the module output; element 0 has the
                heads on its leading axis.

        Returns:
            One-element tuple with the replacement gradient w.r.t. the input.
        """
        grads = grads_output[0]
        moo_directions = self.moo_method(grads, self.inputs, self.outputs)
        self.outputs = None

        # we scale the gradients so that they have the same magnitude as the unmodified ones
        original_norm = grads.sum(dim=0).norm(p=2)
        moo_norm = moo_directions.sum(dim=0).norm(p=2).clamp_min(1e-10)
        # Out-of-place on purpose: ``moo_method`` may hand back ``grads`` itself, and a
        # backward hook must not modify its arguments (an in-place multiply would also
        # fail on an expanded, non-contiguous gradient).
        moo_directions = moo_directions * (original_norm / moo_norm)

        return moo_directions.sum(dim=0),

    def forward(self, z):
        """Return ``z`` expanded to shape ``(num_heads, *z.shape)`` (a view, no copy)."""
        extended_shape = [self.num_heads] + [-1 for _ in range(z.ndim)]
        extended_z = z.unsqueeze(0).expand(extended_shape)
        return extended_z


class MultiMOOForLoop(nn.Module):
    """Bundle of ``MOOForLoop`` blocks, one per input tensor.

    Args:
        num_heads: Number of heads passed to every ``MOOForLoop``.
        moo_methods: One MOO method (or ``None``) per expected input.
    """

    def __init__(self, num_heads: int, moo_methods: Sequence[Optional[nn.Module]] = None):
        super().__init__()

        self.num_inputs = len(moo_methods)
        # Kept in a plain list, so the loops are not registered as submodules.
        loops = []
        for method in moo_methods:
            loops.append(MOOForLoop(num_heads, method))
        self.loops = loops

    def forward(self, *args):
        """Return a generator yielding ``loop(z)`` for each input, in order."""
        n_received = len(args)
        assert n_received == self.num_inputs
        return (loop(z) for loop, z in zip(self.loops, args))
        
class HIVAE(pl.LightningModule):
    def __init__(self, x_train):
        """Build the probabilistic model, priors, encoder, per-variable heads and MOO block.

        Args:
            x_train: Training data. ``x_train.shape[1]`` is used as the number
                of variables and the data is forwarded to ``ProbabilisticModel``.
        """
        super().__init__()

        n_features = x_train.shape[1]
        self.prob_model = ProbabilisticModel(n_features, x_train)
        self.samples = 1

        # Normalisation statistics, one entry per distribution (identity until overwritten)
        n_dists = len(self.prob_model)
        self.mean_data = [0.] * n_dists
        self.std_data = [1.] * n_dists

        # s, z and y all share this size: 75% of the gathered distributions, rounded, at least 1
        latent_size = max(1, int(len(self.prob_model.gathered) * 0.75 + 0.5))

        # Priors: uniform over s, and a learned linear map from s to the location of z
        self.prior_s_pi = torch.ones(latent_size) / latent_size
        self.p_z_loc = nn.Linear(latent_size, latent_size)
        self.p_z_loc.apply(init_weights)

        # Encoder
        self.encoder = Encoder(self.prob_model, latent_size, latent_size, n_features)

        # Decoder: the shared part is an identity placeholder, followed by one head per distribution
        self.decoder_shared = nn.Identity()
        self.decoder_shared.apply(init_weights)

        heads = []
        for dist in self.prob_model:
            heads.append(HIVAEHead(dist, latent_size, latent_size, latent_size))
        self.heads = nn.ModuleList(heads)

        # MOO: the methods are wrapped in 1-tuples before being stored on the module
        self._mtl_module_y: Tuple[nn.Module] = (self.setup_moo(),)
        self._mtl_module_s: Tuple[nn.Module] = (self.setup_moo(),)
        self.moo_block = MultiMOOForLoop(
            len(self.heads),
            moo_methods=(self.mtl_module_y, self.mtl_module_s),
        )

    def setup_moo(self) -> nn.Module:
        """Return the multi-objective-optimisation method applied to the per-head gradients.

        The method is a pass-through: the gradients are returned unchanged.

        Returns:
            A module whose ``forward(grads, inputs, outputs)`` returns ``grads``.
        """
        class Identity(nn.Module):
            """Pass-through MOO method."""

            def forward(self, grads: torch.Tensor, inputs: Optional[torch.Tensor],
                        outputs: Optional[torch.Tensor]) -> torch.Tensor:
                return grads

        # NOTE: an alternative set-up (pcgrad / nsgd / mgda / imtl-g / graddrop composed via a
        # ``moo`` module) used to follow this return. It was unreachable and relied on a name
        # that is not imported in the visible part of this file, so it has been dropped.
        return Identity()

    @property
    def mtl_module_y(self):
        """MOO method for the ``y`` branch, unwrapped from its 1-tuple holder."""
        holder = self._mtl_module_y
        return holder[0]

    @property
    def mtl_module_s(self):
        """MOO method for the ``s`` branch, unwrapped from its 1-tuple holder."""
        holder = self._mtl_module_s
        return holder[0]

    def prior_z(self, loc):
        """Return the prior over ``z``: a unit-scale Normal centred at ``loc``."""
        unit_scale = 1.
        return dists.Normal(loc, unit_scale)

    @property
    def prior_s(self):
        """Return the prior over ``s``: a one-hot categorical with probabilities ``prior_s_pi``."""
        probs = self.prior_s_pi
        return dists.OneHotCategorical(probs=probs, validate_args=False)

    def normalize_data(self, x, mask, epsilon=1e-6):
        """Normalise every column of ``x`` according to its distribution type.

        * ``normal``: standardised with the mean / std of the observed entries.
        * ``lognormal``: ``log1p`` followed by standardisation.
        * ``poisson``: ``log1p`` only.
        * anything else: passed through unchanged.

        Side effect: ``self.mean_data[i]`` and ``self.std_data[i]`` are
        overwritten with the statistics of the current ``x`` for every
        normal / lognormal column, on every call.

        Args:
            x: Data whose last dimension has one column per distribution.
            mask: Observation mask indexed like ``x`` (non-zero = observed),
                or ``None`` to treat every entry as observed.
            epsilon: Added to the std in the denominator.

        Returns:
            The normalised data, multiplied by the mask when one is given.
        """
        assert len(self.prob_model) == x.size(-1)

        columns = []
        for i, d in enumerate(self.prob_model):
            column = x[..., i]
            # Statistics are computed on the observed entries only
            observed = torch.masked_select(column, mask[..., i].bool()) if mask is not None else column
            new_x_i = torch.unsqueeze(column, 1)
            kind = str(d)

            if kind == 'normal':
                self.mean_data[i] = observed.mean()
                self.std_data[i] = torch.clamp(observed.std(), 1e-6, 1e20)
                new_x_i = (new_x_i - self.mean_data[i]) / (self.std_data[i] + epsilon)
            elif kind == 'lognormal':
                observed = torch.log1p(observed)
                self.mean_data[i] = observed.mean()
                self.std_data[i] = torch.clamp(observed.std(), 1e-10, 1e20)
                new_x_i = (torch.log1p(new_x_i) - self.mean_data[i]) / (self.std_data[i] + epsilon)
            elif kind == 'poisson':
                new_x_i = torch.log1p(new_x_i)  # x[..., i] can have 0 values (just as a poisson distribution)

            columns.append(new_x_i)

        new_x = torch.cat(columns, 1)

        # Without a mask there is nothing to broadcast or apply. (Previously the mask was
        # indexed before this check, which raised for mask=None whenever a distribution
        # had size > 1.)
        if mask is None:
            return new_x

        # Repeat each mask column ``d.size`` times when some distribution spans several columns
        if any(d.size != 1 for d in self.prob_model):
            mask = torch.cat(
                [mask[:, i].unsqueeze(-1).expand(-1, d.size) for i, d in enumerate(self.prob_model)],
                dim=-1,
            )

        return new_x * mask

    def denormalize_params(self, etas):
        """Undo the input normalisation on the natural parameters of each distribution.

        For ``normal`` and ``lognormal`` distributions the natural parameters
        are converted to ``(mean, std)``, rescaled with the stored
        ``mean_data`` / ``std_data`` and converted back. The ``normal`` case
        additionally clamps the stored std (min 1e-3) and the predicted std
        (to [1e-3, 1e20]). All other distributions are returned untouched.

        Args:
            etas: Sequence with one parameter tensor per distribution.

        Returns:
            List of parameter tensors, in the same order.
        """
        rescaled = []
        for i, d in enumerate(self.prob_model):
            etas_i = etas[i]
            kind = str(d)

            if kind in ('normal', 'lognormal'):
                clamp = kind == 'normal'  # only the normal case is clamped
                shift, spread = self.mean_data[i], self.std_data[i]
                if clamp:
                    spread = torch.clamp(spread, min=1e-3)

                mean, std = d.to_params(etas_i)
                mean = mean * spread + shift
                if clamp:
                    std = torch.clamp(std, min=1e-3, max=1e20)
                std = std * spread

                etas_i = torch.stack(d.to_naturals([mean, std]), dim=0)

            rescaled.append(etas_i)
        return rescaled

    def _run_step(self, x, mask):
        """Run one stochastic encode/decode pass and return the ELBO ingredients.

        Args:
            x: Input batch.
            mask: Observation mask passed to ``normalize_data`` and ``log_likelihood``.

        Returns:
            ``(log_px_z, kl_z, kl_s)``: a list with one log-likelihood tensor
            per distribution, and the sample-based KL terms for ``z`` and ``s``.
        """
        # Normalization layer
        x_norm = self.normalize_data(x, mask)

        # Sample s and get the parameters of q(s | x) and q(z | x, s)
        s_samples, (s_logits, z_loc, z_log_scale) = self.encoder(x_norm)

        # Reparameterised sample of z
        q_z = self.encoder.q_z(z_loc, z_log_scale)
        z = q_z.rsample()

        # Per-head parameters of x; the MOO block hands every head its own copy of y and s
        y_shared = self.decoder_shared(z)
        y_per_head, s_per_head = self.moo_block(y_shared, s_samples)
        x_params = [
            head(y_i, s_i) for head, y_i, s_i in zip(self.heads, y_per_head, s_per_head)
        ]
        x_params = self.denormalize_params(x_params)

        # One log-likelihood term per distribution
        log_px_z = []
        for i, params_i in enumerate(x_params):
            log_px_z.append(self.log_likelihood(x, mask, i, params_i))

        # KL for z, estimated at the sampled z (summed over latent dimensions)
        prior_z = self.prior_z(self.p_z_loc(s_samples))
        log_pz = prior_z.log_prob(z).sum(dim=-1)
        log_qz_x = q_z.log_prob(z).sum(dim=-1)
        kl_z = log_qz_x - log_pz

        # KL for s, estimated at the sampled s
        log_ps = self.prior_s.log_prob(s_samples)
        posterior_s = dists.OneHotCategorical(logits=s_logits, validate_args=False)
        log_qs_x = posterior_s.log_prob(s_samples)
        kl_s = log_qs_x - log_ps

        return log_px_z, kl_z, kl_s

    def _step(self, batch, batch_idx):
        """Compute the negative-ELBO loss and the logging dictionary for one batch.

        The batch is used directly as ``x`` with an all-ones mask.

        Args:
            batch: Input tensor.
            batch_idx: Unused.

        Returns:
            ``(loss, logs)`` where ``loss`` is the negative ELBO summed over
            the batch and ``logs`` maps metric names to values.
        """
        x = batch
        mask = torch.ones_like(x)
        log_px_z, kl_z, kl_s = self._run_step(x, mask)

        loss = -(sum(log_px_z) - kl_z - kl_s).sum(dim=0)
        assert loss.size() == torch.Size([])

        logs = {'loss': loss / x.size(0)}

        with torch.no_grad():
            # Per-dimension log-likelihood averaged over the observed entries
            per_dim = (self.log_likelihood_real(x, mask) * mask).sum(dim=0) / mask.sum(dim=0)
            logs['re'] = -per_dim.mean(dim=0)
            logs['kl_z'] = kl_z.mean(dim=0)
            logs['kl_s'] = kl_s.mean(dim=0)
            per_dim_logs = {}
            for i, l_i in enumerate(per_dim):
                per_dim_logs[f'll_{i}'] = l_i.item()
            logs.update(per_dim_logs)

            if self.training:
                logs['temperature'] = self.encoder.temperature

        return loss, logs

    def training_step(self, batch, batch_idx):
        """Anneal the encoder temperature, run one step and log under ``training/``.

        The temperature decreases linearly by 0.01 per epoch from 1 and is
        floored at 1e-3.
        """
        epoch = self.trainer.current_epoch
        self.encoder.temperature = max(1e-3, 1. - 0.01 * epoch)

        loss, logs = self._step(batch, batch_idx)
        prefixed = {}
        for k, v in logs.items():
            prefixed[f'training/{k}'] = v
        self.log_dict(prefixed)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one step on a validation batch and log its metrics under ``validation/``."""
        loss, logs = self._step(batch, batch_idx)
        prefixed = {}
        for k, v in logs.items():
            prefixed[f'validation/{k}'] = v
        self.log_dict(prefixed)
        return loss

    def _infer_step(self, x, mask, mode):
        """Encode ``x`` and decode the (denormalised) parameters of every distribution.

        Unlike ``_run_step`` this bypasses the MOO block: every head receives
        the shared ``y`` and ``s`` directly.

        Args:
            x: Input batch.
            mask: Observation mask forwarded to ``normalize_data``.
            mode: If true use the arg-max ``s`` and the location of ``q(z)``;
                otherwise sample both.

        Returns:
            List with one parameter tensor per distribution.
        """
        x_norm = self.normalize_data(x, mask)
        s_samples, (_, z_loc, z_log_scale) = self.encoder(x_norm, mode=mode)

        if not mode:
            z = self.encoder.q_z(z_loc, z_log_scale).sample()
        else:
            z = z_loc  # Mode of a Normal distribution

        y_shared = self.decoder_shared(z)
        raw_params = [head(y_shared, s_samples) for head in self.heads]
        return self.denormalize_params(raw_params)

    def _impute_step(self, x, mask, mode):
        """Impute one value per gathered distribution from the inferred parameters.

        Args:
            x: Input batch.
            mask: Observation mask.
            mode: Forwarded to ``_infer_step``.

        Returns:
            ``(imputed, x_params)``: the imputed columns stacked along the last
            dimension, and the parameters they were computed from.
        """
        x_params = self._infer_step(x, mask, mode=mode)

        imputed = []
        for idxs, dist_i in self.prob_model.gathered:
            # Join the parameters of all distributions gathered under dist_i
            joined = torch.cat([x_params[j] for j in idxs], dim=0)
            column = dist_i.impute(joined).float().flatten()
            if str(dist_i) == 'lognormal':
                column = torch.clamp(column, 1e-20, 1e20)
            imputed.append(column)

        return torch.stack(imputed, dim=-1), x_params

    def forward(self, batch, mode=True):
        """Return the imputed data for ``batch``.

        Args:
            batch: Three-element sequence ``(x, mask, _)``; the third element is ignored.
            mode: Forwarded to ``_impute_step``.
        """
        x, mask, _ = batch
        imputed, _ = self._impute_step(x, mask, mode=mode)
        return imputed

    # Measures
    def log_likelihood(self, x, mask, i, params_i):
        """Return the log-probability of column ``i`` of ``x`` under distribution ``i``.

        Args:
            x: Data; column ``i`` along the last dimension is evaluated.
            mask: Observation mask, or ``None``. When given, the result is
                multiplied by column ``i`` of the mask.
            i: Index of the column / distribution.
            params_i: Parameters passed to the distribution's ``log_prob``.
        """
        log_prob_i = self.prob_model[i].log_prob(x[..., i], params_i)
        if mask is None:
            return log_prob_i
        return log_prob_i * mask[..., i].float()

    def _log_likelihood(self, x, x_params):
        """Evaluate ``real_log_prob`` of every gathered distribution on its column of ``x``.

        Args:
            x: Data; column ``k`` is scored by the ``k``-th gathered distribution.
            x_params: Per-distribution parameters, indexed by the ids stored in
                ``prob_model.gathered``.

        Returns:
            The per-column log-probabilities stacked along the last dimension,
            with a leading singleton dimension squeezed away.
        """
        per_column = []
        for col, (idxs, dist_i) in enumerate(self.prob_model.gathered):
            joined = torch.cat([x_params[j] for j in idxs], dim=0)
            per_column.append(dist_i.real_log_prob(x[..., col], joined))

        stacked = torch.stack(per_column, dim=-1)
        return stacked.squeeze(dim=0)  # batch_size x num_dimensions

    def log_likelihood_real(self, x, mask):
        """Return the per-column log-likelihood of ``x`` under parameters inferred with ``mode=True``."""
        return self._log_likelihood(x, self._infer_step(x, mask, mode=True))

    def configure_optimizers(self):
        """Return the optimizer: Adam over all parameters with learning rate 0.001.

        No learning-rate scheduler is configured. The exponential-decay
        scheduler set-up that used to follow the ``return`` was unreachable
        (and read ``self.hparams.decay``), so it has been removed.
        """
        return torch.optim.Adam([
            {'params': self.parameters(), 'lr': 0.001}
        ])
    
    def mytrain(self, x_train, batch_size, n_epochs):
        """Fit the model on ``x_train`` with a Lightning trainer.

        Args:
            x_train: Training tensor; it is cast to float and wrapped in a ``DataLoader``.
            batch_size: Batch size of the loader.
            n_epochs: Requested number of epochs. Only ``n_epochs // 2`` epochs
                are actually run.
        """
        trainer = pl.Trainer(max_epochs=n_epochs // 2)
        loader = torch.utils.data.DataLoader(x_train.float(), batch_size=batch_size)

        # Train
        trainer.fit(self, loader)
    
    def get_novelty_score(self, model_fake, data):
        """Score ``data`` with the summed log-likelihood terms of the model.

        The model is put in eval mode and the data is moved to the CPU. The
        score comes from ``_run_step``, which samples ``s`` and ``z``, so it
        is stochastic.

        Args:
            model_fake: Unused.
            data: Input tensor.

        Returns:
            ``(pred, conf)``: ``pred`` is an all-zeros NumPy array with the
            shape of ``conf``; ``conf`` is the NumPy array holding the sum over
            distributions of the log-likelihood terms.
        """
        self.eval()

        x = data.to('cpu')
        mask = torch.ones_like(x)

        # Inference only: no autograd graph is needed. The KL terms returned by
        # ``_run_step`` are not part of the score.
        with torch.no_grad():
            log_px_z, _, _ = self._run_step(x, mask)
            conf = sum(log_px_z)

        pred = np.zeros(conf.shape)

        return pred, conf.detach().cpu().numpy()